Динамически индексируемый массив numpy

Я хочу создать функцию, которая принимает массив numpy, ось и индекс этой оси и возвращает массив с фиксированным индексом на указанной оси. Я думал создать строку, которая динамически изменяется, а затем оценивается как индексный срез в массиве (как показано на этот ответ). Я придумал такую ​​функцию:

import numpy as np

def select_slc_ax(arr, slc, axs):

    dim = len(arr.shape)-1
    slc = str(slc)
    slice_str = ":,"*axs+slc+",:"*(dim-axs)
    print(slice_str)
    slice_obj = eval(f'np.s_[{slice_str}]')
    return arr[slice_obj]

Пример

>>> arr = np.array([[[0, 0, 0],
                     [0, 0, 0],
                     [0, 0, 0]],
                    [[0, 1, 0],
                     [1, 1, 1],
                     [0, 1, 0]],
                    [[0, 0, 0],
                     [0, 0, 0], 
                     [0, 0, 0]]], dtype="uint8")

>>> select_slc_ax(arr, 2, 1)
:,2,:
array([[0, 0, 0],
       [0, 1, 0],
       [0, 0, 0]], dtype=uint8)

Мне было интересно, есть ли лучший способ сделать это.

0

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *