Я знаю, что этот фрагмент кода довольно странный, но он очень хорошо выполняет свою работу с точки зрения производительности, сокращая время выполнения операции с очень интенсивными вычислениями в 3-5 раз без использования лучшего алгоритма. Но так как код, используя ключевое слово controversal register
а я использую арифметику с указателями довольно небезопасно (на мой взгляд), хотелось бы услышать мнение других людей.
Немного о требованиях: поскольку проект представляет собой исследовательский проект с упором на простоту, мы не хотим использовать библиотеку линейной алгебры вроде Eigen3
или же Blas
или вообще какая-то стандартная библиотека C ++. Он также не должен использовать какой-либо другой алгоритм, кроме школьного метода умножения векторов матриц, потому что мы хотели бы сосредоточиться на аппаратном обеспечении.
У меня вопрос: часто ли такой код встречается в природе и считается ли он нормальным? Как я могу дальше оптимизировать этот код?
Буду признателен за любой конструктивный отзыв.
/**
* Multiply matrix with vector
*
* Matrix is transposed. Hence we can do row-wise inner product.
*
* @param matrix (n x m)
* @param vector (1 x m)
* @param output (1 x n)
* @param input_height_ n
* @param input_width_ m
*/
void matrix_vector_multiply(float *matrix, float *vector, float *output, uint32_t input_height_, uint32_t input_width_) {
/**
* The functional principle of this code block is very simple. We iterate 4 rows parallel.
*
* With this trick, we only have to fetch the vector's data once and effectively reuse the vector.
*
* We used the keyword register to give the compiler hint which variable we would love to keep on the CPU register.
*
* Since CPU registers are rare, we really want to use it where it needs to be. Since we want to put them all on registers, we will utilize only 4 rows at once.
*
* Also the register keyword is only a hint to the compiler, which can be completely ignored.
*/
register uint32_t input_height = input_height_;
register uint32_t input_width = input_width_;
// Put the needed data into a registered variable
//
// We will obtain a higher chance for the compiler to optimize our code better
register float * output_ptr = output;
register float * input_ptr = matrix;
/**
* Using blocked data only if we have more than 4 rows, everything else would be
* a waste of overhead
*/
if(input_height > 4 && input_width > 4) {
uint32_t y = 0;
// Four at once
for (; y < input_height - 4; y += 4) {
// Since we iterate the vector_ptr manually for higher cache locality, we have to reset it every loop
register float * vector_ptr = vector;
// Load the data from matrix into four rows
register float *input_cols_ptr1 = input_ptr;
input_ptr += input_width;
register float *input_cols_ptr2 = input_ptr;
input_ptr += input_width;
register float *input_cols_ptr3 = input_ptr;
input_ptr += input_width;
register float *input_cols_ptr4 = input_ptr;
input_ptr += input_width;
// Result for each row
register float product0 = 0;
register float product1 = 0;
register float product2 = 0;
register float product3 = 0;
for (uint32_t x = 0; x < input_width; x++) {
// Picking the value of the vector at the position
register float vector_val = *vector_ptr++;
product0 += vector_val * *input_cols_ptr1++;
product1 += vector_val * *input_cols_ptr2++;
product2 += vector_val * *input_cols_ptr3++;
product3 += vector_val * *input_cols_ptr4++;
}
// Store the result
*output_ptr++ += product0;
*output_ptr++ += product1;
*output_ptr++ += product2;
*output_ptr++ += product3;
}
// Processing the rest columns
for (; y < input_height; y++, output_ptr++) {
register float * vector_ptr = vector;
for (uint32_t x = 0; x < input_width; x++) {
*output_ptr += *vector_ptr++ * *input_ptr++;
}
}
/**
* Everything else goes into this.
*/
} else {
for (register uint32_t y = 0; y < input_height; y++, output_ptr++) {
register float * vector_ptr = vector;
for (register uint32_t x = 0; x < input_width; x++) {
*output_ptr += *vector_ptr++ * *input_ptr++;
}
}
}
}
Исходный код выглядел так
void gemm(const float *matrix, const float *vector, float *output, uint32_t input_height, uint32_t input_width) {
#pragma omp parallel for
for (uint32_t y = 0; y < input_height; y++) {
float sum = 0.0f;
const float * row = matrix + y * input_width;
for (uint32_t x = 0; x < input_width; x++) {
sum += vector[x] * row[x];
}
output[y] += sum;
}
}