Оптимизация матричного умножения векторов с помощью ключевого слова «регистр» и арифметики с небезопасным указателем

Я знаю, что этот фрагмент кода довольно странный, но он очень хорошо выполняет свою работу с точки зрения производительности, сокращая время выполнения операции с очень интенсивными вычислениями в 3-5 раз без использования лучшего алгоритма. Но так как код, используя ключевое слово controversal register а я использую арифметику с указателями довольно небезопасно (на мой взгляд), хотелось бы услышать мнение других людей.

Немного о требованиях: поскольку проект представляет собой исследовательский проект с упором на простоту, мы не хотим использовать библиотеку линейной алгебры вроде Eigen3 или же Blas или вообще какая-то стандартная библиотека C ++. Он также не должен использовать какой-либо другой алгоритм, кроме школьного метода умножения векторов матриц, потому что мы хотели бы сосредоточиться на аппаратном обеспечении.

У меня вопрос: часто ли такой код встречается в природе и считается ли он нормальным? Как я могу дальше оптимизировать этот код?

Буду признателен за любой конструктивный отзыв.

/**
 * Multiply matrix with vector
 *
 * Matrix is transposed. Hence we can do row-wise inner product.
 *
 * @param matrix (n x m)
 * @param vector (1 x m)
 * @param output (1 x n)
 * @param input_height_ n
 * @param input_width_ m
 */
void matrix_vector_multiply(float *matrix, float *vector, float *output, uint32_t input_height_, uint32_t input_width_) {
    /**
     * The functional principle of this code block is very simple. We iterate 4 rows parallel.
     *
     * With this trick, we only have to fetch the vector's data once and effectively reuse the vector.
     *
     * We used the keyword register to give the compiler hint which variable we would love to keep on the CPU register.
     *
     * Since CPU registers are rare, we really want to use it where it needs to be. Since we want to put them all on registers, we will utilize only 4 rows at once.
     *
     * Also the register keyword is only a hint to the compiler, which can be completely ignored.
     */
    register uint32_t input_height = input_height_;
    register uint32_t input_width = input_width_;
    
    // Put the needed data into a registered variable
    //
    // We will obtain a higher chance for the compiler to optimize our code better
    register float * output_ptr = output;
    register float * input_ptr = matrix;

    /**
     * Using blocked data only if we have more than 4 rows, everything else would be
     * a waste of overhead
     */
    if(input_height > 4 && input_width > 4) {
        uint32_t y = 0;

        // Four at once
        for (; y < input_height - 4; y += 4) {
            // Since we iterate the vector_ptr manually for higher cache locality, we have to reset it every loop
            register float * vector_ptr = vector;

            // Load the data from matrix into four rows
            register float *input_cols_ptr1 = input_ptr;
            input_ptr += input_width;
            register float *input_cols_ptr2 = input_ptr;
            input_ptr += input_width;
            register float *input_cols_ptr3 = input_ptr;
            input_ptr += input_width;
            register float *input_cols_ptr4 = input_ptr;
            input_ptr += input_width;

            // Result for each row
            register float product0 = 0;
            register float product1 = 0;
            register float product2 = 0;
            register float product3 = 0;

            for (uint32_t x = 0; x < input_width; x++) {
                // Picking the value of the vector at the position
                register float vector_val = *vector_ptr++;


                product0 += vector_val * *input_cols_ptr1++;
                product1 += vector_val * *input_cols_ptr2++;
                product2 += vector_val * *input_cols_ptr3++;
                product3 += vector_val * *input_cols_ptr4++;
            }

            // Store the result
            *output_ptr++ += product0;
            *output_ptr++ += product1;
            *output_ptr++ += product2;
            *output_ptr++ += product3;
        }

        // Processing the rest columns
        for (; y < input_height; y++, output_ptr++) {
            register float * vector_ptr = vector;
            for (uint32_t x = 0; x < input_width; x++) {
                *output_ptr += *vector_ptr++ * *input_ptr++;
            }
        }

        /**
         * Everything else goes into this.
         */
    } else {
        for (register uint32_t y = 0; y < input_height; y++, output_ptr++) {
            register float * vector_ptr = vector;
            for (register uint32_t x = 0; x < input_width; x++) {
                *output_ptr += *vector_ptr++ * *input_ptr++;
            }
        }
    }
}

Исходный код выглядел так

void gemm(const float *matrix, const float *vector, float *output, uint32_t input_height, uint32_t input_width) {
#pragma omp parallel for
    for (uint32_t y = 0; y < input_height; y++) {
        float sum = 0.0f;
        const float * row = matrix + y * input_width;
        for (uint32_t x = 0; x < input_width; x++) {
            sum += vector[x] * row[x];
        }
        output[y] += sum;
    }
}

0

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *