# Source code for newsreclib.models.components.encoders.user.nrms

import torch
import torch.nn as nn

from newsreclib.models.components.layers.attention import AdditiveAttention


class UserEncoder(nn.Module):
    """Implements the user encoder of NRMS.

    Reference: Wu, Chuhan, Fangzhao Wu, Suyu Ge, Tao Qi, Yongfeng Huang, and Xing Xie. "Neural news
    recommendation with multi-head self-attention." In Proceedings of the 2019 conference on
    empirical methods in natural language processing and the 9th international joint conference
    on natural language processing (EMNLP-IJCNLP), pp. 6389-6394. 2019.

    For further details, please refer to the `paper <https://aclanthology.org/D19-1671/>`_

    The encoder applies multi-head self-attention over a user's clicked (history) news vectors
    and then pools the attended vectors into a single user vector with additive attention.

    Attributes:
        news_embed_dim:
            The number of features in the news vector.
        num_heads:
            The number of heads in the ``MultiheadAttention``.
        query_dim:
            The number of features in the query vector.
    """

    def __init__(self, news_embed_dim: int, num_heads: int, query_dim: int) -> None:
        super().__init__()

        # initialize
        # ``batch_first=True`` is required because ``forward`` receives batch-first input of
        # shape (batch_size, num_clicked_news_user, news_embed_dim). With PyTorch's default
        # (``batch_first=False``) dim 0 would be treated as the sequence axis, so attention
        # would mix the news vectors of *different users* in the batch instead of attending
        # over one user's clicked news. The flag adds no parameters, so existing checkpoints
        # (state dicts) still load unchanged.
        self.multihead_attention = nn.MultiheadAttention(
            embed_dim=news_embed_dim, num_heads=num_heads, batch_first=True
        )
        self.additive_attention = AdditiveAttention(input_dim=news_embed_dim, query_dim=query_dim)

    def forward(self, hist_news_vector: torch.Tensor) -> torch.Tensor:
        """Encode users from the vectors of the news they clicked.

        Args:
            hist_news_vector:
                Clicked-news vectors of shape
                ``(batch_size, num_clicked_news_user, news_embed_dim)``.

        Returns:
            User vectors of shape ``(batch_size, news_embed_dim)``.
        """
        # self-attention over the clicked-news axis (query = key = value)
        # batch_size, num_clicked_news_user, news_embed_dim
        user_vector, _ = self.multihead_attention(
            hist_news_vector, hist_news_vector, hist_news_vector
        )

        # pool the clicked-news axis into one vector per user
        # batch_size, news_embed_dim
        user_vector = self.additive_attention(user_vector)

        return user_vector