laser_weeding/losses/__init__.py

17 lines
589 B
Python

import torch.nn as nn
from .losses import CustormLoss
# Registry of the loss functions that can be requested by name.
# Keys are the names accepted by `get_losses`; values are the loss classes
# (not instances) that get constructed on demand.
AVAI_LOSS = {
    'ce': nn.CrossEntropyLoss,
    'multi_label_soft_margin': nn.MultiLabelSoftMarginLoss,
    'test_custom': CustormLoss,
    'mse': nn.MSELoss,
}
def get_losses(losses):
    """Instantiate every loss requested in ``losses``.

    Args:
        losses: Mapping-like object. Iterating it yields loss names, and
            ``losses[name]`` exposes a ``params`` attribute that is either
            ``None`` or a mapping of keyword arguments for the loss
            constructor.

    Returns:
        dict: Maps each requested name to a freshly constructed loss object
        taken from ``AVAI_LOSS``.

    Raises:
        ValueError: If a requested name is not registered in ``AVAI_LOSS``.
    """
    loss_dict = {}
    for name in losses:
        # Raise explicitly instead of using `assert`: assertions are removed
        # under `python -O`, and the previous `assert ..., print(...)` form
        # attached `None` as the message because print() returns None.
        if name not in AVAI_LOSS:
            raise ValueError(
                '{name} is not supported, please implement it first.'.format(name=name)
            )

        params = losses[name].params
        # `params is None` means "use the constructor defaults".
        kwargs = params if params is not None else {}
        loss_dict[name] = AVAI_LOSS[name](**kwargs)
    return loss_dict