C++20: N-мерная матрица, broadcasting (вещание), np.dot() и np.matmul()

Я доработал свой проект N-мерной матрицы на C++20 (см. предыдущий вопрос «C++20: N-мерный минимальный класс Matrix»).

Реализованы обобщённые (с поддержкой broadcasting) сложение/вычитание матриц, поэлементное умножение/деление, скалярное произведение, матричное произведение, изменение формы (reshape) и транспонирование.

Кода довольно много:

ObjectBase.h

#ifndef FROZENCA_OBJECTBASE_H
#define FROZENCA_OBJECTBASE_H

#include <functional>
#include <utility>
#include "MatrixUtils.h"

namespace frozenca {

/// CRTP base shared by all matrix-like types.
///
/// Derived must provide its own begin()/end() (and the c/r variants). The members below only
/// forward to them through self(); if Derived does not declare them, name lookup finds these
/// forwarders again and the call recurses forever.
///
/// On top of iteration it provides elementwise application of callables (applyFunction) and
/// elementwise compound assignment with a non-matrix scalar operand.
template <typename Derived>
class ObjectBase {
private:
    Derived& self() { return static_cast<Derived&>(*this); }
    const Derived& self() const { return static_cast<const Derived&>(*this); }

protected:
    // Protected: ObjectBase is only meant to exist as a base-class subobject.
    ObjectBase() = default;
    ~ObjectBase() noexcept = default;

public:
    auto begin() { return self().begin(); }
    auto begin() const { return self().begin(); }
    auto cbegin() const { return self().cbegin(); }
    auto end() { return self().end(); }
    auto end() const { return self().end(); }
    auto cend() const { return self().cend(); }
    auto rbegin() { return self().rbegin(); }
    auto rbegin() const { return self().rbegin(); }
    auto crbegin() const { return self().crbegin(); }
    auto rend() { return self().rend(); }
    auto rend() const { return self().rend(); }
    auto crend() const { return self().crend(); }

    /// Calls f(element) for every element in iteration order.
    template <typename F> requires std::invocable<F, typename Derived::reference>
    ObjectBase& applyFunction(F&& f);

    /// Calls f(element, args...) for every element in iteration order.
    template <typename F, typename... Args> requires std::invocable<F, typename Derived::reference, Args...>
    ObjectBase& applyFunction(F&& f, Args&&... args);

    /// Calls f(element, other_element) pairwise, walking both objects in iteration order.
    template <typename DerivedOther, typename F> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference>
    ObjectBase& applyFunction(const ObjectBase<DerivedOther>& other, F&& f);

    /// Calls f(element, other_element, args...) pairwise, walking both objects in iteration order.
    template <typename DerivedOther, typename F, typename... Args> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference, Args...>
    ObjectBase& applyFunction(const ObjectBase<DerivedOther>& other, F&& f, Args&&... args);

    /// Assigns val to every element.
    /// The constraint is assignability; it used to be Addable, which is unrelated to `v = val`.
    template <isNotMatrix U> requires std::is_assignable_v<typename Derived::value_type&, const U&>
    ObjectBase& operator=(const U& val) {
        return applyFunction([&val](auto& v) {v = val;});
    }

    /// Adds val to every element.
    template <isNotMatrix U> requires Addable<typename Derived::value_type, U>
    ObjectBase& operator+=(const U& val) {
        return applyFunction([&val](auto& v) {v += val;});
    }

    /// Subtracts val from every element.
    template <isNotMatrix U> requires Subtractable<typename Derived::value_type, U>
    ObjectBase& operator-=(const U& val) {
        return applyFunction([&val](auto& v) {v -= val;});
    }

    /// Multiplies every element by val.
    template <isNotMatrix U> requires Multipliable<typename Derived::value_type, U>
    ObjectBase& operator*=(const U& val) {
        return applyFunction([&val](auto& v) {v *= val;});
    }

    /// Divides every element by val.
    template <isNotMatrix U> requires Dividable<typename Derived::value_type, U>
    ObjectBase& operator/=(const U& val) {
        return applyFunction([&val](auto& v) {v /= val;});
    }

    /// Replaces every element by its remainder modulo val.
    template <isNotMatrix U> requires Remaindable<typename Derived::value_type, U>
    ObjectBase& operator%=(const U& val) {
        return applyFunction([&val](auto& v) {v %= val;});
    }

    /// Bitwise-ANDs every element with val.
    template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
    ObjectBase& operator&=(const U& val) {
        return applyFunction([&val](auto& v) {v &= val;});
    }

    /// Bitwise-ORs every element with val.
    template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
    ObjectBase& operator|=(const U& val) {
        return applyFunction([&val](auto& v) {v |= val;});
    }

    /// Bitwise-XORs every element with val.
    template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
    ObjectBase& operator^=(const U& val) {
        return applyFunction([&val](auto& v) {v ^= val;});
    }

    /// Left-shifts every element by val.
    template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
    ObjectBase& operator<<=(const U& val) {
        return applyFunction([&val](auto& v) {v <<= val;});
    }

    /// Right-shifts every element by val.
    template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
    ObjectBase& operator>>=(const U& val) {
        return applyFunction([&val](auto& v) {v >>= val;});
    }
};

/// Applies f to every element of this object, in iteration order.
/// @return *this, to allow chaining.
template <typename Derived>
template <typename F> requires std::invocable<F, typename Derived::reference>
ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(F&& f) {
    for (auto it = begin(); it != end(); ++it) {
        f(*it);
    }
    return *this;
}

/// Applies f(element, args...) to every element of this object, in iteration order.
/// The same args are passed on every call. Callers should not pass arguments that f consumes
/// (moves from), because the second and later calls would then see moved-from values.
/// @return *this, to allow chaining.
template <typename Derived>
template <typename F, typename... Args> requires std::invocable<F, typename Derived::reference, Args...>
ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(F&& f, Args&&... args) {
    for (auto it = begin(); it != end(); ++it) {
        // Expand the pack element-wise: `std::forward<Args...>(args...)` only compiles for one argument.
        f(*it, std::forward<Args>(args)...);
    }
    return *this;
}

/// Applies f(element, other_element) pairwise, walking both objects in iteration order.
/// Iteration stops at the end of the shorter object instead of reading past `other`'s end.
/// @return *this, to allow chaining.
template <typename Derived>
template <typename DerivedOther, typename F> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference>
ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(const ObjectBase<DerivedOther>& other, F&& f) {
    // The two iterators generally have different types (mutable vs. const, and possibly a different
    // Derived), so they cannot share a single `auto` declaration.
    auto it2 = other.begin();
    const auto end2 = other.end();
    for (auto it = begin(); it != end() && it2 != end2; ++it, ++it2) {
        f(*it, *it2);
    }
    return *this;
}

/// Applies f(element, other_element, args...) pairwise, walking both objects in iteration order.
/// Iteration stops at the end of the shorter object. The same args are passed on every call, so f
/// should not consume (move from) them.
/// @return *this, to allow chaining.
template <typename Derived>
template <typename DerivedOther, typename F, typename... Args> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference, Args...>
ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(const ObjectBase<DerivedOther>& other, F&& f, Args&&... args) {
    auto it2 = other.begin();
    const auto end2 = other.end();
    for (auto it = begin(); it != end() && it2 != end2; ++it, ++it2) {
        f(*it, *it2, std::forward<Args>(args)...);
    }
    return *this;
}

// Binary `object OP scalar` operators. Each one is intended to copy `m`, apply the matching
// compound-assignment operator of ObjectBase (which updates every element) to the copy, and
// return the copy.
//
// NOTE(review): as written these templates cannot be instantiated. `ObjectBase<Derived> res = m;`
// declares a local ObjectBase, but ObjectBase's destructor is protected, so the declaration is
// ill-formed outside a derived class. Even if it compiled, it would slice `m`: ObjectBase holds no
// elements, and `res += val` would static_cast a plain ObjectBase to Derived. These probably need
// to build and return a concrete owning matrix type instead — TODO confirm against MatrixImpl.h.

// m + val, elementwise.
template <typename Derived, isNotMatrix U> requires Addable<typename Derived::value_type, U>
ObjectBase<Derived> operator+(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> res = m;
    res += val;
    return res;
}

// m - val, elementwise.
template <typename Derived, isNotMatrix U> requires Subtractable<typename Derived::value_type, U>
ObjectBase<Derived> operator-(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> res = m;
    res -= val;
    return res;
}

// m * val, elementwise.
template <typename Derived, isNotMatrix U> requires Multipliable<typename Derived::value_type, U>
ObjectBase<Derived> operator*(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> res = m;
    res *= val;
    return res;
}

// m / val, elementwise.
template <typename Derived, isNotMatrix U> requires Dividable<typename Derived::value_type, U>
ObjectBase<Derived> operator/(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> res = m;
    res /= val;
    return res;
}

// m % val, elementwise.
template <typename Derived, isNotMatrix U> requires Remaindable<typename Derived::value_type, U>
ObjectBase<Derived> operator%(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> res = m;
    res %= val;
    return res;
}

// m & val, elementwise.
template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase<Derived> operator&(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> res = m;
    res &= val;
    return res;
}

// m ^ val, elementwise.
template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase<Derived> operator^(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> res = m;
    res ^= val;
    return res;
}

// m | val, elementwise.
template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase<Derived> operator|(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> res = m;
    res |= val;
    return res;
}

// m << val, elementwise (bit shift, not stream output).
template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase<Derived> operator<<(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> res = m;
    res <<= val;
    return res;
}

// m >> val, elementwise (bit shift).
template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase<Derived> operator>>(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> res = m;
    res >>= val;
    return res;
}

} // namespace frozenca

#endif //FROZENCA_OBJECTBASE_H

MatrixBase.h

#ifndef FROZENCA_MATRIXBASE_H
#define FROZENCA_MATRIXBASE_H

#include <numeric>
#include "ObjectBase.h"
#include "MatrixInitializer.h"

namespace frozenca {

template <std::semiregular T, std::size_t N>
class MatrixView;

/// CRTP base for N-dimensional (N >= 2) matrix-like types.
///
/// It stores only shape metadata (extents, total size, strides). Element storage and iterators
/// come from Derived through self() — presumably Matrix<T, N> (owning) or MatrixView<T, N>
/// (non-owning); TODO confirm. Indexing, sub-views (submatrix/row/col) and row-wise broadcasting
/// are implemented here on top of those iterators.
template <typename Derived, std::semiregular T, std::size_t N>
class MatrixBase : public ObjectBase<MatrixBase<Derived, T, N>> {
    static_assert(N > 1);
public:
    static constexpr std::size_t ndim = N;

private:
    // Extent of each dimension; constructors reject zero extents.
    std::array<std::size_t, N> dims_;
    // Product of all extents.
    std::size_t size_;
    // Per-dimension strides computed by computeStrides(dims_) (presumably row-major — TODO confirm).
    // For a view these describe the view's own logical layout, not the underlying memory layout;
    // see origStrides().
    std::array<std::size_t, N> strides_;

public:
    MatrixBase() = delete;
    using Base = ObjectBase<MatrixBase<Derived, T, N>>;
    using Base::applyFunction;
    using Base::operator=;
    using Base::operator+=;
    using Base::operator-=;
    using Base::operator*=;
    using Base::operator/=;
    using Base::operator%=;

    // Downcast to the most-derived (CRTP) type.
    Derived& self() { return static_cast<Derived&>(*this); }
    const Derived& self() const { return static_cast<const Derived&>(*this); }

protected:
    ~MatrixBase() noexcept = default;
    // Shape from an array of N extents; throws std::invalid_argument if any extent is zero.
    MatrixBase(const std::array<std::size_t, N>& dims);

    // Shape from fewer than N extents, padded via prepend<N, M>.
    template <std::size_t M> requires (M < N)
    MatrixBase(const std::array<std::size_t, M>& dims);

    // Shape from exactly N integral extents.
    template <IndexType... Dims>
    explicit MatrixBase(Dims... dims);

    // Copies only the shape of another matrix whose element type converts to T.
    template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T>
    MatrixBase(const MatrixBase<DerivedOther, U, N>&);

    // Shape derived from a nested initializer list (deriveDims<N>).
    MatrixBase(typename MatrixInitializer<T, N>::type init);

public:
    // Flat initializer lists are rejected; use the nested MatrixInitializer form.
    template <typename U>
    MatrixBase(std::initializer_list<U>) = delete;

    template <typename U>
    MatrixBase& operator=(std::initializer_list<U>) = delete;

    using value_type = T;
    using reference = T&;
    using const_reference = const T&;
    using pointer = T*;

public:
    // Swaps only the shape metadata held here; element storage lives in Derived.
    friend void swap(MatrixBase& a, MatrixBase& b) noexcept {
        std::swap(a.size_, b.size_);
        std::swap(a.dims_, b.dims_);
        std::swap(a.strides_, b.strides_);
    }

    // Iteration is delegated entirely to Derived.
    auto begin() { return self().begin(); }
    auto begin() const { return self().begin(); }
    auto cbegin() const { return self().cbegin(); }
    auto end() { return self().end(); }
    auto end() const { return self().end(); }
    auto cend() const { return self().cend(); }
    auto rbegin() { return self().rbegin(); }
    auto rbegin() const { return self().rbegin(); }
    auto crbegin() const { return self().crbegin(); }
    auto rend() { return self().rend(); }
    auto rend() const { return self().rend(); }
    auto crend() const { return self().crend(); }

    // Bounds-checked element access with N indices: m(i, j, ...).
    template <IndexType... Args>
    reference operator()(Args... args);

    template <IndexType... Args>
    const_reference operator()(Args... args) const;

    // Bounds-checked element access with a full index array; throws std::out_of_range.
    reference operator[](const std::array<std::size_t, N>& pos);
    const_reference operator[](const std::array<std::size_t, N>& pos) const;

    [[nodiscard]] std::size_t size() const { return size_;}

    [[nodiscard]] const std::array<std::size_t, N>& dims() const {
        return dims_;
    }

    // Extent of dimension n; throws std::out_of_range if n >= N.
    [[nodiscard]] std::size_t dims(std::size_t n) const {
        if (n >= N) {
            throw std::out_of_range("Out of range in dims");
        }
        return dims_[n];
    }

    [[nodiscard]] const std::array<std::size_t, N>& strides() const {
        return strides_;
    }

    // Stride of dimension n; throws std::out_of_range if n >= N.
    [[nodiscard]] std::size_t strides(std::size_t n) const {
        if (n >= N) {
            throw std::out_of_range("Out of range in strides");
        }
        return strides_[n];
    }

    auto dataView() const {
        return self().dataView();
    }

    // Strides of the underlying storage; used by row()/col() when Derived is a MatrixView.
    auto origStrides() const {
        return self().origStrides();
    }

    // Non-owning views into this object. pos_end is exclusive; the one-argument form extends to dims().
    MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin);
    MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin, const std::array<std::size_t, N>& pos_end);
    // n-th slice along the first axis.
    MatrixView<T, N - 1> row(std::size_t n);
    // n-th slice along the last axis.
    MatrixView<T, N - 1> col(std::size_t n);
    MatrixView<T, N - 1> operator[](std::size_t n) { return row(n); }

    MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin) const;
    MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin, const std::array<std::size_t, N>& pos_end) const;
    MatrixView<T, N - 1> row(std::size_t n) const;
    MatrixView<T, N - 1> col(std::size_t n) const;
    MatrixView<T, N - 1> operator[](std::size_t n) const { return row(n); }

    // Prints as nested braces, recursing row by row: {{a, b}, {c, d}}.
    friend std::ostream& operator<<(std::ostream& os, const MatrixBase& m) {
        os << '{';
        for (std::size_t i = 0; i != m.dims(0); ++i) {
            os << m[i];
            if (i + 1 != m.dims(0)) {
                os << ", ";
            }
        }
        return os << '}';
    }

    // Applies f row by row (first axis) to this object and two operands, broadcasting an operand
    // whose first extent or rank is smaller. At least one operand must have rank N.
    template <typename DerivedOther1, typename DerivedOther2,
            std::semiregular U, std::semiregular V,
            std::size_t N1, std::size_t N2,
            std::invocable<MatrixView<T, N - 1>&,
                    const MatrixView<U, std::min(N1, N - 1)>&,
                    const MatrixView<V, std::min(N2, N - 1)>&> F>
    requires (std::max(N1, N2) == N)
    MatrixBase& applyFunctionWithBroadcast(const MatrixBase<DerivedOther1, U, N1>& m1,
                                           const MatrixBase<DerivedOther2, V, N2>& m2,
                                           F&& f);

};

template <typename Derived, std::semiregular T, std::size_t N>
MatrixBase<Derived, T, N>::MatrixBase(const std::array<std::size_t, N>& dims) : dims_ {dims} {
    if (std::ranges::find(dims_, 0lu) != std::end(dims_)) {
        throw std::invalid_argument("Zero dimension not allowed");
    }
    size_ = std::accumulate(std::begin(dims_), std::end(dims_), 1lu, std::multiplies<>{});
    strides_ = computeStrides(dims_);
}

template <typename Derived, std::semiregular T, std::size_t N>
template <std::size_t M> requires (M < N)
MatrixBase<Derived, T, N>::MatrixBase(const std::array<std::size_t, M>& dims) : MatrixBase (prepend<N, M>(dims)) {}

template <typename Derived, std::semiregular T, std::size_t N>
template <IndexType... Dims>
MatrixBase<Derived, T, N>::MatrixBase(Dims... dims) : dims_{static_cast<std::size_t>(dims)...} {
    static_assert(sizeof...(Dims) == N);
    static_assert((std::is_integral_v<Dims> && ...));
    if (std::ranges::find(dims_, 0lu) != std::end(dims_)) {
        throw std::invalid_argument("Zero dimension not allowed");
    }
    size_ = std::accumulate(std::begin(dims_), std::end(dims_), 1lu, std::multiplies<>{});
    strides_ = computeStrides(dims_);
}

template <typename Derived, std::semiregular T, std::size_t N>
template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T>
MatrixBase<Derived, T, N>::MatrixBase(const MatrixBase<DerivedOther, U, N>& other) : MatrixBase(other.dims()) {}

template <typename Derived, std::semiregular T, std::size_t N>
MatrixBase<Derived, T, N>::MatrixBase(typename MatrixInitializer<T, N>::type init) : MatrixBase(deriveDims<N>(init)) {}

/// Mutable element access with N indices; see the const overload.
template <typename Derived, std::semiregular T, std::size_t N>
template <IndexType... Args>
typename MatrixBase<Derived, T, N>::reference MatrixBase<Derived, T, N>::operator()(Args... args) {
    return const_cast<typename MatrixBase<Derived, T, N>::reference>(std::as_const(*this).operator()(args...));
}

/// Element access with exactly N indices; forwards to the bounds-checked operator[].
/// @throws std::out_of_range if any index is not below its extent.
template <typename Derived, std::semiregular T, std::size_t N>
template <IndexType... Args>
typename MatrixBase<Derived, T, N>::const_reference MatrixBase<Derived, T, N>::operator()(Args... args) const {
    static_assert(sizeof...(args) == N);
    std::array<std::size_t, N> pos {std::size_t(args)...};
    return operator[](pos);
}

/// Mutable element access with an index array; see the const overload.
template <typename Derived, std::semiregular T, std::size_t N>
typename MatrixBase<Derived, T, N>::reference MatrixBase<Derived, T, N>::operator[](const std::array<std::size_t, N>& pos) {
    return const_cast<typename MatrixBase<Derived, T, N>::reference>(std::as_const(*this).operator[](pos));
}

/// Element access with an index array: requires pos[i] < dims_[i] for every i, then advances
/// Derived's iterator by the stride-weighted offset.
/// @throws std::out_of_range if any index is out of range.
template <typename Derived, std::semiregular T, std::size_t N>
typename MatrixBase<Derived, T, N>::const_reference MatrixBase<Derived, T, N>::operator[](const std::array<std::size_t, N>& pos) const {
    if (!std::equal(std::cbegin(pos), std::cend(pos), std::cbegin(dims_), std::less<>{})) {
        throw std::out_of_range("Out of range in element access");
    }
    // Seed with std::size_t so the offset is not accumulated in a (possibly 32-bit) `unsigned long`.
    const std::size_t offset = std::inner_product(std::cbegin(pos), std::cend(pos), std::cbegin(strides_), std::size_t{0});
    return *(cbegin() + offset);
}

/// View from pos_begin up to the end of every dimension.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin) {
    return submatrix(pos_begin, dims_);
}

/// View from pos_begin up to the end of every dimension.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin) const {
    return submatrix(pos_begin, dims_);
}

/// View of the half-open box [pos_begin, pos_end).
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin,
                                                const std::array<std::size_t, N>& pos_end) {
    return std::as_const(*this).submatrix(pos_begin, pos_end);
}

/// View of the half-open box [pos_begin, pos_end).
/// @throws std::out_of_range unless pos_begin[i] < pos_end[i] <= dims()[i] for every i.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin,
                                                      const std::array<std::size_t, N>& pos_end) const {
    if (!std::equal(std::cbegin(pos_begin), std::cend(pos_begin), std::cbegin(pos_end), std::less<>{})) {
        throw std::out_of_range("submatrix begin/end position error");
    }
    // pos_end is exclusive, so it may reach but never exceed each extent; otherwise the view
    // would address elements outside this object.
    if (!std::equal(std::cbegin(pos_end), std::cend(pos_end), std::cbegin(dims_), std::less_equal<>{})) {
        throw std::out_of_range("submatrix begin/end position error");
    }
    std::array<std::size_t, N> view_dims;
    std::transform(std::cbegin(pos_end), std::cend(pos_end), std::cbegin(pos_begin), std::begin(view_dims),
                   std::minus<>{});

    // Same rule as row()/col(): a view's own strides() describe its logical layout, so the new
    // view must be built from the view's strides over the underlying storage.
    std::array<std::size_t, N> orig_strides;
    if constexpr (std::is_same_v<Derived, MatrixView<T, N>>) {
        orig_strides = origStrides();
    } else {
        orig_strides = strides();
    }

    MatrixView<T, N> view(view_dims, const_cast<T*>(&operator[](pos_begin)), orig_strides);
    return view;
}

/// Mutable n-th slice along the first axis; see the const overload.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N - 1> MatrixBase<Derived, T, N>::row(std::size_t n) {
    return std::as_const(*this).row(n);
}

/// n-th slice along the first axis, as an (N-1)-dimensional view sharing this object's elements.
/// @throws std::out_of_range if n >= dims(0).
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N - 1> MatrixBase<Derived, T, N>::row(std::size_t n) const {
    const auto& orig_dims = dims();
    if (n >= orig_dims[0]) {
        throw std::out_of_range("row index error");
    }
    // The row keeps every extent except the first.
    std::array<std::size_t, N - 1> row_dims;
    std::copy(std::cbegin(orig_dims) + 1, std::cend(orig_dims), std::begin(row_dims));
    std::array<std::size_t, N> pos_begin = {n, };
    std::array<std::size_t, N - 1> row_strides;

    // A view's strides() are logical; the sub-view needs the strides over the underlying storage.
    std::array<std::size_t, N> orig_strides;
    if constexpr (std::is_same_v<Derived, MatrixView<T, N>>) {
        orig_strides = origStrides();
    } else {
        orig_strides = strides();
    }

    std::copy(std::cbegin(orig_strides) + 1, std::cend(orig_strides), std::begin(row_strides));
    MatrixView<T, N - 1> nth_row(row_dims, const_cast<T*>(&operator[](pos_begin)), row_strides);
    return nth_row;
}

/// Mutable n-th slice along the last axis; see the const overload.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N - 1> MatrixBase<Derived, T, N>::col(std::size_t n) {
    return std::as_const(*this).col(n);
}

/// n-th slice along the last axis, as an (N-1)-dimensional view sharing this object's elements.
/// @throws std::out_of_range if n >= dims(N - 1).
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N - 1> MatrixBase<Derived, T, N>::col(std::size_t n) const {
    const auto& orig_dims = dims();
    if (n >= orig_dims[N - 1]) {
        // Previously reported as "row index error", which misidentified the failing accessor.
        throw std::out_of_range("col index error");
    }
    // The column keeps every extent except the last.
    std::array<std::size_t, N - 1> col_dims;
    std::copy(std::cbegin(orig_dims), std::cend(orig_dims) - 1, std::begin(col_dims));
    std::array<std::size_t, N> pos_begin = {0};
    pos_begin[N - 1] = n;
    std::array<std::size_t, N - 1> col_strides;

    // A view's strides() are logical; the sub-view needs the strides over the underlying storage.
    std::array<std::size_t, N> orig_strides;
    if constexpr (std::is_same_v<Derived, MatrixView<T, N>>) {
        orig_strides = origStrides();
    } else {
        orig_strides = strides();
    }

    std::copy(std::cbegin(orig_strides), std::cend(orig_strides) - 1, std::begin(col_strides));
    MatrixView<T, N - 1> nth_col(col_dims, const_cast<T*>(&operator[](pos_begin)), col_strides);
    return nth_col;
}

/// Applies f row by row (along the first axis) to this object and the operands m1 and m2.
/// f receives a mutable view of this object's i-th row plus the matching (possibly broadcast)
/// operand row or whole lower-rank operand. The element-level update happens when the recursion
/// reaches the 1-D specialization.
///
/// Cases handled:
///  - both operands have rank N: an operand whose first extent differs from dims(0) has its row 0
///    reused for every i (broadcast);
///  - one operand has lower rank: that whole operand is passed for every row.
/// @return *this.
template <typename Derived, std::semiregular T, std::size_t N>
template <typename DerivedOther1, typename DerivedOther2,
        std::semiregular U, std::semiregular V,
        std::size_t N1, std::size_t N2,
        std::invocable<MatrixView<T, N - 1>&,
                const MatrixView<U, std::min(N1, N - 1)>&,
                const MatrixView<V, std::min(N2, N - 1)>&> F>
requires (std::max(N1, N2) == N)
MatrixBase<Derived, T, N>& MatrixBase<Derived, T, N>::applyFunctionWithBroadcast(const MatrixBase<DerivedOther1, U, N1>& m1,
                                                                                 const MatrixBase<DerivedOther2, V, N2>& m2,
                                                                                 F&& f) {
    if constexpr (N1 == N) {
        if constexpr (N2 == N) {
            auto r = dims(0);
            auto r1 = m1.dims(0);
            auto r2 = m2.dims(0);
            if (r1 == r) {
                if (r2 == r) {
                    // No broadcasting on this axis: pair row i with row i.
                    for (std::size_t i = 0; i < r; ++i) {
                        auto row = this->row(i);
                        f(row, m1.row(i), m2.row(i));
                    }
                } else { // r2 < r == r1
                    // NOTE(review): assumes r2 == 1 (broadcastable); not checked here, presumably
                    // validated by the caller — TODO confirm. Any other r2 silently reuses row 0.
                    auto row2 = m2.row(0);
                    for (std::size_t i = 0; i < r; ++i) {
                        auto row = this->row(i);
                        f(row, m1.row(i), row2);
                    }
                }
            } else if (r2 == r) { // r1 < r == r2
                // NOTE(review): same assumption as above, with r1 == 1.
                auto row1 = m1.row(0);
                for (std::size_t i = 0; i < r; ++i) {
                    auto row = this->row(i);
                    f(row, row1, m2.row(i));
                }
            } else {
                // NOTE(review): compiled out under NDEBUG; the call then silently does nothing.
                assert(0); // cannot happen
            }
        } else { // N2 < N == N1
            // m2 has lower rank: the whole of m2 is passed alongside every row. m1 must match
            // dims(0) exactly here (no broadcasting of m1's first axis in this branch).
            auto r = dims(0);
            assert(r == m1.dims(0));
            MatrixView<V, N2> view2 (m2);
            for (std::size_t i = 0; i < r; ++i) {
                auto row = this->row(i);
                f(row, m1.row(i), view2);
            }
        }
    } else if constexpr (N2 == N) { // N1 < N == N2
        // m1 has lower rank: mirror image of the branch above.
        auto r = dims(0);
        assert(r == m2.dims(0));
        MatrixView<U, N1> view1 (m1);
        for (std::size_t i = 0; i < r; ++i) {
            auto row = this->row(i);
            f(row, view1, m2.row(i));
        }
    } else {
        // Excluded by the requires-clause (max(N1, N2) == N); never instantiated.
        assert(0); // cannot happen
    }
    return *this;
}

/// One-dimensional specialization of MatrixBase: stores a single extent and stride.
/// Element storage and iterators come from Derived via self().
template <typename Derived, std::semiregular T>
class MatrixBase<Derived, T, 1> : public ObjectBase<MatrixBase<Derived, T, 1>> {
public:
    static constexpr std::size_t ndim = 1;

private:
    // Number of elements (the single extent).
    std::size_t dims_;
    // Logical stride; set to 1 by the extent constructor.
    std::size_t strides_;

    Derived& self() { return static_cast<Derived&>(*this); }
    const Derived& self() const { return static_cast<const Derived&>(*this); }

public:
    MatrixBase() = delete;
    using Base = ObjectBase<MatrixBase<Derived, T, 1>>;
    using Base::applyFunction;
    using Base::operator=;
    using Base::operator+=;
    using Base::operator-=;
    using Base::operator*=;
    using Base::operator/=;
    using Base::operator%=;

protected:
    ~MatrixBase() noexcept = default;
    // Shape from a single integral extent.
    template <typename Dim> requires std::is_integral_v<Dim>
    explicit MatrixBase(Dim dim) : dims_(dim), strides_(1) {}

    // NOTE(review): declared but not defined in MatrixBase.h (unlike the N-D counterpart);
    // using it fails to link unless it is defined elsewhere — TODO confirm.
    template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T>
    MatrixBase(const MatrixBase<DerivedOther, U, 1>&);

    // Shape derived from a flat initializer list.
    MatrixBase(typename MatrixInitializer<T, 1>::type init);

public:
    using value_type = T;
    using reference = T&;
    using const_reference = const T&;
    using pointer = T*;

public:
    // Swaps only the shape metadata held here. The 1-D base has no size_ member (the extent
    // is dims_ itself), so the previous `std::swap(a.size_, b.size_)` could not compile when used.
    friend void swap(MatrixBase& a, MatrixBase& b) noexcept {
        std::swap(a.dims_, b.dims_);
        std::swap(a.strides_, b.strides_);
    }

    auto begin() { return self().begin(); }
    auto begin() const { return self().begin(); }
    auto cbegin() const { return self().cbegin(); }
    auto end() { return self().end(); }
    auto end() const { return self().end(); }
    auto cend() const { return self().cend(); }
    auto rbegin() { return self().rbegin(); }
    auto rbegin() const { return self().rbegin(); }
    auto crbegin() const { return self().crbegin(); }
    auto rend() { return self().rend(); }
    auto rend() const { return self().rend(); }
    auto crend() const { return self().crend(); }

    /// Bounds-checked element access, mirroring the N-D operator().
    /// @throws std::out_of_range if dim is not below the extent (negative values wrap and are rejected).
    template <typename Dim> requires std::is_integral_v<Dim>
    reference operator()(Dim dim) {
        return const_cast<reference>(std::as_const(*this)(dim));
    }

    template <typename Dim> requires std::is_integral_v<Dim>
    const_reference operator()(Dim dim) const {
        const auto idx = static_cast<std::size_t>(dim);
        if (idx >= dims_) {
            throw std::out_of_range("Out of range in element access");
        }
        return operator[](idx);
    }

    /// Number of elements, for parity with the N-D MatrixBase::size().
    [[nodiscard]] std::size_t size() const { return dims_; }

    [[nodiscard]] std::array<std::size_t, 1> dims() const {
        return {dims_};
    }

    // Extent of dimension n; only n == 0 is valid.
    [[nodiscard]] std::size_t dims(std::size_t n) const {
        if (n >= 1) {
            throw std::out_of_range("Out of range in dims");
        }
        return dims_;
    }

    [[nodiscard]] std::size_t strides() const {
        return strides_;
    }

    auto dataView() const {
        return self().dataView();
    }

    auto origStrides() const {
        return self().origStrides();
    }

    // In 1-D, row(n) and col(n) both return the n-th element (bounds-checked).
    // operator[] is unchecked.
    MatrixView<T, 1> submatrix(std::size_t pos_begin);
    MatrixView<T, 1> submatrix(std::size_t pos_begin, std::size_t pos_end);
    T& row(std::size_t n);
    T& col(std::size_t n);
    T& operator[](std::size_t n) { return *(begin() + n); }

    MatrixView<T, 1> submatrix(std::size_t pos_begin) const;
    MatrixView<T, 1> submatrix(std::size_t pos_begin, std::size_t pos_end) const;
    const T& row(std::size_t n) const;
    const T& col(std::size_t n) const;
    const T& operator[](std::size_t n) const { return *(cbegin() + n); }

    // Prints as {a, b, c}.
    friend std::ostream& operator<<(std::ostream& os, const MatrixBase& m) {
        os << '{';
        for (std::size_t i = 0; i != m.dims_; ++i) {
            os << m[i];
            if (i + 1 != m.dims_) {
                os << ", ";
            }
        }
        return os << '}';
    }

    // Element-level end of the broadcasting recursion: f(this[i], m1[i or 0], m2[i or 0]).
    template <typename DerivedOther1, typename DerivedOther2,
            std::semiregular U, std::semiregular V,
            std::invocable<T&, const U&, const V&> F>
    MatrixBase& applyFunctionWithBroadcast(const frozenca::MatrixBase<DerivedOther1, U, 1>& m1,
                                           const frozenca::MatrixBase<DerivedOther2, V, 1>& m2,
                                           F&& f);

};

/// Derives the single extent from a flat initializer list.
template <typename Derived, std::semiregular T>
MatrixBase<Derived, T, 1>::MatrixBase(typename MatrixInitializer<T, 1>::type init) : MatrixBase(deriveDims<1>(init)[0]) {
}

/// View from pos_begin to the end.
template <typename Derived, std::semiregular T>
MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin) {
    return submatrix(pos_begin, dims_);
}

/// View from pos_begin to the end.
template <typename Derived, std::semiregular T>
MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin) const {
    return submatrix(pos_begin, dims_);
}

/// View of the half-open range [pos_begin, pos_end).
template <typename Derived, std::semiregular T>
MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin,
                                        std::size_t pos_end) {
    return std::as_const(*this).submatrix(pos_begin, pos_end);
}

/// View of the half-open range [pos_begin, pos_end).
/// @throws std::out_of_range unless pos_begin < pos_end <= extent.
template <typename Derived, std::semiregular T>
MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin,
                                              std::size_t pos_end) const {
    // pos_end is exclusive; it may reach but not exceed the extent, otherwise the view would
    // address elements past the end.
    if (pos_begin >= pos_end || pos_end > dims_) {
        throw std::out_of_range("submatrix begin/end position error");
    }
    // NOTE(review): strides_ is always 1 here. For a strided 1-D MatrixView (e.g. one produced by
    // col()) the N-D code passes origStrides() instead. Confirm whether this should do the same.
    MatrixView<T, 1> view ({pos_end - pos_begin}, const_cast<T*>(&operator[](pos_begin)), {strides_});
    return view;
}

/// Mutable n-th element; see the const overload.
template <typename Derived, std::semiregular T>
T& MatrixBase<Derived, T, 1>::row(std::size_t n) {
    return const_cast<T&>(std::as_const(*this).row(n));
}

/// n-th element, bounds-checked.
/// @throws std::out_of_range if n >= extent.
template <typename Derived, std::semiregular T>
const T& MatrixBase<Derived, T, 1>::row(std::size_t n) const {
    if (n >= dims_) {
        throw std::out_of_range("row index error");
    }
    const T& val = operator[](n);
    return val;
}

/// In 1-D a "column" is the same single element as a "row".
template <typename Derived, std::semiregular T>
T& MatrixBase<Derived, T, 1>::col(std::size_t n) {
    return row(n);
}

/// In 1-D a "column" is the same single element as a "row".
template <typename Derived, std::semiregular T>
const T& MatrixBase<Derived, T, 1>::col(std::size_t n) const {
    return row(n);
}

/// Element-level step of broadcasting: calls f(this[i], m1[i], m2[i]) for every i.
/// An operand whose extent differs from this object's has element 0 reused for every i.
/// That element is copied once, before the loop, so later writes through `this` do not change
/// the value f sees even if the operand aliases this object.
/// @return *this.
template <typename Derived, std::semiregular T>
template <typename DerivedOther1, typename DerivedOther2,
        std::semiregular U, std::semiregular V,
        std::invocable<T&, const U&, const V&> F>
MatrixBase<Derived, T, 1>& MatrixBase<Derived, T, 1>::applyFunctionWithBroadcast(
        const frozenca::MatrixBase<DerivedOther1, U, 1>& m1,
        const frozenca::MatrixBase<DerivedOther2, V, 1>& m2,
        F&& f) {
    // real update is done here by passing lvalue reference T&
    auto r = dims(0);
    auto r1 = m1.dims(0);
    auto r2 = m2.dims(0);

    if (r1 == r) {
        if (r2 == r) {
            for (std::size_t i = 0; i < r; ++i) {
                f(this->row(i), m1.row(i), m2.row(i));
            }
        } else { // r2 < r == r1
            // NOTE(review): assumes r2 == 1 (broadcastable); not checked here.
            auto row2 = m2.row(0);
            for (std::size_t i = 0; i < r; ++i) {
                f(this->row(i), m1.row(i), row2);
            }
        }
    } else if (r2 == r) { // r1 < r == r2
        // NOTE(review): assumes r1 == 1 (broadcastable); not checked here.
        auto row1 = m1.row(0);
        for (std::size_t i = 0; i < r; ++i) {
            f(this->row(i), row1, m2.row(i));
        }
    }
    // NOTE(review): if neither extent matches, nothing is done and no error is reported
    // (the N-D version at least asserts) — presumably shapes are validated by the caller.
    return *this;
}

} // namespace frozenca

#endif //FROZENCA_MATRIXBASE_H

(Stack Exchange сообщает, что пост слишком длинный, поэтому два файла я привожу в виде ссылок.)

MatrixImpl.h (https://github.com/frozenca/Ndim-Matrix/blob/main/MatrixImpl.h)

MatrixView.h (https://github.com/frozenca/Ndim-Matrix/blob/main/MatrixView.h)

MatrixUtils.h

#ifndef FROZENCA_MATRIXUTILS_H
#define FROZENCA_MATRIXUTILS_H

#include <algorithm>
#include <array>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <type_traits>

namespace frozenca {

template <std::semiregular T, std::size_t N>
class Matrix;

template <std::semiregular T, std::size_t N>
class MatrixView;

template <typename Derived, std::semiregular T, std::size_t N>
class MatrixBase;

template <typename Derived>
class ObjectBase;

// Trait: true for every type except the library's own matrix-like class templates, which opt out
// through the partial specializations that follow.
template <typename T>
constexpr bool NotMatrix = true;

// CRTP bases.
template <typename Derived>
constexpr bool NotMatrix<ObjectBase<Derived>> = false;

template <typename Derived, std::semiregular T, std::size_t N>
constexpr bool NotMatrix<MatrixBase<Derived, T, N>> = false;

// Concrete owning and non-owning matrix types.
template <std::semiregular T, std::size_t N>
constexpr bool NotMatrix<MatrixView<T, N>> = false;

template <std::semiregular T, std::size_t N>
constexpr bool NotMatrix<Matrix<T, N>> = false;

// A scalar-like operand: not a library matrix type, and semiregular (copyable and
// default-initializable).
template <typename T>
concept isNotMatrix = NotMatrix<T> && std::semiregular<T>;

// One of the library's matrix-like types.
template <typename T>
concept isMatrix = !NotMatrix<T>;

template <typename T>
concept OneExists = requires () {
    { T{0} } -> std::convertible_to<T>;
    { T{1} } -> std::convertible_to<T>;
};

template <typename A, typename B>
concept WeakAddable = requires (A a, B b) {
    a + b;
};

template <typename A, typename B>
concept WeakSubtractable = requires (A a, B b) {
    a - b;
};

template <typename A, typename B>
concept WeakMultipliable = requires (A a, B b) {
    a * b;
};

template <typename A, typename B>
concept WeakDividable = requires (A a, B b) {
    a / b;
};

template <typename A, typename B>
concept WeakRemaindable = requires (A a, B b) {
    a / b;
    a % b;
};

// Operator concepts with a target type: `a op b` must be valid for lvalues of
// A and B, and its result must be convertible to C. They constrain the
// store-into helpers (PlusTo, MinusTo, ...) and the matrix operators.
template <typename A, typename B, typename C>
concept AddableTo = requires (A a, B b) {
    { a + b } -> std::convertible_to<C>;
};

template <typename A, typename B, typename C>
concept SubtractableTo = requires (A a, B b) {
    { a - b } -> std::convertible_to<C>;
};

template <typename A, typename B, typename C>
concept MultipliableTo = requires (A a, B b) {
    { a * b } -> std::convertible_to<C>;
};

template <typename A, typename B, typename C>
concept DividableTo = requires (A a, B b) {
    { a / b } -> std::convertible_to<C>;
};

// Only `%` is required: ModulusTo() and operator% never divide, so also
// demanding `a / b` needlessly rejected types that provide only operator%.
template <typename A, typename B, typename C>
concept RemaindableTo = requires (A a, B b) {
    { a % b } -> std::convertible_to<C>;
};

// All bitwise and shift operators must be available and convert to C.
template <typename A, typename B, typename C>
concept BitMaskableTo = requires (A a, B b) {
    { a & b } -> std::convertible_to<C>;
    { a | b } -> std::convertible_to<C>;
    { a ^ b } -> std::convertible_to<C>;
    { a << b } -> std::convertible_to<C>;
    { a >> b } -> std::convertible_to<C>;
};

// Homogeneous shorthands: the result of `a op b` must convert back to the left
// operand type A (e.g. Addable<int, int> means `int + int` converts to int).

// `a + b` converts to A.
template <typename A, typename B>
concept Addable = AddableTo<A, B, A>;

// `a - b` converts to A.
template <typename A, typename B>
concept Subtractable = SubtractableTo<A, B, A>;

// `a * b` converts to A.
template <typename A, typename B>
concept Multipliable = MultipliableTo<A, B, A>;

// `a / b` converts to A.
template <typename A, typename B>
concept Dividable = DividableTo<A, B, A>;

// Remainder (see RemaindableTo) converts to A.
template <typename A, typename B>
concept Remaindable = RemaindableTo<A, B, A>;

// All bitwise/shift operators convert to A.
template <typename A, typename B>
concept BitMaskable = BitMaskableTo<A, B, A>;

// Named wrappers around the binary operators. They exist mainly so their
// deduced return types can be queried through std::invoke_result (see AddType,
// SubType, ... below). Operands are taken by value, and `decltype(auto)` keeps
// the operator's exact result type (including any reference).

// Returns a + b.
template <typename A, typename B> requires WeakAddable<A, B>
inline decltype(auto) Plus(A a, B b) {
    return a + b;
}

// Returns a - b.
template <typename A, typename B> requires WeakSubtractable<A, B>
inline decltype(auto) Minus(A a, B b) {
    return a - b;
}

// Returns a * b.
template <typename A, typename B> requires WeakMultipliable<A, B>
inline decltype(auto) Multiplies(A a, B b) {
    return a * b;
}

// Returns a / b.
template <typename A, typename B> requires WeakDividable<A, B>
inline decltype(auto) Divides(A a, B b) {
    return a / b;
}

// Returns a % b (C++ semantics: the sign follows the dividend for built-ins).
template <typename A, typename B> requires WeakRemaindable<A, B>
inline decltype(auto) Modulus(A a, B b) {
    return a % b;
}

// Result types of the elementwise operators for operand types A and B, obtained
// by asking what the corresponding wrapper (Plus, Minus, ...) returns when
// invoked with an A and a B. They are the default element type of the results
// of operator+, operator-, ... and of dot()/matmul().

// Type of A + B.
template <typename A, typename B>
using AddType = std::invoke_result_t<decltype(Plus<A, B>), A, B>;

// Type of A - B.
template <typename A, typename B>
using SubType = std::invoke_result_t<decltype(Minus<A, B>), A, B>;

// Type of A * B.
template <typename A, typename B>
using MulType = std::invoke_result_t<decltype(Multiplies<A, B>), A, B>;

// Type of A / B.
template <typename A, typename B>
using DivType = std::invoke_result_t<decltype(Divides<A, B>), A, B>;

// Type of A % B.
template <typename A, typename B>
using ModType = std::invoke_result_t<decltype(Modulus<A, B>), A, B>;

// Products of an A and a B can be summed: MulType<A, B> + MulType<A, B>
// converts back to MulType<A, B>.
template <typename A, typename B>
concept DotProductable = Addable<MulType<A, B>, MulType<A, B>>;

// In addition, each product converts to the accumulator type C, and C values
// can be summed (C + C converts to C). Used to constrain dot() and matmul().
template <typename A, typename B, typename C>
concept DotProductableTo = DotProductable<A, B> && MultipliableTo<A, B, C> && Addable<C, C>;

// Store-into variants of the elementwise operators: compute `a op b` and assign
// the result to `c`. These are the leaf functions that the broadcasting kernels
// (AddTo, SubtractTo, ...) in MatrixOps.h apply element by element.

// c = a + b
template <typename A, typename B, typename C> requires AddableTo<A, B, C>
inline void PlusTo(C& c, const A& a, const B& b) {
    c = a + b;
}

// c = a - b
template <typename A, typename B, typename C> requires SubtractableTo<A, B, C>
inline void MinusTo(C& c, const A& a, const B& b) {
    c = a - b;
}

// c = a * b
template <typename A, typename B, typename C> requires MultipliableTo<A, B, C>
inline void MultipliesTo(C& c, const A& a, const B& b) {
    c = a * b;
}

// c = a / b
template <typename A, typename B, typename C> requires DividableTo<A, B, C>
inline void DividesTo(C& c, const A& a, const B& b) {
    c = a / b;
}

// c = a % b
template <typename A, typename B, typename C> requires RemaindableTo<A, B, C>
inline void ModulusTo(C& c, const A& a, const B& b) {
    c = a % b;
}

// Compile-time conjunction of a pack of boolean-like values. An empty pack
// yields true.
template <typename... Args>
inline constexpr bool All(Args... args) {
    return (... && args);
}

// Compile-time disjunction of a pack of boolean-like values. An empty pack
// yields false.
template <typename... Args>
inline constexpr bool Some(Args... args) {
    return (... || args);
}

// Left-pads a rank-N shape with 1s up to rank M (the NumPy/ONNX broadcasting
// rule), e.g. prependDims<4>({3, 5}) == {1, 1, 3, 5}.
template <std::size_t M, std::size_t N> requires (N < M)
std::array<std::size_t, M> prependDims(const std::array<std::size_t, N>& arr) {
    std::array<std::size_t, M> padded;
    padded.fill(1);
    // The original extents occupy the last N slots.
    std::copy(arr.begin(), arr.end(), padded.end() - N);
    return padded;
}

// True when two shapes can be broadcast against each other. After left-padding
// the shorter shape with 1s, every pair of extents must be equal or contain a 1.
template <std::size_t M, std::size_t N>
bool bidirBroadcastable(const std::array<std::size_t, M>& sz1,
                        const std::array<std::size_t, N>& sz2) {
    if constexpr (M < N) {
        return bidirBroadcastable(prependDims<N, M>(sz1), sz2);
    } else if constexpr (M > N) {
        return bidirBroadcastable(sz1, prependDims<M, N>(sz2));
    } else {
        for (std::size_t axis = 0; axis < M; ++axis) {
            const std::size_t lhs = sz1[axis];
            const std::size_t rhs = sz2[axis];
            if (lhs != rhs && lhs != 1 && rhs != 1) {
                return false;
            }
        }
        return true;
    }
}

// Shape that results from broadcasting shapes sz1 and sz2 together. Each output
// extent is the larger of the two (padded) input extents.
// Throws std::invalid_argument if the shapes are not broadcastable.
template <std::size_t M, std::size_t N>
std::array<std::size_t, std::max(M, N)> bidirBroadcastedDims(const std::array<std::size_t, M>& sz1,
                                                             const std::array<std::size_t, N>& sz2) {
    if constexpr (M < N) {
        return bidirBroadcastedDims(prependDims<N, M>(sz1), sz2);
    } else if constexpr (M > N) {
        return bidirBroadcastedDims(sz1, prependDims<M, N>(sz2));
    } else {
        if (!bidirBroadcastable(sz1, sz2)) {
            throw std::invalid_argument("Cannot broadcast");
        }
        std::array<std::size_t, M> merged;
        for (std::size_t axis = 0; axis < M; ++axis) {
            merged[axis] = std::max(sz1[axis], sz2[axis]);
        }
        return merged;
    }
}

// Shape of np.dot(a, b) when b is 1-D. The last axis of a is contracted with
// the only axis of b, so the result shape is a.shape[:-1].
// Throws std::invalid_argument if the contracted extents differ.
template <std::size_t M> requires (M > 1)
std::array<std::size_t, M - 1> dotDims(const std::array<std::size_t, M>& sz1,
                                       const std::array<std::size_t, 1>& sz2) {
    const bool aligned = sz1.back() == sz2.front();
    if (!aligned) {
        throw std::invalid_argument("Cannot do dot product, shape is not aligned");
    }
    std::array<std::size_t, M - 1> result;
    std::copy_n(sz1.begin(), M - 1, result.begin());
    return result;
}

// Shape of np.dot(a, b) for rank(b) > 1. The last axis of a is contracted with
// the second-to-last axis of b, giving a.shape[:-1] + b.shape[:-2] + b.shape[-1:].
// Throws std::invalid_argument if the contracted extents differ.
template <std::size_t M, std::size_t N> requires (N > 1)
std::array<std::size_t, M + N - 2> dotDims(const std::array<std::size_t, M>& sz1,
                                           const std::array<std::size_t, N>& sz2) {
    if (sz1.back() != sz2[N - 2]) {
        throw std::invalid_argument("Cannot do dot product, shape is not aligned");
    }
    std::array<std::size_t, M + N - 2> out;
    auto pos = std::copy_n(sz1.begin(), M - 1, out.begin()); // a.shape[:-1]
    pos = std::copy_n(sz2.begin(), N - 2, pos);              // b.shape[:-2]
    *pos = sz2.back();                                       // b.shape[-1]
    return out;
}

// Shape of np.matmul(a, b) for a of rank M and b of rank N.
//  - A 1-D left operand (k) is treated as (1, k), and a 1-D right operand (k)
//    as (k, 1). Note: unlike NumPy, the inserted 1 is kept in the result.
//  - The last two axes form the ordinary matrix product: (r, k) x (k, c) -> (r, c).
//  - Any leading (batch) axes are broadcast against each other.
// The result has rank max(M, N). Throws std::invalid_argument when the
// contracted extents differ or the batch axes cannot be broadcast.
template <std::size_t M, std::size_t N>
std::array<std::size_t, std::max(M, N)> matmulDims(const std::array<std::size_t, M>& sz1,
                                                   const std::array<std::size_t, N>& sz2) {
    static_assert(M > 0 && N > 0, "matmul operands must have rank >= 1");
    static_assert(M > 1 || N > 1, "matmul of two 1-D operands is a scalar; use dot() instead");
    // The 1-D cases must be a proper if-constexpr/else chain. Otherwise the
    // general code below is instantiated too: its returns mismatch the
    // declared size and it forms arrays of rank M - 2 with M == 1.
    if constexpr (M == 1) {
        const std::array<std::size_t, 2> sz1_ = {1, sz1[0]};
        return matmulDims(sz1_, sz2);
    } else if constexpr (N == 1) {
        const std::array<std::size_t, 2> sz2_ = {sz2[0], 1};
        return matmulDims(sz1, sz2_);
    } else {
        if (sz1[M - 1] != sz2[N - 2]) {
            throw std::invalid_argument("Cannot do dot product, shape is not aligned");
        }
        std::array<std::size_t, std::max(M, N)> res_sz;
        // Trailing matrix axes: rows of a, columns of b.
        res_sz[res_sz.size() - 2] = sz1[M - 2];
        res_sz[res_sz.size() - 1] = sz2[N - 1];
        // Leading (batch) axes.
        if constexpr (M > 2 && N > 2) {
            std::array<std::size_t, M - 2> sz1_front;
            std::array<std::size_t, N - 2> sz2_front;
            std::copy(std::begin(sz1), std::begin(sz1) + (M - 2), std::begin(sz1_front));
            std::copy(std::begin(sz2), std::begin(sz2) + (N - 2), std::begin(sz2_front));
            const auto common_sz = bidirBroadcastedDims(sz1_front, sz2_front);
            std::copy(std::begin(common_sz), std::end(common_sz), std::begin(res_sz));
        } else if constexpr (M > 2) {
            // Only a has batch axes. (The old code mistakenly ended this copy
            // at an iterator into sz2.)
            std::copy(std::begin(sz1), std::begin(sz1) + (M - 2), std::begin(res_sz));
        } else if constexpr (N > 2) {
            // Only b has batch axes.
            std::copy(std::begin(sz2), std::begin(sz2) + (N - 2), std::begin(res_sz));
        }
        // M == N == 2: no batch axes.
        return res_sz;
    }
}

// Satisfied when every type in the pack is an integral type (vacuously true for
// an empty pack).
template <typename... Args>
concept IndexType = (... && std::is_integral_v<Args>);

// Row-major (C-order) strides for the given extents. The last axis has stride 1,
// and each earlier axis has the product of all later extents.
template <std::size_t N>
std::array<std::size_t, N> computeStrides(const std::array<std::size_t, N>& dims) {
    std::array<std::size_t, N> strides;
    std::size_t running = 1;
    // `axis-- > 0` visits N-1, N-2, ..., 0, and nothing when N == 0.
    for (std::size_t axis = N; axis-- > 0;) {
        strides[axis] = running;
        running *= dims[axis];
    }
    return strides;
}

// Helper for checkNonJagged: writes into shape[0..N) the extents found by always
// descending into the first sub-list of `init`. Levels below an empty list
// get extent 0.
template <std::size_t N, typename Initializer>
void firstPathShape(const Initializer& init, std::size_t* shape) {
    shape[0] = std::size(init);
    if constexpr (N > 1) {
        if (shape[0] == 0) {
            std::fill(shape + 1, shape + N, std::size_t{0});
        } else {
            firstPathShape<N - 1>(*std::begin(init), shape + 1);
        }
    }
}

// Helper for checkNonJagged: true iff every N-level nesting of `init` has
// exactly the extents shape[0..N).
template <std::size_t N, typename Initializer>
bool hasUniformShape(const Initializer& init, const std::size_t* shape) {
    if (std::size(init) != shape[0]) {
        return false;
    }
    if constexpr (N > 1) {
        return std::all_of(std::begin(init), std::end(init), [shape](const auto& sub) {
            return hasUniformShape<N - 1>(sub, shape + 1);
        });
    } else {
        return true;
    }
}

// True iff the N-level nested initializer list `init` is rectangular: at every
// depth, all sub-lists have the same length.
// The whole tree is checked, not just the immediate children. A list like
// {{{1,2},{3,4}}, {{5,6,7},{8,9}}} has equal-sized children at the top level,
// but it holds more elements than its derived dims allow. Accepting it lets
// the flat copy (insertFlat) write past the end of the buffer. An empty list
// is rectangular and is not dereferenced.
template <std::size_t N, typename Initializer>
bool checkNonJagged(const Initializer& init) {
    std::array<std::size_t, N> shape{};
    firstPathShape<N>(init, shape.data());
    return hasUniformShape<N>(init, shape.data());
}

// Writes the extents of the N-level nested initializer `init` through `first`
// (one per level, outermost first) and advances `first` past them.
// Throws std::invalid_argument("Jagged matrix initializer") if `init` is not
// rectangular. Levels below an empty list get extent 0: the old code
// dereferenced begin() of the empty list here, which is undefined behaviour.
template <std::size_t N, typename Iter, typename Initializer>
void addDims(Iter& first, const Initializer& init) {
    if constexpr (N > 1) {
        if (!checkNonJagged<N>(init)) {
            throw std::invalid_argument("Jagged matrix initializer");
        }
    }
    const std::size_t extent = std::size(init);
    *first = extent;
    ++first;
    if constexpr (N > 1) {
        if (extent == 0) {
            for (std::size_t level = 1; level < N; ++level) {
                *first = 0;
                ++first;
            }
        } else {
            addDims<N - 1>(first, *std::begin(init));
        }
    }
}

template <std::size_t N, typename Initializer>
std::array<std::size_t, N> deriveDims(const Initializer& init) {
    std::array<std::size_t, N> dims;
    auto f = std::begin(dims);
    addDims<N>(f, init);
    return dims;
}

// Copies the scalars in [first, last) into data[index], data[index + 1], ...,
// and leaves `index` one past the last slot written. No bounds checking: the
// caller must have sized `data` for the whole initializer.
template <std::semiregular T>
void addList(std::unique_ptr<T[]>& data,
             const T* first, const T* last,
             std::size_t& index) {
    while (first != last) {
        data[index++] = *first++;
    }
}

// Recursive case of the flattening copy: visits each nested sub-list in order
// and forwards its [begin, end) range to the matching addList overload.
template <std::semiregular T, typename I>
void addList(std::unique_ptr<T[]>& data,
             const std::initializer_list<I>* first, const std::initializer_list<I>* last,
             std::size_t& index) {
    for (auto it = first; it != last; ++it) {
        const std::initializer_list<I>& sub = *it;
        addList(data, sub.begin(), sub.end(), index);
    }
}

// Copies every scalar of a nested initializer list into `data`, in row-major
// order, starting at data[0].
template <std::semiregular T, typename I>
void insertFlat(std::unique_ptr<T[]>& data, std::initializer_list<I> list) {
    std::size_t cursor = 0;
    addList(data, list.begin(), list.end(), cursor);
}

// Floor division, like Python's `a // b`: the quotient is rounded toward
// negative infinity. It pairs with mod() so that quot(a, b) * b + mod(a, b) == a.
// The previous form `(a / b) - (a % b < 0)` was only right for b > 0. For
// example, quot(-7, -2) returned 2 instead of 3, and quot(7, -2) returned -3
// instead of -4.
// Precondition: b != 0 (and not a == LONG_MIN with b == -1).
inline long quot(long a, long b) {
    const long q = a / b;
    const long r = a % b;
    // C++ truncates toward zero. Step down when the remainder is non-zero and
    // its sign differs from the divisor's.
    return (r != 0 && ((r < 0) != (b < 0))) ? q - 1 : q;
}

// Floor modulo, like Python's `a % b`: the result takes the sign of b
// (it lies in [0, b) for b > 0 and in (b, 0] for b < 0).
// Unlike `(a % b + b) % b`, this cannot overflow when |b| is close to LONG_MAX.
// Precondition: b != 0.
inline long mod(long a, long b) {
    const long r = a % b;
    return (r != 0 && ((r < 0) != (b < 0))) ? r + b : r;
}

} // namespace frozenca

#endif //FROZENCA_MATRIXUTILS_H

MatrixInitializer.h

#ifndef FROZENCA_MATRIXINITIALIZER_H
#define FROZENCA_MATRIXINITIALIZER_H

#include <cstddef>
#include <concepts>
#include <initializer_list>

namespace frozenca {

// Maps (T, N) to the nested std::initializer_list type used to brace-initialise
// an N-dimensional matrix: N == 1 -> initializer_list<T>,
// N == 2 -> initializer_list<initializer_list<T>>, and so on.
template <std::semiregular T, std::size_t N>
struct MatrixInitializer {
    using type = std::initializer_list<typename MatrixInitializer<T, N - 1>::type>;
};

// Recursion base case: a 1-D matrix is initialised from a flat list of T.
template <std::semiregular T>
struct MatrixInitializer<T, 1> {
    using type = std::initializer_list<T>;
};

// Declared but not defined, so asking for a 0-dimensional initializer is a
// compile error (presumably intentional).
template <std::semiregular T>
struct MatrixInitializer<T, 0>;

} // namespace frozenca

#endif //FROZENCA_MATRIXINITIALIZER_H

MatrixOps.h

#ifndef FROZENCA_MATRIXOPS_H
#define FROZENCA_MATRIXOPS_H

#include "MatrixImpl.h"

namespace frozenca {

// Matrix constructs

// Matrix of the given shape whose elements are whatever the Matrix(dims)
// constructor leaves behind (cf. numpy.empty). Use zeros/ones/full when
// defined values are required.
template <std::semiregular T, std::size_t N>
Matrix<T, N> empty(const std::array<std::size_t, N>& arr) {
    return Matrix<T, N>(arr);
}

// Like empty(), but takes its shape from an existing matrix or view
// (cf. numpy.empty_like).
template <typename Derived, std::semiregular T, std::size_t N>
Matrix<T, N> empty_like(const MatrixBase<Derived, T, N>& base) {
    return Matrix<T, N>(base.dims());
}

// n x m matrix with T{1} on the main diagonal and T{0} everywhere else
// (cf. numpy.eye).
template <OneExists T>
Matrix<T, 2> eye(std::size_t n, std::size_t m) {
    Matrix<T, 2> mat (n, m);
    // zeros() fills explicitly after construction, so the (rows, cols)
    // constructor is not relied on to clear storage. Without this fill the
    // off-diagonal elements would be left as constructed.
    std::ranges::fill(mat, T{0});
    const std::size_t diag = std::min(n, m);
    for (std::size_t i = 0; i < diag; ++i) {
        mat(i, i) = T{1};
    }
    return mat;
}

// Square n x n identity matrix; shorthand for eye<T>(n, n).
template <OneExists T>
Matrix<T, 2> eye(std::size_t n) {
    const std::size_t side = n;
    return eye<T>(side, side);
}

// n x n identity matrix (cf. numpy.identity). Same result as eye<T>(n).
template <OneExists T>
Matrix<T, 2> identity(std::size_t n) {
    return eye<T>(n);
}

// Matrix of the given shape with every element set to T{1} (cf. numpy.ones).
template <OneExists T, std::size_t N>
Matrix<T, N> ones(const std::array<std::size_t, N>& arr) {
    Matrix<T, N> result (arr);
    const T one {1};
    std::ranges::fill(result, one);
    return result;
}

// ones() with the shape of an existing matrix or view (cf. numpy.ones_like).
template <typename Derived, OneExists T, std::size_t N>
Matrix<T, N> ones_like(const MatrixBase<Derived, T, N>& base) {
    return ones<T, N>(base.dims());
}

// Matrix of the given shape with every element set to T{0} (cf. numpy.zeros).
template <std::semiregular T, std::size_t N>
Matrix<T, N> zeros(const std::array<std::size_t, N>& arr) {
    Matrix<T, N> result (arr);
    const T zero {0};
    std::ranges::fill(result, zero);
    return result;
}

// zeros() with the shape of an existing matrix or view (cf. numpy.zeros_like).
template <typename Derived, std::semiregular T, std::size_t N>
Matrix<T, N> zeros_like(const MatrixBase<Derived, T, N>& base) {
    return zeros<T, N>(base.dims());
}

// Matrix of the given shape with every element set to `fill_value`
// (cf. numpy.full).
template <std::semiregular T, std::size_t N>
Matrix<T, N> full(const std::array<std::size_t, N>& arr, const T& fill_value) {
    Matrix<T, N> result (arr);
    std::ranges::fill(result, fill_value);
    return result;
}

// full() with the shape of an existing matrix or view (cf. numpy.full_like).
template <typename Derived, std::semiregular T, std::size_t N>
Matrix<T, N> full_like(const MatrixBase<Derived, T, N>& base, const T& fill_value) {
    return full<T, N>(base.dims(), fill_value);
}

// binary matrix operators

// Recursive broadcasting kernels for the elementwise operators. Each XxxTo
// writes `m1 op m2` into the pre-sized target view m:
//  - for a rank-1 target it applies the scalar store-into function
//    (PlusTo, MinusTo, ...) element by element;
//  - otherwise it recurses one rank down through applyFunctionWithBroadcast.
// The sub-kernel's operand ranks are min(Ni, N - 1). A full-rank operand is
// presumably sliced along its first axis, while a lower-rank one keeps its rank.
//
// These used to live in an unnamed namespace. In a header that gives every
// translation unit its own copy, while the public operator templates that call
// them have external linkage, which violates the one-definition rule. A named
// inline namespace gives them linkage and keeps unqualified lookup from
// frozenca working unchanged.
inline namespace matrix_ops_detail {

// m = m1 + m2 (with broadcasting).
template <std::semiregular U, std::semiregular V, std::semiregular T,
        std::size_t N1, std::size_t N2, std::size_t N>
requires AddableTo<U, V, T> && (std::max(N1, N2) == N)
void AddTo(MatrixView<T, N>& m,
           const MatrixView<U, N1>& m1,
           const MatrixView<V, N2>& m2) {
    if constexpr (N == 1) {
        m.applyFunctionWithBroadcast(m1, m2, PlusTo<U, V, T>);
    } else {
        m.applyFunctionWithBroadcast(m1, m2, AddTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

// m = m1 - m2 (with broadcasting).
template <std::semiregular U, std::semiregular V, std::semiregular T,
        std::size_t N1, std::size_t N2, std::size_t N>
requires SubtractableTo<U, V, T> && (std::max(N1, N2) == N)
void SubtractTo(MatrixView<T, N>& m,
                const MatrixView<U, N1>& m1,
                const MatrixView<V, N2>& m2) {
    if constexpr (N == 1) {
        m.applyFunctionWithBroadcast(m1, m2, MinusTo<U, V, T>);
    } else {
        m.applyFunctionWithBroadcast(m1, m2, SubtractTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

// m = m1 * m2 elementwise (with broadcasting).
template <std::semiregular U, std::semiregular V, std::semiregular T,
        std::size_t N1, std::size_t N2, std::size_t N>
requires MultipliableTo<U, V, T> && (std::max(N1, N2) == N)
void MultiplyTo(MatrixView<T, N>& m,
                const MatrixView<U, N1>& m1,
                const MatrixView<V, N2>& m2) {
    if constexpr (N == 1) {
        m.applyFunctionWithBroadcast(m1, m2, MultipliesTo<U, V, T>);
    } else {
        m.applyFunctionWithBroadcast(m1, m2, MultiplyTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

// m = m1 / m2 elementwise (with broadcasting).
template <std::semiregular U, std::semiregular V, std::semiregular T,
        std::size_t N1, std::size_t N2, std::size_t N>
requires DividableTo<U, V, T> && (std::max(N1, N2) == N)
void DivideTo(MatrixView<T, N>& m,
              const MatrixView<U, N1>& m1,
              const MatrixView<V, N2>& m2) {
    if constexpr (N == 1) {
        m.applyFunctionWithBroadcast(m1, m2, DividesTo<U, V, T>);
    } else {
        m.applyFunctionWithBroadcast(m1, m2, DivideTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

// m = m1 % m2 elementwise (with broadcasting).
template <std::semiregular U, std::semiregular V, std::semiregular T,
        std::size_t N1, std::size_t N2, std::size_t N>
requires RemaindableTo<U, V, T> && (std::max(N1, N2) == N)
void ModuloTo(MatrixView<T, N>& m,
              const MatrixView<U, N1>& m1,
              const MatrixView<V, N2>& m2) {
    if constexpr (N == 1) {
        m.applyFunctionWithBroadcast(m1, m2, ModulusTo<U, V, T>);
    } else {
        m.applyFunctionWithBroadcast(m1, m2, ModuloTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

} // namespace matrix_ops_detail

// Elementwise sum with NumPy-style broadcasting. The result has rank
// max(N1, N2) and element type AddType<U, V> by default.
// Throws std::invalid_argument if the shapes cannot be broadcast.
template <typename Derived1, typename Derived2,
        std::semiregular U, std::semiregular V,
        std::size_t N1, std::size_t N2,
        std::semiregular T = AddType<U, V>> requires AddableTo<U, V, T>
decltype(auto) operator+ (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 1) {
        // 1-D operands: apply the scalar kernel directly, as AddTo does for its
        // own rank-1 case. Passing AddTo<..., 0> here would recurse below rank 0
        // and fail to compile.
        res.applyFunctionWithBroadcast(m1, m2, PlusTo<U, V, T>);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, AddTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}

// Elementwise difference with NumPy-style broadcasting. The result has rank
// max(N1, N2) and element type SubType<U, V> by default.
// Throws std::invalid_argument if the shapes cannot be broadcast.
template <typename Derived1, typename Derived2,
        std::semiregular U, std::semiregular V,
        std::size_t N1, std::size_t N2,
        std::semiregular T = SubType<U, V>> requires SubtractableTo<U, V, T>
decltype(auto) operator- (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 1) {
        // 1-D operands: apply the scalar kernel directly (SubtractTo<..., 0>
        // would recurse below rank 0 and fail to compile).
        res.applyFunctionWithBroadcast(m1, m2, MinusTo<U, V, T>);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, SubtractTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}

// Elementwise (Hadamard) product with NumPy-style broadcasting. This is not the
// matrix product; see dot()/matmul(). The result has rank max(N1, N2) and
// element type MulType<U, V> by default.
// Throws std::invalid_argument if the shapes cannot be broadcast.
template <typename Derived1, typename Derived2,
        std::semiregular U, std::semiregular V,
        std::size_t N1, std::size_t N2,
        std::semiregular T = MulType<U, V>> requires MultipliableTo<U, V, T>
decltype(auto) operator* (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 1) {
        // 1-D operands: apply the scalar kernel directly (MultiplyTo<..., 0>
        // would recurse below rank 0 and fail to compile).
        res.applyFunctionWithBroadcast(m1, m2, MultipliesTo<U, V, T>);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, MultiplyTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}

// Elementwise quotient with NumPy-style broadcasting. The result has rank
// max(N1, N2) and element type DivType<U, V> by default.
// Throws std::invalid_argument if the shapes cannot be broadcast.
template <typename Derived1, typename Derived2,
        std::semiregular U, std::semiregular V,
        std::size_t N1, std::size_t N2,
        std::semiregular T = DivType<U, V>> requires DividableTo<U, V, T>
decltype(auto) operator/ (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 1) {
        // 1-D operands: apply the scalar kernel directly (DivideTo<..., 0>
        // would recurse below rank 0 and fail to compile).
        res.applyFunctionWithBroadcast(m1, m2, DividesTo<U, V, T>);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, DivideTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}

// Elementwise remainder with NumPy-style broadcasting. The result has rank
// max(N1, N2) and element type ModType<U, V> by default.
// Throws std::invalid_argument if the shapes cannot be broadcast.
template <typename Derived1, typename Derived2,
        std::semiregular U, std::semiregular V,
        std::size_t N1, std::size_t N2,
        std::semiregular T = ModType<U, V>> requires RemaindableTo<U, V, T>
decltype(auto) operator% (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 1) {
        // 1-D operands: apply the scalar kernel directly (ModuloTo<..., 0>
        // would recurse below rank 0 and fail to compile).
        res.applyFunctionWithBroadcast(m1, m2, ModulusTo<U, V, T>);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, ModuloTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}



} // namespace frozenca

#endif //FROZENCA_MATRIXOPS_H

LinalgOps.h (операции линейной алгебры)

#ifndef FROZENCA_LINALGOPS_H
#define FROZENCA_LINALGOPS_H

#include "Matrix.h"

namespace frozenca {

// Implementation kernels for dot() and matmul(). Each DotTo overload adds
// np.dot(m1, m2) into a zero-initialised target m.
//
// These used to live in an unnamed namespace. In a header that gives every
// translation unit its own copy, while the public dot()/matmul() templates that
// call them have external linkage, which is an ODR violation. A named inline
// namespace keeps unqualified lookup from frozenca unchanged.
// std::inner_product requires <numeric>.
inline namespace linalg_ops_detail {

// dot(1-D, 1-D): adds the inner product of m1 and m2 to the scalar m.
template <std::semiregular U, std::semiregular V, std::semiregular T>
requires DotProductableTo<U, V, T>
void DotTo(T& m,
           const MatrixView<U, 1>& m1,
           const MatrixView<V, 1>& m2) {
    // DotProductableTo only guarantees `T + T` (via Addable<T, T>), not `T += T`.
    m = m + std::inner_product(std::begin(m1), std::end(m1), std::begin(m2), T{0});
}

// dot(1-D, 2-D): for every column j of m2, adds <m1, m2[:, j]> to m[j].
template <std::semiregular U, std::semiregular V, std::semiregular T>
requires DotProductableTo<U, V, T>
void DotTo(MatrixView<T, 1>& m,
           const MatrixView<U, 1>& m1,
           const MatrixView<V, 2>& m2) {
    assert(m.dims(0) == m2.dims(1));
    // inner_product walks a column of m2 for as many steps as m1 has elements.
    assert(m1.dims(0) == m2.dims(0));
    const std::size_t c = m2.dims(1);
    for (std::size_t j = 0; j < c; ++j) {
        auto col2 = m2.col(j);
        m[j] = m[j] + std::inner_product(std::begin(m1), std::end(m1), std::begin(col2), T{0});
    }
}

// dot(1-D, N2-D) with N2 > 2: recurses over the leading axis shared by m and m2.
template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N2>
requires DotProductableTo<U, V, T> && (N2 > 2)
void DotTo(MatrixView<T, N2 - 1>& m,
           const MatrixView<U, 1>& m1,
           const MatrixView<V, N2>& m2) {
    assert(m.dims(0) == m2.dims(0));
    const std::size_t r = m.dims(0);
    for (std::size_t i = 0; i < r; ++i) {
        auto&& row0 = m.row(i);
        auto row2 = m2.row(i);
        DotTo(row0, m1, row2);
    }
}

// dot(N1-D, 1-D) with N1 > 1: recurses over the leading axis shared by m and m1.
// (This overload used to declare an extra template parameter N2 that appears in
// no function parameter. It could never be deduced, so the overload was never
// viable.)
template <std::semiregular U, std::semiregular V, std::semiregular T,
        std::size_t N1>
requires DotProductableTo<U, V, T> && (N1 > 1)
void DotTo(MatrixView<T, N1 - 1>& m,
           const MatrixView<U, N1>& m1,
           const MatrixView<V, 1>& m2) {
    assert(m.dims(0) == m1.dims(0));
    const std::size_t r = m.dims(0);
    for (std::size_t i = 0; i < r; ++i) {
        // auto&& rather than auto: when N1 == 2, m.row(i) presumably yields a
        // T& into m, and a by-value `auto` copy would make the scalar DotTo
        // accumulate into a local instead of into m.
        auto&& row0 = m.row(i);
        auto row1 = m1.row(i);
        DotTo(row0, row1, m2);
    }
}

// dot(N1-D, N2-D) with N1, N2 > 1: recurses over the leading axis shared by m and m1.
template <std::semiregular U, std::semiregular V, std::semiregular T,
        std::size_t N1, std::size_t N2>
requires DotProductableTo<U, V, T> && (N1 > 1) && (N2 > 1)
void DotTo(MatrixView<T, N1 + N2 - 2>& m,
           const MatrixView<U, N1>& m1,
           const MatrixView<V, N2>& m2) {
    assert(m.dims(0) == m1.dims(0));
    const std::size_t r = m.dims(0);
    for (std::size_t i = 0; i < r; ++i) {
        auto&& row0 = m.row(i);
        auto row1 = m1.row(i);
        DotTo(row0, row1, m2);
    }
}

// Entry point for owning matrices or other MatrixBase-derived objects: wraps
// all three arguments in views and dispatches to the view overloads above.
template <typename Derived0, typename Derived1, typename Derived2,
        std::semiregular U, std::semiregular V, std::semiregular T,
        std::size_t N1, std::size_t N2>
requires DotProductableTo<U, V, T>
void DotTo(MatrixBase<Derived0, T, (N1 + N2 - 2)>& m,
           const MatrixBase<Derived1, U, N1>& m1,
           const MatrixBase<Derived2, V, N2>& m2) {
    MatrixView<T, (N1 + N2 - 2)> m_view (m);
    MatrixView<U, N1> m1_view (m1);
    MatrixView<V, N2> m2_view (m2);
    DotTo(m_view, m1_view, m2_view);
}

// matmul kernel. A rank-2 target is the ordinary matrix product (dot); a
// higher rank recurses over the leading (batch) axis with broadcasting.
// NOTE(review): a 1-D operand that reaches the rank-2 case (N1 == 1 or
// N2 == 1) has no DotTo overload with a rank-2 target. Confirm whether 1-D
// matmul operands are meant to be supported.
template <std::semiregular U, std::semiregular V, std::semiregular T,
        std::size_t N1, std::size_t N2, std::size_t N>
requires DotProductableTo<U, V, T> && (std::max(N1, N2) == N)
void MatmulTo(MatrixView<T, N>& m,
           const MatrixView<U, N1>& m1,
           const MatrixView<V, N2>& m2) {
    if constexpr (N == 2) {
        DotTo(m, m1, m2);
    } else {
        m.applyFunctionWithBroadcast(m1, m2, MatmulTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
}

} // namespace linalg_ops_detail

// np.dot for operands of rank M and N where at least one rank exceeds 1. The
// last axis of m1 is contracted with the second-to-last axis of m2 (or with
// its only axis when N == 1), giving a result of rank M + N - 2.
// Throws std::invalid_argument if the contracted extents differ.
template <typename Derived1, typename Derived2,
        std::semiregular U, std::semiregular V,
        std::size_t M, std::size_t N,
        std::semiregular T = MulType<U, V>> requires DotProductableTo<U, V, T>
decltype(auto) dot(const MatrixBase<Derived1, U, M>& m1, const MatrixBase<Derived2, V, N>& m2) {
    constexpr std::size_t R = M + N - 2; // rank of the result
    Matrix<T, R> res = zeros<T, R>(dotDims(m1.dims(), m2.dims()));
    DotTo(res, m1, m2);
    return res;
}

// np.dot of two 1-D operands: their inner product, returned as a scalar T.
// Throws std::invalid_argument if the lengths differ.
template <typename Derived1, typename Derived2,
        std::semiregular U, std::semiregular V,
        std::semiregular T = MulType<U, V>> requires DotProductableTo<U, V, T>
decltype(auto) dot(const MatrixBase<Derived1, U, 1>& m1, const MatrixBase<Derived2, V, 1>& m2) {
    // dotDims has no (1, 1) overload (both overloads require a rank > 1), so
    // check the lengths here instead.
    if (m1.dims()[0] != m2.dims()[0]) {
        throw std::invalid_argument("Cannot do dot product, shape is not aligned");
    }
    // The scalar DotTo overload takes MatrixView arguments. MatrixBase
    // references do not deduce against it, so build views the same way the
    // MatrixBase DotTo wrapper does.
    MatrixView<U, 1> m1_view (m1);
    MatrixView<V, 1> m2_view (m2);
    T res {0};
    DotTo(res, m1_view, m2_view);
    return res;
}

// np.matmul-style matrix product. The last two axes of each operand are
// multiplied as matrices, and any leading (batch) axes are broadcast. The
// result has rank max(N1, N2); see matmulDims for the shape rules.
// Throws std::invalid_argument on misaligned or non-broadcastable shapes.
template <typename Derived1, typename Derived2,
        std::semiregular U, std::semiregular V,
        std::size_t N1, std::size_t N2,
        std::semiregular T = MulType<U, V>> requires DotProductableTo<U, V, T>
decltype(auto) matmul(const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = matmulDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 2) {
        // No batch axes: this is a plain matrix product, exactly what MatmulTo
        // does for rank 2. Broadcasting row by row into MatmulTo<..., 1> (as
        // before) recursed below rank 1 and did not compile.
        // NOTE(review): 1-D operands (N1 == 1 or N2 == 1) still have no
        // matching DotTo for a rank-2 target.
        DotTo(res, m1, m2);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, MatmulTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}

} // namespace frozenca

#endif //FROZENCA_LINALGOPS_H

Matrix.h

#ifndef FROZENCA_MATRIX_H
#define FROZENCA_MATRIX_H

#include "MatrixImpl.h"
#include "MatrixOps.h"
#include "LinalgOps.h"

#endif //FROZENCA_MATRIX_H

Не стесняйтесь комментировать что угодно!

Что мне не нравится:

  • Пользователям открыто слишком много API. Я ненавижу #include в C++ и очень жду модулей C++20! (Текущая реализация модулей в MSVC 19.28 с моим кодом не работает.)
  • .reshape(). Он должен перемещать буфер, поэтому принимает rvalue-ссылку, но хотелось бы найти решение получше.
  • Вспомогательные функции перегружены шаблонами. Точнее, мне не нравится, что функции вроде applyFunctionWithBroadcast инстанцируются заново для каждого N > 1, хотя делают одно и то же. Но для N = 1 мне нужна специализация (потому что подматрица, строка и столбец в этом случае должны возвращать T&), и лучшего способа я не знаю.

Подробнее о broadcasting («транслировании» размерностей), скалярном произведении (np.dot) и матричном умножении (np.matmul) см. ниже:

https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md

https://numpy.org/doc/stable/reference/generated/numpy.dot.html

https://numpy.org/doc/stable/reference/generated/numpy.matmul.html

0

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *