Evaluation Metrics

CLAX supports evaluation metrics for click and relevance prediction. For click prediction, CLAX implements log-likelihood as well as conditional and unconditional perplexity. To evaluate ranking performance, CLAX supports ranking metrics such as nDCG and MRR from the RAX library.

CLAX follows the design of FLAX NNX metrics, where a MultiMetric updates multiple metrics at once. Each metric picks the input parameters it requires and updates its internal state:

from clax.metrics import (
    MultiMetric,
    LogLikelihood,
    Perplexity,
    ConditionalPerplexity
)

metrics = MultiMetric(
    **{
        "ll": LogLikelihood(),
        "ppl": Perplexity(),
        "cond_ppl": ConditionalPerplexity(),
    }
)

metrics.update(
    log_probs=log_probs,
    cond_log_probs=cond_log_probs,
    clicks=clicks,
    where=mask,
)

results = metrics.compute()
rank_results = metrics.compute_per_rank()

Finally, you can compute the metric values by calling metrics.compute(), or metrics.compute_per_rank() if you want the mean metric value per rank position.

Click Metrics

Log-likelihood

The most common metric for click prediction is the log-likelihood, measuring how well a model fits observed clicks:

\[ \operatorname{LL}(\mathcal{D}) = \frac{1}{|\mathcal{D}|} \sum_{(d, k, c) \in \mathcal{D}} \Big[ c \log \hat{c} + (1 - c) \log \left(1 - \hat{c} \right) \Big], \]

where \(\hat{c} = P(C = 1 \mid d, k, C_{<k})\) are a model's click predictions for a document \(d\) at rank \(k\), conditioned on clicks observed before the current rank \(C_{<k}\). Log-likelihood values are negative, with higher values (closer to zero) indicating better model fit.
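To make the formula concrete, here is a minimal sketch in plain jax.numpy (no CLAX classes) that evaluates it for a single query with three documents; the values mirror the LogLikelihood example below, and jnp.log1p stands in for the log1mexp helper used in the library:

import jax.numpy as jnp

# Conditional log click probabilities ln P(C=1 | d, k, C_<k) and observed clicks
cond_log_probs = jnp.array([[-0.01, -10.0, -0.7]])
clicks = jnp.array([[1, 0, 1]])

# ln P(C=0) = ln(1 - exp(ln P(C=1)))
log_p_click = cond_log_probs
log_p_no_click = jnp.log1p(-jnp.exp(cond_log_probs))

# Per-document log-likelihood, averaged over all documents
ll = clicks * log_p_click + (1 - clicks) * log_p_no_click
print(ll.mean())  # approx. (-0.01 + ln(1 - e^-10) - 0.7) / 3 = -0.237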

Bases: RankBasedAverage

Examples:

Compute the mean log-likelihood over a single query with three documents from conditional log probabilities. Use where = False to mask out padding documents:

ll = LogLikelihood()
ll.update(
    cond_log_probs=jnp.array([[-0.01, -10.0, -0.7]]),
    clicks=jnp.array([[1, 0, 1]]),
    where=jnp.array([[True, True, True]]),
)
ll.compute()

Compute the mean log-likelihood for each rank:

ll.compute_per_rank()
Source code in clax/metrics.py
class LogLikelihood(RankBasedAverage):
    """
    Examples:
        Compute the mean log-likelihood over a single query with three documents
        from conditional log probabilities. Use `where = False` to mask out padding documents:

            ll = LogLikelihood()
            ll.update(
                cond_log_probs=jnp.array([[-0.01, -10.0, -0.7]]),
                clicks=jnp.array([[1, 0, 1]]),
                where=jnp.array([[True, True, True]]),
            )
            ll.compute()

        Compute the mean log-likelihood for each rank:

            ll.compute_per_rank()
    """

    def update(
        self,
        *,
        cond_log_probs: Array,
        clicks: Array,
        where: Optional[Array] = None,
        **kwargs,
    ):
        # log P(click) and log P(no click) = log(1 - exp(log P(click)))
        p_click = cond_log_probs
        p_no_click = log1mexp(cond_log_probs)
        log_likelihood = clicks * p_click + (1 - clicks) * p_no_click

        super().update_values(log_likelihood, where=where)

Conditional Perplexity

Perplexity can offer a more intuitive interpretation than log-likelihood. It measures how surprised a model is by the observed data, with a lower value indicating a better model fit. Intuitively, it represents the weighted average number of choices a model is considering. Perfect predictions yield a perplexity of \(1\), while random guessing for binary outcomes gives a perplexity of \(2\), as the model is as uncertain as a coin flip. Conditional perplexity is defined as:

\[ \operatorname{PPL}(\mathcal{D}) = 2^{- \frac{1}{|\mathcal{D}|} \sum_{(d, k, c) \in \mathcal{D}} \Big[ c \log_2 \hat{c} + (1 - c) \log_2 \left(1 - \hat{c} \right) \Big]}, \]

where \(\hat{c} = P(C=1 \mid d, k, C_{<k})\) are a model's click predictions for a document \(d\) at rank \(k\), conditioned on clicks observed before the current rank \(C_{<k}\). Note that models that adapt their behavior based on clicks earlier in the current search session can score better on conditional predictions.
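As a rough sketch of the per-rank computation in plain jax.numpy (reusing the example values from below), conditional perplexity averages the base-2 log-likelihood per rank, exponentiates, and then averages over ranks, matching what ConditionalPerplexity.compute() does:

import jax.numpy as jnp

cond_log_probs = jnp.array([[-0.01, -10.0, -0.7]])
clicks = jnp.array([[1, 0, 1]])

# Base-2 log-likelihood per document: log2 P(C=1) or log2 P(C=0)
log2_p_click = cond_log_probs / jnp.log(2)
log2_p_no_click = jnp.log1p(-jnp.exp(cond_log_probs)) / jnp.log(2)
ll2 = clicks * log2_p_click + (1 - clicks) * log2_p_no_click

# Perplexity per rank (mean over queries), then averaged over ranks
ppl_per_rank = 2 ** -ll2.mean(axis=0)
print(ppl_per_rank)         # approx. [1.01, 1.00, 2.01]
print(ppl_per_rank.mean())  # approx. 1.34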

Bases: RankBasedAverage

Examples:

Compute the mean conditional perplexity over a single query with three documents from conditional log probabilities. Use where = False to mask out padding documents:

cond_ppl = ConditionalPerplexity()
cond_ppl.update(
    cond_log_probs=jnp.array([[-0.01, -10.0, -0.7]]),
    clicks=jnp.array([[1, 0, 1]]),
    where=jnp.array([[True, True, True]]),
)
cond_ppl.compute()

Compute the mean conditional perplexity for each rank:

cond_ppl.compute_per_rank()
Source code in clax/metrics.py
class ConditionalPerplexity(RankBasedAverage):
    """
    Examples:
        Compute the mean conditional perplexity over a single query with three documents
        from conditional log probabilities. Use `where = False` to mask out padding documents:

            cond_ppl = ConditionalPerplexity()
            cond_ppl.update(
                cond_log_probs=jnp.array([[-0.01, -10.0, -0.7]]),
                clicks=jnp.array([[1, 0, 1]]),
                where=jnp.array([[True, True, True]]),
            )
            cond_ppl.compute()

        Compute the mean conditional perplexity for each rank:

            cond_ppl.compute_per_rank()
    """

    def update(
        self,
        *,
        cond_log_probs: Array,
        clicks: Array,
        where: Optional[Array] = None,
        **kwargs,
    ):
        # Convert log probabilities ln(p) to log_2(p)
        p_click = cond_log_probs / jnp.log(2)
        p_no_click = log1mexp(cond_log_probs) / jnp.log(2)
        log_likelihood = clicks * p_click + (1 - clicks) * p_no_click

        super().update_values(log_likelihood, where=where)

    def compute(self):
        # Avg. cond. perplexity is the mean over ranks:
        return self.compute_per_rank().mean()

    def compute_per_rank(self):
        return 2 ** -super().compute_per_rank()

Perplexity

Similar to conditional perplexity, unconditional perplexity measures how surprised a model is by the observed data, with a lower value indicating a better model fit. However, in contrast to conditional perplexity and log-likelihood, unconditional perplexity is calculated from click predictions that do not take clicks from the current user session into account. Unconditional perplexity is defined as:

\[ \operatorname{PPL}(\mathcal{D}) = 2^{- \frac{1}{|\mathcal{D}|} \sum_{(d, k, c) \in \mathcal{D}} \Big[ c \log_2 \hat{c} + (1 - c) \log_2 \left(1 - \hat{c} \right) \Big]}, \]

where \(\hat{c} = P(C=1 \mid d, k)\) are a model's unconditional click predictions for a document \(d\) at rank \(k\). In other words, the model has to predict all clicks in the current session without access to any clicks at earlier ranks.
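In practice, the main difference to conditional perplexity is which model outputs are passed to update(): Perplexity reads log_probs, while ConditionalPerplexity reads cond_log_probs. A short sketch; the probability values below are purely illustrative:

import jax.numpy as jnp
from clax.metrics import ConditionalPerplexity, Perplexity

# Illustrative model outputs for one session with three documents
log_probs = jnp.array([[-0.36, -0.92, -1.20]])      # ln P(C=1 | d, k)
cond_log_probs = jnp.array([[-0.01, -10.0, -0.7]])  # ln P(C=1 | d, k, C_<k)
clicks = jnp.array([[1, 0, 1]])
where = jnp.array([[True, True, True]])

ppl = Perplexity()
cond_ppl = ConditionalPerplexity()

# Each metric reads only the keyword argument it needs
ppl.update(log_probs=log_probs, clicks=clicks, where=where)
cond_ppl.update(cond_log_probs=cond_log_probs, clicks=clicks, where=where)

print(ppl.compute(), cond_ppl.compute())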

Bases: RankBasedAverage

Examples:

Compute the mean (unconditional) perplexity over a single query with three documents from unconditional log probabilities. Use where = False to mask out padding documents:

ppl = Perplexity()
ppl.update(
    log_probs=jnp.array([[-0.01, -10.0, -0.7]]),
    clicks=jnp.array([[1, 0, 1]]),
    where=jnp.array([[True, True, True]]),
)
ppl.compute()

Compute the mean perplexity for each rank:

ppl.compute_per_rank()
Source code in clax/metrics.py
class Perplexity(RankBasedAverage):
    """
    Examples:
        Compute the mean (unconditional) perplexity over a single query with three documents
        from unconditional log probabilities. Use `where = False` to mask out padding documents:

            ppl = Perplexity()
            ppl.update(
                log_probs=jnp.array([[-0.01, -10.0, -0.7]]),
                clicks=jnp.array([[1, 0, 1]]),
                where=jnp.array([[True, True, True]]),
            )
            ppl.compute()

        Compute the mean perplexity for each rank:

            ppl.compute_per_rank()
    """

    def update(
        self,
        *,
        log_probs: Array,
        clicks: Array,
        where: Optional[Array] = None,
        **kwargs,
    ):
        # Convert log probabilities ln(p) to log_2(p)
        p_click = log_probs / jnp.log(2)
        p_no_click = log1mexp(log_probs) / jnp.log(2)
        log_likelihood = clicks * p_click + (1 - clicks) * p_no_click

        super().update_values(log_likelihood, where=where)

    def compute(self):
        # Avg. perplexity is the mean over ranks:
        return self.compute_per_rank().mean()

    def compute_per_rank(self):
        return 2 ** -super().compute_per_rank()

Ranking Metrics

Instead of re-implementing ranking metrics in CLAX, we opted to integrate with RAX, the most popular ranking library in JAX. You can use any ranking metric from RAX by wrapping it in a RaxMetric object. Below, scores are the relevance predictions of a click model and labels are typically expert-annotated relevance labels:

import rax
from clax.metrics import RaxMetric, MultiMetric

metrics = MultiMetric(
    **{
        "dcg@10": RaxMetric(rax.dcg_metric, top_n=10),
        "mrr@10": RaxMetric(rax.mrr_metric, top_n=10),
    }
)

metrics.update(scores=scores, labels=labels, where=mask)
metrics.compute()
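
Since each metric only picks up the keyword arguments it needs, click and ranking metrics can share a single MultiMetric and a single update call. A sketch, assuming log_probs, cond_log_probs, clicks, scores, labels, and mask have been computed for a batch, and that rax.ndcg_metric is wrapped the same way as the RAX metrics above:

import rax
from clax.metrics import LogLikelihood, MultiMetric, Perplexity, RaxMetric

metrics = MultiMetric(
    **{
        "ll": LogLikelihood(),
        "ppl": Perplexity(),
        "ndcg@10": RaxMetric(rax.ndcg_metric, top_n=10),
    }
)

# One update call; each metric selects the arguments it requires:
metrics.update(
    log_probs=log_probs,
    cond_log_probs=cond_log_probs,
    clicks=clicks,
    scores=scores,
    labels=labels,
    where=mask,
)
metrics.compute()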