Welcome to NewsRecLib’s documentation!

NewsRecLib is a library based on PyTorch Lightning and Hydra for the development and evaluation of neural news recommenders (NNR). The framework is highly configurable and modularized, decoupling core model components from one another. It enables running experiments from a single configuration file that navigates the pipeline from dataset selection and loading to model evaluation. NewsRecLib provides implementations of several neural news recommenders, training methods, standard evaluation benchmarks, hyperparameter optimization algorithms, extensive logging functionalities, and evaluation metrics (ranging from accuracy-based to beyond-accuracy performance evaluation).

The foremost goals of NewsRecLib are to promote reproducible research and rigorous experimental evaluation.

[Figure: NewsRecLib’s schema]


Installation

NewsRecLib requires Python version 3.8 or later.

NewsRecLib requires PyTorch, PyTorch Lightning, and TorchMetrics version 2.0 or later. To use NewsRecLib with a GPU, please ensure that CUDA (or cudatoolkit) version 11.8 is installed.

Install from source

CONDA

git clone https://github.com/andreeaiana/newsreclib.git
cd newsreclib
conda create --name newsreclib_env python=3.8
conda activate newsreclib_env
pip install -e .

Quick Start

NewsRecLib’s entry point is the function train, which accepts a configuration file that drives the entire experiment.

Basic Configuration

The following example shows how to train an NRMS model on the MINDsmall dataset with its original configuration (i.e., a news encoder that contextualizes pretrained embeddings, trained by optimizing the cross-entropy loss), using an existing configuration file.

python newsreclib/train.py experiment=nrms_mindsmall_pretrainedemb_celoss_bertsent

In this basic experiment, the experiment configuration only specifies the required hyperparameter values that are not already set in the configurations of the corresponding modules.

defaults:
    - override /data: mind_rec_bert_sent.yaml
    - override /model: nrms.yaml
    - override /callbacks: default.yaml
    - override /logger: many_loggers.yaml
    - override /trainer: gpu.yaml
data:
    dataset_size: "small"
model:
    use_plm: False
    pretrained_embeddings_path: ${paths.data_dir}MINDsmall_train/transformed_word_embeddings.npy
    embed_dim: 300
    num_heads: 15

Advanced Configuration

The advanced scenario depicts a more complex experimental setting. From the main experiment configuration file, users can override any of the predefined module configurations. The following example shows how to train an NRMS model with a PLM-based news encoder and a supervised contrastive loss objective instead of the default settings.

python newsreclib/train.py experiment=nrms_mindsmall_plm_supconloss_bertsent

This is achieved by creating an experiment configuration file with the following specifications:

defaults:
    - override /data: mind_rec_bert_sent.yaml
    - override /model: nrms.yaml
    - override /callbacks: default.yaml
    - override /logger: many_loggers.yaml
    - override /trainer: gpu.yaml
data:
    dataset_size: "small"
    use_plm: True
    tokenizer_name: "roberta-base"
    tokenizer_use_fast: True
    tokenizer_max_len: 96
model:
    loss: "sup_con_loss"
    temperature: 0.1
    use_plm: True
    plm_model: "roberta-base"
    frozen_layers: [0, 1, 2, 3, 4, 5, 6, 7]
    pretrained_embeddings_path: null
    embed_dim: 768
    num_heads: 16

Alternatively, configurations can be overridden from the command line, as follows:

python newsreclib/train.py experiment=nrms_mindsmall_plm_supconloss_bertsent data.batch_size=128
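In both experiment configurations above, embed_dim is a multiple of num_heads (300 = 15 × 20 and 768 = 16 × 48). Assuming the NRMS news encoder is built on torch.nn.MultiheadAttention, which requires the embedding dimension to be divisible by the number of heads, a quick pre-flight check for a new configuration might look like the following minimal sketch (head_dim is a hypothetical helper, not part of NewsRecLib):

```python
def head_dim(embed_dim: int, num_heads: int) -> int:
    """Return the per-head dimension, failing early if the split is uneven."""
    if num_heads <= 0:
        raise ValueError("num_heads must be positive")
    if embed_dim % num_heads != 0:
        raise ValueError(
            f"embed_dim={embed_dim} is not divisible by num_heads={num_heads}"
        )
    return embed_dim // num_heads


# Values taken from the two NRMS experiment configurations.
print(head_dim(300, 15))  # pretrained word embeddings -> 20 features per head
print(head_dim(768, 16))  # roberta-base hidden size -> 48 features per head
```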

Summary of the Datasets

NewsRecLib integrates, to date, 2 benchmark datasets: MIND and Adressa. Each is supported in two variants, depending on the dataset size.

MIND Dataset

NewsRecLib provides downloading, parsing, annotation, and loading functionalities for two variants of the MIND dataset: MINDsmall and MINDlarge.

Reference: Wu, Fangzhao, Ying Qiao, Jiun-Hung Chen, Chuhan Wu, Tao Qi, Jianxun Lian, Danyang Liu et al. “Mind: A large-scale dataset for news recommendation.” In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pp. 3597-3606. 2020.

For further details, please refer to the paper

Adressa Dataset

NewsRecLib provides downloading, parsing, annotation, and loading functionalities for two variants of the Adressa dataset: 1-week and 3-month.

Reference: Gulla, Jon Atle, Lemei Zhang, Peng Liu, Özlem Özgöbek, and Xiaomeng Su. “The adressa dataset for news recommendation.” In Proceedings of the international conference on web intelligence, pp. 1042-1048. 2017.

For further details, please refer to the paper

Summary of the Recommendation Models

NewsRecLib integrates, to date, 13 recommendation models, partitioned into two classes: general recommenders (GeneralRec) and fairness-aware recommenders (FairRec).

Depending on its class, each recommendation model inherits from one of the following common abstract classes:

  • GeneralRec

  • FairRec

Click Behavior Fusion

NewsRecLib supports 2 strategies for aggregating users’ click behaviors: early fusion and late fusion.

Early Fusion

This is the predominant paradigm used in all recommendation models. It involves aggregating the representations of clicked news (i.e., building an explicit user representation) before comparison with the recommendation candidate.

When choosing this option, users will have to select one of the available user encoders or implement a new one.

Late Fusion

This lightweight approach replaces user encoders with the mean-pooling of dot-product scores between the embedding of the candidate \(n^c\) and the embeddings of the clicked news \(n_i^u\).

Given a candidate news \(n^c\) and a sequence of news clicked by the user \(H = (n_1^u, \ldots, n_N^u)\), the relevance score of the candidate news with regard to the history of user \(u\) is computed as \(s(\mathbf{n}^c, u) = \frac{1}{N} \sum_{i=1}^N \mathbf{n}^c \cdot \mathbf{n}_i^u\), where \(\mathbf{n}\) denotes the embedding of a news article and \(N\) the history length.

For further details, please refer to the paper

Reference: Iana, Andreea, Goran Glavas, and Heiko Paulheim. “Simplifying Content-Based Neural News Recommendation: On User Modeling and Training Objectives.” In Proceedings of the 46th International ACM SIGIR Conference on Research and Development in Information Retrieval, pp. 2384-2388. 2023.
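The late-fusion score can be sketched in plain Python as follows. Lists stand in for embedding tensors, and the function names are illustrative rather than NewsRecLib's API:

```python
from typing import List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Dot product of two vectors of equal dimensionality."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same dimensionality")
    return sum(a * b for a, b in zip(u, v))


def late_fusion_score(candidate: Sequence[float], history: List[Sequence[float]]) -> float:
    """s(n^c, u) = 1/N * sum_i n^c . n_i^u, i.e. mean-pooled dot-product scores."""
    if not history:
        raise ValueError("the user history must contain at least one clicked news")
    return sum(dot(candidate, clicked) for clicked in history) / len(history)


candidate = [0.5, 1.0, -0.5]
clicked_news = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
print(late_fusion_score(candidate, clicked_news))  # (0.5 + 0.5) / 2 = 0.5
```

Because the dot product is linear, the same score equals the dot product between the candidate and the mean of the clicked-news embeddings, so late fusion adds no trainable parameters on the user side.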

Training Objectives

NewsRecLib supports three types of training objectives: point-wise classification objectives, contrastive-learning objectives, and dual training objectives.

Point-wise classification objectives

NewsRecLib implements model training with Cross-Entropy Loss, the most standard classification objective.

Contrastive-learning objectives

NewsRecLib implements Supervised Contrastive Loss as its contrastive-learning objective.

Reference: Khosla, Prannay, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan. “Supervised contrastive learning.” Advances in neural information processing systems 33 (2020): 18661-18673.
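The supervised contrastive loss can be sketched without any framework as below. This follows the variant from Khosla et al. that averages over positives outside the logarithm, computed on L2-normalized embeddings; which representations and labels NewsRecLib feeds into the loss is not described here, so the inputs are purely illustrative. The default temperature of 0.1 mirrors the temperature value in the advanced configuration.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in v]


def sup_con_loss(
    embeddings: List[Sequence[float]], labels: List[int], temperature: float = 0.1
) -> float:
    """Supervised contrastive loss, averaged over anchors that have a positive.

    For anchor i, P(i) are the other samples with the same label and A(i) are
    all samples except i:
        L_i = -1/|P(i)| * sum_{p in P(i)} log(exp(z_i.z_p / t) / sum_{a in A(i)} exp(z_i.z_a / t))
    """
    z = [l2_normalize(e) for e in embeddings]
    n = len(z)
    total, num_anchors = 0.0, 0
    for i in range(n):
        positives = [p for p in range(n) if p != i and labels[p] == labels[i]]
        if not positives:  # anchors without any positive are skipped
            continue
        # Scaled cosine similarities between anchor i and every a in A(i).
        logits = {
            a: sum(x * y for x, y in zip(z[i], z[a])) / temperature
            for a in range(n)
            if a != i
        }
        # Log of the denominator via log-sum-exp for numerical stability.
        peak = max(logits.values())
        log_denom = peak + math.log(sum(math.exp(v - peak) for v in logits.values()))
        total += -sum(logits[p] - log_denom for p in positives) / len(positives)
        num_anchors += 1
    return total / num_anchors if num_anchors else 0.0


emb = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.8]]
print(sup_con_loss(emb, labels=[1, 1, 0, 0]))  # same-label pairs aligned: low loss
print(sup_con_loss(emb, labels=[1, 0, 1, 0]))  # same-label pairs misaligned: higher loss
```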

Dual training objectives

Models can also be trained with a dual learning objective, which combines cross-entropy loss \(\mathcal{L}_{CE}\) and supervised contrastive loss \(\mathcal{L}_{SCL}\) with a weighted average.

\(\mathcal{L} = (1-\lambda) \mathcal{L}_{CE} + \lambda \mathcal{L}_{SCL}\)

Reference: Gunel, Beliz, Jingfei Du, Alexis Conneau, and Ves Stoyanov. “Supervised contrastive learning for pre-trained language model fine-tuning.” arXiv preprint arXiv:2011.01403 (2020).
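The weighted combination can be sketched as below. The cross-entropy term is written as softmax cross-entropy over the scores of one impression (one clicked candidate among non-clicked ones), a common formulation for news recommendation that is assumed here rather than taken from NewsRecLib; lam stands for \(\lambda\), and the supervised contrastive term is passed in as a precomputed value:

```python
import math
from typing import Sequence


def cross_entropy(scores: Sequence[float], target: int) -> float:
    """Softmax cross-entropy of one impression; target indexes the clicked candidate."""
    peak = max(scores)
    log_partition = peak + math.log(sum(math.exp(s - peak) for s in scores))
    return log_partition - scores[target]


def dual_loss(ce_loss: float, scl_loss: float, lam: float) -> float:
    """L = (1 - lam) * L_CE + lam * L_SCL with lam in [0, 1]."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return (1.0 - lam) * ce_loss + lam * scl_loss


# One clicked candidate (index 0) scored against two non-clicked candidates.
ce = cross_entropy([2.0, 0.5, -1.0], target=0)
print(dual_loss(ce, scl_loss=0.8, lam=0.3))
```

Setting lam to 0 recovers pure cross-entropy training and lam to 1 pure supervised contrastive training, so \(\lambda\) interpolates between the two objectives.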

Recommenders

GeneralRec

Summary

CAUM

CenNewsRec

DKN

LSTUR

MINER

MINS

NAML

NPA

NRMS

TANR

FairRec

Summary

MANNER: A-Module

MANNER: CR-Module

MANNER

SentiDebias

SentiRec

Callbacks

Early Stopping

NewsRecLib integrates the early stopping functionality provided by PyTorch Lightning and adopts the same notation.

For more details, please refer to the corresponding PyTorch Lightning early stopping page.

This is an example that shows all the options.

early_stopping:
  _target_: lightning.pytorch.callbacks.EarlyStopping
  monitor: ??? # quantity to be monitored, must be specified !!!
  min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
  patience: 3 # number of checks with no improvement after which training will be stopped
  verbose: False # verbosity mode
  mode: "min" # "min" means a lower metric value is better; can also be "max"
  strict: True # whether to crash the training if monitor is not found in the validation metrics
  check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
  stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
  divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
  check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
  log_rank_zero_only: False  # this keyword argument isn't available in stable version
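To make the interplay of monitor, min_delta, patience, and mode concrete, here is a simplified pure-Python sketch of the stopping rule. It is modeled on, not taken from, Lightning's EarlyStopping and ignores the remaining options (e.g., check_finite and the thresholds); the class name is hypothetical:

```python
import math


class EarlyStoppingSketch:
    """Tracks one monitored quantity and decides when to stop training."""

    def __init__(self, min_delta: float = 0.0, patience: int = 3, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.mode = mode
        self.best = math.inf if mode == "min" else -math.inf
        self.wait = 0  # consecutive checks without improvement

    def _improved(self, value: float) -> bool:
        # An improvement must beat the best value by more than min_delta.
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def step(self, value: float) -> bool:
        """Record one validation check and return True if training should stop."""
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStoppingSketch(patience=3, mode="min")
for epoch, val_loss in enumerate([0.9, 0.7, 0.72, 0.71, 0.70]):
    if stopper.step(val_loss):
        print(f"stopping after epoch {epoch}")  # 0.70 only ties the best value
        break
```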

Model Checkpointing

NewsRecLib integrates the checkpointing functionality provided by PyTorch Lightning and adopts the same notation.

For more details, please refer to the corresponding PyTorch Lightning checkpointing page.

This is an example that shows all the options.

model_checkpoint:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  dirpath: null # directory to save the model file
  filename: null # checkpoint filename
  monitor: null # name of the logged metric which determines when model is improving
  verbose: False # verbosity mode
  save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
  save_top_k: 1 # save k best models (determined by above metric)
  mode: "min" # "min" means a lower metric value is better; can also be "max"
  auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
  save_weights_only: False # if True, then only the model’s weights will be saved
  every_n_train_steps: null # number of training steps between checkpoints
  train_time_interval: null # checkpoints are monitored at the specified time interval
  every_n_epochs: null # number of epochs between checkpoints
  save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
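The bookkeeping behind save_top_k and mode can be pictured with the short sketch below. It only tracks which (score, checkpoint) pairs would be kept; writing and deleting files is left out, the helper name is hypothetical, and the handling of 0 (keep nothing) and -1 (keep everything) follows the documented meaning of save_top_k in Lightning:

```python
from typing import List, Tuple

Entry = Tuple[float, str]  # (monitored score, checkpoint filename)


def update_top_k(kept: List[Entry], score: float, ckpt: str, k: int, mode: str = "min") -> List[Entry]:
    """Return the checkpoints to keep after a new one with `score` is produced."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    if k == 0:  # save_top_k=0: no checkpoints are kept
        return []
    # Best first: ascending scores for "min", descending for "max".
    ranked = sorted(kept + [(score, ckpt)], key=lambda entry: entry[0], reverse=(mode == "max"))
    return ranked if k == -1 else ranked[:k]


kept: List[Entry] = []
for epoch, val_loss in enumerate([0.5, 0.4, 0.6, 0.3]):
    kept = update_top_k(kept, val_loss, f"epoch={epoch}.ckpt", k=2, mode="min")
print(kept)
```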

Hyperparameter Optimization

NewsRecLib supports hyperparameter optimization by integrating the functionalities of the Optuna library through the Optuna Sweeper plugin of Hydra.

This is an example that shows how to perform hyperparameter optimization.

defaults:
  - override /hydra/sweeper: optuna
optimized_metric: "val/acc_best"
hydra:
  mode: "MULTIRUN"
  sweeper:
    _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
    storage: null
    study_name: null
    n_jobs: 1
    direction: maximize
    n_trials: 20
    sampler:
      _target_: optuna.samplers.TPESampler
      seed: 1234
      n_startup_trials: 10
    params:
      data.neg_sampling_ratio: range(1, 10, step=1)
      model.lr: choice(1e-4, 1e-5, 1e-6)
      model.temperature: interval(0.0, 1.0)

Metrics

NewsRecLib provides several evaluation metrics for assessing recommendation models along the following dimensions: classification, ranking, diversity, and personalization. NewsRecLib relies on TorchMetrics for its metric implementations, and custom metrics are built by extending the Metric class.

Users can add any metric available in TorchMetrics or implement a new one by following the TorchMetrics guide on implementing custom metrics.

newsreclib package

Subpackages

newsreclib.data package

Subpackages
newsreclib.data.components package
Submodules
newsreclib.data.components.adressa_dataframe module
newsreclib.data.components.adressa_user_info module
class newsreclib.data.components.adressa_user_info.UserInfo(train_date_split: int, test_date_split: int)[source]

Bases: object

train_date_split

A string with the date before which click behaviors are included in the history of a user.

test_date_split

A string with the date after which click behaviors are included in the test set.

sort_click()[source]

Sorts user clicks by time in ascending order.

update(nindex: int, click_time: int, date: str)[source]
Parameters:
  • nindex – The index of a news article.

  • click_time – The time when the user clicked on the news article.

  • date – The processed click time used to assign the sample into the history of the user, the train or the test set.

newsreclib.data.components.batch module
class newsreclib.data.components.batch.NewsBatch(*args, **kwargs)[source]

Bases: dict

Batch used for reshaping the embedding space based on an aspect of the news.

Reference: Iana, Andreea, Goran Glavaš, and Heiko Paulheim. “Train Once, Use Flexibly: A Modular Framework for Multi-Aspect Neural News Recommendation.” arXiv preprint arXiv:2307.16089 (2023). https://arxiv.org/pdf/2307.16089.pdf

news

Dictionary mapping features of news to values.

Type:

Dict[str, Any]

labels

Labels of news based on the specified aspect.

Type:

torch.Tensor

labels: Tensor
news: Dict[str, Any]
class newsreclib.data.components.batch.RecommendationBatch(*args, **kwargs)[source]

Bases: dict

Batch used for recommendation.

batch_hist

Batch of histories of users.

Type:

torch.Tensor

batch_cand

Batch of candidates for each user.

Type:

torch.Tensor

x_hist

Dictionary of news from the users’ history, mapping news features to values.

Type:

Dict[str, Any]

x_cand

Dictionary of news from the users’ candidates, mapping news features to values.

Type:

Dict[str, Any]

labels

Ground truth specifying whether the news is relevant to the user.

Type:

torch.Tensor

users

Users included in the batch.

Type:

torch.Tensor

batch_cand: Tensor
batch_hist: Tensor
labels: Tensor
users: Tensor
x_cand: Dict[str, Any]
x_hist: Dict[str, Any]
newsreclib.data.components.data_utils module
newsreclib.data.components.download_utils module
newsreclib.data.components.file_utils module
newsreclib.data.components.file_utils.check_integrity(fpath: str) bool[source]

Checks whether a file exists.

newsreclib.data.components.file_utils.load_idx_map_as_dict(fpath: str) Dict[str, int][source]

Loads a table as a dictionary.

newsreclib.data.components.file_utils.to_tsv(df: DataFrame, fpath: str) None[source]

Stores a dataframe in .tsv format.

newsreclib.data.components.mind_dataframe module
newsreclib.data.components.news_dataset module
newsreclib.data.components.rec_dataset module
newsreclib.data.components.sentiment_annotator module

Submodules
newsreclib.data.adressa_news_datamodule module
newsreclib.data.adressa_rec_datamodule module
newsreclib.data.mind_news_datamodule module
newsreclib.data.mind_rec_datamodule module
Module contents

newsreclib.metrics package

Submodules
newsreclib.metrics.base module
class newsreclib.metrics.base.CustomRetrievalMetric(empty_target_action: str = 'neg', ignore_index: Optional[int] = None, **kwargs: Any)[source]

Bases: Metric, ABC

Works with binary target data. Accepts float predictions from a model output.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...)

  • cand_aspects (Tensor): A long tensor of shape (N, ...)

  • clicked_aspects (Tensor): A long tensor of shape (N, ...)

  • cand_indexes (Tensor): A long tensor of shape (N, ...) indicating to which query a prediction belongs

  • hist_indexes (Tensor): A long tensor of shape (N, ...) indicating to which user a target belongs

Note

cand_indexes, preds and cand_aspects must have the same dimension and will be flattened to a single dimension once provided.

Note

Predictions will be first grouped by cand_indexes and then the real metric, defined by overriding the _metric method, will be computed as the mean of the scores over each query.

As output to forward and compute the metric returns the following output:

  • metric (Tensor): A tensor as computed by _metric if the number of positive targets is at least 1; otherwise, it behaves as specified by self.empty_target_action.

Parameters:
  • empty_target_action – Specify what to do with queries that do not have at least one positive or negative (depending on the metric) target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index – Ignore predictions where the target is equal to this number.

  • kwargs – Additional keyword arguments, see Metric kwargs for more info.

Raises:
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is not None or an integer.

cand_aspects: List[Tensor]
cand_indexes: List[Tensor]
clicked_aspects: List[Tensor]
compute() Tensor[source]

First concat state cand_indexes, hist_indexes, preds, cand_aspects, and clicked_aspects since they were stored as lists.

After that, compute list of groups that will help in keeping together predictions about the same query. Finally, for each group compute the _metric if the number of positive targets is at least 1, otherwise behave as specified by self.empty_target_action.
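The grouping and averaging described for compute() can be sketched in plain Python as below. For brevity the sketch uses binary relevance targets and a toy per-query metric (hit@1) in place of the aspect tensors that CustomRetrievalMetric actually accumulates, and it ignores ignore_index; all names are illustrative:

```python
from collections import defaultdict
from typing import Callable, Dict, List, Sequence

PerQueryMetric = Callable[[List[float], List[int]], float]


def hit_at_1(preds: List[float], targets: List[int]) -> float:
    """Toy per-query metric: 1.0 if the top-scored candidate is relevant."""
    best = max(range(len(preds)), key=lambda i: preds[i])
    return float(targets[best])


def grouped_metric(
    indexes: Sequence[int],
    preds: Sequence[float],
    targets: Sequence[int],
    metric: PerQueryMetric,
    empty_target_action: str = "neg",
) -> float:
    """Group flattened predictions by query index and average the per-query metric."""
    if empty_target_action not in ("neg", "pos", "skip", "error"):
        raise ValueError(f"invalid empty_target_action: {empty_target_action}")
    groups: Dict[int, List[int]] = defaultdict(list)
    for position, query in enumerate(indexes):
        groups[query].append(position)

    scores: List[float] = []
    for positions in groups.values():
        group_preds = [preds[p] for p in positions]
        group_targets = [targets[p] for p in positions]
        if sum(group_targets) >= 1:
            scores.append(metric(group_preds, group_targets))
        elif empty_target_action == "neg":
            scores.append(0.0)
        elif empty_target_action == "pos":
            scores.append(1.0)
        elif empty_target_action == "error":
            raise ValueError("found a query without positive targets")
        # "skip": the query does not contribute to the average
    return sum(scores) / len(scores) if scores else 0.0


idx = [0, 0, 0, 1, 1, 2, 2]
flat_preds = [0.9, 0.1, 0.3, 0.2, 0.8, 0.5, 0.4]
flat_targets = [1, 0, 0, 1, 0, 0, 0]  # query 2 has no positive target
for action in ("neg", "pos", "skip"):
    print(action, grouped_metric(idx, flat_preds, flat_targets, hit_at_1, action))
```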

full_state_update: bool = False
higher_is_better: bool = True
hist_indexes: List[Tensor]
is_differentiable: bool = False
preds: List[Tensor]
update(preds: Tensor, cand_aspects: Tensor, clicked_aspects: Tensor, cand_indexes: Tensor, hist_indexes: Tensor) None[source]

Check shape, check and convert dtypes, flatten and add to accumulators.

newsreclib.metrics.diversity module
class newsreclib.metrics.diversity.Diversity(num_classes: int, empty_target_action: str = 'neg', ignore_index: Optional[int] = None, top_k: Optional[int] = None, **kwargs: Any)[source]

Bases: RetrievalMetric

Implementation of the Aspect-based Diversity.

Reference: Iana, Andreea, Goran Glavaš, and Heiko Paulheim. “Train Once, Use Flexibly: A Modular Framework for Multi-Aspect Neural News Recommendation.” arXiv preprint arXiv:2307.16089 (2023). https://arxiv.org/pdf/2307.16089.pdf

For further details, please refer to the paper

full_state_update: bool = False
higher_is_better: bool = True
is_differentiable: bool = False
newsreclib.metrics.functional module
newsreclib.metrics.functional.diversity(preds: Tensor, target: Tensor, num_classes: int, top_k: Optional[int] = None) Tensor[source]

Computes Aspect-based Diversity.

Reference: Iana, Andreea, Goran Glavaš, and Heiko Paulheim. “Train Once, Use Flexibly: A Modular Framework for Multi-Aspect Neural News Recommendation.” arXiv preprint arXiv:2307.16089 (2023). https://arxiv.org/pdf/2307.16089.pdf

Parameters:
  • preds – Estimated probabilities of each candidate news to be clicked.

  • target – Ground truth about the aspect \(A_p\) of the news being relevant or not.

  • num_classes – Number of classes of the aspect \(A_p\).

  • top_k – Consider only the top k elements for each query (default: None, which considers them all).

Returns:

A single-value tensor with the aspect-based diversity (\(D_{A_p}\)) of the predictions preds wrt the labels target.
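As a rough illustration, the sketch below computes aspect-based diversity as the normalized entropy of the aspect classes among the top-k ranked candidates. This reading of \(D_{A_p}\) is an assumption based on the cited paper and should be checked against it; the sketch also assumes that each candidate carries a single aspect class (standing in for target) and uses lists instead of tensors:

```python
import math
from collections import Counter
from typing import Optional, Sequence


def aspect_diversity(
    preds: Sequence[float],
    aspects: Sequence[int],
    num_classes: int,
    top_k: Optional[int] = None,
) -> float:
    """Normalized entropy of the aspect classes of the top-k ranked candidates."""
    if num_classes < 2:
        raise ValueError("num_classes must be at least 2")
    # Rank candidates by predicted click probability, highest first.
    ranked = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    if top_k is not None:
        ranked = ranked[:top_k]
    if not ranked:
        raise ValueError("at least one candidate is required")
    counts = Counter(aspects[i] for i in ranked)
    probs = [c / len(ranked) for c in counts.values()]
    # Shannon entropy H = sum_j p_j * log(1 / p_j), in nats.
    entropy = sum(p * math.log(1.0 / p) for p in probs)
    # Dividing by log(num_classes), the maximum possible entropy, maps the score to [0, 1].
    return entropy / math.log(num_classes)


preds = [0.9, 0.8, 0.4, 0.1]
print(aspect_diversity(preds, aspects=[0, 1, 2, 0], num_classes=3, top_k=3))  # maximal diversity
print(aspect_diversity(preds, aspects=[0, 0, 2, 1], num_classes=3, top_k=2))  # a single class only
```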

newsreclib.metrics.functional.generalized_jaccard(pred: Tensor, target: Tensor) Tensor[source]

Computes the Generalized Jaccard metric.

Reference: Bonnici, Vincenzo. “Kullback-Leibler divergence between quantum distributions, and its upper-bound.” arXiv preprint arXiv:2008.05932 (2020).

Parameters:
  • pred – Estimated probability distribution.

  • target – Target probability distribution.

Returns:

A single-value tensor with the generalized Jaccard of the predictions pred wrt the labels target.
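The generalized (weighted) Jaccard similarity of two non-negative vectors is commonly defined as \(J(p, t) = \sum_i \min(p_i, t_i) / \sum_i \max(p_i, t_i)\). Assuming generalized_jaccard follows this standard definition, a list-based sketch is:

```python
from typing import Sequence


def generalized_jaccard(pred: Sequence[float], target: Sequence[float]) -> float:
    """Sum of element-wise minima divided by sum of element-wise maxima."""
    if len(pred) != len(target):
        raise ValueError("distributions must have the same length")
    if any(x < 0 for x in pred) or any(x < 0 for x in target):
        raise ValueError("entries must be non-negative")
    numerator = sum(min(p, t) for p, t in zip(pred, target))
    denominator = sum(max(p, t) for p, t in zip(pred, target))
    if denominator == 0:
        raise ValueError("at least one entry must be positive")
    return numerator / denominator


print(generalized_jaccard([0.5, 0.5], [0.5, 0.5]))  # identical -> 1.0
print(generalized_jaccard([1.0, 0.0], [0.0, 1.0]))  # disjoint -> 0.0
print(generalized_jaccard([0.5, 0.5], [1.0, 0.0]))  # partial overlap
```

The score is 1 for identical distributions and 0 for distributions with disjoint support.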

newsreclib.metrics.functional.harmonic_mean(scores: Tensor) Tensor[source]

Computes the harmonic mean of N scores.

Parameters:

scores – A tensor of scores.

Returns:

A single-value tensor with the harmonic mean of the scores.
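For \(N\) positive scores \(x_1, \ldots, x_N\), the harmonic mean is \(N / \sum_{i=1}^N 1/x_i\). A list-based sketch, cross-checked against Python's statistics.harmonic_mean (the restriction to strictly positive scores is a choice made here):

```python
import statistics
from typing import Sequence


def harmonic_mean(scores: Sequence[float]) -> float:
    """N / sum(1 / x_i) for a non-empty sequence of strictly positive scores."""
    if not scores or any(s <= 0 for s in scores):
        raise ValueError("scores must be a non-empty sequence of positive numbers")
    return len(scores) / sum(1.0 / s for s in scores)


pair = [0.5, 0.25]
print(harmonic_mean(pair), statistics.harmonic_mean(pair))
```

Because small values dominate the sum of reciprocals, the harmonic mean stays low unless all scores are high, which suits aggregating scores that should be balanced.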

newsreclib.metrics.functional.personalization(preds: Tensor, predicted_aspects: Tensor, target_aspects: Tensor, num_classes: int, top_k: Optional[int] = None) Tensor[source]

Computes Aspect-based Personalization.

Reference: Iana, Andreea, Goran Glavaš, and Heiko Paulheim. “Train Once, Use Flexibly: A Modular Framework for Multi-Aspect Neural News Recommendation.” arXiv preprint arXiv:2307.16089 (2023). https://arxiv.org/pdf/2307.16089.pdf

Parameters:
  • preds – Estimated probabilities of each candidate news to be clicked.

  • predicted_aspects – Aspects of the news predicted to be clicked.

  • target_aspects – Ground truth about the aspect \(A_p\) of the news being relevant or not.

  • num_classes – Number of classes of the aspect \(A_p\).

  • top_k – Consider only the top k elements for each query (default: None, which considers them all).

Returns:

A single-value tensor with the aspect-based personalization (\(PS_{A_p}\)) of the predictions preds and predicted_aspects wrt the labels target_aspects.
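One way to read \(PS_{A_p}\), which should be verified against the cited paper, is as the generalized Jaccard similarity between the distribution of aspect classes among the top-k recommendations and the distribution of aspect classes among the user's clicked news. The sketch below follows that assumption, taking predicted_aspects as the aspect class of each candidate and target_aspects as the aspect classes of the clicked news; all names are illustrative:

```python
from collections import Counter
from typing import List, Optional, Sequence


def aspect_distribution(aspects: Sequence[int], num_classes: int) -> List[float]:
    """Relative frequency of each aspect class in 0..num_classes-1."""
    if not aspects:
        raise ValueError("at least one aspect label is required")
    counts = Counter(aspects)
    return [counts.get(c, 0) / len(aspects) for c in range(num_classes)]


def aspect_personalization(
    preds: Sequence[float],
    predicted_aspects: Sequence[int],
    target_aspects: Sequence[int],
    num_classes: int,
    top_k: Optional[int] = None,
) -> float:
    """Generalized Jaccard between recommended and clicked aspect distributions."""
    ranked = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    if top_k is not None:
        ranked = ranked[:top_k]
    rec = aspect_distribution([predicted_aspects[i] for i in ranked], num_classes)
    hist = aspect_distribution(target_aspects, num_classes)
    numerator = sum(min(r, h) for r, h in zip(rec, hist))
    denominator = sum(max(r, h) for r, h in zip(rec, hist))
    return numerator / denominator


preds = [0.9, 0.8, 0.1]
cands = [0, 1, 1]  # aspect class of each candidate
clicked = [0, 1]  # aspect classes of the user's clicked news
print(aspect_personalization(preds, cands, clicked, num_classes=2, top_k=2))  # same mix -> 1.0
print(aspect_personalization(preds, cands, clicked, num_classes=2, top_k=1))  # partial overlap
```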

newsreclib.metrics.personalization module
class newsreclib.metrics.personalization.Personalization(num_classes: int, empty_target_action: str = 'neg', ignore_index: Optional[int] = None, top_k: Optional[int] = None, **kwargs: Any)[source]

Bases: CustomRetrievalMetric

Implementation of the Aspect-based Personalization.

Reference: Iana, Andreea, Goran Glavaš, and Heiko Paulheim. “Train Once, Use Flexibly: A Modular Framework for Multi-Aspect Neural News Recommendation.” arXiv preprint arXiv:2307.16089 (2023). https://arxiv.org/pdf/2307.16089.pdf

For further details, please refer to the paper

full_state_update: bool = False
higher_is_better: bool = True
is_differentiable: bool = False
newsreclib.metrics.utils module
newsreclib.metrics.utils.get_metric(metric_name: str, metric_params: Optional[Dict[str, Union[str, int]]])[source]

Returns a metric object for the specified name.

Parameters:
  • metric_name – Name of the metric.

  • metric_params – Dictionary of parameters for instantiating the metric object.

Module contents

newsreclib.models package

Subpackages
newsreclib.models.components package
Subpackages
newsreclib.models.components.encoders package
Subpackages
newsreclib.models.components.encoders.news package
Submodules
newsreclib.models.components.encoders.news.aspect module
class newsreclib.models.components.encoders.news.aspect.SentimentEncoder(num_sent_classes: int, sent_embed_dim: int, sent_output_dim: int)[source]

Bases: Module

Implements the sentiment encoder from SentiDebias.

Reference: Wu, Chuhan, Fangzhao Wu, Tao Qi, Wei-Qiang Zhang, Xing Xie, and Yongfeng Huang. “Removing AI’s sentiment manipulation of personalized news delivery.” Humanities and Social Sciences Communications 9, no. 1 (2022): 1-9.

For further details, please refer to the paper

num_sent_classes

Number of sentiment classes.

sent_embed_dim

Number of features in the sentiment embedding.

sent_output_dim

Number of output features in the linear layer (equivalent to the final dimensionality of the sentiment vector).

forward(sentiment) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
newsreclib.models.components.encoders.news.category module
class newsreclib.models.components.encoders.news.category.LinearEncoder(pretrained_embeddings: Optional[Tensor], from_pretrained: bool, freeze_pretrained_emb: bool, num_categories: int, embed_dim: Optional[int], use_dropout: bool, dropout_probability: Optional[float], linear_transform: bool, output_dim: Optional[int])[source]

Bases: Module

Implements a category encoder.

pretrained_embeddings

Matrix of pretrained embeddings.

from_pretrained

If True, it initializes the category embedding layer with pretrained embeddings. If False, it initializes the category embedding layer with random weights.

freeze_pretrained_emb

If True, it freezes the pretrained embeddings during training. If False, it updates the pretrained embeddings during training.

num_categories

Number of categories.

embed_dim

Number of features in the category vector.

use_dropout

Whether to use dropout after the embedding layer.

dropout_probability

Dropout probability.

linear_transform

Whether to linearly transform the category vector.

output_dim

Number of output features in the category encoder (equivalent to the final dimensionality of the category vector).

forward(category: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
newsreclib.models.components.encoders.news.news module
class newsreclib.models.components.encoders.news.news.KCNN(pretrained_text_embeddings: Tensor, pretrained_entity_embeddings: Tensor, pretrained_context_embeddings: Optional[Tensor], use_context: bool, text_embed_dim: int, entity_embed_dim: int, num_filters: int, window_sizes: List[int])[source]

Bases: Module

Implements the knowledge-aware CNN from DKN.

Reference: Wang, Hongwei, Fuzheng Zhang, Xing Xie, and Minyi Guo. “DKN: Deep knowledge-aware network for news recommendation.” In Proceedings of the 2018 world wide web conference, pp. 1835-1844. 2018.

For further details, please refer to the paper

pretrained_text_embeddings

Matrix of pretrained text embeddings.

pretrained_entity_embeddings

Matrix of pretrained entity embeddings.

pretrained_context_embeddings

Matrix of pretrained context embeddings.

use_context

Whether to use context embeddings.

text_embed_dim

The number of features in the text vector.

entity_embed_dim

The number of features in the entity vector.

num_filters

The number of filters in the CNN.

window_sizes

List of window sizes for the CNN.

forward(news: Dict[str, Tensor]) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class newsreclib.models.components.encoders.news.news.NewsEncoder(dataset_attributes: List[str], attributes2encode: List[str], concatenate_inputs: bool, text_encoder: Optional[Module], category_encoder: Optional[Module], entity_encoder: Optional[Module], combine_vectors: bool, combine_type: Optional[str], input_dim: Optional[int], query_dim: Optional[int], output_dim: Optional[int])[source]

Bases: Module

Implements a news encoder.

dataset_attributes

List of news features available in the used dataset.

attributes2encode

List of news features used as input to the news encoder.

concatenate_inputs

Whether the inputs (e.g., title and abstract) were concatenated into a single sequence.

text_encoder

The text encoder module.

category_encoder

The category encoder module.

entity_encoder

The entity encoder module.

combine_vectors

Whether to aggregate the representations of multiple news features.

combine_type

The type of aggregation to use for combining multiple news features representations. Choose between add_att (additive attention), linear, and concat (concatenate).

input_dim

The number of input features in the aggregation layer.

query_dim

The number of features in the query vector.

output_dim

The number of features in the final news vector.

forward(news: Dict[str, Tensor]) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
newsreclib.models.components.encoders.news.text module
class newsreclib.models.components.encoders.news.text.CNNAddAtt(pretrained_embeddings: Tensor, embed_dim: int, num_filters: int, window_size: int, query_dim: int, dropout_probability: float)[source]

Bases: Module

Implements a text encoder based on CNN and additive attention.

Reference: Wu, Chuhan, Fangzhao Wu, Mingxiao An, Jianqiang Huang, Yongfeng Huang, and Xing Xie. “Neural news recommendation with attentive multi-view learning.” arXiv preprint arXiv:1907.05576 (2019).

For further details, please refer to the paper

pretrained_embeddings

Matrix of pretrained embeddings.

embed_dim

The number of features in the text vector.

num_filters

The number of filters in the CNN.

window_size

The window size in the CNN.

query_dim

The number of features in the query vector.

dropout_probability

Dropout probability.

forward(text: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
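CNNAddAtt, CNNMHSAAddAtt, and MHSAAddAtt all pool a sequence of word-level vectors into a single text vector with additive attention, whose projection size is set by query_dim. The framework-free sketch below follows the usual formulation (project each vector with tanh, score it against a query vector, softmax the scores, and take the weighted sum); the parameter names and list-based shapes are illustrative assumptions, not the internals of these modules:

```python
import math
from typing import List, Sequence

Vector = List[float]


def additive_attention(
    hidden: Sequence[Vector],
    weight: Sequence[Vector],
    bias: Vector,
    query: Vector,
) -> Vector:
    """Pool vectors h_i into r = sum_i alpha_i * h_i via additive attention.

    a_i = query . tanh(weight @ h_i + bias), alpha = softmax(a).
    weight has query_dim rows of length input_dim; bias and query have length query_dim.
    """
    if not hidden:
        raise ValueError("at least one vector is required")
    scores = []
    for h in hidden:
        projected = [
            math.tanh(sum(w * x for w, x in zip(row, h)) + b)
            for row, b in zip(weight, bias)
        ]
        scores.append(sum(q * p for q, p in zip(query, projected)))
    # Numerically stable softmax over the attention scores.
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    alphas = [e / total for e in exps]
    return [sum(a * h[j] for a, h in zip(alphas, hidden)) for j in range(len(hidden[0]))]


words = [[1.0, 0.0], [-1.0, 0.0]]  # two word vectors, input_dim = 2
pooled = additive_attention(words, weight=[[10.0, 0.0]], bias=[0.0], query=[1.0])  # query_dim = 1
print(pooled)  # leans towards the first word, which receives the higher score
```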
class newsreclib.models.components.encoders.news.text.CNNMHSAAddAtt(pretrained_embeddings: Tensor, embed_dim: int, num_filters: int, window_size: int, num_heads: int, query_dim: int, dropout_probability: float)[source]

Bases: Module

Implements a text encoder based on CNN, multi-head self-attention, and additive attention.

Reference: Qi, Tao, Fangzhao Wu, Chuhan Wu, Yongfeng Huang, and Xing Xie. “Privacy-Preserving News Recommendation Model Learning.” In Findings of the Association for Computational Linguistics: EMNLP 2020, pp. 1423-1432. 2020.

For further details, please refer to the paper

pretrained_embeddings

Matrix of pretrained embeddings.

embed_dim

The number of features in the text vector.

num_filters

The number of filters in the CNN.

window_size

The window size in the CNN.

num_heads

The number of heads in the MultiheadAttention.

query_dim

The number of features in the query vector.

dropout_probability

Dropout probability.

forward(text: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class newsreclib.models.components.encoders.news.text.CNNPersAtt(pretrained_embeddings: Tensor, text_embed_dim: int, user_embed_dim: int, num_filters: int, window_size: int, query_dim: int, dropout_probability: float)[source]

Bases: Module

Implements a text encoder based on CNN and Personalized Attention.

Reference: Wu, Chuhan, Fangzhao Wu, Mingxiao An, Jianqiang Huang, Yongfeng Huang, and Xing Xie. “NPA: neural news recommendation with personalized attention.” In Proceedings of the 25th ACM SIGKDD international conference on knowledge discovery & data mining, pp. 2576-2584. 2019.

For further details, please refer to the paper

pretrained_embeddings

Matrix of pretrained embeddings.

text_embed_dim

The number of features in the text vector.

user_embed_dim

The number of features in the user vector.

num_filters

The number of filters in the CNN.

window_size

The window size in the CNN.

query_dim

The number of features in the query vector.

dropout_probability

Dropout probability.

forward(text: Tensor, lengths: Tensor, projected_users: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class newsreclib.models.components.encoders.news.text.MHSAAddAtt(pretrained_embeddings: Tensor, embed_dim: int, num_heads: int, query_dim: int, dropout_probability: float)[source]

Bases: Module

Implements a text encoder based on multi-head self-attention and additive attention.

Reference: Wu, Chuhan, Fangzhao Wu, Suyu Ge, Tao Qi, Yongfeng Huang, and Xing Xie. “Neural news recommendation with multi-head self-attention.” In Proceedings of the 2019 conference on empirical methods in natural language processing and the 9th international joint conference on natural language processing (EMNLP-IJCNLP), pp. 6389-6394. 2019.

For further details, please refer to the paper

pretrained_embeddings

Matrix of pretrained embeddings.

embed_dim

The number of features in the text vector.

num_heads

The number of heads in the MultiheadAttention.

query_dim

The number of features in the query vector.

dropout_probability

Dropout probability.

forward(text: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class newsreclib.models.components.encoders.news.text.PLM(plm_model, frozen_layers: Optional[List[int]], embed_dim: int, use_mhsa: bool, apply_reduce_dim: bool, reduced_embed_dim: Optional[int], num_heads: Optional[int], query_dim: Optional[int], dropout_probability: float)[source]

Bases: Module

Implements a text encoder based on a pretrained language model.

plm_model

Name of the pretrained language model.

frozen_layers

List of layers to freeze during training.

embed_dim

Number of features in the text vector.

use_mhsa

If True, it aggregates the token embeddings with a multi-head self-attention network into a final text representation. If False, it uses the CLS embedding as the final text representation.

apply_reduce_dim

Whether to linearly reduce the dimensionality of the news vector.

reduced_embed_dim

The number of features in the reduced news vector.

num_heads

The number of heads in the MultiheadAttention.

query_dim

The number of features in the query vector.

dropout_probability

Dropout probability.

forward(text: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
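The `use_mhsa` and `apply_reduce_dim` options of PLM can be illustrated with the sketch below. It works on a stand-in tensor of token embeddings instead of calling a real pretrained language model, so the `transformers` model, layer freezing, and tokenization are left out. The class name `ToyPLMPooling` and the `attention_mask` argument are hypothetical. Following the MHSA layer with additive-attention pooling is an assumption about how the "multi-head self-attention network" aggregates tokens. Position 0 is assumed to hold the [CLS] token.

```python
import torch
import torch.nn as nn


class ToyPLMPooling(nn.Module):
    """Hypothetical sketch of the use_mhsa / apply_reduce_dim options on top of PLM outputs."""

    def __init__(self, embed_dim, use_mhsa, apply_reduce_dim, reduced_embed_dim=None,
                 num_heads=None, query_dim=None, dropout_probability=0.2):
        super().__init__()
        self.use_mhsa = use_mhsa
        self.dropout = nn.Dropout(dropout_probability)
        if use_mhsa:
            self.mhsa = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
            # additive-attention pooling after MHSA (assumption)
            self.proj = nn.Linear(embed_dim, query_dim)
            self.query = nn.Parameter(torch.empty(query_dim).uniform_(-0.1, 0.1))
        self.reduce = nn.Linear(embed_dim, reduced_embed_dim) if apply_reduce_dim else nn.Identity()

    def forward(self, hidden_states, attention_mask):
        # hidden_states: (B, L, embed_dim) token embeddings; position 0 holds [CLS]
        # attention_mask: (B, L) with 1 for real tokens and 0 for padding
        if self.use_mhsa:
            pad = attention_mask == 0
            x, _ = self.mhsa(hidden_states, hidden_states, hidden_states, key_padding_mask=pad)
            scores = (torch.tanh(self.proj(x)) @ self.query).masked_fill(pad, float("-inf"))
            weights = torch.softmax(scores, dim=1)                   # (B, L)
            text_vector = torch.bmm(weights.unsqueeze(1), x).squeeze(1)
        else:
            text_vector = hidden_states[:, 0]                        # [CLS] embedding
        return self.reduce(self.dropout(text_vector))


torch.manual_seed(0)
hidden = torch.randn(2, 8, 32)   # stand-in for a PLM's last hidden states
mask = torch.ones(2, 8, dtype=torch.long)
mask[1, 5:] = 0                   # the second text has 5 real tokens
cls_pool = ToyPLMPooling(32, use_mhsa=False, apply_reduce_dim=True, reduced_embed_dim=16).eval()
mhsa_pool = ToyPLMPooling(32, use_mhsa=True, apply_reduce_dim=False, num_heads=4, query_dim=8).eval()
print(cls_pool(hidden, mask).shape)   # torch.Size([2, 16])
print(mhsa_pool(hidden, mask).shape)  # torch.Size([2, 32])
```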
Module contents
newsreclib.models.components.encoders.user package
Submodules
newsreclib.models.components.encoders.user.caum module
class newsreclib.models.components.encoders.user.caum.UserEncoder(news_embed_dim: int, num_filters: int, dense_att_hidden_dim1: int, dense_att_hidden_dim2: int, user_vector_dim: int, num_heads: int, dropout_probability: float)[source]

Bases: Module

Implements the user encoder of CAUM.

Reference: Qi, Tao, Fangzhao Wu, Chuhan Wu, and Yongfeng Huang. “News recommendation with candidate-aware user modeling.” In Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval, pp. 1917-1921. 2022.

For further details, please refer to the paper

news_embed_dim

The number of features in the news vector.

num_filters

The number of output features in the first linear layer.

dense_att_hidden_dim1

The number of output features in the first hidden state of the DenseAttention.

dense_att_hidden_dim2

The number of output features in the second hidden state of the DenseAttention.

user_vector_dim

The number of features in the user vector.

num_heads

The number of heads in the MultiheadAttention.

dropout_probability

Dropout probability.

forward(hist_news_vector: Tensor, cand_news_vector: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
newsreclib.models.components.encoders.user.cen_news_rec module
class newsreclib.models.components.encoders.user.cen_news_rec.UserEncoder(num_filters: int, num_heads: int, query_dim: int, gru_hidden_dim: int, num_recent_news: int, dropout_probability: float)[source]

Bases: Module

Implements the user encoder of CenNewsRec.

Reference: Qi, Tao, Fangzhao Wu, Chuhan Wu, Yongfeng Huang, and Xing Xie. “Privacy-Preserving News Recommendation Model Learning.” In Findings of the Association for Computational Linguistics: EMNLP 2020, pp. 1423-1432. 2020.

For further details, please refer to the paper

num_filters

The number of input features in the MultiheadAttention.

num_heads

The number of heads in the MultiheadAttention.

query_dim

The number of features in the query vector.

gru_hidden_dim

The number of features in the hidden state of the GRU.

num_recent_news

The number of recent news articles to be encoded in the short-term user representation.

dropout_probability

Dropout probability.

forward(hist_news_vector: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
newsreclib.models.components.encoders.user.dkn module
class newsreclib.models.components.encoders.user.dkn.UserEncoder(input_dim: int, hidden_dim: int)[source]

Bases: Module

Implements the user encoder of DKN.

Reference: Wang, Hongwei, Fuzheng Zhang, Xing Xie, and Minyi Guo. “DKN: Deep knowledge-aware network for news recommendation.” In Proceedings of the 2018 world wide web conference, pp. 1835-1844. 2018.

For further details, please refer to the paper

input_dim

The number of input features to the user encoder.

hidden_dim

The number of features in the hidden state of the user encoder.

forward(hist_news_vector: Tensor, cand_news_vector: Tensor, mask_hist: Tensor, mask_cand: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
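The DKN user encoder builds a candidate-aware user vector: each clicked news item is scored against the candidate by a small DNN, and the user vector is the attention-weighted sum of the clicked news. The sketch below shows this idea. It is an illustration, not NewsRecLib's code. `ToyDKNUserEncoder` and `news_dim` are hypothetical names. Here `news_dim` is the size of one news vector, so the scoring DNN receives `2 * news_dim` features, and the exact meaning of the library's `input_dim` is not stated above. `mask_cand` is not used in this sketch.

```python
import torch
import torch.nn as nn


class ToyDKNUserEncoder(nn.Module):
    """Hypothetical sketch of DKN's candidate-aware attention over clicked news."""

    def __init__(self, news_dim: int, hidden_dim: int) -> None:
        super().__init__()
        # DNN that scores a (clicked news, candidate news) pair
        self.dnn = nn.Sequential(
            nn.Linear(2 * news_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1))

    def forward(self, hist: torch.Tensor, cand: torch.Tensor,
                mask_hist: torch.Tensor) -> torch.Tensor:
        # hist: (B, H, D) clicked news; cand: (B, C, D); mask_hist: (B, H), True = real click
        B, H, D = hist.shape
        C = cand.shape[1]
        pairs = torch.cat([hist.unsqueeze(1).expand(B, C, H, D),
                           cand.unsqueeze(2).expand(B, C, H, D)], dim=-1)  # (B, C, H, 2D)
        scores = self.dnn(pairs).squeeze(-1)                              # (B, C, H)
        scores = scores.masked_fill(~mask_hist.unsqueeze(1), float("-inf"))
        weights = torch.softmax(scores, dim=-1)                           # attention over clicks
        return torch.matmul(weights, hist)                                # (B, C, D)


torch.manual_seed(0)
hist = torch.randn(2, 4, 8)
cand = torch.randn(2, 3, 8)
mask_hist = torch.tensor([[True, True, True, False], [True, True, False, False]])
user_vectors = ToyDKNUserEncoder(8, 16)(hist, cand, mask_hist)
print(user_vectors.shape)  # torch.Size([2, 3, 8]): one user vector per candidate
```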
newsreclib.models.components.encoders.user.lstur module
class newsreclib.models.components.encoders.user.lstur.UserEncoder(num_users: int, input_dim: int, user_masking_probability: float, long_short_term_method: str)[source]

Bases: Module

Implements the user encoder of LSTUR.

Reference: An, Mingxiao, Fangzhao Wu, Chuhan Wu, Kun Zhang, Zheng Liu, and Xing Xie. “Neural news recommendation with long-and short-term user representations.” In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pp. 336-345. 2019.

For further details, please refer to the paper

num_users

The number of users.

input_dim

The number of input features in the embedding layer for the long-term user representation.

user_masking_probability

The probability for randomly masking users in the long-term user representation.

long_short_term_method

The method for combining the long- and short-term user representations. If "ini" is chosen, the GRU is initialized with the long-term user representation. If "con" is chosen, the long- and short-term user representations are concatenated.

forward(user: Tensor, hist_news_vector: Tensor, hist_size: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
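The two `long_short_term_method` options of the LSTUR user encoder can be contrasted with the minimal sketch below. It is not NewsRecLib's implementation, and `ToyLSTURUserEncoder` is a hypothetical name. Three points are assumptions. With "con", each half has `input_dim // 2` features so that the concatenation matches the news-vector size, which requires an even `input_dim`. User masking zeroes the whole long-term vector during training. Every `hist_size` entry is at least 1.

```python
import torch
import torch.nn as nn


class ToyLSTURUserEncoder(nn.Module):
    """Hypothetical sketch of LSTUR's long-/short-term user representation."""

    def __init__(self, num_users: int, input_dim: int,
                 user_masking_probability: float, long_short_term_method: str) -> None:
        super().__init__()
        if long_short_term_method not in ("ini", "con"):
            raise ValueError("long_short_term_method must be 'ini' or 'con'")
        self.method = long_short_term_method
        self.user_masking_probability = user_masking_probability
        # with "con", each half has input_dim // 2 features (assumes an even input_dim)
        hidden_dim = input_dim if long_short_term_method == "ini" else input_dim // 2
        self.long_term = nn.Embedding(num_users, hidden_dim, padding_idx=0)
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)

    def forward(self, user, hist_news_vector, hist_size):
        # user: (B,) IDs; hist_news_vector: (B, max_hist, input_dim); hist_size: (B,) >= 1
        long_term = self.long_term(user)                                  # (B, hidden_dim)
        if self.training:
            # randomly drop the whole long-term representation of some users
            keep = torch.rand(user.shape[0], 1, device=user.device) >= self.user_masking_probability
            long_term = long_term * keep
        packed = nn.utils.rnn.pack_padded_sequence(
            hist_news_vector, hist_size.cpu(), batch_first=True, enforce_sorted=False)
        if self.method == "ini":
            # the long-term representation initializes the GRU hidden state
            _, last_hidden = self.gru(packed, long_term.unsqueeze(0))
            return last_hidden.squeeze(0)                                 # (B, input_dim)
        _, last_hidden = self.gru(packed)
        return torch.cat([last_hidden.squeeze(0), long_term], dim=1)      # (B, input_dim)


torch.manual_seed(0)
users = torch.tensor([3, 7])
history = torch.randn(2, 5, 64)   # (batch_size, max_hist_len, input_dim)
hist_size = torch.tensor([5, 2])  # number of real clicked news per user
for method in ("ini", "con"):
    encoder = ToyLSTURUserEncoder(10, 64, 0.5, method).eval()
    print(method, tuple(encoder(users, history, hist_size).shape))  # (2, 64) for both methods
```

Both variants yield vectors of the same size, so the click predictor downstream does not need to know which method was chosen.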
newsreclib.models.components.encoders.user.mins module
class newsreclib.models.components.encoders.user.mins.UserEncoder(news_embed_dim: int, query_dim: int, num_filters: int, num_gru_channels: int)[source]

Bases: Module

Implements the user encoder of MINS.

Reference: Wang, Rongyao, Shoujin Wang, Wenpeng Lu, and Xueping Peng. “News recommendation via multi-interest news sequence modelling.” In ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 7942-7946. IEEE, 2022.

For further details, please refer to the paper

news_embed_dim

The number of features in the news vector.

query_dim

The number of features in the query vector.

num_filters

The number of filters used in the GRU.

num_gru_channels

The number of channels used in the GRU.

forward(hist_news_vector: Tensor, hist_size: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
newsreclib.models.components.encoders.user.naml module
class newsreclib.models.components.encoders.user.naml.UserEncoder(news_embed_dim: int, query_dim: int)[source]

Bases: Module

Implements the user encoder of NAML.

Reference: Wu, Chuhan, Fangzhao Wu, Mingxiao An, Jianqiang Huang, Yongfeng Huang, and Xing Xie. “Neural news recommendation with attentive multi-view learning.” arXiv preprint arXiv:1907.05576 (2019).

For further details, please refer to the paper

news_embed_dim

The number of features in the news vector.

query_dim

The number of features in the query vector.

forward(hist_news_vector: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
newsreclib.models.components.encoders.user.npa module
class newsreclib.models.components.encoders.user.npa.UserEncoder(user_embed_dim: int, num_filters: int, preference_query_dim: int, dropout_probability: float)[source]

Bases: Module

Implements the user encoder of NPA.

Reference: Wu, Chuhan, Fangzhao Wu, Mingxiao An, Jianqiang Huang, Yongfeng Huang, and Xing Xie. “NPA: neural news recommendation with personalized attention.” In Proceedings of the 25th ACM SIGKDD international conference on knowledge discovery & data mining, pp. 2576-2584. 2019.

For further details, please refer to the paper

user_embed_dim

The number of features in the user vector.

num_filters

The number of filters in the PersonalizedAttention.

preference_query_dim

The number of features in the preference query vector.

dropout_probability

Dropout probability.

forward(hist_news_vector: Tensor, projected_users: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
newsreclib.models.components.encoders.user.nrms module
class newsreclib.models.components.encoders.user.nrms.UserEncoder(news_embed_dim: int, num_heads: int, query_dim: int)[source]

Bases: Module

Implements the user encoder of NRMS.

Reference: Wu, Chuhan, Fangzhao Wu, Suyu Ge, Tao Qi, Yongfeng Huang, and Xing Xie. “Neural news recommendation with multi-head self-attention.” In Proceedings of the 2019 conference on empirical methods in natural language processing and the 9th international joint conference on natural language processing (EMNLP-IJCNLP), pp. 6389-6394. 2019.

For further details, please refer to the paper

news_embed_dim

The number of features in the news vector.

num_heads

The number of heads in the MultiheadAttention.

query_dim

The number of features in the query vector.

forward(hist_news_vector: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
Module contents
newsreclib.models.components.layers package
Submodules
newsreclib.models.components.layers.attention module
class newsreclib.models.components.layers.attention.AdditiveAttention(input_dim: int, query_dim: int)[source]

Bases: Module

forward(input_vector: Tensor) Tensor[source]
Parameters:

input_vector – User tensor of shape (batch_size, hidden_dim, output_dim).

Returns:

User tensor of shape (batch_size, news_emb_dim).

training: bool
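AdditiveAttention carries no description above. The sketch below shows the standard additive-attention formulation that such a layer usually implements, where each item's weight comes from a learned query vector scoring a tanh projection of that item. This is an assumption about AdditiveAttention, not its verified source. `ToyAdditiveAttention` is hypothetical, and the input is read as a sequence of vectors of shape (batch_size, num_items, input_dim).

```python
import torch
import torch.nn as nn


class ToyAdditiveAttention(nn.Module):
    """Hypothetical sketch: a_i = softmax_i(q^T tanh(W h_i + b)), output = sum_i a_i h_i."""

    def __init__(self, input_dim: int, query_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, query_dim)
        self.query = nn.Parameter(torch.empty(query_dim).uniform_(-0.1, 0.1))

    def forward(self, input_vector: torch.Tensor) -> torch.Tensor:
        # input_vector: (batch_size, num_items, input_dim), e.g. clicked-news vectors
        scores = torch.tanh(self.linear(input_vector)) @ self.query      # (B, N)
        weights = torch.softmax(scores, dim=1)                            # attention over items
        return torch.bmm(weights.unsqueeze(1), input_vector).squeeze(1)   # (B, input_dim)


torch.manual_seed(0)
attention = ToyAdditiveAttention(input_dim=64, query_dim=32)
user_vectors = attention(torch.randn(4, 10, 64))
print(user_vectors.shape)  # torch.Size([4, 64])
```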
class newsreclib.models.components.layers.attention.DenseAttention(input_dim: int, hidden_dim1: int, hidden_dim2: int)[source]

Bases: Module

Dense attention used in CAUM.

Reference: Qi, Tao, Fangzhao Wu, Chuhan Wu, and Yongfeng Huang. “News recommendation with candidate-aware user modeling.” In Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval, pp. 1917-1921. 2022.

For further details, please refer to the paper

forward(input_vector: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
class newsreclib.models.components.layers.attention.PersonalizedAttention(preference_query_dim: int, num_filters: int)[source]

Bases: Module

Personalized attention used in NPA.

Reference: Wu, Chuhan, Fangzhao Wu, Mingxiao An, Jianqiang Huang, Yongfeng Huang, and Xing Xie. “NPA: neural news recommendation with personalized attention.” In Proceedings of the 25th ACM SIGKDD international conference on knowledge discovery & data mining, pp. 2576-2584. 2019.

For further details, please refer to the paper

forward(query: Tensor, keys: Tensor) Tensor[source]
Parameters:
  • query – (batch_size, preference_dim)

  • keys – (batch_size, num_filters, num_words_text)

Returns:

(batch_size, num_filters)

training: bool
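The personalized attention of NPA can be sketched as follows, based on our reading of the NPA paper. The user's preference query is projected into the filter space, dotted with every word representation, and the softmax weights pool the words. The sketch is not NewsRecLib's code. `ToyPersonalizedAttention` is hypothetical, and the tanh projection of the query is an assumption.

```python
import torch
import torch.nn as nn


class ToyPersonalizedAttention(nn.Module):
    """Hypothetical sketch of NPA-style personalized attention."""

    def __init__(self, preference_query_dim: int, num_filters: int) -> None:
        super().__init__()
        self.query_proj = nn.Linear(preference_query_dim, num_filters)

    def forward(self, query: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
        # query: (B, preference_query_dim); keys: (B, num_filters, num_words_text)
        q = torch.tanh(self.query_proj(query))                    # (B, num_filters)
        scores = torch.bmm(q.unsqueeze(1), keys).squeeze(1)       # (B, num_words_text)
        weights = torch.softmax(scores, dim=1)                    # user-specific word weights
        return torch.bmm(keys, weights.unsqueeze(2)).squeeze(2)   # (B, num_filters)


torch.manual_seed(0)
att = ToyPersonalizedAttention(preference_query_dim=20, num_filters=40)
out = att(torch.randn(3, 20), torch.randn(3, 40, 12))
print(out.shape)  # torch.Size([3, 40])
```

Because the query depends on the user, two users reading the same title can attend to different words.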
class newsreclib.models.components.layers.attention.PolyAttention(input_dim: int, num_context_codes: int, context_code_dim: int)[source]

Bases: Module

Implementation of the poly attention scheme (used in MINER), which extracts K attention vectors through K additive attentions.

Adapted from https://github.com/duynguyen-0203/miner/blob/master/src/model/model.py.

Reference: Li, Jian, Jieming Zhu, Qiwei Bi, Guohao Cai, Lifeng Shang, Zhenhua Dong, Xin Jiang, and Qun Liu. “MINER: multi-interest matching network for news recommendation.” In Findings of the Association for Computational Linguistics: ACL 2022, pp. 343-352. 2022.

For further details, please refer to the paper

forward(embeddings: Tensor, attn_mask: Tensor, bias: Optional[Tensor] = None)[source]
Parameters:
  • embeddings – (batch_size, hist_length, embed_dim)

  • attn_mask – (batch_size, hist_length)

  • bias – (batch_size, hist_length, num_candidates)

Returns:

(batch_size, num_context_codes, embed_dim)

Return type:

torch.Tensor

training: bool
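The poly attention idea can be sketched with K learnable context codes. Each code runs its own additive attention over the clicked-news embeddings and produces one of the K user-interest vectors. The sketch is not NewsRecLib's or MINER's exact code. `ToyPolyAttention` is hypothetical. Averaging `bias` over the candidate axis and masking padded history with -inf are assumptions.

```python
import torch
import torch.nn as nn


class ToyPolyAttention(nn.Module):
    """Hypothetical sketch: K learnable context codes, each driving one additive attention."""

    def __init__(self, input_dim: int, num_context_codes: int, context_code_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, context_code_dim, bias=False)
        self.context_codes = nn.Parameter(
            nn.init.xavier_uniform_(torch.empty(num_context_codes, context_code_dim)))

    def forward(self, embeddings, attn_mask, bias=None):
        # embeddings: (B, H, D); attn_mask: (B, H) bool, True for real history items
        scores = torch.tanh(self.linear(embeddings)) @ self.context_codes.T   # (B, H, K)
        if bias is not None:
            # bias: (B, H, num_candidates), averaged over candidates (assumption)
            scores = scores + bias.mean(dim=2, keepdim=True)
        scores = scores.transpose(1, 2).masked_fill(~attn_mask.unsqueeze(1), float("-inf"))
        weights = torch.softmax(scores, dim=2)    # (B, K, H): one distribution per context code
        return torch.bmm(weights, embeddings)     # (B, K, D)


torch.manual_seed(0)
poly = ToyPolyAttention(input_dim=16, num_context_codes=4, context_code_dim=8)
hist = torch.randn(2, 6, 16)
mask = torch.tensor([[True] * 6, [True] * 3 + [False] * 3])
interests = poly(hist, mask)
print(interests.shape)  # torch.Size([2, 4, 16])
```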
class newsreclib.models.components.layers.attention.TargetAwareAttention(input_dim: int)[source]

Bases: Module

Implementation of the target-aware attention network used in MINER.

Adapted from https://github.com/duynguyen-0203/miner/blob/master/src/model/model.py

Reference: Li, Jian, Jieming Zhu, Qiwei Bi, Guohao Cai, Lifeng Shang, Zhenhua Dong, Xin Jiang, and Qun Liu. “MINER: multi-interest matching network for news recommendation.” In Findings of the Association for Computational Linguistics: ACL 2022, pp. 343-352. 2022.

For further details, please refer to the paper

forward(query: Tensor, key: Tensor, value: Tensor) Tensor[source]
Parameters:
  • query – (batch_size, num_context_codes, input_embed_dim)

  • key – (batch_size, num_candidates, input_embed_dim)

  • value – (batch_size, num_candidates, num_context_codes)

training: bool
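Given the parameter shapes above, target-aware attention can be read as follows. Each candidate decides how much every interest vector counts, and those weights combine the per-interest matching scores in `value` into one score per candidate. The sketch is hypothetical (`ToyTargetAwareAttention`). The GELU-activated linear projection of the interests is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyTargetAwareAttention(nn.Module):
    """Hypothetical sketch: weight per-interest matching scores by candidate relevance."""

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, input_dim)

    def forward(self, query, key, value):
        # query: (B, K, D) user interest vectors; key: (B, C, D) candidate news vectors
        # value: (B, C, K) matching score of each candidate with each interest
        proj = F.gelu(self.linear(query))                                     # (B, K, D)
        weights = torch.softmax(torch.bmm(key, proj.transpose(1, 2)), dim=2)  # (B, C, K)
        return (weights * value).sum(dim=2)                                   # (B, C)


torch.manual_seed(0)
taa = ToyTargetAwareAttention(16)
interests = torch.randn(2, 4, 16)
candidates = torch.randn(2, 5, 16)
match = torch.bmm(candidates, interests.transpose(1, 2))   # (2, 5, 4) dot-product matching
scores = taa(interests, candidates, match)
print(scores.shape)  # torch.Size([2, 5])
```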
newsreclib.models.components.layers.click_predictor module
class newsreclib.models.components.layers.click_predictor.DNNPredictor(input_dim: int, hidden_dim: int)[source]

Bases: Module

Implementation of the click predictor of DKN.

Reference: Wang, Hongwei, Fuzheng Zhang, Xing Xie, and Minyi Guo. “DKN: Deep knowledge-aware network for news recommendation.” In Proceedings of the 2018 world wide web conference, pp. 1835-1844. 2018.

For further details, please refer to the paper

forward(user_vec: Tensor, cand_news: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
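A DKN-style click predictor can be sketched as a small MLP that takes the concatenated user and candidate vectors and returns one click logit per candidate. The sketch is not NewsRecLib's DNNPredictor. `ToyDNNPredictor` and `news_dim` are hypothetical names, and the mapping to the library's `input_dim` is not stated above. Returning logits, with the sigmoid applied separately, is an assumption.

```python
import torch
import torch.nn as nn


class ToyDNNPredictor(nn.Module):
    """Hypothetical sketch of a DKN-style click predictor: MLP over [user; candidate]."""

    def __init__(self, news_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * news_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1))

    def forward(self, user_vec: torch.Tensor, cand_news: torch.Tensor) -> torch.Tensor:
        # user_vec, cand_news: (B, C, news_dim), one user vector per candidate as in DKN
        return self.mlp(torch.cat([user_vec, cand_news], dim=-1)).squeeze(-1)  # (B, C) logits


torch.manual_seed(0)
predictor = ToyDNNPredictor(news_dim=8, hidden_dim=16)
logits = predictor(torch.randn(2, 3, 8), torch.randn(2, 3, 8))
click_probs = torch.sigmoid(logits)   # click probabilities in (0, 1)
print(logits.shape)  # torch.Size([2, 3])
```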
class newsreclib.models.components.layers.click_predictor.DotProduct[source]

Bases: Module

forward(user_vec: Tensor, cand_news_vector: Tensor) Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
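The DotProduct entry has no description above. A dot-product click predictor scores each candidate by the inner product of its vector with the user vector. The sketch assumes one user vector per batch element and a (batch_size, num_candidates, dim) candidate tensor. The function name `dot_product_scores` is hypothetical.

```python
import torch


def dot_product_scores(user_vec: torch.Tensor, cand_news_vector: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: score each candidate by its dot product with the user vector."""
    # user_vec: (B, D); cand_news_vector: (B, C, D) -> scores: (B, C)
    return torch.bmm(cand_news_vector, user_vec.unsqueeze(-1)).squeeze(-1)


user = torch.tensor([[1.0, 2.0]])
candidates = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
print(dot_product_scores(user, candidates))  # tensor([[1., 2., 3.]])
```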
newsreclib.models.components.layers.projection module
class newsreclib.models.components.layers.projection.UserPreferenceQueryProjection(user_embed_dim: int, preference_query_dim: int, dropout_probability: float)[source]

Bases: Module

Projects dense user representations to preference query vectors.

Reference: Wu, Chuhan, Fangzhao Wu, Mingxiao An, Jianqiang Huang, Yongfeng Huang, and Xing Xie. “NPA: neural news recommendation with personalized attention.” In Proceedings of the 25th ACM SIGKDD international conference on knowledge discovery & data mining, pp. 2576-2584. 2019.

For further details, please refer to the paper

forward(projected_users: Tensor) Tensor[source]
Parameters:

projected_users – Projected user vectors of size (batch_size, embedding_dim)

Returns:

Projected preference query vectors of size (batch_size, preference_dim)

training: bool
class newsreclib.models.components.layers.projection.UserProjection(num_users: int, user_embed_dim: int, dropout_probability: float)[source]

Bases: Module

Embeds user IDs into dense vectors through a lookup table.

Reference: Wu, Chuhan, Fangzhao Wu, Mingxiao An, Jianqiang Huang, Yongfeng Huang, and Xing Xie. “NPA: neural news recommendation with personalized attention.” In Proceedings of the 25th ACM SIGKDD international conference on knowledge discovery & data mining, pp. 2576-2584. 2019.

For further details, please refer to the paper

forward(users: Tensor) Tensor[source]
Parameters:

users – Vector of user IDs of size (batch_size)

Returns:

Projected user vectors of size (batch_size, user_embedding_dim)

training: bool
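The two projection layers can be chained as in NPA. A user ID is first looked up in an embedding table, and the resulting dense vector is then mapped to a preference query. The sketch is hypothetical (`ToyUserProjection`, `ToyUserPreferenceQueryProjection`). The ReLU in the query projection follows our reading of the NPA paper. The dropout placement is an assumption.

```python
import torch
import torch.nn as nn


class ToyUserProjection(nn.Module):
    """Hypothetical sketch: user ID -> dense embedding via a lookup table."""

    def __init__(self, num_users: int, user_embed_dim: int, dropout_probability: float) -> None:
        super().__init__()
        self.embedding = nn.Embedding(num_users, user_embed_dim)
        self.dropout = nn.Dropout(dropout_probability)

    def forward(self, users: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.embedding(users))                     # (B, user_embed_dim)


class ToyUserPreferenceQueryProjection(nn.Module):
    """Hypothetical sketch: dense user embedding -> preference query vector."""

    def __init__(self, user_embed_dim: int, preference_query_dim: int,
                 dropout_probability: float) -> None:
        super().__init__()
        self.linear = nn.Linear(user_embed_dim, preference_query_dim)
        self.dropout = nn.Dropout(dropout_probability)

    def forward(self, projected_users: torch.Tensor) -> torch.Tensor:
        return self.dropout(torch.relu(self.linear(projected_users)))  # (B, preference_query_dim)


torch.manual_seed(0)
user_proj = ToyUserProjection(num_users=100, user_embed_dim=50, dropout_probability=0.2).eval()
query_proj = ToyUserPreferenceQueryProjection(50, 200, 0.2).eval()
queries = query_proj(user_proj(torch.tensor([4, 17, 42])))
print(queries.shape)  # torch.Size([3, 200])
```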
Module contents
Submodules
newsreclib.models.components.losses module
class newsreclib.models.components.losses.SupConLoss(temperature=0.1, **kwargs)[source]

Bases: pytorch_metric_learning.losses.SupConLoss

compute_loss(embeddings, labels, indices_tuple, ref_emb, ref_labels)[source]

This has to be implemented and is what actually computes the loss.

training: bool
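To show what a supervised contrastive loss with a `temperature` computes, the sketch below gives a self-contained version of the standard formulation. Embeddings with the same label are pulled together and all others are pushed apart, using cosine similarities divided by the temperature. It does not reproduce the pytorch-metric-learning base class or NewsRecLib's subclass, and the function name `sup_con_loss_sketch` is hypothetical.

```python
import torch
import torch.nn.functional as F


def sup_con_loss_sketch(embeddings: torch.Tensor, labels: torch.Tensor,
                        temperature: float = 0.1) -> torch.Tensor:
    """Hypothetical sketch of the standard supervised contrastive loss."""
    z = F.normalize(embeddings, dim=1)
    logits = z @ z.T / temperature                           # (N, N) scaled cosine similarities
    self_mask = torch.eye(z.shape[0], dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))     # an anchor is never its own pair
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask
    num_pos = pos_mask.sum(dim=1)
    has_pos = num_pos > 0                                     # anchors without positives are skipped
    pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
    # negative mean log-probability of the positives, averaged over anchors
    return -(pos_log_prob[has_pos] / num_pos[has_pos]).mean()


emb = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
aligned = sup_con_loss_sketch(emb, torch.tensor([0, 0, 1, 1]))
misaligned = sup_con_loss_sketch(emb, torch.tensor([0, 1, 0, 1]))
print(aligned.item() < misaligned.item())  # True
```

A lower temperature sharpens the softmax over similarities, so the loss reacts more strongly to hard negatives.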
newsreclib.models.components.utils module
newsreclib.models.components.utils.pairwise_cosine_similarity(x: Tensor, y: Tensor, zero_diagonal: bool = False) Tensor[source]

Adapted from https://github.com/duynguyen-0203/miner/blob/master/src/utils.py

Calculates the pairwise cosine similarity matrix.

Parameters:
  • x – (batch_size, M, d)

  • y – (batch_size, N, d)

  • zero_diagonal – Whether to set the diagonal of the similarity matrix to zero.

Returns:

A tensor of shape (batch_size, M, N) with the pairwise cosine similarities between the vectors in x and y.
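A batched pairwise cosine similarity with an optional zeroed diagonal fits in a few lines, as in the sketch below. It mirrors the documented parameters, but the name `pairwise_cosine_similarity_sketch` is hypothetical and the body is not copied from NewsRecLib. With `zero_diagonal=True` and `x` equal to `y`, taking the mean is one way to penalize similarity among K interest vectors, for example in a disagreement-style regularizer.

```python
import torch
import torch.nn.functional as F


def pairwise_cosine_similarity_sketch(x: torch.Tensor, y: torch.Tensor,
                                      zero_diagonal: bool = False) -> torch.Tensor:
    """Hypothetical sketch: (B, M, d) x (B, N, d) -> (B, M, N) cosine similarities."""
    sim = torch.bmm(F.normalize(x, dim=-1), F.normalize(y, dim=-1).transpose(1, 2))
    if zero_diagonal:
        if x.shape[1] != y.shape[1]:
            raise ValueError("zero_diagonal requires M == N")
        eye = torch.eye(x.shape[1], dtype=torch.bool, device=x.device)
        sim = sim.masked_fill(eye, 0.0)   # drop each vector's similarity with itself
    return sim


x = torch.tensor([[[1.0, 0.0], [0.0, 2.0]]])
sim = pairwise_cosine_similarity_sketch(x, x)
print(sim)  # [[1., 0.], [0., 1.]] for the single batch element
```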

Module contents
newsreclib.models.fair_rec package
Submodules
newsreclib.models.fair_rec.manner_a_module module
newsreclib.models.fair_rec.manner_cr_module module
newsreclib.models.fair_rec.manner_module module
newsreclib.models.fair_rec.senti_debias_module module
newsreclib.models.fair_rec.sentirec module
newsreclib.models.general_rec package
Submodules
newsreclib.models.general_rec.caum_module module
newsreclib.models.general_rec.cen_news_rec_module module
newsreclib.models.general_rec.dkn_module module
newsreclib.models.general_rec.lstur_module module
newsreclib.models.general_rec.miner_module module
newsreclib.models.general_rec.mins_module module
newsreclib.models.general_rec.naml_module module
newsreclib.models.general_rec.npa_module module
newsreclib.models.general_rec.nrms_module module
newsreclib.models.general_rec.tanr_module module
Submodules
newsreclib.models.abstract_recommender module
Module contents

newsreclib.utils package

Submodules
newsreclib.utils.instantiators module
newsreclib.utils.logging_utils module
newsreclib.utils.pylogger module
newsreclib.utils.rich_utils module
newsreclib.utils.utils module
Module contents

Submodules

newsreclib.train module

newsreclib.eval module

Module contents

Indices and tables