Source code for newsreclib.models.components.encoders.news.news

from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from newsreclib.models.components.layers.attention import AdditiveAttention


class NewsEncoder(nn.Module):
    """Implements a news encoder.

    Each requested news feature is encoded by its feature-type encoder (text, category or
    entity). When more than one feature vector is produced, they are aggregated by the
    configured combine layer.

    Attributes:
        dataset_attributes:
            List of news features available in the used dataset.
        attributes2encode:
            List of news features used as input to the news encoder.
        concatenate_inputs:
            Whether the inputs (e.g., title and abstract) were concatenated into a single sequence.
        text_encoder:
            The text encoder module.
        category_encoder:
            The category encoder module.
        entity_encoder:
            The entity encoder module.
        combine_vectors:
            Whether to aggregate the representations of multiple news features.
        combine_type:
            The type of aggregation to use for combining multiple news features representations.
            Choose between `add_att` (additive attention), `linear`, and `concat` (concatenate).
        input_dim:
            The number of input features in the aggregation layer.
        query_dim:
            The number of features in the query vector.
        output_dim:
            The number of features in the final news vector.
    """

    def __init__(
        self,
        dataset_attributes: List[str],
        attributes2encode: List[str],
        concatenate_inputs: bool,
        text_encoder: Optional[nn.Module],
        category_encoder: Optional[nn.Module],
        entity_encoder: Optional[nn.Module],
        combine_vectors: bool,
        combine_type: Optional[str],
        input_dim: Optional[int],
        query_dim: Optional[int],
        output_dim: Optional[int],
    ) -> None:
        super().__init__()

        # at least one attribute is needed to encode the news
        assert len(dataset_attributes) > 0

        # flags to determine which news attributes are encoded
        self.concatenate_inputs = concatenate_inputs
        self.encode_text = False
        self.encode_category = False
        self.encode_entity = False

        # None means "no aggregation layer"; forward() reports a clear error if one is needed.
        self.combine_type: Optional[str] = None

        # text encoders (the same module instance is registered under every text feature)
        text_encoder_cand = ["title", "abstract"]
        if ("title" in attributes2encode) or ("abstract" in attributes2encode):
            assert isinstance(text_encoder, nn.Module)
            if not self.concatenate_inputs:
                self.text_encoders = nn.ModuleDict(
                    {
                        name: text_encoder
                        for name in self._select_attributes(
                            text_encoder_cand, dataset_attributes, attributes2encode
                        )
                    }
                )
            else:
                # concatenated inputs are read from the single "text" key in forward()
                self.text_encoders = nn.ModuleDict({"text": text_encoder})
            self.encode_text = True

        # category encoders
        categ_encoder_cand = ["category", "subcategory"]
        if ("category" in attributes2encode) or ("subcategory" in attributes2encode):
            assert isinstance(category_encoder, nn.Module)
            self.category_encoders = nn.ModuleDict(
                {
                    name: category_encoder
                    for name in self._select_attributes(
                        categ_encoder_cand, dataset_attributes, attributes2encode
                    )
                }
            )
            self.encode_category = True

        # entity encoders
        entity_encoder_cand = ["title_entities", "abstract_entities"]
        if ("title_entities" in attributes2encode) or ("abstract_entities" in attributes2encode):
            assert isinstance(entity_encoder, nn.Module)
            if not self.concatenate_inputs:
                self.entity_encoders = nn.ModuleDict(
                    {
                        name: entity_encoder
                        for name in self._select_attributes(
                            entity_encoder_cand, dataset_attributes, attributes2encode
                        )
                    }
                )
            else:
                # concatenated inputs are read from the single "entities" key in forward()
                self.entity_encoders = nn.ModuleDict({"entities": entity_encoder})
            self.encode_entity = True

        if combine_vectors:
            assert isinstance(combine_type, str)
            self.combine_type = combine_type
            if self.combine_type == "add_att":
                assert isinstance(input_dim, int) and input_dim > 0
                assert isinstance(query_dim, int) and query_dim > 0
                self.combine_layer = AdditiveAttention(input_dim=input_dim, query_dim=query_dim)
            elif self.combine_type == "linear":
                assert isinstance(input_dim, int) and input_dim > 0
                assert isinstance(output_dim, int) and output_dim > 0
                self.combine_layer = nn.Linear(in_features=input_dim, out_features=output_dim)
            elif self.combine_type == "concat":
                self.combine_layer = lambda vectors: torch.cat(vectors, dim=1)
            else:
                raise ValueError(
                    f"Expected keyword argument `combine_type` to be in [`add_att`, `linear`, `concat`] but got {self.combine_type}."
                )

    @staticmethod
    def _select_attributes(
        candidates: List[str], dataset_attributes: List[str], attributes2encode: List[str]
    ) -> List[str]:
        """Return the candidates present in both attribute lists, in candidate order.

        A set intersection would iterate in an order that depends on string hash
        randomization (``PYTHONHASHSEED``). That order would then vary between processes
        and silently permute the features fed to the ``linear``/``concat`` combiners.
        """
        return [
            name for name in candidates if name in dataset_attributes and name in attributes2encode
        ]

    def forward(self, news: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode a batch of news.

        Args:
            news:
                Mapping from feature name to its input tensor. Keys read are the names
                registered in ``__init__``: ``title``/``abstract``, ``category``/``subcategory``
                and ``title_entities``/``abstract_entities``. When ``concatenate_inputs`` is set,
                ``text`` and ``entities`` are read instead.

        Returns:
            The single encoded feature vector unchanged if only one feature is encoded.
            Otherwise, the output of the combine layer over the feature vectors, which are
            ordered text, then category, then entity.

        Raises:
            ValueError: If no feature is encoded, or if several features are encoded but
                ``combine_vectors`` was False.
        """
        all_vectors: List[torch.Tensor] = []
        if self.encode_text:
            all_vectors.extend(encoder(news[name]) for name, encoder in self.text_encoders.items())
        if self.encode_category:
            all_vectors.extend(
                encoder(news[name]) for name, encoder in self.category_encoders.items()
            )
        if self.encode_entity:
            all_vectors.extend(
                encoder(news[name]) for name, encoder in self.entity_encoders.items()
            )

        if len(all_vectors) == 1:
            return all_vectors[0]
        if not all_vectors:
            raise ValueError(
                "No news features were encoded; check `attributes2encode` against `dataset_attributes`."
            )
        if self.combine_type is None:
            raise ValueError(
                "Several news features were encoded but `combine_vectors` is False, so there is no layer to combine them."
            )

        if self.combine_type == "add_att":
            return self.combine_layer(torch.stack(all_vectors, dim=1))
        if self.combine_type == "linear":
            # all feature vectors (any number per group) side by side along the feature axis
            return self.combine_layer(torch.cat(all_vectors, dim=-1))
        return self.combine_layer(all_vectors)
class KCNN(nn.Module):
    """Implements the knowledge-aware CNN from DKN.

    Reference: Wang, Hongwei, Fuzheng Zhang, Xing Xie, and Minyi Guo. "DKN: Deep knowledge-aware
    network for news recommendation." In Proceedings of the 2018 world wide web conference,
    pp. 1835-1844. 2018.

    For further details, please refer to the
    `paper <https://dl.acm.org/doi/10.1145/3178876.3186175>`_

    Attributes:
        pretrained_text_embeddings:
            Matrix of pretrained text embeddings.
        pretrained_entity_embeddings:
            Matrix of pretrained entity embeddings.
        pretrained_context_embeddings:
            Matrix of pretrained context embeddings.
        use_context:
            Whether to use context embeddings.
        text_embed_dim:
            The number of features in the text vector.
        entity_embed_dim:
            The number of features in the entity vector.
        num_filters:
            The number of filters in the ``CNN``.
        window_sizes:
            List of window sizes for the ``CNN``.
    """

    def __init__(
        self,
        pretrained_text_embeddings: torch.Tensor,
        pretrained_entity_embeddings: torch.Tensor,
        pretrained_context_embeddings: Optional[torch.Tensor],
        use_context: bool,
        text_embed_dim: int,
        entity_embed_dim: int,
        num_filters: int,
        window_sizes: List[int],
    ) -> None:
        super().__init__()

        # initialize
        self.window_sizes = window_sizes
        # torch.as_tensor accepts tensors of any dtype/device (and array-likes) and only
        # converts when needed; the legacy torch.FloatTensor(tensor) constructor raises a
        # TypeError for anything that is not already a CPU float32 tensor.
        self.text_embedding_layer = nn.Embedding.from_pretrained(
            torch.as_tensor(pretrained_text_embeddings, dtype=torch.float32),
            freeze=False,
            padding_idx=0,
        )
        self.entity_embedding_layer = nn.Embedding.from_pretrained(
            torch.as_tensor(pretrained_entity_embeddings, dtype=torch.float32),
            freeze=False,
            padding_idx=0,
        )

        self.use_context = use_context
        if self.use_context:
            assert isinstance(pretrained_context_embeddings, torch.Tensor)
            self.context_embedding_layer = nn.Embedding.from_pretrained(
                torch.as_tensor(pretrained_context_embeddings, dtype=torch.float32),
                freeze=False,
                padding_idx=0,
            )

        # shared transform g(e) = tanh(M e + b) mapping entity/context vectors to text space
        self.transform_matrix = nn.Parameter(
            torch.empty(entity_embed_dim, text_embed_dim).uniform_(-0.1, 0.1)
        )
        self.transform_bias = nn.Parameter(torch.empty(text_embed_dim).uniform_(-0.1, 0.1))

        # one 2-D convolution per window size; its kernel spans the whole embedding width,
        # and the input channels are text + entity (+ context)
        self.conv_filters = nn.ModuleDict(
            {
                str(x): nn.Conv2d(3 if self.use_context else 2, num_filters, (x, text_embed_dim))
                for x in self.window_sizes
            }
        )

    def _transform(self, vectors: torch.Tensor) -> torch.Tensor:
        """Map ``entity_embed_dim`` vectors to ``text_embed_dim`` via tanh(vectors @ M + b)."""
        return torch.tanh(torch.add(torch.matmul(vectors, self.transform_matrix), self.transform_bias))

    def forward(self, news: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode news titles with the knowledge-aware CNN.

        Args:
            news:
                Mapping with ``title`` (word ids) and ``title_entities`` (entity ids) index
                tensors. The two must have the same shape, since their embeddings are
                stacked as channels.

        Returns:
            Tensor of shape (batch_size, len(window_sizes) * num_filters): the max-pooled
            ReLU activations of each convolution, concatenated in ``window_sizes`` order.
        """
        # batch_size, num_words_text, text_embed_dim
        text_vector = self.text_embedding_layer(news["title"])
        # entity_embed_dim embeddings -> batch_size, num_words_text, text_embed_dim
        transformed_entity_vector = self._transform(
            self.entity_embedding_layer(news["title_entities"])
        )
        channels = [text_vector, transformed_entity_vector]
        if self.use_context:
            # context embeddings are looked up with the same entity ids
            # batch_size, num_words_text, text_embed_dim
            channels.append(self._transform(self.context_embedding_layer(news["title_entities"])))

        # batch_size, 2 or 3 channels, num_words_text, text_embed_dim
        multi_channel_vector = torch.stack(channels, dim=1)

        pooled_vectors = []
        for size in self.window_sizes:
            # batch_size, num_filters, num_words_text + 1 - size
            convoluted = self.conv_filters[str(size)](multi_channel_vector).squeeze(dim=3)
            # batch_size, num_filters
            pooled_vectors.append(F.relu(convoluted).max(dim=-1)[0])

        # batch_size, len(window_sizes) * num_filters
        return torch.cat(pooled_vectors, dim=1)