Более быстрое умножение трехдиагональной матрицы

У меня есть функция, которую я пытался сделать быстрее, которая вычисляет $ XA + B $, куда $ A in mathbb {R} ^ {n times n} $ — трехдиагональная матрица, а $ X, B in mathbb {R} ^ {m times n} $. В коде нижняя, основная и верхняя диагонали представлены dl, dm, du. Наивная реализация этого:

/* Compute BA + X where A is tridiagonal shape (n, n), B, X are matrices shape (m, n) */
void dgtaxmt(const double *dl, const double *dm, const double *du,
             const double *x, const double *b, double *out,
             const unsigned int n, const unsigned int m, double *temp)
{
    int i, j, ii;

    for (i = 0; i < m; i++){
        ii = i * n;
        temp[ii] = (x[ii] * dm[0]) + (x[ii + 1] * dl[0]) + b[ii];
        for (j = 1; j < n - 1; j++){
            temp[ii + j] = ((x[ii + j - 1] * du[j - 1]) + ((x[ii + j] * dm[j]) + ((x[ii + j + 1] * dl[j]) + b[ii + j])));
        }
        temp[ii + j] = (x[ii + j] * dm[j]) + (x[ii + j - 1] * du[j - 1]) + b[ii + j];
    }
    // store the transpose of temp into out
    // otrans(temp, out, m, n, 16);

}

Обратите внимание, что все указатели используют restrict ключевое слово, но я оставил их для удобства чтения. Мои попытки повысить производительность включают в себя мозаику и использование встроенных функций AVX в надежде, что не придется перезагружать dl, dm, du ускорит код, но на самом деле делает его намного медленнее. Есть ли способ эффективно векторизовать эту функцию? Clang не сообщает о векторизации, а генерируемый им ассемблерный код использует регистры xmm, но только первое значение. Может ли быть другой порядок операций, который позволил бы лучше повторно использовать данные, которые были ранее загружены?

ИЗМЕНИТЬ Вот попытка, которую я сделал с встроенными функциями AVX

void dgtaxmt_avx(const double *dl, const double *dm, const double *du,
                 const double *x, const double *b, double *out,
                 const unsigned int n, const unsigned int m, double *temp)
{
    // idx is for "rolling" values in vector to left by one index
    const __m256i idx = _mm256_set_epi32(1, 0, 7, 6, 5, 4, 3, 2);
    __m256d dl_vec, dm_vec, du_vec, bn, xnm1, xn, xn1, tmp4;
    __m128d tmp2;

    unsigned int i, j, ii;
    const unsigned int r = ((n - 2) & (-4)) + 1;

    for (i = 0; i < m; i++){
        ii = i * n;
        temp[ii] = (x[ii] * dm[0]) + (x[ii + 1] * dl[0]) + b[ii];
        for (j = 1; j < r; j += 4){
            dl_vec = _mm256_loadu_pd(&dl[j - 1]);
            dm_vec = _mm256_loadu_pd(&dm[j]);
            du_vec = _mm256_loadu_pd(&du[j]);
            bn     = _mm256_loadu_pd(&b[ii + j]);

            xnm1  = _mm256_loadu_pd(&x[ii + j - 1]);
            tmp2  = _mm_loadu_pd(&x[ii + j + 3]);
            xn1   = _mm256_set_m128d(tmp2, _mm256_extractf128_pd(xnm1, 1));

            // use permutations to avoid doing extra loads
            xn   = (__m256d)_mm256_permutevar8x32_ps((__m256)xnm1, idx);
            tmp4 = (__m256d)_mm256_permutevar8x32_ps((__m256)xn1,  idx);
            tmp2 = _mm256_extractf128_pd(tmp4, 0);
            xn   = _mm256_insertf128_pd(xn, tmp2, 1);

            tmp4 = _mm256_fmadd_pd(du_vec, xn1, bn);
            tmp4 = _mm256_fmadd_pd(dm_vec, xn, tmp4);
            tmp4 = _mm256_fmadd_pd(dl_vec, xnm1, tmp4);
            
            _mm256_storeu_pd(&out[ii + j], tmp4);
        }
        for (j = r; j < n - 1; j++){
            temp[ii + j] = ((x[ii + j - 1] * du[j - 1]) + ((x[ii + j] * dm[j]) + ((x[ii + j + 1] * dl[j]) + b[ii + j])));
        }
        temp[ii + j] = (x[ii + j] * dm[j]) + (x[ii + j - 1] * du[j - 1]) + b[ii + j];
    }
    
    otrans(temp, out, m, n, 16);
}

0

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *