Вот в чем проблема.
Вернуть копию x, где минимальное значение в каждой строке установлено на 0.
Например, если x:
x = torch.tensor([[ [10, 20, 30], [ 2, 5, 1] ]])
Тогда y = zero_row_min (x) должно быть:
torch.tensor([ [0, 20, 30], [2, 5, 0] ])
Ваша реализация должна использовать операции сокращения и индексации; вы не должны использовать никаких явных циклов. Входной тензор не следует изменять.
Входы:
- x: Тензор формы (M, N)
Возврат:
- y: тензор формы (M, N), который является копией x, за исключением того, что минимальное значение в каждой строке заменяется на 0.
Намекнули на это клон а также аргмин должен быть использован.
У меня проблемы с пониманием того, как это сделать без цикла, и мой текущий код ниже (хотя он правильно решает проблему) груб. Я ищу лучший способ решить эту проблему.
x0 = torch.tensor([[10, 20, 30], [2, 5, 1]])
x1 = torch.tensor([[2, 5, 10, -1], [1, 3, 2, 4], [5, 6, 2, 10]])
func(x0)
func(x1)
def func(x):
y = None
# g = x.argmin(dim=1)
g = x.min(dim=1)[1]
if x.shape[0] == 2:
x[0,:][g[0]] = 0
x[1,:][g[1]] = 0
elif x.shape[0] == 3:
x[0,:][g[0]] = 0
x[1,:][g[1]] = 0
x[2,:][g[2]] = 0
y = x
return y
1 ответ
Небольшой обзор кода, но это должно сработать:
def zero_min(x):
y = x.clone()
y[torch.arange(x.shape[0]), torch.argmin(x, dim=1)] = 0
return y
В каждой строке, если минимум не уникален, обнуляется только вхождение с наименьшим индексом.
Чтобы обнулить все вхождения, вы можете сделать что-то вроде следующего:
def zero_em_all_min(x):
y = x.clone()
y[x == x.min(dim=1, keepdims=True).values] = 0
return y
Спасибо, не подумал использовать
torch.arange
для перебора строк. Спрошу на штатном SE в будущем.— Райан
@Ryan Я рад, что на него можно было ответить, но, возможно, это действительно лучшая платформа для аналогичного вопроса!
— Эндрю