Source code for newsreclib.models.components.layers.attention

import torch
import torch.nn as nn
import torch.nn.functional as F


[docs]class AdditiveAttention(nn.Module): def __init__(self, input_dim: int, query_dim: int) -> None: super().__init__() if not isinstance(input_dim, int): raise ValueError( f"Expected keyword argument `input_dim` to be an `int` but got {input_dim}" ) if not isinstance(query_dim, int): raise ValueError( f"Expected keyword argument `query_dim` to be an `int` but got {query_dim}" ) # initialize self.linear = nn.Linear(in_features=input_dim, out_features=query_dim) self.query = nn.Parameter(torch.empty(query_dim).uniform_(-0.1, 0.1))
[docs] def forward(self, input_vector: torch.Tensor) -> torch.Tensor: """ Args: input_vector: User tensor of shape `(batch_size, hidden_dim, output_dim)`. Returns: User tensor of shape `(batch_size, news_emb_dim)`. """ # batch_size, hidden_dim, output_dim attention = torch.tanh(self.linear(input_vector)) # batch_size, hidden_dim attention_weights = F.softmax(torch.matmul(attention, self.query), dim=1) # batch_size, output_dim weighted_input = torch.bmm(attention_weights.unsqueeze(dim=1), input_vector).squeeze(dim=1) return weighted_input
[docs]class PolyAttention(nn.Module): """Implementation of Poly attention scheme (used in MINER) that extracts K attention vectors through K additive attentions. Adapted from https://github.com/duynguyen-0203/miner/blob/master/src/model/model.py. Reference: Li, Jian, Jieming Zhu, Qiwei Bi, Guohao Cai, Lifeng Shang, Zhenhua Dong, Xin Jiang, and Qun Liu. "MINER: multi-interest matching network for news recommendation." In Findings of the Association for Computational Linguistics: ACL 2022, pp. 343-352. 2022. For further details, please refer to the `paper <https://aclanthology.org/2022.findings-acl.29/>`_ """ def __init__(self, input_dim: int, num_context_codes: int, context_code_dim: int) -> None: """ Args: input_dim: The number of expected features in the input. num_context_codes: The number of attention vectors. context_code_dim: The number of features in a context code. """ super().__init__() if not isinstance(input_dim, int): raise ValueError( f"Expected keyword argument `input_dim` to be an `int` but got {input_dim}" ) if not isinstance(num_context_codes, int): raise ValueError( f"Expected keyword argument `num_context_codes` to be an `int` but got {num_context_codes}" ) if not isinstance(context_code_dim, int): raise ValueError( f"Expected keyword argument `context_code_dim` to be an `int` but got {context_code_dim}" ) # initialize self.linear = nn.Linear(in_features=input_dim, out_features=context_code_dim, bias=False) self.context_codes = nn.Parameter( nn.init.xavier_uniform_( torch.empty(num_context_codes, context_code_dim), gain=nn.init.calculate_gain("tanh"), ) )
[docs] def forward( self, embeddings: torch.Tensor, attn_mask: torch.Tensor, bias: torch.Tensor = None ): """ Args: embeddings: `(batch_size, hist_length, embed_dim)` attn_mask: `(batch_size, hist_length)` bias: `(batch_size, hist_length, num_candidates)` Returns: torch.Tensor: `(batch_size, num_context_codes, embed_dim)` """ projection = torch.tanh(self.linear(embeddings)) if bias is None: weights = torch.matmul(projection, self.context_codes.T) else: bias = bias.mean(dim=2).unsqueeze(dim=2) weights = torch.matmul(projection, self.context_codes.T) + bias weights = weights.permute(0, 2, 1) weights = weights.masked_fill_(~attn_mask.unsqueeze(dim=1), 1e-30) weights = F.softmax(weights, dim=2) poly_news_vector = torch.matmul(weights, embeddings) return poly_news_vector
[docs]class TargetAwareAttention(nn.Module): """Implementation of the target-aware attention network used in MINER. Adapted from https://github.com/duynguyen-0203/miner/blob/master/src/model/model.py Reference: Li, Jian, Jieming Zhu, Qiwei Bi, Guohao Cai, Lifeng Shang, Zhenhua Dong, Xin Jiang, and Qun Liu. "MINER: multi-interest matching network for news recommendation." In Findings of the Association for Computational Linguistics: ACL 2022, pp. 343-352. 2022. For further details, please refer to the `paper <https://aclanthology.org/2022.findings-acl.29/>`_ """ def __init__(self, input_dim: int) -> None: """ Args: input_dim: The number of features in the query and key vectors. """ super().__init__() if not isinstance(input_dim, int): raise ValueError( f"Expected keyword argument `input_dim` to be an `int` but got {input_dim}" ) # initialize self.linear = nn.Linear(in_features=input_dim, out_features=input_dim, bias=False)
[docs] def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: """ Args: query: `(batch_size, num_context_codes, input_embed_dim)` key: `(batch_size, num_candidates, input_embed_dim)` value: `(batch_size, num_candidates, num_context_codes)` """ projection = F.gelu(self.linear(query)) weights = F.softmax(torch.matmul(key, projection.permute(0, 2, 1)), dim=2) outputs = torch.mul(weights, value).sum(dim=2) return outputs
[docs]class DenseAttention(nn.Module): """Dense attention used in CAUM. Reference: Qi, Tao, Fangzhao Wu, Chuhan Wu, and Yongfeng Huang. "News recommendation with candidate-aware user modeling." In Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval, pp. 1917-1921. 2022. For further details, please refer to the `paper <https://dl.acm.org/doi/abs/10.1145/3477495.3531778>`_ """ def __init__(self, input_dim: int, hidden_dim1: int, hidden_dim2: int) -> None: super().__init__() if not isinstance(input_dim, int): raise ValueError( f"Expected keyword argument `input_dim` to be an `int` but got {input_dim}" ) if not isinstance(hidden_dim1, int): raise ValueError( f"Expected keyword argument `hidden_dim1` to be an `int` but got {hidden_dim1}" ) if not isinstance(hidden_dim2, int): raise ValueError( f"Expected keyword argument `hidden_dim2` to be an `int` but got {hidden_dim2}" ) # initialize self.linear = nn.Linear(input_dim, hidden_dim1) self.tanh1 = nn.Tanh() self.linear2 = nn.Linear(hidden_dim1, hidden_dim2) self.tanh2 = nn.Tanh() self.linear3 = nn.Linear(hidden_dim2, 1)
[docs] def forward(self, input_vector: torch.Tensor) -> torch.Tensor: transformed_vector = self.linear(input_vector) transformed_vector = self.tanh1(transformed_vector) transformed_vector = self.linear2(transformed_vector) transformed_vector = self.tanh2(transformed_vector) transformed_vector = self.linear3(transformed_vector) return transformed_vector
[docs]class PersonalizedAttention(nn.Module): """Personalized attention used in NPA. Reference: Wu, Chuhan, Fangzhao Wu, Mingxiao An, Jianqiang Huang, Yongfeng Huang, and Xing Xie. "NPA: neural news recommendation with personalized attention." In Proceedings of the 25th ACM SIGKDD international conference on knowledge discovery & data mining, pp. 2576-2584. 2019. For further details, please refer to the `paper <https://dl.acm.org/doi/abs/10.1145/3292500.3330665>`_ """ def __init__(self, preference_query_dim: int, num_filters: int) -> None: super().__init__() if not isinstance(preference_query_dim, int): raise ValueError( f"Expected keyword argument `preference_query_dim` to be an `int` but got {preference_query_dim}" ) if not isinstance(num_filters, int): raise ValueError( f"Expected keyword argument `num_filters` to be an `int` but got {num_filters}" ) # initialize self.preference_query_projection = nn.Linear(preference_query_dim, num_filters)
[docs] def forward(self, query: torch.Tensor, keys: torch.Tensor) -> torch.Tensor: """ Args: query: `(batch_size * preference_dim)` keys: `(batch_size * num_filters * num_words_text)` Returns: `(batch_size * num_filters)` """ # batch_size * 1 * num_filters query = torch.tanh(self.preference_query_projection(query).unsqueeze(dim=1)) # batch_size * 1 * num_words_text attn_results = torch.bmm(query, keys) # batch_size * num_words_text * 1 attn_weights = F.softmax(attn_results, dim=2).permute(0, 2, 1) # batch_size * num_filters attn_aggr = torch.bmm(keys, attn_weights).squeeze() return attn_aggr