Operations¶
-
namespace ops¶
Enums
-
enum class OpType¶
Values:
-
enumerator Add¶
-
enumerator Subtract¶
-
enumerator Multiply¶
-
enumerator Divide¶
-
enumerator Power¶
-
enumerator Modulo¶
-
enumerator Equal¶
-
enumerator NotEqual¶
-
enumerator Less¶
-
enumerator LessEqual¶
-
enumerator Greater¶
-
enumerator GreaterEqual¶
-
enumerator LogicalAnd¶
-
enumerator LogicalOr¶
-
enumerator LogicalXor¶
-
enumerator LogicalNot¶
-
enumerator BitwiseAnd¶
-
enumerator BitwiseOr¶
-
enumerator BitwiseXor¶
-
enumerator LeftShift¶
-
enumerator RightShift¶
-
enumerator Maximum¶
-
enumerator Minimum¶
-
enumerator Atan2¶
-
enumerator Hypot¶
-
enumerator Negate¶
-
enumerator Abs¶
-
enumerator Sqrt¶
-
enumerator Exp¶
-
enumerator Log¶
-
enumerator Sin¶
-
enumerator Cos¶
-
enumerator Tan¶
-
enumerator Erf¶
-
enumerator Sign¶
-
enumerator Floor¶
-
enumerator Ceil¶
-
enumerator Trunc¶
-
enumerator Round¶
-
enumerator Reciprocal¶
-
enumerator Square¶
-
enumerator Cbrt¶
-
enumerator IsNaN¶
-
enumerator IsInf¶
-
enumerator IsFinite¶
-
enumerator Conj¶
-
enumerator Real¶
-
enumerator Imag¶
-
enumerator ReLU¶
-
enumerator LeakyReLU¶
-
enumerator SiLU¶
-
enumerator Sigmoid¶
-
enumerator Tanh¶
-
enumerator GELU¶
-
enumerator Softmax¶
-
enumerator LogSoftmax¶
-
enumerator Sum¶
-
enumerator Mean¶
-
enumerator Max¶
-
enumerator Min¶
-
enumerator ArgMax¶
-
enumerator ArgMin¶
-
enumerator Any¶
-
enumerator All¶
-
enumerator Prod¶
-
enumerator MatMul¶
-
enumerator BatchMatMul¶
-
enumerator Where¶
-
enumerator MaskedFill¶
-
enumerator MaskedSelect¶
-
enumerator Gather¶
-
enumerator Scatter¶
-
enumerator IndexSelect¶
-
enumerator Take¶
-
enumerator TakeAlongAxis¶
-
enumerator LayerNorm¶
-
enumerator RMSNorm¶
-
enumerator Dropout¶
-
enumerator Cast¶
-
enumerator MaxPool1D¶
-
enumerator MaxPool2D¶
-
enumerator MaxPool3D¶
-
enumerator AvgPool1D¶
-
enumerator AvgPool2D¶
-
enumerator AvgPool3D¶
-
enumerator AdaptiveMaxPool2D¶
-
enumerator AdaptiveAvgPool2D¶
-
enumerator Conv1D¶
-
enumerator Conv2D¶
-
enumerator ConvTranspose1D¶
-
enumerator ConvTranspose2D¶
-
enumerator ScaledDotProductAttention¶
-
enumerator Reshape¶
-
enumerator Transpose¶
-
enumerator Pad¶
-
enumerator Slice¶
-
enumerator Unsqueeze¶
-
enumerator Concat¶
-
enumerator GLU¶
-
enumerator BatchNorm1D¶
-
enumerator Sort¶
-
enumerator Argsort¶
-
enumerator TopK¶
-
enumerator _Count¶
Functions
-
template<typename Func>
Tensor apply_over_axes(Func &&func, const Tensor &a, const std::vector<int> &axes)¶
-
BroadcastInfo compute_broadcast_info(const Shape &lhs_shape, const Shape &rhs_shape)¶
-
bool are_broadcastable(const Shape &lhs_shape, const Shape &rhs_shape)¶
-
Shape broadcast_shapes(const std::vector<Shape> &shapes)¶
-
Tensor matmul(const Tensor &a, const Tensor &b, bool transpose_a = false, bool transpose_b = false)¶
-
Tensor layer_norm(const Tensor &input, const Tensor &weight, const Tensor &bias, int axis = -1, float eps = 1e-5f)¶
-
std::vector<Tensor> meshgrid(const std::vector<Tensor> &tensors, const std::string &indexing = "xy")¶
-
Tensor pad(const Tensor &input, const std::vector<std::pair<size_t, size_t>> &pad_width, const std::string &mode = "constant", double value = 0.0)¶
-
Tensor avg_pool1d(const Tensor &input, int kernel_size, int stride = 1, int padding = 0, bool count_include_pad = true)¶
-
Tensor max_pool2d(const Tensor &input, const std::vector<int> &kernel_size, const std::vector<int> &stride = {}, const std::vector<int> &padding = {})¶
-
Tensor avg_pool2d(const Tensor &input, const std::vector<int> &kernel_size, const std::vector<int> &stride = {}, const std::vector<int> &padding = {}, bool count_include_pad = true)¶
-
Tensor max_pool3d(const Tensor &input, const std::vector<int> &kernel_size, const std::vector<int> &stride = {}, const std::vector<int> &padding = {})¶
-
Tensor avg_pool3d(const Tensor &input, const std::vector<int> &kernel_size, const std::vector<int> &stride = {}, const std::vector<int> &padding = {}, bool count_include_pad = true)¶
-
Tensor conv1d(const Tensor &input, const Tensor &weight, const Tensor &bias = Tensor(), int stride = 1, int padding = 0, int dilation = 1, int groups = 1)¶
-
Tensor conv2d(const Tensor &input, const Tensor &weight, const Tensor &bias = Tensor(), std::array<int, 2> stride = {1, 1}, std::array<int, 2> padding = {0, 0}, std::array<int, 2> dilation = {1, 1}, int groups = 1)¶
-
Tensor conv_transpose1d(const Tensor &input, const Tensor &weight, const Tensor &bias = Tensor(), int stride = 1, int padding = 0, int output_padding = 0, int dilation = 1, int groups = 1)¶
-
Tensor conv_transpose2d(const Tensor &input, const Tensor &weight, const Tensor &bias = Tensor(), std::array<int, 2> stride = {1, 1}, std::array<int, 2> padding = {0, 0}, std::array<int, 2> output_padding = {0, 0}, std::array<int, 2> dilation = {1, 1}, int groups = 1)¶
-
Tensor interpolate(const Tensor &input, const std::vector<size_t> &size = {}, const std::vector<float> &scale_factor = {}, InterpolateMode mode = InterpolateMode::Nearest, bool align_corners = false)¶
-
struct BroadcastInfo¶
- #include <operations.hpp>
-
class Operation¶
- #include <operations.hpp>
Public Functions
-
virtual ~Operation() = default¶
-
virtual std::string name() const = 0¶
-
virtual Tensor execute_reduction(const Tensor &input, const std::vector<int> &axis, bool keep_dims) const¶
-
virtual Tensor execute_matmul(const Tensor &a, const Tensor &b, bool transpose_a, bool transpose_b) const¶
-
virtual Tensor execute_masked_fill(const Tensor &input, const Tensor &mask, const Tensor &value) const¶
-
class OperationRegistry¶
- #include <operations.hpp>
-
namespace detail¶
-
template<typename T>
struct callable_traits¶
- #include <functors.hpp>
-
template<typename C, typename R, typename A>
struct callable_traits<R (C::*)(A) const>¶
- #include <functors.hpp>
Public Static Attributes
-
static constexpr size_t arity = 1¶
-
template<typename C, typename R, typename A>
struct callable_traits<R (C::*)(A)>¶
- #include <functors.hpp>
Public Static Attributes
-
static constexpr size_t arity = 1¶
-
template<typename T>
-
enum class OpType¶