Эффективная функция скользящего окна NumPy

Вот функция для создания скользящих окон из одномерного массива NumPy:

from math import ceil, floor
import numpy as np

def slide_window(A, win_size, stride, padding = None):
    '''Collects windows that slides over a one-dimensional array.

    If padding is None, the last (rightmost) window is dropped if it
    is incomplete, otherwise it is padded with the padding value.
    '''
    if win_size <= 0:
        raise ValueError('Window size must be positive.')
    if not (0 < stride <= win_size):
        raise ValueError(f'Stride must satisfy 0 < stride <= {win_size}.')
    if not A.base is None:
        raise ValueError('Views cannot be slided over!')

    n_elems = len(A)
    if padding is not None:
        n_windows = ceil(n_elems / stride)
        A = np.pad(A, (0, n_windows * win_size - n_elems),
                   constant_values = padding)
    else:
        n_windows = floor(n_elems / stride)
    shape = n_windows, win_size

    elem_size = A.strides[-1]
    return np.lib.stride_tricks.as_strided(
        A, shape = shape,
        strides = (elem_size * stride, elem_size),
        writeable = False)

(Код был обновлен на основе отзывов Марка) Предназначен для использования следующим образом:

>>> slide_window(np.arange(5), 3, 2, -1)
array([[ 0,  1,  2],
       [ 2,  3,  4],
       [ 4, -1, -1]])

Моя реализация верна? Можно ли сделать код более читабельным? В NumPy 1.20 есть функция под названием slide_window_view, но мой код должен работать со старыми версиями NumPy.

1 ответ
1

Несколько предложений:

  • Проверка ввода: нет проверки ввода для win_size и padding. Если win_size является -3 исключение говорит ValueError: Stride must satisfy 0 < stride <= -3.. Если padding является строкой, numpy выдает исключение.

  • Подсказки по типу: подумайте о добавлении набор текста чтобы предоставить вызывающему абоненту дополнительную информацию.

  • f-струны: в зависимости от используемой версии Python сообщение об исключении можно немного упростить.

    Из:

    if not (0 < stride <= win_size):
        fmt="Stride must satisfy 0 < stride <= %d."
        raise ValueError(fmt % win_size)
    

    К:

    if not 0 < stride <= win_size:
        raise ValueError(f'Stride must satisfy 0 < stride <= {win_size}.')
    
  • Дублирование: заявление shape = n_windows, win_size кажется дублированным и может быть упрощен. Из:

    if padding is not None:
        n_windows = ceil(n_elems / stride)
        shape = n_windows, win_size
        A = np.pad(A, (0, n_windows * win_size - n_elems),
                 constant_values = padding)
    else:
        n_windows = floor(n_elems / stride)
        shape = n_windows, win_size
    

    К:

    if padding is not None:
        n_windows = ceil(n_elems / stride)
        A = np.pad(A, (0, n_windows * win_size - n_elems),
                       constant_values=padding)
    else:
        n_windows = floor(n_elems / stride)
    shape = n_windows, win_size
    
  • Предупреждение: FYI на документ np.lib.stride_tricks.as_strided есть предупреждение, в котором говорится This function has to be used with extreme care, see notes.. Не уверен, что это применимо к вашему варианту использования, но подумайте о том, чтобы проверить это.

    Добавить комментарий

    Ваш адрес email не будет опубликован. Обязательные поля помечены *