Source code for rsdiv.diversity.mmr

from typing import Optional, Sequence

import numpy as np
import numpy.ma as ma

from .base import BaseReranker


class MaximalMarginalRelevance(BaseReranker):
    """Improve the diversity with Maximal Marginal Relevance algorithm.

    Items are chosen greedily. The first pick is the item with the highest
    quality score; every later pick is the not-yet-selected item ``i`` that
    maximizes::

        lbd * quality[i] - (1 - lbd) * max_{j in selected} similarity[i, j]

    Args:
        lbd: Trade-off weight in ``[0, 1]``. ``1`` ranks purely by quality,
            ``0`` (after the first pick) purely by dissimilarity to the items
            already selected.
    """

    def __init__(self, lbd: float):
        assert 0 <= lbd <= 1, "lbd should be within the interval [0, 1]!"
        self.lbd = lbd

    def rerank(
        self,
        quality_scores: np.ndarray,
        k: int,
        *,
        similarity_scores: np.ndarray,
        embeddings: Optional[np.ndarray] = None,
    ) -> Sequence[int]:
        """Greedily select up to ``k`` item indices using MMR.

        Args:
            quality_scores: One relevance score per item (indexed along
                axis 0).
            k: Number of items to select; clamped to the number of items.
            similarity_scores: Pairwise similarity matrix where
                ``similarity_scores[i, j]`` is the similarity between items
                ``i`` and ``j`` (assumed square ``(n, n)`` — TODO confirm
                against callers).
            embeddings: Accepted for interface compatibility; not used by
                this implementation.

        Returns:
            A list of distinct item indices (plain ``int``) in selection
            order; empty when there are no items.

        Raises:
            AssertionError: If ``k`` is not positive.
        """
        assert k > 0, "k must be larger than 0!"
        quality = np.asarray(quality_scores)
        similarity = np.asarray(similarity_scores)
        n = quality.shape[0]
        k = min(k, n)
        if k == 0:
            # Nothing to rank; np.argmax would raise on an empty array.
            return []

        first = int(np.argmax(quality))
        selected_ind = [first]
        is_selected = np.zeros(n, dtype=bool)
        is_selected[first] = True
        # Running maximum similarity of every item to the selected set.
        # Updated with one column per pick instead of re-reducing the full
        # matrix each step (O(n) per step rather than O(n^2)).
        max_sim = similarity[:, first].copy()

        for _ in range(k - 1):
            scores = self.lbd * quality - (1.0 - self.lbd) * max_sim
            # Restrict the argmax to unselected items explicitly, so an
            # already-selected index can never win a tie (e.g. when every
            # remaining score is -inf). np.argmax returns the first maximum,
            # which keeps lowest-index tie-breaking among candidates.
            candidates = np.flatnonzero(~is_selected)
            new_selection = int(candidates[np.argmax(scores[candidates])])
            selected_ind.append(new_selection)
            is_selected[new_selection] = True
            max_sim = np.maximum(max_sim, similarity[:, new_selection])

        return selected_ind