Source code for newsreclib.models.components.encoders.user.cen_news_rec

# Adapted from https://github.com/taoqi98/FedNewsRec/blob/master/code/models.py

import torch
import torch.nn as nn

from newsreclib.models.components.layers.attention import AdditiveAttention


[docs]class UserEncoder(nn.Module): """Implements the user encoder of CenNewsRec. Reference: Qi, Tao, Fangzhao Wu, Chuhan Wu, Yongfeng Huang, and Xing Xie. "Privacy-Preserving News Recommendation Model Learning." In Findings of the Association for Computational Linguistics: EMNLP 2020, pp. 1423-1432. 2020. For further details, please refer to the `paper <https://aclanthology.org/2020.findings-emnlp.128/>`_ Attributes: num_filters: The number of input features in the ``MultiheadAttention`` num_heads: The number of heads in the ``MultiheadAttention``. query_dim: The number of features in the query vector. gru_hidden_dim: The number of features in the hidden state of the ``GRU``. num_recent_news: Number of recent news to be encoded in the short-term user representation. dropout_probability: Dropout probability. """ def __init__( self, num_filters: int, num_heads: int, query_dim: int, gru_hidden_dim: int, num_recent_news: int, dropout_probability: float, ) -> None: super().__init__() if not isinstance(num_recent_news, int): raise ValueError( f"Expected keyword argument `num_recent_news` to be an `int` but got {num_recent_news}" ) if not isinstance(dropout_probability, float): raise ValueError( f"Expected keyword argument `dropout_probability` to be a `float` but got {dropout_probability}" ) # initialize self.num_recent_news = num_recent_news self.multihead_attention = nn.MultiheadAttention( embed_dim=num_filters, num_heads=num_heads ) self.additive_attention = AdditiveAttention(input_dim=num_filters, query_dim=query_dim) self.gru = nn.GRU(input_size=num_filters, hidden_size=gru_hidden_dim, batch_first=True) self.final_additive_attention = AdditiveAttention( input_dim=gru_hidden_dim, query_dim=query_dim ) self.dropout = nn.Dropout(p=dropout_probability)
[docs] def forward(self, hist_news_vector: torch.Tensor) -> torch.Tensor: # long-term user representation # batch_size, num_clicked_news, num_filters longterm_user_vector, _ = self.multihead_attention( hist_news_vector, hist_news_vector, hist_news_vector ) longterm_user_vector = self.dropout(longterm_user_vector) # batch_size, num_filters longterm_user_vector = self.additive_attention(longterm_user_vector) # short-term user representation # batch_size, num_recent_news, num_filters recent_hist_news = hist_news_vector[:, -self.num_recent_news :, :] # 1, batch_size, gru_hidden_dim _, hidden = self.gru(recent_hist_news) # batch_size, gru_hidden_dim shortterm_user_vector = hidden.squeeze(dim=0) # aggregated user representation # batch_size, 2, gru_hidden_dim user_vector = torch.stack([shortterm_user_vector, longterm_user_vector], dim=1) # batch_size, gru_hidden_dim user_vector = self.final_additive_attention(user_vector) return user_vector