Source code for newsreclib.models.components.utils

import torch


def pairwise_cosine_similarity(
    x: torch.Tensor, y: torch.Tensor, zero_diagonal: bool = False
) -> torch.Tensor:
    """Calculate the batched pairwise cosine similarity matrix.

    Implemented from https://github.com/duynguyen-0203/miner/blob/master/src/utils.py

    Args:
        x: Tensor of shape ``(batch_size, M, d)``.
        y: Tensor of shape ``(batch_size, N, d)``.
        zero_diagonal: If ``True``, the diagonal of every ``(M, N)`` similarity
            matrix is set to zero. Requires ``M == N``.

    Returns:
        A tensor of shape ``(batch_size, M, N)`` whose entry ``[b, i, j]`` is the
        cosine similarity between ``x[b, i]`` and ``y[b, j]``. Rows or columns
        that correspond to all-zero vectors have similarity ``0``.

    Raises:
        AssertionError: If ``zero_diagonal`` is set and ``M != N``.
    """
    x_norm = torch.linalg.norm(x, dim=2, keepdim=True)
    y_norm = torch.linalg.norm(y, dim=2, keepdim=True)

    # Replace zero norms by one so all-zero vectors are divided by ~1 instead of
    # by the bare epsilon. The previous ``norm[torch.where(norm) == 0.0] = 1``
    # compared a tuple with a float (always ``False``) and therefore never
    # changed anything. The replacement is done out of place so the norm
    # tensors saved for the backward pass are not mutated.
    x_norm = torch.where(x_norm == 0, torch.ones_like(x_norm), x_norm)
    y_norm = torch.where(y_norm == 0, torch.ones_like(y_norm), y_norm)

    # The small epsilon is kept so results for non-zero vectors are unchanged.
    x = torch.div(x, 1e-8 + x_norm)
    y = torch.div(y, 1e-8 + y_norm)

    # (batch_size, M, d) @ (batch_size, d, N) -> (batch_size, M, N)
    distance = torch.matmul(x, y.permute(0, 2, 1))

    if zero_diagonal:
        assert x.shape[1] == y.shape[1], (
            "zero_diagonal requires x and y to have the same size in dim 1, "
            f"got {x.shape[1]} and {y.shape[1]}"
        )
        # One (M, M) mask created directly on the target device; it broadcasts
        # over the batch dimension, so no per-batch copy is needed.
        mask = torch.eye(x.shape[1], device=distance.device).bool()
        distance.masked_fill_(mask, 0)

    return distance