Source code for newsreclib.models.components.losses

from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu


[docs]class SupConLoss(losses.SupConLoss): def __init__(self, temperature=0.1, **kwargs): super().__init__() self.temperature = temperature self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False)
[docs] def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): # Overwrite original method to use directly similarity matrix instead of embeddings if all(len(x) <= 1 for x in indices_tuple): return self.zero_losses() mat = embeddings return self.loss_method(mat, indices_tuple)
def _compute_loss(self, mat, pos_mask, neg_mask): if pos_mask.bool().any() and neg_mask.bool().any(): mat = mat / self.temperature mat_max, _ = mat.max(dim=1, keepdim=True) mat = mat - mat_max.detach() # for numerical stability denominator = lmu.logsumexp( mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1 ) log_prob = mat - denominator mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / ( pos_mask.sum(dim=1) + c_f.small_val(mat.dtype) ) return { "loss": { "losses": -mean_log_prob_pos, "indices": c_f.torch_arange_from_size(mat), "reduction_type": "element", } } return self.zero_losses()