Привет всем, я пытаюсь уменьшить сложность моего кода Python. Приведенная ниже функция предназначена для вычисления потерь при проверке и тестировании для различных моделей прогнозирования временных рядов PyTorch. Я не буду вдаваться во все тонкости, но мне нужно поддерживать модели, которые возвращают несколько целей, выходное распределение + стандартное выражение (в отличие от одного тензора) и модели, требующие замаскированных элементов целевой последовательности. Со временем это привело к длительным блокировкам if else и множеству других плохих приемов.
Раньше я использовал словари для сопоставления длинных операторов if else, но из-за вложенности этого кода не похоже, что здесь он будет хорошо работать. Я также не вижу смысла просто создавать больше функций, поскольку это просто перемещает операторы if else в другое место и требует передачи большего количества параметров. У кого-нибудь есть идеи? Сейчас в этом коде есть несколько модульных тестов, которые запускаются разными путями. Однако читать его по-прежнему неудобно. Кроме того, скоро у меня будет еще больше вариантов моделей, которые я буду расширять и поддерживать. Полный код в контексте можно увидеть по этой ссылке.
def compute_validation(validation_loader: DataLoader,
model,
epoch: int,
sequence_size: int,
criterion: Type[torch.nn.modules.loss._Loss],
device: torch.device,
decoder_structure=False,
meta_data_model=None,
use_wandb: bool = False,
meta_model=None,
multi_targets=1,
val_or_test="validation_loss",
probabilistic=False) -> float:
"""Function to compute the validation loss metrics
:param validation_loader: The data-loader of either validation or test-data
:type validation_loader: DataLoader
:param model: model
:type model: [type]
:param epoch: The epoch where the validation/test loss is being computed.
:type epoch: int
:param sequence_size: The number of historical time steps passed into the model
:type sequence_size: int
:param criterion: The evaluation metric function
:type criterion: Type[torch.nn.modules.loss._Loss]
:param device: The device
:type device: torch.device
:param decoder_structure: Whether the model should use sequential decoding, defaults to False
:type decoder_structure: bool, optional
:param meta_data_model: The model to handle the meta-data, defaults to None
:type meta_data_model: PyTorchForecast, optional
:param use_wandb: Whether Weights and Biases is in use, defaults to False
:type use_wandb: bool, optional
:param meta_model: Whether the model leverages meta-data, defaults to None
:type meta_model: bool, optional
:param multi_targets: Whether the model, defaults to 1
:type multi_targets: int, optional
:param val_or_test: Whether validation or test loss is computed, defaults to "validation_loss"
:type val_or_test: str, optional
:param probabilistic: Whether the model is probablistic, defaults to False
:type probabilistic: bool, optional
:return: The loss of the first metric in the list.
:rtype: float
"""
print('Computing validation loss')
unscaled_crit = dict.fromkeys(criterion, 0)
scaled_crit = dict.fromkeys(criterion, 0)
model.eval()
output_std = None
multi_targs1 = multi_targets
scaler = None
if validation_loader.dataset.no_scale:
scaler = validation_loader.dataset
with torch.no_grad():
i = 0
loss_unscaled_full = 0.0
for src, targ in validation_loader:
src = src if isinstance(src, list) else src.to(device)
targ = targ if isinstance(targ, list) else targ.to(device)
# targ = targ if isinstance(targ, list) else targ.to(device)
i += 1
if decoder_structure:
if type(model).__name__ == "SimpleTransformer":
targ_clone = targ.detach().clone()
output = greedy_decode(
model,
src,
targ.shape[1],
targ_clone,
device=device)[
:,
:,
0]
elif type(model).__name__ == "Informer":
multi_targets = multi_targs1
filled_targ = targ[1].clone()
pred_len = model.pred_len
filled_targ[:, -pred_len:, :] = torch.zeros_like(filled_targ[:, -pred_len:, :]).float().to(device)
output = model(src[0].to(device), src[1].to(device), filled_targ.to(device), targ[0].to(device))
labels = targ[1][:, -pred_len:, 0:multi_targets]
src = src[0]
multi_targets = False
else:
output = simple_decode(model=model,
src=src,
max_seq_len=targ.shape[1],
real_target=targ,
output_len=sequence_size,
multi_targets=multi_targets,
probabilistic=probabilistic,
scaler=scaler)
if probabilistic:
output, output_std = output[0], output[1]
output, output_std = output[:, :, 0], output_std[0]
output_dist = torch.distributions.Normal(output, output_std)
else:
if probabilistic:
output_dist = model(src.float())
output = output_dist.mean.detach().numpy()
output_std = output_dist.stddev.detach().numpy()
else:
output = model(src.float())
if multi_targets == 1:
labels = targ[:, :, 0]
elif multi_targets > 1:
labels = targ[:, :, 0:multi_targets]
validation_dataset = validation_loader.dataset
for crit in criterion:
if validation_dataset.scale:
# Should this also do loss.item() stuff?
if len(src.shape) == 2:
src = src.unsqueeze(0)
src1 = src[:, :, 0:multi_targets]
loss_unscaled_full = compute_loss(labels, output, src1, crit, validation_dataset,
probabilistic, output_std, m=multi_targets)
unscaled_crit[crit] += loss_unscaled_full.item() * len(labels.float())
loss = compute_loss(labels, output, src, crit, False, probabilistic, output_std, m=multi_targets)
scaled_crit[crit] += loss.item() * len(labels.float())
if use_wandb:
if loss_unscaled_full:
scaled = {k.__class__.__name__: v / (len(validation_loader.dataset) - 1) for k, v in scaled_crit.items()}
newD = {k.__class__.__name__: v / (len(validation_loader.dataset) - 1) for k, v in unscaled_crit.items()}
wandb.log({'epoch': epoch,
val_or_test: scaled,
"unscaled_" + val_or_test: newD})
else:
scaled = {k.__class__.__name__: v / (len(validation_loader.dataset) - 1) for k, v in scaled_crit.items()}
wandb.log({'epoch': epoch, val_or_test: scaled})
model.train()
return list(scaled_crit.values())[0]