Operations

namespace ops

Enums

enum class OpType

Values:

enumerator Add
enumerator Subtract
enumerator Multiply
enumerator Divide
enumerator Power
enumerator Modulo
enumerator Equal
enumerator NotEqual
enumerator Less
enumerator LessEqual
enumerator Greater
enumerator GreaterEqual
enumerator LogicalAnd
enumerator LogicalOr
enumerator LogicalXor
enumerator LogicalNot
enumerator BitwiseAnd
enumerator BitwiseOr
enumerator BitwiseXor
enumerator LeftShift
enumerator RightShift
enumerator Maximum
enumerator Minimum
enumerator Atan2
enumerator Hypot
enumerator Negate
enumerator Abs
enumerator Sqrt
enumerator Exp
enumerator Log
enumerator Sin
enumerator Cos
enumerator Tan
enumerator Erf
enumerator Sign
enumerator Floor
enumerator Ceil
enumerator Trunc
enumerator Round
enumerator Reciprocal
enumerator Square
enumerator Cbrt
enumerator IsNaN
enumerator IsInf
enumerator IsFinite
enumerator Conj
enumerator Real
enumerator Imag
enumerator ReLU
enumerator LeakyReLU
enumerator SiLU
enumerator Sigmoid
enumerator Tanh
enumerator GELU
enumerator Softmax
enumerator LogSoftmax
enumerator Sum
enumerator Mean
enumerator Max
enumerator Min
enumerator ArgMax
enumerator ArgMin
enumerator Any
enumerator All
enumerator Prod
enumerator MatMul
enumerator BatchMatMul
enumerator Where
enumerator MaskedFill
enumerator MaskedSelect
enumerator Gather
enumerator Scatter
enumerator IndexSelect
enumerator Take
enumerator TakeAlongAxis
enumerator LayerNorm
enumerator RMSNorm
enumerator Dropout
enumerator Cast
enumerator MaxPool1D
enumerator MaxPool2D
enumerator MaxPool3D
enumerator AvgPool1D
enumerator AvgPool2D
enumerator AvgPool3D
enumerator AdaptiveMaxPool2D
enumerator AdaptiveAvgPool2D
enumerator Conv1D
enumerator Conv2D
enumerator ConvTranspose1D
enumerator ConvTranspose2D
enumerator ScaledDotProductAttention
enumerator Reshape
enumerator Transpose
enumerator Pad
enumerator Slice
enumerator Unsqueeze
enumerator Concat
enumerator GLU
enumerator BatchNorm1D
enumerator Sort
enumerator Argsort
enumerator TopK
enumerator _Count

enum class InterpolateMode

Values:

enumerator Nearest
enumerator Bilinear
enumerator Bicubic

Functions

template<typename Func>
Tensor apply(const Tensor &input, Func &&func)
template<typename Func>
Tensor apply(const Tensor &a, const Tensor &b, Func &&func)
template<typename Func>
auto vectorize(Func &&func)
template<typename Func>
Tensor apply_along_axis(Func &&func1d, int axis, const Tensor &arr)
template<typename Func>
Tensor apply_over_axes(Func &&func, const Tensor &a, const std::vector<int> &axes)
template<typename Func>
auto fromfunc(Func &&func)
BroadcastInfo compute_broadcast_info(const Shape &lhs_shape, const Shape &rhs_shape)
bool are_broadcastable(const Shape &lhs_shape, const Shape &rhs_shape)
Shape broadcast_shapes(const std::vector<Shape> &shapes)
std::vector<Tensor> broadcast_tensors(const std::vector<Tensor> &tensors)
DType promote_types(DType lhs_dtype, DType rhs_dtype)
DType result_type(const Tensor &lhs, const Tensor &rhs)
Tensor add(const Tensor &lhs, const Tensor &rhs)
Tensor subtract(const Tensor &lhs, const Tensor &rhs)
Tensor multiply(const Tensor &lhs, const Tensor &rhs)
Tensor divide(const Tensor &lhs, const Tensor &rhs)
Tensor power(const Tensor &lhs, const Tensor &rhs)
Tensor modulo(const Tensor &lhs, const Tensor &rhs)
Tensor equal(const Tensor &lhs, const Tensor &rhs)
Tensor not_equal(const Tensor &lhs, const Tensor &rhs)
Tensor less(const Tensor &lhs, const Tensor &rhs)
Tensor less_equal(const Tensor &lhs, const Tensor &rhs)
Tensor greater(const Tensor &lhs, const Tensor &rhs)
Tensor greater_equal(const Tensor &lhs, const Tensor &rhs)
Tensor logical_and(const Tensor &lhs, const Tensor &rhs)
Tensor logical_or(const Tensor &lhs, const Tensor &rhs)
Tensor logical_xor(const Tensor &lhs, const Tensor &rhs)
Tensor logical_not(const Tensor &input)
Tensor bitwise_and(const Tensor &lhs, const Tensor &rhs)
Tensor bitwise_or(const Tensor &lhs, const Tensor &rhs)
Tensor bitwise_xor(const Tensor &lhs, const Tensor &rhs)
Tensor left_shift(const Tensor &lhs, const Tensor &rhs)
Tensor right_shift(const Tensor &lhs, const Tensor &rhs)
Tensor maximum(const Tensor &lhs, const Tensor &rhs)
Tensor minimum(const Tensor &lhs, const Tensor &rhs)
Tensor atan2(const Tensor &lhs, const Tensor &rhs)
Tensor hypot(const Tensor &lhs, const Tensor &rhs)
Tensor negate(const Tensor &input)
Tensor abs(const Tensor &input)
Tensor sqrt(const Tensor &input)
Tensor exp(const Tensor &input)
Tensor log(const Tensor &input)
Tensor sin(const Tensor &input)
Tensor cos(const Tensor &input)
Tensor tan(const Tensor &input)
Tensor erf(const Tensor &input)
Tensor sign(const Tensor &input)
Tensor floor(const Tensor &input)
Tensor ceil(const Tensor &input)
Tensor trunc(const Tensor &input)
Tensor round(const Tensor &input, int decimals = 0)
Tensor reciprocal(const Tensor &input)
Tensor square(const Tensor &input)
Tensor cbrt(const Tensor &input)
Tensor isnan(const Tensor &input)
Tensor isinf(const Tensor &input)
Tensor isfinite(const Tensor &input)
Tensor clip(const Tensor &input, const Tensor &min_val, const Tensor &max_val)
Tensor conj(const Tensor &input)
Tensor real(const Tensor &input)
Tensor imag(const Tensor &input)
Tensor relu(const Tensor &input)
Tensor leaky_relu(const Tensor &input, float negative_slope = 0.01f)
Tensor silu(const Tensor &input)
Tensor sigmoid(const Tensor &input)
Tensor tanh(const Tensor &input)
Tensor gelu(const Tensor &input)
Tensor softmax(const Tensor &input, int axis = -1)
Tensor log_softmax(const Tensor &input, int axis = -1)
Tensor glu(const Tensor &input, int dim = -1)
Tensor sum(const Tensor &input, const std::vector<int> &axis = {}, bool keep_dims = false)
Tensor mean(const Tensor &input, const std::vector<int> &axis = {}, bool keep_dims = false)
Tensor max(const Tensor &input, const std::vector<int> &axis = {}, bool keep_dims = false)
Tensor min(const Tensor &input, const std::vector<int> &axis = {}, bool keep_dims = false)
Tensor argmax(const Tensor &input, int axis = -1, bool keep_dims = false)
Tensor argmin(const Tensor &input, int axis = -1, bool keep_dims = false)
Tensor any(const Tensor &input, const std::vector<int> &axis = {}, bool keep_dims = false)
Tensor all(const Tensor &input, const std::vector<int> &axis = {}, bool keep_dims = false)
Tensor prod(const Tensor &input, const std::vector<int> &axis = {}, bool keep_dims = false)
Tensor matmul(const Tensor &a, const Tensor &b, bool transpose_a = false, bool transpose_b = false)
Tensor where(const Tensor &condition, const Tensor &a, const Tensor &b)
Tensor masked_fill(const Tensor &input, const Tensor &mask, float value)
Tensor masked_fill(const Tensor &input, const Tensor &mask, double value)
Tensor masked_fill(const Tensor &input, const Tensor &mask, const Tensor &value)
Tensor masked_select(const Tensor &input, const Tensor &mask)
Tensor gather(const Tensor &input, int dim, const Tensor &indices)
Tensor scatter(const Tensor &input, int dim, const Tensor &indices, const Tensor &src)
Tensor index_select(const Tensor &input, int dim, const Tensor &indices)
Tensor take(const Tensor &input, const Tensor &indices, int axis = -1)
Tensor take_along_axis(const Tensor &input, const Tensor &indices, int axis)
Tensor put_along_axis(const Tensor &input, const Tensor &indices, const Tensor &values, int axis)
Tensor embedding(const Tensor &weight, const Tensor &indices)
Tensor layer_norm(const Tensor &input, const Tensor &weight, const Tensor &bias, int axis = -1, float eps = 1e-5f)
Tensor rms_norm(const Tensor &input, const Tensor &weight, int axis = -1, float eps = 1e-5f)
std::pair<Tensor, Tensor> dropout(const Tensor &input, float p = 0.5f, bool training = true)
void add_inplace(Tensor &lhs, const Tensor &rhs)
void subtract_inplace(Tensor &lhs, const Tensor &rhs)
void multiply_inplace(Tensor &lhs, const Tensor &rhs)
void divide_inplace(Tensor &lhs, const Tensor &rhs)
void execute_binary_inplace(OpType op_type, Tensor &lhs, const Tensor &rhs)
std::vector<Tensor> meshgrid(const std::vector<Tensor> &tensors, const std::string &indexing = "xy")
Tensor pad(const Tensor &input, const std::vector<std::pair<size_t, size_t>> &pad_width, const std::string &mode = "constant", double value = 0.0)
Tensor atleast_1d(const Tensor &tensor)
Tensor atleast_2d(const Tensor &tensor)
Tensor atleast_3d(const Tensor &tensor)
std::vector<Tensor> atleast_1d(const std::vector<Tensor> &tensors)
std::vector<Tensor> atleast_2d(const std::vector<Tensor> &tensors)
std::vector<Tensor> atleast_3d(const std::vector<Tensor> &tensors)
Tensor max_pool1d(const Tensor &input, int kernel_size, int stride = 1, int padding = 0)
Tensor avg_pool1d(const Tensor &input, int kernel_size, int stride = 1, int padding = 0, bool count_include_pad = true)
Tensor max_pool2d(const Tensor &input, const std::vector<int> &kernel_size, const std::vector<int> &stride = {}, const std::vector<int> &padding = {})
Tensor avg_pool2d(const Tensor &input, const std::vector<int> &kernel_size, const std::vector<int> &stride = {}, const std::vector<int> &padding = {}, bool count_include_pad = true)
Tensor max_pool3d(const Tensor &input, const std::vector<int> &kernel_size, const std::vector<int> &stride = {}, const std::vector<int> &padding = {})
Tensor avg_pool3d(const Tensor &input, const std::vector<int> &kernel_size, const std::vector<int> &stride = {}, const std::vector<int> &padding = {}, bool count_include_pad = true)
Tensor adaptive_max_pool2d(const Tensor &input, const std::vector<int> &output_size)
Tensor adaptive_avg_pool2d(const Tensor &input, const std::vector<int> &output_size)
Tensor adaptive_max_pool1d(const Tensor &input, int output_size)
Tensor adaptive_avg_pool1d(const Tensor &input, int output_size)
Tensor conv1d(const Tensor &input, const Tensor &weight, const Tensor &bias = Tensor(), int stride = 1, int padding = 0, int dilation = 1, int groups = 1)
Tensor conv2d(const Tensor &input, const Tensor &weight, const Tensor &bias = Tensor(), std::array<int, 2> stride = {1, 1}, std::array<int, 2> padding = {0, 0}, std::array<int, 2> dilation = {1, 1}, int groups = 1)
Tensor conv_transpose1d(const Tensor &input, const Tensor &weight, const Tensor &bias = Tensor(), int stride = 1, int padding = 0, int output_padding = 0, int dilation = 1, int groups = 1)
Tensor conv_transpose2d(const Tensor &input, const Tensor &weight, const Tensor &bias = Tensor(), std::array<int, 2> stride = {1, 1}, std::array<int, 2> padding = {0, 0}, std::array<int, 2> output_padding = {0, 0}, std::array<int, 2> dilation = {1, 1}, int groups = 1)
Tensor interpolate(const Tensor &input, const std::vector<size_t> &size = {}, const std::vector<float> &scale_factor = {}, InterpolateMode mode = InterpolateMode::Nearest, bool align_corners = false)
Tensor scaled_dot_product_attention(const Tensor &query, const Tensor &key, const Tensor &value, const Tensor &mask = Tensor(), float scale = -1.0f, bool is_causal = false)
Tensor sort(const Tensor &input, int axis = -1, bool descending = false)
Tensor argsort(const Tensor &input, int axis = -1, bool descending = false)
std::pair<Tensor, Tensor> topk(const Tensor &input, int k, int axis = -1, bool largest = true, bool sorted = true)

struct BroadcastInfo
#include <operations.hpp>

Public Members

Shape result_shape
std::vector<int> lhs_strides_adjustment
std::vector<int> rhs_strides_adjustment
bool needs_broadcast

class Operation
#include <operations.hpp>

Public Functions

virtual ~Operation() = default
virtual OpType type() const = 0
virtual std::string name() const = 0
virtual Device device() const = 0
inline virtual bool supports_binary(const Tensor &lhs, const Tensor &rhs) const
virtual Tensor execute_binary(const Tensor &lhs, const Tensor &rhs) const = 0
virtual Tensor execute_unary(const Tensor &input) const
virtual Tensor execute_reduction(const Tensor &input, const std::vector<int> &axis, bool keep_dims) const
virtual Tensor execute_matmul(const Tensor &a, const Tensor &b, bool transpose_a, bool transpose_b) const
virtual Tensor execute_where(const Tensor &condition, const Tensor &a, const Tensor &b) const
virtual Tensor execute_masked_fill(const Tensor &input, const Tensor &mask, const Tensor &value) const
virtual Tensor execute_masked_select(const Tensor &input, const Tensor &mask) const
virtual Tensor execute_gather(const Tensor &input, int dim, const Tensor &indices) const
virtual Tensor execute_scatter(const Tensor &input, int dim, const Tensor &indices, const Tensor &src) const
virtual Tensor execute_index_select(const Tensor &input, int dim, const Tensor &indices) const
virtual Tensor execute_cast(const Tensor &input, DType target_dtype) const
virtual void execute_binary_inplace(Tensor &lhs, const Tensor &rhs) const

class OperationRegistry
#include <operations.hpp>

Public Static Functions

static void register_operation(OpType op_type, Device device, std::unique_ptr<Operation> operation)
static const Operation *get_operation(OpType op_type, Device device)
static std::vector<Device> available_devices_for_operation(OpType op_type)
static bool is_operation_available(OpType op_type, Device device)
static void initialize_builtin_operations()

namespace detail
template<typename T>
struct callable_traits
#include <functors.hpp>

template<typename C, typename R, typename A>
struct callable_traits<R (C::*)(A) const>
#include <functors.hpp>

Public Types

using return_type = R
using arg_type = A

Public Static Attributes

static constexpr size_t arity = 1

template<typename C, typename R, typename A>
struct callable_traits<R (C::*)(A)>
#include <functors.hpp>

Public Types

using return_type = R
using arg_type = A

Public Static Attributes

static constexpr size_t arity = 1

template<typename C, typename R, typename A1, typename A2>
struct callable_traits<R (C::*)(A1, A2) const>
#include <functors.hpp>

Public Types

using return_type = R
using arg1_type = A1
using arg2_type = A2

Public Static Attributes

static constexpr size_t arity = 2

template<typename C, typename R, typename A1, typename A2>
struct callable_traits<R (C::*)(A1, A2)>
#include <functors.hpp>

Public Types

using return_type = R
using arg1_type = A1
using arg2_type = A2

Public Static Attributes

static constexpr size_t arity = 2