Я улучшил свой проект N-мерной матрицы на C++20 (см. «C++20: минимальный N-мерный класс Matrix»).
Реализовано общее сложение / вычитание матриц, поэлементное умножение / деление, скалярное произведение, матричное произведение, изменение формы, транспонирование.
Там много кода:
ObjectBase.h
#ifndef FROZENCA_OBJECTBASE_H
#define FROZENCA_OBJECTBASE_H
#include <functional>
#include <utility>
#include "MatrixUtils.h"
namespace frozenca {
template <typename Derived>
class ObjectBase {
private:
Derived& self() { return static_cast<Derived&>(*this); }
const Derived& self() const { return static_cast<const Derived&>(*this); }
protected:
ObjectBase() = default;
~ObjectBase() noexcept = default;
public:
auto begin() { return self().begin(); }
auto begin() const { return self().begin(); }
auto cbegin() const { return self().cbegin(); }
auto end() { return self().end(); }
auto end() const { return self().end(); }
auto cend() const { return self().cend(); }
auto rbegin() { return self().rbegin(); }
auto rbegin() const { return self().rbegin(); }
auto crbegin() const { return self().crbegin(); }
auto rend() { return self().rend(); }
auto rend() const { return self().rend(); }
auto crend() const { return self().crend(); }
template <typename F> requires std::invocable<F, typename Derived::reference>
ObjectBase& applyFunction(F&& f);
template <typename F, typename... Args> requires std::invocable<F, typename Derived::reference, Args...>
ObjectBase& applyFunction(F&& f, Args&&... args);
template <typename DerivedOther, typename F> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference>
ObjectBase& applyFunction(const ObjectBase<DerivedOther>& other, F&& f);
template <typename DerivedOther, typename F, typename... Args> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference, Args...>
ObjectBase& applyFunction(const ObjectBase<DerivedOther>& other, F&& f, Args&&... args);
template <isNotMatrix U> requires Addable<typename Derived::value_type, U>
ObjectBase& operator=(const U& val) {
return applyFunction([&val](auto& v) {v = val;});
}
template <isNotMatrix U> requires Addable<typename Derived::value_type, U>
ObjectBase& operator+=(const U& val) {
return applyFunction([&val](auto& v) {v += val;});
}
template <isNotMatrix U> requires Subtractable<typename Derived::value_type, U>
ObjectBase& operator-=(const U& val) {
return applyFunction([&val](auto& v) {v -= val;});
}
template <isNotMatrix U> requires Multipliable<typename Derived::value_type, U>
ObjectBase& operator*=(const U& val) {
return applyFunction([&val](auto& v) {v *= val;});
}
template <isNotMatrix U> requires Dividable<typename Derived::value_type, U>
ObjectBase& operator/=(const U& val) {
return applyFunction([&val](auto& v) {v /= val;});
}
template <isNotMatrix U> requires Remaindable<typename Derived::value_type, U>
ObjectBase& operator%=(const U& val) {
return applyFunction([&val](auto& v) {v %= val;});
}
template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase& operator&=(const U& val) {
return applyFunction([&val](auto& v) {v &= val;});
}
template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase& operator|=(const U& val) {
return applyFunction([&val](auto& v) {v |= val;});
}
template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase& operator^=(const U& val) {
return applyFunction([&val](auto& v) {v ^= val;});
}
template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase& operator<<=(const U& val) {
return applyFunction([&val](auto& v) {v <<= val;});
}
template <isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase& operator>>=(const U& val) {
return applyFunction([&val](auto& v) {v >>= val;});
}
};
/// Invokes `f` once per element, in iteration order, and returns *this.
template <typename Derived>
template <typename F> requires std::invocable<F, typename Derived::reference>
ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(F&& f) {
    auto cursor = begin();
    // end() is re-evaluated on every pass, exactly like a classic for-loop.
    while (cursor != end()) {
        f(*cursor);
        ++cursor;
    }
    return *this;
}
/// Invokes `f(elem, args...)` once per element, in iteration order, and
/// returns *this.
///
/// The extra arguments are forwarded with one `std::forward` per pack element.
/// The previous `std::forward<Args...>(args...)` expanded the pack inside a
/// single call and only compiled for exactly one extra argument.
/// Note that the same arguments are forwarded on every iteration, so `f`
/// should not consume rvalue arguments.
template <typename Derived>
template <typename F, typename... Args> requires std::invocable<F, typename Derived::reference, Args...>
ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(F&& f, Args&&... args) {
    for (auto it = begin(); it != end(); ++it) {
        f(*it, std::forward<Args>(args)...);
    }
    return *this;
}
/// Invokes `f(elem, other_elem)` for each element of *this, advancing an
/// iterator over `other` in lock-step, and returns *this.
///
/// The loop is bounded by this object's range only; `other` is not checked
/// for length here.
template <typename Derived>
template <typename DerivedOther, typename F> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference>
ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(const ObjectBase<DerivedOther>& other, F&& f) {
    // Declared separately: `auto a = x, b = y;` requires both initialisers to
    // deduce the same type, which fails when the two iterator types differ
    // (const vs. non-const, or different element types).
    auto it = begin();
    auto it2 = other.begin();
    for (; it != end(); ++it, ++it2) {
        f(*it, *it2);
    }
    return *this;
}
/// Invokes `f(elem, other_elem, args...)` for each element of *this,
/// advancing an iterator over `other` in lock-step, and returns *this.
///
/// The loop is bounded by this object's range only. The extra arguments are
/// forwarded on every iteration, so `f` should not consume rvalue arguments.
template <typename Derived>
template <typename DerivedOther, typename F, typename... Args> requires std::invocable<F, typename Derived::reference, typename DerivedOther::reference, Args...>
ObjectBase<Derived>& ObjectBase<Derived>::applyFunction(const ObjectBase<DerivedOther>& other, F&& f, Args&&... args) {
    // Separate declarations: a single `auto` declaration would require both
    // iterators to have the same deduced type.
    auto it = begin();
    auto it2 = other.begin();
    for (; it != end(); ++it, ++it2) {
        // One std::forward per pack element; the old
        // `std::forward<Args...>(args...)` only worked for a single argument.
        f(*it, *it2, std::forward<Args>(args)...);
    }
    return *this;
}
/// Scalar addition: copies `m`, applies `+= val` to the copy and returns it.
/// NOTE(review): the result is an ObjectBase<Derived> copied from `m`, not a
/// full Derived, and ObjectBase's destructor is protected. Confirm this
/// free function is usable as intended.
template <typename Derived, isNotMatrix U> requires Addable<typename Derived::value_type, U>
ObjectBase<Derived> operator+(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> sum = m;
    sum += val;
    return sum;
}
/// Scalar subtraction: copies `m`, applies `-= val` to the copy and returns it.
/// NOTE(review): the result is an ObjectBase<Derived> copied from `m`, not a
/// full Derived, and ObjectBase's destructor is protected. Confirm this
/// free function is usable as intended.
template <typename Derived, isNotMatrix U> requires Subtractable<typename Derived::value_type, U>
ObjectBase<Derived> operator-(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> difference = m;
    difference -= val;
    return difference;
}
/// Scalar multiplication: copies `m`, applies `*= val` to the copy and returns it.
/// NOTE(review): the result is an ObjectBase<Derived> copied from `m`, not a
/// full Derived, and ObjectBase's destructor is protected. Confirm this
/// free function is usable as intended.
template <typename Derived, isNotMatrix U> requires Multipliable<typename Derived::value_type, U>
ObjectBase<Derived> operator*(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> product = m;
    product *= val;
    return product;
}
/// Scalar division: copies `m`, applies `/= val` to the copy and returns it.
/// NOTE(review): the result is an ObjectBase<Derived> copied from `m`, not a
/// full Derived, and ObjectBase's destructor is protected. Confirm this
/// free function is usable as intended.
template <typename Derived, isNotMatrix U> requires Dividable<typename Derived::value_type, U>
ObjectBase<Derived> operator/(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> quotient = m;
    quotient /= val;
    return quotient;
}
/// Scalar remainder: copies `m`, applies `%= val` to the copy and returns it.
/// NOTE(review): the result is an ObjectBase<Derived> copied from `m`, not a
/// full Derived, and ObjectBase's destructor is protected. Confirm this
/// free function is usable as intended.
template <typename Derived, isNotMatrix U> requires Remaindable<typename Derived::value_type, U>
ObjectBase<Derived> operator%(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> remainder = m;
    remainder %= val;
    return remainder;
}
/// Scalar bitwise AND: copies `m`, applies `&= val` to the copy and returns it.
/// NOTE(review): the result is an ObjectBase<Derived> copied from `m`, not a
/// full Derived, and ObjectBase's destructor is protected. Confirm this
/// free function is usable as intended.
template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase<Derived> operator&(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> masked = m;
    masked &= val;
    return masked;
}
/// Scalar bitwise XOR: copies `m`, applies `^= val` to the copy and returns it.
/// NOTE(review): the result is an ObjectBase<Derived> copied from `m`, not a
/// full Derived, and ObjectBase's destructor is protected. Confirm this
/// free function is usable as intended.
template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase<Derived> operator^(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> toggled = m;
    toggled ^= val;
    return toggled;
}
/// Scalar bitwise OR: copies `m`, applies `|= val` to the copy and returns it.
/// NOTE(review): the result is an ObjectBase<Derived> copied from `m`, not a
/// full Derived, and ObjectBase's destructor is protected. Confirm this
/// free function is usable as intended.
template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase<Derived> operator|(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> combined = m;
    combined |= val;
    return combined;
}
/// Scalar left shift: copies `m`, applies `<<= val` to the copy and returns it.
/// NOTE(review): the result is an ObjectBase<Derived> copied from `m`, not a
/// full Derived, and ObjectBase's destructor is protected. Confirm this
/// free function is usable as intended.
template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase<Derived> operator<<(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> shifted = m;
    shifted <<= val;
    return shifted;
}
/// Scalar right shift: copies `m`, applies `>>= val` to the copy and returns it.
/// NOTE(review): the result is an ObjectBase<Derived> copied from `m`, not a
/// full Derived, and ObjectBase's destructor is protected. Confirm this
/// free function is usable as intended.
template <typename Derived, isNotMatrix U> requires BitMaskable<typename Derived::value_type, U>
ObjectBase<Derived> operator>>(const ObjectBase<Derived>& m, const U& val) {
    ObjectBase<Derived> shifted = m;
    shifted >>= val;
    return shifted;
}
} // namespace frozenca
#endif //FROZENCA_OBJECTBASE_H
MatrixBase.h
#ifndef FROZENCA_MATRIXBASE_H
#define FROZENCA_MATRIXBASE_H
#include <numeric>
#include "ObjectBase.h"
#include "MatrixInitializer.h"
namespace frozenca {
template <std::semiregular T, std::size_t N>
class MatrixView;
// CRTP base for N-dimensional (N > 1) matrix-like types.
// Stores the shape (dims_), the total element count (size_) and the strides
// (strides_). Element storage and iterators are supplied by Derived and
// reached through self(). The one-dimensional case is handled by a separate
// partial specialization.
template <typename Derived, std::semiregular T, std::size_t N>
class MatrixBase : public ObjectBase<MatrixBase<Derived, T, N>> {
static_assert(N > 1);
public:
// Number of dimensions.
static constexpr std::size_t ndim = N;
private:
std::array<std::size_t, N> dims_;
std::size_t size_;
std::array<std::size_t, N> strides_;
public:
MatrixBase() = delete;
using Base = ObjectBase<MatrixBase<Derived, T, N>>;
// Re-export the element-wise helpers and scalar operators from ObjectBase.
using Base::applyFunction;
using Base::operator=;
using Base::operator+=;
using Base::operator-=;
using Base::operator*=;
using Base::operator/=;
using Base::operator%=;
// CRTP downcasts to the derived object.
Derived& self() { return static_cast<Derived&>(*this); }
const Derived& self() const { return static_cast<const Derived&>(*this); }
protected:
// Protected destructor and constructors: only derived classes may create or
// destroy a MatrixBase.
~MatrixBase() noexcept = default;
// Shape given as a full N-element array.
MatrixBase(const std::array<std::size_t, N>& dims);
// Shape given with fewer than N extents.
template <std::size_t M> requires (M < N)
MatrixBase(const std::array<std::size_t, M>& dims);
// Shape given as N separate extents.
template <IndexType... Dims>
explicit MatrixBase(Dims... dims);
// Takes the shape of another matrix whose element type converts to T.
template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T>
MatrixBase(const MatrixBase<DerivedOther, U, N>&);
// Shape derived from a nested initializer.
MatrixBase(typename MatrixInitializer<T, N>::type init);
public:
// Flat initializer lists are rejected for construction and assignment.
template <typename U>
MatrixBase(std::initializer_list<U>) = delete;
template <typename U>
MatrixBase& operator=(std::initializer_list<U>) = delete;
using value_type = T;
using reference = T&;
using const_reference = const T&;
using pointer = T*;
public:
// Swaps the shape bookkeeping (size, dims, strides) of two bases.
friend void swap(MatrixBase& a, MatrixBase& b) noexcept {
std::swap(a.size_, b.size_);
std::swap(a.dims_, b.dims_);
std::swap(a.strides_, b.strides_);
}
// Iterator access, forwarded to Derived.
auto begin() { return self().begin(); }
auto begin() const { return self().begin(); }
auto cbegin() const { return self().cbegin(); }
auto end() { return self().end(); }
auto end() const { return self().end(); }
auto cend() const { return self().cend(); }
auto rbegin() { return self().rbegin(); }
auto rbegin() const { return self().rbegin(); }
auto crbegin() const { return self().crbegin(); }
auto rend() { return self().rend(); }
auto rend() const { return self().rend(); }
auto crend() const { return self().crend(); }
// Element access by N separate indices.
template <IndexType... Args>
reference operator()(Args... args);
template <IndexType... Args>
const_reference operator()(Args... args) const;
// Element access by a full position array; the const overload throws
// std::out_of_range when any index is not below the matching extent.
reference operator[](const std::array<std::size_t, N>& pos);
const_reference operator[](const std::array<std::size_t, N>& pos) const;
// Total number of elements.
[[nodiscard]] std::size_t size() const { return size_;}
// Full shape.
[[nodiscard]] const std::array<std::size_t, N>& dims() const {
return dims_;
}
// Extent of axis n; throws std::out_of_range when n >= N.
[[nodiscard]] std::size_t dims(std::size_t n) const {
if (n >= N) {
throw std::out_of_range("Out of range in dims");
}
return dims_[n];
}
// All strides.
[[nodiscard]] const std::array<std::size_t, N>& strides() const {
return strides_;
}
// Stride of axis n; throws std::out_of_range when n >= N.
[[nodiscard]] std::size_t strides(std::size_t n) const {
if (n >= N) {
throw std::out_of_range("Out of range in strides");
}
return strides_[n];
}
// Forwarded to Derived.
auto dataView() const {
return self().dataView();
}
// Forwarded to Derived.
auto origStrides() const {
return self().origStrides();
}
// Views into this matrix. The single-argument submatrix overloads use dims_
// as the end position.
MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin);
MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin, const std::array<std::size_t, N>& pos_end);
MatrixView<T, N - 1> row(std::size_t n);
MatrixView<T, N - 1> col(std::size_t n);
// Indexing with a single size_t yields the n-th row.
MatrixView<T, N - 1> operator[](std::size_t n) { return row(n); }
MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin) const;
MatrixView<T, N> submatrix(const std::array<std::size_t, N>& pos_begin, const std::array<std::size_t, N>& pos_end) const;
MatrixView<T, N - 1> row(std::size_t n) const;
MatrixView<T, N - 1> col(std::size_t n) const;
MatrixView<T, N - 1> operator[](std::size_t n) const { return row(n); }
// Streams the matrix as "{row0, row1, ...}", printing each row through its
// own operator<<.
friend std::ostream& operator<<(std::ostream& os, const MatrixBase& m) {
os << '{';
for (std::size_t i = 0; i != m.dims(0); ++i) {
os << m[i];
if (i + 1 != m.dims(0)) {
os << ", ";
}
}
return os << '}';
}
// Applies f row by row to (*this, m1, m2). At least one operand must have
// rank N (see the requires-clause). The callable receives a row view of *this
// plus views of the two operands.
template <typename DerivedOther1, typename DerivedOther2,
std::semiregular U, std::semiregular V,
std::size_t N1, std::size_t N2,
std::invocable<MatrixView<T, N - 1>&,
const MatrixView<U, std::min(N1, N - 1)>&,
const MatrixView<V, std::min(N2, N - 1)>&> F>
requires (std::max(N1, N2) == N)
MatrixBase& applyFunctionWithBroadcast(const MatrixBase<DerivedOther1, U, N1>& m1,
const MatrixBase<DerivedOther2, V, N2>& m2,
F&& f);
};
/// Builds the shape bookkeeping from a full N-element extent array.
///
/// @throws std::invalid_argument if any extent is zero.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixBase<Derived, T, N>::MatrixBase(const std::array<std::size_t, N>& dims) : dims_ {dims} {
    if (std::ranges::find(dims_, std::size_t{0}) != std::end(dims_)) {
        throw std::invalid_argument("Zero dimension not allowed");
    }
    // The accumulator takes the type of the initial value. Use std::size_t
    // rather than `1lu`, because `unsigned long` is narrower than size_t on
    // LLP64 platforms and would truncate large products.
    size_ = std::accumulate(std::begin(dims_), std::end(dims_), std::size_t{1}, std::multiplies<>{});
    strides_ = computeStrides(dims_);
}
/// Builds an N-dimensional shape from fewer than N extents by padding the
/// leading axes with 1 and delegating to the full-array constructor.
///
/// Uses `prependDims` from MatrixUtils.h; the previously referenced `prepend`
/// is not defined in the visible sources.
template <typename Derived, std::semiregular T, std::size_t N>
template <std::size_t M> requires (M < N)
MatrixBase<Derived, T, N>::MatrixBase(const std::array<std::size_t, M>& dims) : MatrixBase (prependDims<N, M>(dims)) {}
/// Builds the shape bookkeeping from N separate integral extents.
///
/// @throws std::invalid_argument if any extent is zero.
template <typename Derived, std::semiregular T, std::size_t N>
template <IndexType... Dims>
MatrixBase<Derived, T, N>::MatrixBase(Dims... dims) : dims_{static_cast<std::size_t>(dims)...} {
    static_assert(sizeof...(Dims) == N);
    static_assert((std::is_integral_v<Dims> && ...));
    if (std::ranges::find(dims_, std::size_t{0}) != std::end(dims_)) {
        throw std::invalid_argument("Zero dimension not allowed");
    }
    // The accumulator takes the type of the initial value. Use std::size_t
    // rather than `1lu`, because `unsigned long` is narrower than size_t on
    // LLP64 platforms and would truncate large products.
    size_ = std::accumulate(std::begin(dims_), std::end(dims_), std::size_t{1}, std::multiplies<>{});
    strides_ = computeStrides(dims_);
}
// Converting constructor: takes only the shape of `other` (via other.dims())
// and delegates to the full-array constructor. No element data is touched here.
template <typename Derived, std::semiregular T, std::size_t N>
template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T>
MatrixBase<Derived, T, N>::MatrixBase(const MatrixBase<DerivedOther, U, N>& other) : MatrixBase(other.dims()) {}
// Initializer constructor: obtains the shape from deriveDims<N>(init) and
// delegates to the full-array constructor.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixBase<Derived, T, N>::MatrixBase(typename MatrixInitializer<T, N>::type init) : MatrixBase(deriveDims<N>(init)) {}
/// Mutable element access by N separate indices.
/// Reuses the const overload and strips constness from its result.
template <typename Derived, std::semiregular T, std::size_t N>
template <IndexType... Args>
typename MatrixBase<Derived, T, N>::reference MatrixBase<Derived, T, N>::operator()(Args... args) {
    const MatrixBase& const_this = *this;
    return const_cast<reference>(const_this(args...));
}
/// Read-only element access by N separate indices.
/// Packs the indices into a position array and defers to operator[].
template <typename Derived, std::semiregular T, std::size_t N>
template <IndexType... Args>
typename MatrixBase<Derived, T, N>::const_reference MatrixBase<Derived, T, N>::operator()(Args... args) const {
    static_assert(sizeof...(Args) == N);
    const std::array<std::size_t, N> position {std::size_t(args)...};
    return (*this)[position];
}
/// Mutable element access by position array.
/// Reuses the const overload and strips constness from its result.
template <typename Derived, std::semiregular T, std::size_t N>
typename MatrixBase<Derived, T, N>::reference MatrixBase<Derived, T, N>::operator[](const std::array<std::size_t, N>& pos) {
    const MatrixBase& const_this = *this;
    return const_cast<reference>(const_this[pos]);
}
/// Read-only element access by position array.
///
/// The element offset is the inner product of `pos` with the strides, added
/// to cbegin().
///
/// @throws std::out_of_range if any index is not below the matching extent.
template <typename Derived, std::semiregular T, std::size_t N>
typename MatrixBase<Derived, T, N>::const_reference MatrixBase<Derived, T, N>::operator[](const std::array<std::size_t, N>& pos) const {
    if (!std::equal(std::cbegin(pos), std::cend(pos), std::cbegin(dims_), std::less<>{})) {
        throw std::out_of_range("Out of range in element access");
    }
    // The initial value fixes the accumulator type. Use std::size_t rather
    // than `0lu`, because `unsigned long` is narrower than size_t on LLP64.
    const std::size_t offset =
        std::inner_product(std::cbegin(pos), std::cend(pos), std::cbegin(strides_), std::size_t{0});
    return *(cbegin() + offset);
}
/// Mutable overload: view starting at `pos_begin` and ending at the matrix
/// extents on every axis.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin) {
    const std::array<std::size_t, N>& pos_end = dims_;
    return this->submatrix(pos_begin, pos_end);
}
/// Const overload: view starting at `pos_begin` and ending at the matrix
/// extents on every axis.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin) const {
    const std::array<std::size_t, N>& pos_end = dims_;
    return this->submatrix(pos_begin, pos_end);
}
/// Mutable overload of the two-position submatrix: defers to the const
/// overload.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin,
                                                      const std::array<std::size_t, N>& pos_end) {
    const MatrixBase& const_this = *this;
    return const_this.submatrix(pos_begin, pos_end);
}
/// Returns a view over the half-open box [pos_begin, pos_end).
///
/// The view's extents are `pos_end - pos_begin` per axis. It points at the
/// element at `pos_begin` and reuses this matrix's strides().
/// NOTE(review): row()/col() pick origStrides() when Derived is a MatrixView,
/// while this function always uses strides(). Confirm that is intended.
///
/// @throws std::out_of_range if `pos_begin` is not strictly below `pos_end`
///         on every axis, if `pos_end` exceeds the matrix extents, or (from
///         operator[]) if `pos_begin` is outside the matrix.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N> MatrixBase<Derived, T, N>::submatrix(const std::array<std::size_t, N>& pos_begin,
                                                      const std::array<std::size_t, N>& pos_end) const {
    if (!std::equal(std::cbegin(pos_begin), std::cend(pos_begin), std::cbegin(pos_end), std::less<>{})) {
        throw std::out_of_range("submatrix begin/end position error");
    }
    // The end position is exclusive, so it may equal the extents but not
    // exceed them. Without this check the view could extend past the data.
    if (!std::equal(std::cbegin(pos_end), std::cend(pos_end), std::cbegin(dims_), std::less_equal<>{})) {
        throw std::out_of_range("submatrix end position exceeds matrix dimensions");
    }
    std::array<std::size_t, N> view_dims;
    std::transform(std::cbegin(pos_end), std::cend(pos_end), std::cbegin(pos_begin), std::begin(view_dims),
                   std::minus<>{});
    MatrixView<T, N> view(view_dims, const_cast<T*>(&operator[](pos_begin)), strides());
    return view;
}
/// Mutable overload of row(): defers to the const overload.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N - 1> MatrixBase<Derived, T, N>::row(std::size_t n) {
    const MatrixBase& const_this = *this;
    return const_this.row(n);
}
/// Returns an (N-1)-dimensional view of the n-th slice along the first axis.
///
/// The view keeps every extent and stride except the first and points at
/// position {n, 0, ..., 0}.
///
/// @throws std::out_of_range if n is not below the first extent.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N - 1> MatrixBase<Derived, T, N>::row(std::size_t n) const {
    const auto& full_dims = dims();
    if (!(n < full_dims[0])) {
        throw std::out_of_range("row index error");
    }

    // Drop the leading extent.
    std::array<std::size_t, N - 1> row_dims;
    std::copy(std::next(std::cbegin(full_dims)), std::cend(full_dims), std::begin(row_dims));

    // First element of the slice: index n on axis 0, zero elsewhere.
    std::array<std::size_t, N> first_pos{};
    first_pos[0] = n;

    // A MatrixView supplies its strides through origStrides(); any other
    // derived type uses strides().
    std::array<std::size_t, N> full_strides;
    if constexpr (std::is_same_v<Derived, MatrixView<T, N>>) {
        full_strides = origStrides();
    } else {
        full_strides = strides();
    }

    // Drop the leading stride.
    std::array<std::size_t, N - 1> row_strides;
    std::copy(std::next(std::cbegin(full_strides)), std::cend(full_strides), std::begin(row_strides));

    MatrixView<T, N - 1> slice(row_dims, const_cast<T*>(&operator[](first_pos)), row_strides);
    return slice;
}
/// Mutable overload of col(): defers to the const overload.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N - 1> MatrixBase<Derived, T, N>::col(std::size_t n) {
    const MatrixBase& const_this = *this;
    return const_this.col(n);
}
/// Returns an (N-1)-dimensional view of the n-th slice along the last axis.
///
/// The view keeps every extent and stride except the last and points at
/// position {0, ..., 0, n}.
///
/// @throws std::out_of_range if n is not below the last extent.
template <typename Derived, std::semiregular T, std::size_t N>
MatrixView<T, N - 1> MatrixBase<Derived, T, N>::col(std::size_t n) const {
    const auto& orig_dims = dims();
    if (n >= orig_dims[N - 1]) {
        // The message used to say "row index error", copied from row().
        throw std::out_of_range("col index error");
    }
    // Drop the trailing extent.
    std::array<std::size_t, N - 1> col_dims;
    std::copy(std::cbegin(orig_dims), std::cend(orig_dims) - 1, std::begin(col_dims));
    // First element of the slice: index n on the last axis, zero elsewhere.
    std::array<std::size_t, N> pos_begin = {0};
    pos_begin[N - 1] = n;
    std::array<std::size_t, N - 1> col_strides;
    // A MatrixView supplies its strides through origStrides(); any other
    // derived type uses strides().
    std::array<std::size_t, N> orig_strides;
    if constexpr (std::is_same_v<Derived, MatrixView<T, N>>) {
        orig_strides = origStrides();
    } else {
        orig_strides = strides();
    }
    // Drop the trailing stride.
    std::copy(std::cbegin(orig_strides), std::cend(orig_strides) - 1, std::begin(col_strides));
    MatrixView<T, N - 1> nth_col(col_dims, const_cast<T*>(&operator[](pos_begin)), col_strides);
    return nth_col;
}
// Applies f to each row of *this together with the matching parts of m1 and m2.
// Returns *this.
//
// The rank combination is resolved at compile time:
//  - both operands have rank N: rows are paired by index. An operand whose
//    leading extent differs from this matrix's has its row(0) reused for
//    every row.
//  - one operand has lower rank: that operand is wrapped whole in a MatrixView
//    and passed unchanged for every row.
// The leading extents are only checked with assert, so a mismatch is not
// detected when assertions are disabled.
template <typename Derived, std::semiregular T, std::size_t N>
template <typename DerivedOther1, typename DerivedOther2,
std::semiregular U, std::semiregular V,
std::size_t N1, std::size_t N2,
std::invocable<MatrixView<T, N - 1>&,
const MatrixView<U, std::min(N1, N - 1)>&,
const MatrixView<V, std::min(N2, N - 1)>&> F>
requires (std::max(N1, N2) == N)
MatrixBase<Derived, T, N>& MatrixBase<Derived, T, N>::applyFunctionWithBroadcast(const MatrixBase<DerivedOther1, U, N1>& m1,
const MatrixBase<DerivedOther2, V, N2>& m2,
F&& f) {
if constexpr (N1 == N) {
if constexpr (N2 == N) {
// Both operands have full rank: compare leading extents at run time.
auto r = dims(0);
auto r1 = m1.dims(0);
auto r2 = m2.dims(0);
if (r1 == r) {
if (r2 == r) {
for (std::size_t i = 0; i < r; ++i) {
// Named local because f takes the destination row by lvalue reference.
auto row = this->row(i);
f(row, m1.row(i), m2.row(i));
}
} else { // r2 < r == r1
// m2's first row is reused for every destination row.
auto row2 = m2.row(0);
for (std::size_t i = 0; i < r; ++i) {
auto row = this->row(i);
f(row, m1.row(i), row2);
}
}
} else if (r2 == r) { // r1 < r == r2
// m1's first row is reused for every destination row.
auto row1 = m1.row(0);
for (std::size_t i = 0; i < r; ++i) {
auto row = this->row(i);
f(row, row1, m2.row(i));
}
} else {
assert(0); // cannot happen
}
} else { // N2 < N == N1
auto r = dims(0);
assert(r == m1.dims(0));
// The lower-rank operand is passed whole for every row.
MatrixView<V, N2> view2 (m2);
for (std::size_t i = 0; i < r; ++i) {
auto row = this->row(i);
f(row, m1.row(i), view2);
}
}
} else if constexpr (N2 == N) { // N1 < N == N2
auto r = dims(0);
assert(r == m2.dims(0));
// The lower-rank operand is passed whole for every row.
MatrixView<U, N1> view1 (m1);
for (std::size_t i = 0; i < r; ++i) {
auto row = this->row(i);
f(row, view1, m2.row(i));
}
} else {
assert(0); // cannot happen
}
return *this;
}
/**
 * One-dimensional specialization of MatrixBase.
 *
 * Stores a single extent (`dims_`) and a single stride (`strides_`). Element
 * storage and iterators are supplied by `Derived` and reached through
 * `self()`. Here `row(n)`, `col(n)` and `operator[](n)` yield element
 * references rather than views.
 */
template <typename Derived, std::semiregular T>
class MatrixBase<Derived, T, 1> : public ObjectBase<MatrixBase<Derived, T, 1>> {
public:
    /// Number of dimensions.
    static constexpr std::size_t ndim = 1;

private:
    std::size_t dims_;
    std::size_t strides_;
    // CRTP downcasts to the derived object.
    Derived& self() { return static_cast<Derived&>(*this); }
    const Derived& self() const { return static_cast<const Derived&>(*this); }

public:
    MatrixBase() = delete;
    using Base = ObjectBase<MatrixBase<Derived, T, 1>>;
    // Re-export the element-wise helpers and scalar operators from ObjectBase.
    using Base::applyFunction;
    using Base::operator=;
    using Base::operator+=;
    using Base::operator-=;
    using Base::operator*=;
    using Base::operator/=;
    using Base::operator%=;

protected:
    ~MatrixBase() noexcept = default;
    /// Creates a vector shape of `dim` elements with unit stride.
    template <typename Dim> requires std::is_integral_v<Dim>
    explicit MatrixBase(Dim dim) : dims_(dim), strides_(1) {}
    template <typename DerivedOther, std::semiregular U> requires std::is_convertible_v<U, T>
    MatrixBase(const MatrixBase<DerivedOther, U, 1>&);
    MatrixBase(typename MatrixInitializer<T, 1>::type init);

public:
    using value_type = T;
    using reference = T&;
    using const_reference = const T&;
    using pointer = T*;

public:
    /// Swaps the shape bookkeeping of two bases.
    /// This specialization has no `size_` member, so only the extent and the
    /// stride are exchanged. The old body referenced `size_` and could not be
    /// instantiated.
    friend void swap(MatrixBase& a, MatrixBase& b) noexcept {
        std::swap(a.dims_, b.dims_);
        std::swap(a.strides_, b.strides_);
    }

    // Iterator access, forwarded to Derived.
    auto begin() { return self().begin(); }
    auto begin() const { return self().begin(); }
    auto cbegin() const { return self().cbegin(); }
    auto end() { return self().end(); }
    auto end() const { return self().end(); }
    auto cend() const { return self().cend(); }
    auto rbegin() { return self().rbegin(); }
    auto rbegin() const { return self().rbegin(); }
    auto crbegin() const { return self().crbegin(); }
    auto rend() { return self().rend(); }
    auto rend() const { return self().rend(); }
    auto crend() const { return self().crend(); }

    /// Element access by index; same as operator[] (not bounds-checked).
    template <typename Dim> requires std::is_integral_v<Dim>
    reference operator()(Dim dim) {
        return operator[](dim);
    }
    template <typename Dim> requires std::is_integral_v<Dim>
    const_reference operator()(Dim dim) const {
        return operator[](dim);
    }

    /// Total number of elements. For a 1-D shape this is the single extent.
    /// Provided for parity with the N-dimensional template.
    [[nodiscard]] std::size_t size() const { return dims_; }

    /// Shape as a one-element array.
    [[nodiscard]] std::array<std::size_t, 1> dims() const {
        return {dims_};
    }
    /// Extent of axis n; throws std::out_of_range unless n == 0.
    [[nodiscard]] std::size_t dims(std::size_t n) const {
        if (n >= 1) {
            throw std::out_of_range("Out of range in dims");
        }
        return dims_;
    }
    [[nodiscard]] std::size_t strides() const {
        return strides_;
    }
    // Forwarded to Derived.
    auto dataView() const {
        return self().dataView();
    }
    auto origStrides() const {
        return self().origStrides();
    }

    MatrixView<T, 1> submatrix(std::size_t pos_begin);
    MatrixView<T, 1> submatrix(std::size_t pos_begin, std::size_t pos_end);
    T& row(std::size_t n);
    T& col(std::size_t n);
    /// Unchecked element access: dereferences begin() + n.
    T& operator[](std::size_t n) { return *(begin() + n); }
    MatrixView<T, 1> submatrix(std::size_t pos_begin) const;
    MatrixView<T, 1> submatrix(std::size_t pos_begin, std::size_t pos_end) const;
    const T& row(std::size_t n) const;
    const T& col(std::size_t n) const;
    /// Unchecked element access: dereferences cbegin() + n.
    const T& operator[](std::size_t n) const { return *(cbegin() + n); }

    /// Streams the vector as "{e0, e1, ...}".
    friend std::ostream& operator<<(std::ostream& os, const MatrixBase& m) {
        os << '{';
        for (std::size_t i = 0; i != m.dims_; ++i) {
            os << m[i];
            if (i + 1 != m.dims_) {
                os << ", ";
            }
        }
        return os << '}';
    }

    /// Applies f element by element to (*this, m1, m2); returns *this.
    template <typename DerivedOther1, typename DerivedOther2,
              std::semiregular U, std::semiregular V,
              std::invocable<T&, const U&, const V&> F>
    MatrixBase& applyFunctionWithBroadcast(const frozenca::MatrixBase<DerivedOther1, U, 1>& m1,
                                           const frozenca::MatrixBase<DerivedOther2, V, 1>& m2,
                                           F&& f);
};
// Initializer constructor for the 1-D case: takes the first (only) extent from
// deriveDims<1>(init) and delegates to the integral-extent constructor.
template <typename Derived, std::semiregular T>
MatrixBase<Derived, T, 1>::MatrixBase(typename MatrixInitializer<T, 1>::type init) : MatrixBase(deriveDims<1>(init)[0]) {
}
/// Mutable overload: view from `pos_begin` to the end of the vector.
template <typename Derived, std::semiregular T>
MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin) {
    const std::size_t pos_end = dims_;
    return this->submatrix(pos_begin, pos_end);
}
/// Const overload: view from `pos_begin` to the end of the vector.
template <typename Derived, std::semiregular T>
MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin) const {
    const std::size_t pos_end = dims_;
    return this->submatrix(pos_begin, pos_end);
}
/// Mutable overload of the two-position submatrix: defers to the const
/// overload.
template <typename Derived, std::semiregular T>
MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin,
                                                      std::size_t pos_end) {
    const MatrixBase& const_this = *this;
    return const_this.submatrix(pos_begin, pos_end);
}
/// Returns a view over the half-open index range [pos_begin, pos_end).
///
/// @throws std::out_of_range if the range is empty or reversed, or if
///         `pos_end` lies beyond the vector's extent.
template <typename Derived, std::semiregular T>
MatrixView<T, 1> MatrixBase<Derived, T, 1>::submatrix(std::size_t pos_begin,
                                                      std::size_t pos_end) const {
    if (pos_begin >= pos_end) {
        throw std::out_of_range("submatrix begin/end position error");
    }
    // operator[] is unchecked in this specialization, so the upper bound has
    // to be validated here. The end position is exclusive and may equal the
    // extent.
    if (pos_end > dims_) {
        throw std::out_of_range("submatrix end position exceeds matrix dimensions");
    }
    MatrixView<T, 1> view ({pos_end - pos_begin}, const_cast<T*>(&operator[](pos_begin)), {strides_});
    return view;
}
/// Mutable overload of row(): reuses the bounds-checked const overload and
/// strips constness from its result.
template <typename Derived, std::semiregular T>
T& MatrixBase<Derived, T, 1>::row(std::size_t n) {
    const MatrixBase& const_this = *this;
    const T& element = const_this.row(n);
    return const_cast<T&>(element);
}
/// Bounds-checked read-only access to the n-th element.
///
/// @throws std::out_of_range if n is not below the extent.
template <typename Derived, std::semiregular T>
const T& MatrixBase<Derived, T, 1>::row(std::size_t n) const {
    if (!(n < dims_)) {
        throw std::out_of_range("row index error");
    }
    return (*this)[n];
}
/// For a 1-D matrix a "column" is the same element as the "row": forwards to
/// the mutable row().
template <typename Derived, std::semiregular T>
T& MatrixBase<Derived, T, 1>::col(std::size_t n) {
    T& element = this->row(n);
    return element;
}
/// For a 1-D matrix a "column" is the same element as the "row": forwards to
/// the const row().
template <typename Derived, std::semiregular T>
const T& MatrixBase<Derived, T, 1>::col(std::size_t n) const {
    const T& element = this->row(n);
    return element;
}
/// Element-wise application of f to (*this, m1, m2) for 1-D operands.
///
/// An operand whose extent equals this vector's is read index by index. An
/// operand whose extent differs contributes a copy of its element 0 to every
/// call. If neither operand matches this vector's extent, nothing is applied.
/// Returns *this.
template <typename Derived, std::semiregular T>
template <typename DerivedOther1, typename DerivedOther2,
          std::semiregular U, std::semiregular V,
          std::invocable<T&, const U&, const V&> F>
MatrixBase<Derived, T, 1>& MatrixBase<Derived, T, 1>::applyFunctionWithBroadcast(
        const frozenca::MatrixBase<DerivedOther1, U, 1>& m1,
        const frozenca::MatrixBase<DerivedOther2, V, 1>& m2,
        F&& f) {
    // The destination element is handed to f as T&, so f performs the update.
    const auto count = dims(0);
    const bool first_matches = (m1.dims(0) == count);
    const bool second_matches = (m2.dims(0) == count);

    if (first_matches && second_matches) {
        for (std::size_t idx = 0; idx < count; ++idx) {
            f(this->row(idx), m1.row(idx), m2.row(idx));
        }
    } else if (first_matches) {
        // Only m2 is broadcast: its first element is reused throughout.
        auto broadcast2 = m2.row(0);
        for (std::size_t idx = 0; idx < count; ++idx) {
            f(this->row(idx), m1.row(idx), broadcast2);
        }
    } else if (second_matches) {
        // Only m1 is broadcast: its first element is reused throughout.
        auto broadcast1 = m1.row(0);
        for (std::size_t idx = 0; idx < count; ++idx) {
            f(this->row(idx), broadcast1, m2.row(idx));
        }
    }
    return *this;
}
} // namespace frozenca
#endif //FROZENCA_MATRIXBASE_H
(Stack Exchange сообщает, что сообщение слишком длинное, поэтому два файла я заменил ссылками)
MatrixImpl.h (https://github.com/frozenca/Ndim-Matrix/blob/main/MatrixImpl.h)
MatrixView.h (https://github.com/frozenca/Ndim-Matrix/blob/main/MatrixView.h)
MatrixUtils.h
#ifndef FROZENCA_MATRIXUTILS_H
#define FROZENCA_MATRIXUTILS_H
#include <algorithm>
#include <array>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <type_traits>
namespace frozenca {
template <std::semiregular T, std::size_t N>
class Matrix;
template <std::semiregular T, std::size_t N>
class MatrixView;
template <typename Derived, std::semiregular T, std::size_t N>
class MatrixBase;
template <typename Derived>
class ObjectBase;
// Type trait: true for every type except the library's own matrix-like class
// templates (Matrix, MatrixView, MatrixBase, ObjectBase), for which the
// specializations below set it to false.
template <typename T>
constexpr bool NotMatrix = true;
template <std::semiregular T, std::size_t N>
constexpr bool NotMatrix<Matrix<T, N>> = false;
template <std::semiregular T, std::size_t N>
constexpr bool NotMatrix<MatrixView<T, N>> = false;
template <typename Derived, std::semiregular T, std::size_t N>
constexpr bool NotMatrix<MatrixBase<Derived, T, N>> = false;
template <typename Derived>
constexpr bool NotMatrix<ObjectBase<Derived>> = false;
// Satisfied by semiregular types that are not one of the library's matrix-like
// templates; used to constrain scalar operands.
template <typename T>
concept isNotMatrix = NotMatrix<T> && std::semiregular<T>;
// Satisfied exactly by the types for which NotMatrix is false.
template <typename T>
concept isMatrix = !NotMatrix<T>;
// Satisfied when T can be brace-initialized from the literals 0 and 1 and the
// results convert to T.
template <typename T>
concept OneExists = requires () {
{ T{0} } -> std::convertible_to<T>;
{ T{1} } -> std::convertible_to<T>;
};
// Satisfied when `a + b` is a valid expression; the result type is unconstrained.
template <typename A, typename B>
concept WeakAddable = requires (A a, B b) {
a + b;
};
// Satisfied when `a - b` is a valid expression; the result type is unconstrained.
template <typename A, typename B>
concept WeakSubtractable = requires (A a, B b) {
a - b;
};
// Satisfied when `a * b` is a valid expression; the result type is unconstrained.
template <typename A, typename B>
concept WeakMultipliable = requires (A a, B b) {
a * b;
};
// Satisfied when `a / b` is a valid expression; the result type is unconstrained.
template <typename A, typename B>
concept WeakDividable = requires (A a, B b) {
a / b;
};
// Satisfied when both `a / b` and `a % b` are valid expressions; the result
// types are unconstrained.
template <typename A, typename B>
concept WeakRemaindable = requires (A a, B b) {
a / b;
a % b;
};
// Satisfied when `a + b` is valid and its result converts to C.
template <typename A, typename B, typename C>
concept AddableTo = requires (A a, B b) {
{ a + b } -> std::convertible_to<C>;
};
// Satisfied when `a - b` is valid and its result converts to C.
template <typename A, typename B, typename C>
concept SubtractableTo = requires (A a, B b) {
{ a - b } -> std::convertible_to<C>;
};
// Satisfied when `a * b` is valid and its result converts to C.
template <typename A, typename B, typename C>
concept MultipliableTo = requires (A a, B b) {
{ a * b } -> std::convertible_to<C>;
};
// Satisfied when `a / b` is valid and its result converts to C.
template <typename A, typename B, typename C>
concept DividableTo = requires (A a, B b) {
{ a / b } -> std::convertible_to<C>;
};
// Satisfied when both `a / b` and `a % b` are valid and their results convert
// to C.
template <typename A, typename B, typename C>
concept RemaindableTo = requires (A a, B b) {
{ a / b } -> std::convertible_to<C>;
{ a % b } -> std::convertible_to<C>;
};
// Satisfied when all five bitwise/shift expressions (&, |, ^, <<, >>) are
// valid for (a, b) and each result converts to C.
template <typename A, typename B, typename C>
concept BitMaskableTo = requires (A a, B b) {
{ a & b } -> std::convertible_to<C>;
{ a | b } -> std::convertible_to<C>;
{ a ^ b } -> std::convertible_to<C>;
{ a << b } -> std::convertible_to<C>;
{ a >> b } -> std::convertible_to<C>;
};
// `a + b` must convert back to the left operand type A.
template <typename A, typename B>
concept Addable = AddableTo<A, B, A>;
// `a - b` must convert back to the left operand type A.
template <typename A, typename B>
concept Subtractable = SubtractableTo<A, B, A>;
// `a * b` must convert back to the left operand type A.
template <typename A, typename B>
concept Multipliable = MultipliableTo<A, B, A>;
// `a / b` must convert back to the left operand type A.
template <typename A, typename B>
concept Dividable = DividableTo<A, B, A>;
// `a / b` and `a % b` must both convert back to the left operand type A.
template <typename A, typename B>
concept Remaindable = RemaindableTo<A, B, A>;
// All bitwise/shift results must convert back to the left operand type A.
template <typename A, typename B>
concept BitMaskable = BitMaskableTo<A, B, A>;
template <typename A, typename B> requires WeakAddable<A, B>
inline decltype(auto) Plus(A a, B b) {
return a + b;
}
template <typename A, typename B> requires WeakSubtractable<A, B>
inline decltype(auto) Minus(A a, B b) {
return a - b;
}
template <typename A, typename B> requires WeakMultipliable<A, B>
inline decltype(auto) Multiplies(A a, B b) {
return a * b;
}
template <typename A, typename B> requires WeakDividable<A, B>
inline decltype(auto) Divides(A a, B b) {
return a / b;
}
template <typename A, typename B> requires WeakRemaindable<A, B>
inline decltype(auto) Modulus(A a, B b) {
return a % b;
}
// Result type of Plus<A, B> when invoked with (A, B).
template <typename A, typename B>
using AddType = std::invoke_result_t<decltype(Plus<A, B>), A, B>;
// Result type of Minus<A, B> when invoked with (A, B).
template <typename A, typename B>
using SubType = std::invoke_result_t<decltype(Minus<A, B>), A, B>;
// Result type of Multiplies<A, B> when invoked with (A, B).
template <typename A, typename B>
using MulType = std::invoke_result_t<decltype(Multiplies<A, B>), A, B>;
// Result type of Divides<A, B> when invoked with (A, B).
template <typename A, typename B>
using DivType = std::invoke_result_t<decltype(Divides<A, B>), A, B>;
// Result type of Modulus<A, B> when invoked with (A, B).
template <typename A, typename B>
using ModType = std::invoke_result_t<decltype(Modulus<A, B>), A, B>;
// Satisfied when the product type of (A, B) can be added to itself and the
// sum converts back to that product type.
template <typename A, typename B>
concept DotProductable = Addable<MulType<A, B>, MulType<A, B>>;
// DotProductable<A, B>, and additionally `a * b` converts to C and C can be
// added to itself with the sum converting back to C.
template <typename A, typename B, typename C>
concept DotProductableTo = DotProductable<A, B> && MultipliableTo<A, B, C> && Addable<C, C>;
template <typename A, typename B, typename C> requires AddableTo<A, B, C>
inline void PlusTo(C& c, const A& a, const B& b) {
c = a + b;
}
template <typename A, typename B, typename C> requires SubtractableTo<A, B, C>
inline void MinusTo(C& c, const A& a, const B& b) {
c = a - b;
}
template <typename A, typename B, typename C> requires MultipliableTo<A, B, C>
inline void MultipliesTo(C& c, const A& a, const B& b) {
c = a * b;
}
template <typename A, typename B, typename C> requires DividableTo<A, B, C>
inline void DividesTo(C& c, const A& a, const B& b) {
c = a / b;
}
template <typename A, typename B, typename C> requires RemaindableTo<A, B, C>
inline void ModulusTo(C& c, const A& a, const B& b) {
c = a % b;
}
/// Logical AND over all arguments (left fold); true for an empty pack.
template <typename... Args>
inline constexpr bool All(Args... args) {
    const bool every_one_holds = (... && args);
    return every_one_holds;
}
/// Logical OR over all arguments (left fold); false for an empty pack.
template <typename... Args>
inline constexpr bool Some(Args... args) {
    const bool at_least_one_holds = (... || args);
    return at_least_one_holds;
}
/// Pads an N-element shape to M elements by inserting 1s in front, so the
/// original extents occupy the last N slots.
template <std::size_t M, std::size_t N> requires (N < M)
std::array<std::size_t, M> prependDims(const std::array<std::size_t, N>& arr) {
    std::array<std::size_t, M> padded;
    padded.fill(1u);
    std::copy(arr.begin(), arr.end(), padded.begin() + (M - N));
    return padded;
}
/// Reports whether two shapes can be broadcast against each other.
///
/// The shorter shape is first padded with leading 1s. Two equal-rank shapes
/// are compatible when, on every axis, the extents are equal or one of them
/// is 1.
template <std::size_t M, std::size_t N>
bool bidirBroadcastable(const std::array<std::size_t, M>& sz1,
                        const std::array<std::size_t, N>& sz2) {
    if constexpr (M < N) {
        return bidirBroadcastable(prependDims<N, M>(sz1), sz2);
    } else if constexpr (M > N) {
        return bidirBroadcastable(sz1, prependDims<M, N>(sz2));
    } else {
        const auto axis_compatible = [](const auto& d1, const auto& d2) {
            return (d1 == d2) || (d1 == 1) || (d2 == 1);
        };
        return std::ranges::equal(sz1, sz2, axis_compatible);
    }
}
/// Computes the shape that results from broadcasting two shapes together.
/// The lower-rank shape is left-padded with 1s; the result takes the larger
/// extent on each axis.
/// Throws std::invalid_argument("Cannot broadcast") when the shapes are not
/// broadcast-compatible.
template <std::size_t M, std::size_t N>
std::array<std::size_t, std::max(M, N)> bidirBroadcastedDims(const std::array<std::size_t, M>& sz1,
                                                             const std::array<std::size_t, N>& sz2) {
    if constexpr (M < N) {
        return bidirBroadcastedDims(prependDims<N, M>(sz1), sz2);
    } else if constexpr (M > N) {
        return bidirBroadcastedDims(sz1, prependDims<M, N>(sz2));
    } else {
        const bool compatible = bidirBroadcastable(sz1, sz2);
        if (!compatible) {
            throw std::invalid_argument("Cannot broadcast");
        }
        std::array<std::size_t, M> merged;
        for (std::size_t axis = 0; axis < M; ++axis) {
            merged[axis] = std::max(sz1[axis], sz2[axis]);
        }
        return merged;
    }
}
/// Result shape of a dot product between a rank-M operand (M > 1) and a
/// rank-1 operand: the left shape with its last axis removed.
/// Throws std::invalid_argument when the last extent of sz1 differs from the
/// length of sz2.
template <std::size_t M> requires (M > 1)
std::array<std::size_t, M - 1> dotDims(const std::array<std::size_t, M>& sz1,
                                       const std::array<std::size_t, 1>& sz2) {
    const bool aligned = (sz1[M - 1] == sz2[0]);
    if (!aligned) {
        throw std::invalid_argument("Cannot do dot product, shape is not aligned");
    }
    std::array<std::size_t, M - 1> result;
    for (std::size_t axis = 0; axis + 1 < M; ++axis) {
        result[axis] = sz1[axis];
    }
    return result;
}
/// Result shape of a dot product whose right operand has rank N > 1.
/// The contraction pairs the last axis of sz1 with the second-to-last axis of
/// sz2. The result is: all axes of sz1 but the last, then the leading N - 2
/// axes of sz2, then the last axis of sz2.
/// Throws std::invalid_argument when the contracted extents differ.
template <std::size_t M, std::size_t N> requires (N > 1)
std::array<std::size_t, M + N - 2> dotDims(const std::array<std::size_t, M>& sz1,
                                           const std::array<std::size_t, N>& sz2) {
    constexpr std::size_t kept_from_lhs = M - 1;
    constexpr std::size_t lead_from_rhs = N - 2;
    if (sz1[kept_from_lhs] != sz2[lead_from_rhs]) {
        throw std::invalid_argument("Cannot do dot product, shape is not aligned");
    }
    std::array<std::size_t, M + N - 2> result;
    auto out = std::begin(result);
    out = std::copy_n(std::begin(sz1), kept_from_lhs, out);
    out = std::copy_n(std::begin(sz2), lead_from_rhs, out);
    // Only the final axis of sz2 remains; it lands at index M + N - 3.
    *out = sz2[N - 1];
    return result;
}
/// Result shape of a matrix product between a rank-M and a rank-N operand.
///
/// - A rank-1 left operand is treated as a 1 x k row, a rank-1 right operand
///   as a k x 1 column; the inserted unit axis stays in the result, whose rank
///   is always max(M, N).
/// - The last two axes are multiplied as matrices: (r x k) * (k x c) -> r x c.
/// - Any leading ("batch") axes are broadcast against each other.
///
/// Throws std::invalid_argument when the contracted extents differ, or (via
/// bidirBroadcastedDims) when the batch axes cannot be broadcast.
/// Two rank-1 operands are rejected at compile time: the declared result rank
/// max(1, 1) == 1 cannot hold the 1 x 1 product shape.
template <std::size_t M, std::size_t N>
std::array<std::size_t, std::max(M, N)> matmulDims(const std::array<std::size_t, M>& sz1,
                                                   const std::array<std::size_t, N>& sz2) {
    static_assert(!(M == 1 && N == 1),
                  "matmulDims: two rank-1 operands have no rank-1 matrix-product shape");
    if constexpr (M == 1) {
        const std::array<std::size_t, 2> as_row = {1, sz1[0]};
        return matmulDims(as_row, sz2);
    } else if constexpr (N == 1) {
        const std::array<std::size_t, 2> as_col = {sz2[0], 1};
        return matmulDims(sz1, as_col);
    } else {
        // Everything below must sit inside this else-branch: otherwise it is
        // still instantiated for rank-1 operands, where `M - 2` / `N - 2`
        // underflow and the return types no longer match.
        static_assert(M >= 2 && N >= 2);
        if (sz1[M - 1] != sz2[N - 2]) {
            throw std::invalid_argument("Cannot do dot product, shape is not aligned");
        }
        const std::array<std::size_t, 2> last_sz = {sz1[M - 2], sz2[N - 1]};
        if constexpr (M == 2 && N == 2) {
            return last_sz;
        } else if constexpr (M == 2) { // M = 2, N > 2: batch axes come from sz2
            std::array<std::size_t, N> res_sz;
            std::copy(std::begin(sz2), std::begin(sz2) + (N - 2), std::begin(res_sz));
            std::copy(std::begin(last_sz), std::end(last_sz), std::begin(res_sz) + (N - 2));
            return res_sz;
        } else if constexpr (N == 2) { // M > 2, N = 2: batch axes come from sz1
            std::array<std::size_t, M> res_sz;
            // Both ends of the range must be iterators into sz1.
            std::copy(std::begin(sz1), std::begin(sz1) + (M - 2), std::begin(res_sz));
            std::copy(std::begin(last_sz), std::end(last_sz), std::begin(res_sz) + (M - 2));
            return res_sz;
        } else { // M > 2, N > 2: batch axes are broadcast together
            std::array<std::size_t, std::max(M, N)> res_sz;
            std::array<std::size_t, M - 2> sz1_front;
            std::array<std::size_t, N - 2> sz2_front;
            std::copy(std::begin(sz1), std::begin(sz1) + (M - 2), std::begin(sz1_front));
            std::copy(std::begin(sz2), std::begin(sz2) + (N - 2), std::begin(sz2_front));
            auto common_sz = bidirBroadcastedDims(sz1_front, sz2_front);
            std::copy(std::begin(common_sz), std::end(common_sz), std::begin(res_sz));
            std::copy(std::begin(last_sz), std::end(last_sz), std::end(res_sz) - 2);
            return res_sz;
        }
    }
}
// Satisfied when every type in Args is an integral type (std::is_integral_v).
// The check is folded through All(), so an empty pack satisfies the concept.
template <typename... Args>
concept IndexType = All(std::is_integral_v<Args>...);
/// Computes element strides for a shape, with the last axis varying fastest:
/// the stride of the last axis is 1 and each earlier stride is the product of
/// all extents to its right.
template <std::size_t N>
std::array<std::size_t, N> computeStrides(const std::array<std::size_t, N>& dims) {
    std::array<std::size_t, N> strides;
    std::size_t running = 1;
    // Walk the axes from last to first; `axis-- > 0` stops cleanly at zero
    // without relying on comparing a wrapped-around index against N.
    for (std::size_t axis = N; axis-- > 0;) {
        strides[axis] = running;
        running *= dims[axis];
    }
    return strides;
}
/// Returns true when every sub-list of `init` has the same size as the first
/// one, i.e. this nesting level is rectangular.
/// An empty `init` has no sub-lists to disagree and is reported as non-jagged;
/// it is handled up front because advancing past begin() of an empty range
/// is undefined behaviour.
/// The template parameter N is not used by the check itself; it is kept so
/// existing `checkNonJagged<N>(init)` calls continue to compile.
template <std::size_t N, typename Initializer>
bool checkNonJagged(const Initializer& init) {
    const auto first = std::cbegin(init);
    const auto last = std::cend(init);
    if (first == last) {
        return true;
    }
    for (auto row = std::next(first); row != last; ++row) {
        if (row->size() != first->size()) {
            return false;
        }
    }
    return true;
}
/// Writes the extent of this nesting level of `init` through `first`, advances
/// `first`, and recurses into the first sub-list for the remaining N - 1
/// levels.
/// For N > 1 the level is first verified to be rectangular; a jagged level
/// throws std::invalid_argument("Jagged matrix initializer").
/// `first` is taken by reference so the caller's iterator ends up N positions
/// further on.
template <std::size_t N, typename Iter, typename Initializer>
void addDims(Iter& first, const Initializer& init) {
    constexpr bool has_inner_level = (N > 1);
    if constexpr (has_inner_level) {
        const bool rectangular = checkNonJagged<N>(init);
        if (!rectangular) {
            throw std::invalid_argument("Jagged matrix initializer");
        }
    }
    *first = std::size(init);
    ++first;
    if constexpr (has_inner_level) {
        // All sub-lists have equal size here, so the first one is representative.
        const auto& leading_sublist = *std::begin(init);
        addDims<N - 1>(first, leading_sublist);
    }
}
template <std::size_t N, typename Initializer>
std::array<std::size_t, N> deriveDims(const Initializer& init) {
std::array<std::size_t, N> dims;
auto f = std::begin(dims);
addDims<N>(f, init);
return dims;
}
/// Leaf overload: copies the elements in [first, last) into `data`, starting
/// at position `index`, and leaves `index` one past the last slot written.
/// No bounds check is performed against the size of the buffer.
template <std::semiregular T>
void addList(std::unique_ptr<T[]>& data,
             const T* first, const T* last,
             std::size_t& index) {
    while (first != last) {
        data[index] = *first;
        ++first;
        ++index;
    }
}
/// Nested overload: visits each initializer_list in [first, last) in order and
/// forwards its contents to addList again, so arbitrarily deep nesting is
/// flattened into `data` in sequence. `index` is shared across all levels
/// and keeps advancing.
template <std::semiregular T, typename I>
void addList(std::unique_ptr<T[]>& data,
             const std::initializer_list<I>* first, const std::initializer_list<I>* last,
             std::size_t& index) {
    while (first != last) {
        const std::initializer_list<I>& sublist = *first;
        addList(data, sublist.begin(), sublist.end(), index);
        ++first;
    }
}
/// Flattens a (possibly nested) initializer list into `data`, writing from
/// position 0 onwards via addList. The buffer must already be large enough;
/// no size check is made here.
template <std::semiregular T, typename I>
void insertFlat(std::unique_ptr<T[]>& data, std::initializer_list<I> list) {
    std::size_t write_pos = 0;
    addList(data, list.begin(), list.end(), write_pos);
}
/// Floored integer division: the largest integer q with q * b <= a for b > 0
/// (and the matching definition for b < 0), so that
/// `a == b * quot(a, b) + mod(a, b)` holds for every non-zero b.
/// Built-in `/` truncates toward zero; the quotient is lowered by one exactly
/// when the remainder is non-zero and its sign differs from the divisor's.
/// Behaviour is undefined for b == 0 and for LONG_MIN / -1, as with `/`.
inline long quot(long a, long b) {
    const long q = a / b;
    const long r = a % b;
    const bool needs_floor_step = (r != 0) && ((r < 0) != (b < 0));
    return needs_floor_step ? q - 1 : q;
}
/// Floored modulus: the remainder that takes the sign of the divisor, so for
/// b > 0 the result lies in [0, b). Useful for wrapping negative indices.
/// The built-in remainder is adjusted by adding b only when it is non-zero
/// and its sign differs from b's. This avoids the intermediate `a % b + b`,
/// which can overflow for operands near LONG_MAX.
/// Behaviour is undefined for b == 0, as with `%`.
inline long mod(long a, long b) {
    const long r = a % b;
    const bool wrong_sign = (r != 0) && ((r < 0) != (b < 0));
    return wrong_sign ? r + b : r;
}
} // namespace frozenca
#endif //FROZENCA_MATRIXUTILS_H
MatrixInitializer.h
#ifndef FROZENCA_MATRIXINITIALIZER_H
#define FROZENCA_MATRIXINITIALIZER_H
#include <cstddef>
#include <concepts>
#include <initializer_list>
namespace frozenca {
// Maps an element type T and a rank N to the nested std::initializer_list type
// with N levels of nesting around T. The primary template wraps the type for
// rank N - 1 in one more initializer_list.
template <std::semiregular T, std::size_t N>
struct MatrixInitializer {
using type = std::initializer_list<typename MatrixInitializer<T, N - 1>::type>;
};
// Recursion base case: a rank-1 initializer is a flat list of T.
template <std::semiregular T>
struct MatrixInitializer<T, 1> {
using type = std::initializer_list<T>;
};
// Rank 0 is declared but never defined, so MatrixInitializer<T, 0>::type cannot
// be named and the recursion above cannot run past the rank-1 base case.
template <std::semiregular T>
struct MatrixInitializer<T, 0>;
} // namespace frozenca
#endif //FROZENCA_MATRIXINITIALIZER_H
MatrixOps.h
#ifndef FROZENCA_MATRIXOPS_H
#define FROZENCA_MATRIXOPS_H
#include "MatrixImpl.h"
namespace frozenca {
// Matrix constructs
/// Creates a rank-N matrix with the given extents; the elements are whatever
/// Matrix's shape constructor leaves them as (no fill is applied here).
template <std::semiregular T, std::size_t N>
Matrix<T, N> empty(const std::array<std::size_t, N>& arr) {
    return Matrix<T, N>(arr);
}
/// Creates a matrix with the same extents and element type as `base`; no fill
/// is applied here.
template <typename Derived, std::semiregular T, std::size_t N>
Matrix<T, N> empty_like(const MatrixBase<Derived, T, N>& base) {
    return Matrix<T, N>(base.dims());
}
/// Creates an n x m matrix and writes T{1} on its main diagonal, i.e. at
/// (i, i) for i < min(n, m). Off-diagonal elements keep whatever the
/// Matrix(n, m) constructor produced.
/// NOTE(review): presumably that constructor value-initializes the elements —
/// confirm, since no explicit zero fill happens here.
template <OneExists T>
Matrix<T, 2> eye(std::size_t n, std::size_t m) {
    Matrix<T, 2> result (n, m);
    const std::size_t diagonal_length = std::min(n, m);
    std::size_t k = 0;
    while (k < diagonal_length) {
        result(k, k) = T{1};
        ++k;
    }
    return result;
}
/// Square convenience overload: equivalent to eye<T>(n, n).
template <OneExists T>
Matrix<T, 2> eye(std::size_t n) {
    const std::size_t side = n;
    return eye<T>(side, side);
}
/// n x n identity matrix; an alias for the square form of eye<T>.
template <OneExists T>
Matrix<T, 2> identity(std::size_t n) {
    const std::size_t side = n;
    return eye<T>(side, side);
}
/// Creates a rank-N matrix with the given extents and sets every element to
/// T{1}.
template <OneExists T, std::size_t N>
Matrix<T, N> ones(const std::array<std::size_t, N>& arr) {
    Matrix<T, N> result (arr);
    const T one{1};
    std::ranges::fill(result, one);
    return result;
}
/// Creates a matrix with the extents of `base` and sets every element to T{1}.
template <typename Derived, OneExists T, std::size_t N>
Matrix<T, N> ones_like(const MatrixBase<Derived, T, N>& base) {
    Matrix<T, N> result (base.dims());
    const T one{1};
    std::ranges::fill(result, one);
    return result;
}
/// Creates a rank-N matrix with the given extents and sets every element to
/// T{0}.
template <std::semiregular T, std::size_t N>
Matrix<T, N> zeros(const std::array<std::size_t, N>& arr) {
    Matrix<T, N> result (arr);
    const T zero{0};
    std::ranges::fill(result, zero);
    return result;
}
/// Creates a matrix with the extents of `base` and sets every element to T{0}.
template <typename Derived, std::semiregular T, std::size_t N>
Matrix<T, N> zeros_like(const MatrixBase<Derived, T, N>& base) {
    Matrix<T, N> result (base.dims());
    const T zero{0};
    std::ranges::fill(result, zero);
    return result;
}
/// Creates a rank-N matrix with the given extents and sets every element to a
/// copy of `fill_value`.
template <std::semiregular T, std::size_t N>
Matrix<T, N> full(const std::array<std::size_t, N>& arr, const T& fill_value) {
    Matrix<T, N> result (arr);
    std::ranges::fill(result, fill_value);
    return result;
}
/// Creates a matrix with the extents of `base` and sets every element to a
/// copy of `fill_value`.
template <typename Derived, std::semiregular T, std::size_t N>
Matrix<T, N> full_like(const MatrixBase<Derived, T, N>& base, const T& fill_value) {
    Matrix<T, N> result (base.dims());
    std::ranges::fill(result, fill_value);
    return result;
}
// binary matrix operators
namespace {
/// Broadcasting addition into a pre-sized destination view: m = m1 + m2.
/// The destination rank N equals max(N1, N2). For N != 1 the work is handed
/// to applyFunctionWithBroadcast with this same function instantiated one rank
/// lower (operand ranks are clamped to N - 1). At rank 1 the element-wise
/// PlusTo is applied instead.
template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires AddableTo<U, V, T> && (std::max(N1, N2) == N)
void AddTo(MatrixView<T, N>& m,
           const MatrixView<U, N1>& m1,
           const MatrixView<V, N2>& m2) {
    if constexpr (N != 1) {
        // Recursive case: peel off one axis.
        m.applyFunctionWithBroadcast(
            m1, m2,
            AddTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    } else {
        // Leaf case: combine individual elements.
        m.applyFunctionWithBroadcast(m1, m2, PlusTo<U, V, T>);
    }
}
/// Broadcasting subtraction into a pre-sized destination view: m = m1 - m2.
/// The destination rank N equals max(N1, N2). For N != 1 the work is handed
/// to applyFunctionWithBroadcast with this same function instantiated one rank
/// lower (operand ranks are clamped to N - 1). At rank 1 the element-wise
/// MinusTo is applied instead.
template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires SubtractableTo<U, V, T> && (std::max(N1, N2) == N)
void SubtractTo(MatrixView<T, N>& m,
                const MatrixView<U, N1>& m1,
                const MatrixView<V, N2>& m2) {
    if constexpr (N != 1) {
        // Recursive case: peel off one axis.
        m.applyFunctionWithBroadcast(
            m1, m2,
            SubtractTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    } else {
        // Leaf case: combine individual elements.
        m.applyFunctionWithBroadcast(m1, m2, MinusTo<U, V, T>);
    }
}
/// Broadcasting element-wise multiplication into a pre-sized destination view:
/// m = m1 * m2. The destination rank N equals max(N1, N2). For N != 1 the work
/// is handed to applyFunctionWithBroadcast with this same function instantiated
/// one rank lower (operand ranks are clamped to N - 1). At rank 1 the
/// element-wise MultipliesTo is applied instead.
template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires MultipliableTo<U, V, T> && (std::max(N1, N2) == N)
void MultiplyTo(MatrixView<T, N>& m,
                const MatrixView<U, N1>& m1,
                const MatrixView<V, N2>& m2) {
    if constexpr (N != 1) {
        // Recursive case: peel off one axis.
        m.applyFunctionWithBroadcast(
            m1, m2,
            MultiplyTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    } else {
        // Leaf case: combine individual elements.
        m.applyFunctionWithBroadcast(m1, m2, MultipliesTo<U, V, T>);
    }
}
/// Broadcasting element-wise division into a pre-sized destination view:
/// m = m1 / m2. The destination rank N equals max(N1, N2). For N != 1 the work
/// is handed to applyFunctionWithBroadcast with this same function instantiated
/// one rank lower (operand ranks are clamped to N - 1). At rank 1 the
/// element-wise DividesTo is applied instead.
template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires DividableTo<U, V, T> && (std::max(N1, N2) == N)
void DivideTo(MatrixView<T, N>& m,
              const MatrixView<U, N1>& m1,
              const MatrixView<V, N2>& m2) {
    if constexpr (N != 1) {
        // Recursive case: peel off one axis.
        m.applyFunctionWithBroadcast(
            m1, m2,
            DivideTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    } else {
        // Leaf case: combine individual elements.
        m.applyFunctionWithBroadcast(m1, m2, DividesTo<U, V, T>);
    }
}
/// Broadcasting element-wise remainder into a pre-sized destination view:
/// m = m1 % m2. The destination rank N equals max(N1, N2). For N != 1 the work
/// is handed to applyFunctionWithBroadcast with this same function instantiated
/// one rank lower (operand ranks are clamped to N - 1). At rank 1 the
/// element-wise ModulusTo is applied instead.
template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires RemaindableTo<U, V, T> && (std::max(N1, N2) == N)
void ModuloTo(MatrixView<T, N>& m,
              const MatrixView<U, N1>& m1,
              const MatrixView<V, N2>& m2) {
    if constexpr (N != 1) {
        // Recursive case: peel off one axis.
        m.applyFunctionWithBroadcast(
            m1, m2,
            ModuloTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    } else {
        // Leaf case: combine individual elements.
        m.applyFunctionWithBroadcast(m1, m2, ModulusTo<U, V, T>);
    }
}
} // anonymous namespace
/// Broadcasting element-wise addition of two matrices.
/// The result has rank max(N1, N2), element type T (AddType<U, V> by default)
/// and the broadcast shape of both operands.
/// Throws std::invalid_argument (from bidirBroadcastedDims) when the shapes
/// cannot be broadcast together.
template <typename Derived1, typename Derived2,
          std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2,
          std::semiregular T = AddType<U, V>> requires AddableTo<U, V, T>
decltype(auto) operator+ (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 1) {
        // A rank-1 result is combined element by element, exactly as AddTo
        // does at its own rank-1 leaf. Instantiating AddTo with rank
        // N - 1 == 0 would name rank-0 views.
        res.applyFunctionWithBroadcast(m1, m2, PlusTo<U, V, T>);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, AddTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}
/// Broadcasting element-wise subtraction of two matrices.
/// The result has rank max(N1, N2), element type T (SubType<U, V> by default)
/// and the broadcast shape of both operands.
/// Throws std::invalid_argument (from bidirBroadcastedDims) when the shapes
/// cannot be broadcast together.
template <typename Derived1, typename Derived2,
          std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2,
          std::semiregular T = SubType<U, V>> requires SubtractableTo<U, V, T>
decltype(auto) operator- (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 1) {
        // A rank-1 result is combined element by element, exactly as
        // SubtractTo does at its own rank-1 leaf. Instantiating SubtractTo
        // with rank N - 1 == 0 would name rank-0 views.
        res.applyFunctionWithBroadcast(m1, m2, MinusTo<U, V, T>);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, SubtractTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}
/// Broadcasting element-wise multiplication of two matrices. This is not the
/// matrix product; see matmul / dot for that.
/// The result has rank max(N1, N2), element type T (MulType<U, V> by default)
/// and the broadcast shape of both operands.
/// Throws std::invalid_argument (from bidirBroadcastedDims) when the shapes
/// cannot be broadcast together.
template <typename Derived1, typename Derived2,
          std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2,
          std::semiregular T = MulType<U, V>> requires MultipliableTo<U, V, T>
decltype(auto) operator* (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 1) {
        // A rank-1 result is combined element by element, exactly as
        // MultiplyTo does at its own rank-1 leaf. Instantiating MultiplyTo
        // with rank N - 1 == 0 would name rank-0 views.
        res.applyFunctionWithBroadcast(m1, m2, MultipliesTo<U, V, T>);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, MultiplyTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}
/// Broadcasting element-wise division of two matrices.
/// The result has rank max(N1, N2), element type T (DivType<U, V> by default)
/// and the broadcast shape of both operands.
/// Throws std::invalid_argument (from bidirBroadcastedDims) when the shapes
/// cannot be broadcast together. Zero divisors are not checked.
template <typename Derived1, typename Derived2,
          std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2,
          std::semiregular T = DivType<U, V>> requires DividableTo<U, V, T>
decltype(auto) operator/ (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 1) {
        // A rank-1 result is combined element by element, exactly as DivideTo
        // does at its own rank-1 leaf. Instantiating DivideTo with rank
        // N - 1 == 0 would name rank-0 views.
        res.applyFunctionWithBroadcast(m1, m2, DividesTo<U, V, T>);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, DivideTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}
/// Broadcasting element-wise remainder of two matrices.
/// The result has rank max(N1, N2), element type T (ModType<U, V> by default)
/// and the broadcast shape of both operands.
/// Throws std::invalid_argument (from bidirBroadcastedDims) when the shapes
/// cannot be broadcast together. Zero divisors are not checked.
template <typename Derived1, typename Derived2,
          std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2,
          std::semiregular T = ModType<U, V>> requires RemaindableTo<U, V, T>
decltype(auto) operator% (const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = bidirBroadcastedDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 1) {
        // A rank-1 result is combined element by element, exactly as ModuloTo
        // does at its own rank-1 leaf. Instantiating ModuloTo with rank
        // N - 1 == 0 would name rank-0 views.
        res.applyFunctionWithBroadcast(m1, m2, ModulusTo<U, V, T>);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, ModuloTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}
} // namespace frozenca
#endif //FROZENCA_MATRIXOPS_H
LinalgOps.h (ОПЕРАЦИИ ЛИНЕЙНОЙ АЛГЕБРЫ)
#ifndef FROZENCA_LINALGOPS_H
#define FROZENCA_LINALGOPS_H
#include "Matrix.h"
namespace frozenca {
namespace {
/// Scalar leaf of the dot product: adds the inner product of two rank-1 views
/// to `m` (accumulating, not overwriting). The inner product itself starts
/// from T{0}. Lengths are not checked here; m2 is read for as many elements
/// as m1 has.
template <std::semiregular U, std::semiregular V, std::semiregular T>
requires DotProductableTo<U, V, T>
void DotTo(T& m,
           const MatrixView<U, 1>& m1,
           const MatrixView<V, 1>& m2) {
    const T partial_sum =
        std::inner_product(std::begin(m1), std::end(m1), std::begin(m2), T{0});
    m += partial_sum;
}
/// Vector-by-matrix dot product: for every column j of the rank-2 view m2,
/// adds the inner product of m1 with that column to m[j] (accumulating).
/// Debug builds assert that m has one entry per column of m2.
template <std::semiregular U, std::semiregular V, std::semiregular T>
requires DotProductableTo<U, V, T>
void DotTo(MatrixView<T, 1>& m,
           const MatrixView<U, 1>& m1,
           const MatrixView<V, 2>& m2) {
    const std::size_t n_cols = m2.dims(1);
    assert(m.dims(0) == n_cols);
    std::size_t j = 0;
    while (j < n_cols) {
        auto column = m2.col(j);
        m[j] += std::inner_product(std::begin(m1), std::end(m1), std::begin(column), T{0});
        ++j;
    }
}
/// Vector-by-tensor dot product for a right operand of rank N2 > 2.
/// Walks the leading axis of m2 and of the destination in lock-step and
/// recurses with the rank-1 left operand unchanged, so each step lowers the
/// rank of the right operand and of the destination by one.
/// Debug builds assert that the leading extents of m and m2 agree.
template <std::semiregular U, std::semiregular V, std::semiregular T, std::size_t N2>
requires DotProductableTo<U, V, T> && (N2 > 2)
void DotTo(MatrixView<T, N2 - 1>& m,
           const MatrixView<U, 1>& m1,
           const MatrixView<V, N2>& m2) {
    const std::size_t n_slices = m.dims(0);
    assert(n_slices == m2.dims(0));
    std::size_t i = 0;
    while (i < n_slices) {
        auto dst_slice = m.row(i);
        auto rhs_slice = m2.row(i);
        DotTo(dst_slice, m1, rhs_slice);
        ++i;
    }
}
/// Tensor-by-vector dot product for a left operand of rank N1 > 1.
/// Walks the leading axis of m1 and of the destination in lock-step and
/// recurses with the rank-1 right operand unchanged, until the scalar leaf
/// overload is reached.
/// Debug builds assert that the leading extents of m and m1 agree.
///
/// N2 appears in no function parameter, so it can never be deduced. Without a
/// default this overload is unreachable by ordinary `DotTo(m, m1, m2)` calls.
/// It is kept (defaulted to the right operand's rank, 1) only so that calls
/// spelling out all five template arguments keep compiling.
template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2 = 1>
requires DotProductableTo<U, V, T> && (N1 > 1)
void DotTo(MatrixView<T, N1 - 1>& m,
           const MatrixView<U, N1>& m1,
           const MatrixView<V, 1>& m2) {
    assert(m.dims(0) == m1.dims(0));
    const std::size_t r = m.dims(0);
    for (std::size_t i = 0; i < r; ++i) {
        // decltype(auto) keeps reference-ness: when the destination is rank 1,
        // row(i) is expected to yield a T& to the element (TODO confirm
        // against MatrixView). A plain `auto` would accumulate into a
        // temporary copy and drop the result. For higher ranks this is an
        // ordinary sub-view object, as before.
        decltype(auto) row0 = m.row(i);
        auto row1 = m1.row(i);
        DotTo(row0, row1, m2);
    }
}
/// General dot product for operands of rank N1 > 1 and N2 > 1; the destination
/// has rank N1 + N2 - 2. Walks the leading axis of m1 and of the destination
/// in lock-step and recurses with the right operand unchanged, so each step
/// lowers the rank of the left operand and of the destination by one.
/// Debug builds assert that the leading extents of m and m1 agree.
template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2>
requires DotProductableTo<U, V, T> && (N1 > 1) && (N2 > 1)
void DotTo(MatrixView<T, N1 + N2 - 2>& m,
           const MatrixView<U, N1>& m1,
           const MatrixView<V, N2>& m2) {
    const std::size_t n_slices = m.dims(0);
    assert(n_slices == m1.dims(0));
    std::size_t i = 0;
    while (i < n_slices) {
        auto dst_slice = m.row(i);
        auto lhs_slice = m1.row(i);
        DotTo(dst_slice, lhs_slice, m2);
        ++i;
    }
}
/// Entry point for arbitrary MatrixBase-derived operands: wraps the
/// destination and both operands in MatrixView objects and dispatches to the
/// view-based DotTo overloads. The destination must have rank N1 + N2 - 2.
template <typename Derived0, typename Derived1, typename Derived2,
          std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2>
requires DotProductableTo<U, V, T>
void DotTo(MatrixBase<Derived0, T, (N1 + N2 - 2)>& m,
           const MatrixBase<Derived1, U, N1>& m1,
           const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t result_rank = N1 + N2 - 2;
    MatrixView<T, result_rank> dst_view (m);
    MatrixView<U, N1> lhs_view (m1);
    MatrixView<V, N2> rhs_view (m2);
    DotTo(dst_view, lhs_view, rhs_view);
}
/// Broadcasting matrix product into a pre-sized destination view.
/// The destination rank N equals max(N1, N2). At rank 2 the product is
/// delegated to DotTo. For any other rank the work is handed to
/// applyFunctionWithBroadcast with this same function instantiated one rank
/// lower (operand ranks are clamped to N - 1).
/// NOTE(review): the rank-2 leaf forwards to DotTo with a rank-2 destination;
/// whether a DotTo overload accepts that when N1 or N2 is 1 should be checked.
template <std::semiregular U, std::semiregular V, std::semiregular T,
          std::size_t N1, std::size_t N2, std::size_t N>
requires DotProductableTo<U, V, T> && (std::max(N1, N2) == N)
void MatmulTo(MatrixView<T, N>& m,
              const MatrixView<U, N1>& m1,
              const MatrixView<V, N2>& m2) {
    if constexpr (N != 2) {
        // Recursive case: peel off one batch axis.
        m.applyFunctionWithBroadcast(
            m1, m2,
            MatmulTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    } else {
        // Leaf case: a plain matrix product.
        DotTo(m, m1, m2);
    }
}
} // anonymous namespace
/// Generalised dot product of a rank-M and a rank-N operand.
/// The result is a new matrix of rank M + N - 2 with element type T
/// (MulType<U, V> by default), zero-filled and then accumulated into by DotTo.
/// Throws std::invalid_argument (from dotDims) when the contracted extents do
/// not match.
template <typename Derived1, typename Derived2,
          std::semiregular U, std::semiregular V,
          std::size_t M, std::size_t N,
          std::semiregular T = MulType<U, V>> requires DotProductableTo<U, V, T>
decltype(auto) dot(const MatrixBase<Derived1, U, M>& m1, const MatrixBase<Derived2, V, N>& m2) {
    constexpr std::size_t result_rank = M + N - 2;
    auto result_dims = dotDims(m1.dims(), m2.dims());
    Matrix<T, result_rank> res = zeros<T, result_rank>(result_dims);
    DotTo(res, m1, m2);
    return res;
}
/// Inner product of two rank-1 operands, returned as a scalar of type T
/// (MulType<U, V> by default).
/// Throws std::invalid_argument when the two operands differ in length.
template <typename Derived1, typename Derived2,
          std::semiregular U, std::semiregular V,
          std::semiregular T = MulType<U, V>> requires DotProductableTo<U, V, T>
decltype(auto) dot(const MatrixBase<Derived1, U, 1>& m1, const MatrixBase<Derived2, V, 1>& m2) {
    // dotDims offers no overload for two rank-1 shapes (one requires M > 1,
    // the other N > 1), so the length check is done directly.
    if (m1.dims()[0] != m2.dims()[0]) {
        throw std::invalid_argument("Cannot do dot product, shape is not aligned");
    }
    // The scalar DotTo overload is declared on MatrixView operands. Template
    // deduction does not look through conversions, so the views are built
    // explicitly, as the MatrixBase DotTo overload does.
    MatrixView<U, 1> m1_view (m1);
    MatrixView<V, 1> m2_view (m2);
    T res {0};
    DotTo(res, m1_view, m2_view);
    return res;
}
/// Broadcasting matrix product of two operands.
/// The result has rank max(N1, N2), element type T (MulType<U, V> by default)
/// and the shape computed by matmulDims: the last two axes are multiplied as
/// matrices and leading axes are broadcast.
/// Throws std::invalid_argument (from matmulDims) when the shapes are not
/// compatible.
template <typename Derived1, typename Derived2,
          std::semiregular U, std::semiregular V,
          std::size_t N1, std::size_t N2,
          std::semiregular T = MulType<U, V>> requires DotProductableTo<U, V, T>
decltype(auto) matmul(const MatrixBase<Derived1, U, N1>& m1, const MatrixBase<Derived2, V, N2>& m2) {
    constexpr std::size_t N = std::max(N1, N2);
    auto dims = matmulDims(m1.dims(), m2.dims());
    Matrix<T, N> res = zeros<T, N>(dims);
    if constexpr (N == 2) {
        // MatmulTo reaches its DotTo leaf at rank 2. Passing MatmulTo<..., 1>
        // as a per-row callback would skip that leaf and keep recursing
        // towards rank 0, so a rank-2 result is multiplied directly on views.
        MatrixView<T, N> res_view (res);
        MatrixView<U, N1> m1_view (m1);
        MatrixView<V, N2> m2_view (m2);
        MatmulTo(res_view, m1_view, m2_view);
    } else {
        res.applyFunctionWithBroadcast(m1, m2, MatmulTo<U, V, T, std::min(N1, N - 1), std::min(N2, N - 1), N - 1>);
    }
    return res;
}
} // namespace frozenca
#endif //FROZENCA_LINALGOPS_H
Matrix.h
#ifndef FROZENCA_MATRIX_H
#define FROZENCA_MATRIX_H
#include "MatrixImpl.h"
#include "MatrixOps.h"
#include "LinalgOps.h"
#endif //FROZENCA_MATRIX_H
Не стесняйтесь комментировать что угодно!
Что мне не нравится:
- Слишком много API для пользователей. Я ненавижу `#include` в C++ и очень жду модулей C++20! (Текущая реализация модулей в MSVC 19.28 сломана и не работает с моим кодом.)
- `.reshape()`. Он должен перемещать буфер, поэтому принимает rvalue-ссылку, но хотелось бы найти решение получше.
- Вспомогательные функции слишком перегружены шаблонами. Точнее, мне не нравится, что такие функции, как `applyFunctionWithBroadcast`, инстанцируются заново для каждого N > 1, хотя делают одно и то же. Но мне нужна специализация для N = 1 (потому что подматрица, строка и столбец должны возвращать `T&`), поэтому лучшего способа я не знаю.
Чтобы понять, как работают «широковещание» (broadcasting), скалярное произведение и матричное умножение, см. ссылки ниже:
https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
https://numpy.org/doc/stable/reference/generated/numpy.dot.html
https://numpy.org/doc/stable/reference/generated/numpy.matmul.html