Einops

namespace axiom::einops

Typedefs

using AxisElement = std::variant<SimpleAxis, GroupedAxes, UnityAxis, EllipsisAxis>

Enums

enum class AxisElementType

Values:

enumerator Simple
enumerator Grouped
enumerator Unity
enumerator Ellipsis

Functions

Tensor rearrange(const Tensor &tensor, const std::string &pattern, const std::map<std::string, size_t> &axis_sizes = {})

Rearrange tensor according to einops pattern.

Parameters:
  • tensor – Input tensor

  • pattern – Einops pattern string like “b h w c -> b c h w”

  • axis_sizes – Optional axis size specifications for splitting

Returns:

Rearranged tensor

Tensor reduce(const Tensor &tensor, const std::string &pattern, const std::string &reduction, const std::map<std::string, size_t> &axis_sizes = {})

Reduce tensor according to einops pattern.

Example:

    reduce(x, "b h w c -> b c", "mean")  // Pool over h and w
    reduce(x, "b (h h2) (w w2) c -> b h w c", "mean", {{"h2", 2}, {"w2", 2}})

Parameters:
  • tensor – Input tensor

  • pattern – Einops pattern string like “b h w c -> b c” (axes not in output are reduced)

  • reduction – Reduction operation: “sum”, “mean”, “max”, “min”, “prod”

  • axis_sizes – Optional axis size specifications

Returns:

Reduced tensor

Tensor einsum(const std::string &equation, const std::vector<Tensor> &operands)

Einstein summation convention.

Supported patterns:

    einsum("ij,jk->ik", {A, B})    // matrix multiply
    einsum("ii->", {A})            // trace
    einsum("ij->ji", {A})          // transpose
    einsum("ij,ij->ij", {A, B})    // element-wise multiply
    einsum("bij,bjk->bik", {A, B}) // batched matmul
    einsum("ijk->", {A})           // sum all elements
    einsum("ij->j", {A})           // sum over rows

Parameters:
  • equation – Einsum equation like “ij,jk->ik” for matrix multiply

  • operands – Vector of input tensors

Returns:

Result tensor

Tensor repeat(const Tensor &tensor, const std::string &pattern, const std::map<std::string, size_t> &axis_sizes = {})

Repeat tensor according to einops pattern (add/tile dimensions).

Example:

    repeat(x, "h w -> h w c", {{"c", 3}})            // Add channel dim
    repeat(x, "h w -> h repeat w", {{"repeat", 3}})  // Tile rows

Parameters:
  • tensor – Input tensor

  • pattern – Einops pattern string like “h w -> h w c” (new axes in output)

  • axis_sizes – Sizes for new axes

Returns:

Repeated tensor

std::pair<Tensor, std::vector<Shape>> pack(const std::vector<Tensor> &tensors, const std::string &pattern)

Pack multiple tensors into a single tensor with a wildcard dimension.

Example: auto [packed, ps] = pack({img1, img2, img3}, "* h w")

Parameters:
  • tensors – Input tensors

  • pattern – Pattern string with exactly one ‘*’ wildcard

Returns:

Pair of (packed tensor, packed_shapes for unpack)

std::vector<Tensor> unpack(const Tensor &tensor, const std::vector<Shape> &packed_shapes, const std::string &pattern)

Unpack a packed tensor back into individual tensors.

Example: auto tensors = unpack(packed, ps, "* h w")

Parameters:
  • tensor – Packed tensor

  • packed_shapes – Shapes from pack() for each tensor’s wildcard dims

  • pattern – Pattern string with exactly one ‘*’ wildcard

Returns:

Vector of unpacked tensors

struct SimpleAxis
#include <einops.hpp>

Public Members

std::string name
struct GroupedAxes
#include <einops.hpp>

Public Members

std::vector<std::string> axes
std::map<std::string, size_t> anonymous_sizes
struct UnityAxis
#include <einops.hpp>
struct EllipsisAxis
#include <einops.hpp>
struct ParsedPattern
#include <einops.hpp>

Public Members

std::vector<AxisElement> elements
class EinopsExpression
#include <einops.hpp>

Public Functions

EinopsExpression(const std::string &pattern, const std::map<std::string, size_t> &axis_sizes = {})
Tensor apply(const Tensor &tensor) const
void validate_input(const Tensor &tensor) const
Shape get_output_shape(const Tensor &input) const
ParsedPattern parse_single_pattern(const std::string &pattern) const
std::vector<std::string> get_pattern_axes(const ParsedPattern &pattern) const
std::map<std::string, size_t> infer_axis_sizes(const Tensor &tensor) const
inline const ParsedPattern &parsed_input() const
inline const ParsedPattern &parsed_output() const
class EinopsError : public std::runtime_error
#include <einops.hpp>

Subclassed by axiom::einops::EinopsParseError, axiom::einops::EinopsShapeError

Public Functions

inline explicit EinopsError(const std::string &message)
class EinopsParseError : public axiom::einops::EinopsError
#include <einops.hpp>

Public Functions

inline explicit EinopsParseError(const std::string &message)
class EinopsShapeError : public axiom::einops::EinopsError
#include <einops.hpp>

Public Functions

inline explicit EinopsShapeError(const std::string &message)