Evaluation Metrics
CLAX supports evaluation metrics for click and relevance prediction. For click prediction, CLAX implements log-likelihood as well as conditional and unconditional perplexity. To evaluate ranking performance, CLAX supports ranking metrics such as nDCG or MRR from the RAX library.
CLAX follows the design of Flax NNX metrics, in which a MultiMetric updates several metrics at once. Each metric picks the input parameters it requires and updates its internal state:
from clax.metrics import (
    MultiMetric,
    LogLikelihood,
    Perplexity,
    ConditionalPerplexity,
)

metrics = MultiMetric(
    **{
        "ll": LogLikelihood(),
        "ppl": Perplexity(),
        "cond_ppl": ConditionalPerplexity(),
    }
)

metrics.update(
    log_probs=log_probs,
    conditional_log_probs=cond_log_probs,
    clicks=clicks,
    where=mask,
)

results = metrics.compute()
rank_results = metrics.compute_per_rank()
Finally, you can compute the metric values by calling compute(), or compute_per_rank() if you want the mean metric value per rank position.
Click Metrics
Log-likelihood
The most common metric for click prediction is the log-likelihood, measuring how well a model fits observed clicks:

\[
\mathcal{LL} = \frac{1}{N} \sum_{i=1}^{N} \Big( c_i \log \hat{c}_i + (1 - c_i) \log \big( 1 - \hat{c}_i \big) \Big)
\]

where \(\hat{c} = P(C = 1 \mid d, k, C_{<k})\) are a model's click predictions for a document \(d\) at rank \(k\), conditioned on clicks observed before the current rank \(C_{<k}\), \(c_i \in \{0, 1\}\) are the observed clicks, and \(N\) is the total number of documents. Log-likelihood values are negative, with higher values (closer to zero) indicating better model fit.
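As a quick illustration of the formula (not the CLAX API), the mean log-likelihood can be computed directly from hypothetical click probabilities with jax.numpy:

import jax.numpy as jnp

# Hypothetical click predictions and observed clicks for three documents
c_hat = jnp.array([0.9, 0.1, 0.5])
clicks = jnp.array([1, 0, 1])

# Reward log(c_hat) for clicked documents and log(1 - c_hat) for skipped ones
ll = jnp.mean(clicks * jnp.log(c_hat) + (1 - clicks) * jnp.log(1 - c_hat))
# ll = (log 0.9 + log 0.9 + log 0.5) / 3 ≈ -0.30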
Bases: RankBasedAverage
Examples:
Compute the mean log-likelihood over a single query with three documents from conditional log probabilities. Use where = False to mask out padding documents:

ll = LogLikelihood()
ll.update(
    cond_log_probs=jnp.array([[-0.01, -10.0, -0.7]]),
    clicks=jnp.array([[1, 0, 1]]),
    where=jnp.array([[True, True, True]]),
)
ll.compute()
Compute the mean log-likelihood for each rank:
ll.compute_per_rank()
Source code in clax/metrics.py
Conditional Perplexity
Perplexity can offer a more intuitive interpretation than log-likelihood. It measures how surprised a model is by the observed data, with a lower value indicating a better model fit. Intuitively, it represents the weighted average number of choices a model is considering. Perfect predictions yield a perplexity of \(1\), while random guessing for binary outcomes gives a perplexity of \(2\), as the model is as uncertain as a coin flip. Conditional perplexity at rank \(k\) is defined as:

\[
PPL_k = 2^{-\frac{1}{|S|} \sum_{s \in S} \left( c_k^{s} \log_2 \hat{c}_k^{s} + (1 - c_k^{s}) \log_2 \left( 1 - \hat{c}_k^{s} \right) \right)}
\]

where \(\hat{c} = P(C=1 \mid d, k, C_{<k})\) are a model's click predictions for a document \(d\) at rank \(k\), conditioned on clicks observed before the current rank \(C_{<k}\), \(c_k^{s}\) are the clicks observed at rank \(k\) in session \(s\), and \(S\) is the set of sessions. The overall perplexity is the mean of \(PPL_k\) over all ranks. Note that models that adapt their behavior based on clicks in the current search session might score better in conditional predictions.
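As a quick illustration of the formula (not the CLAX API), per-rank perplexity can be computed from hypothetical conditional click probabilities with jax.numpy:

import jax.numpy as jnp

# Hypothetical conditional click predictions for two sessions with three ranks each
c_hat = jnp.array([[0.9, 0.1, 0.5],
                   [0.8, 0.3, 0.2]])
clicks = jnp.array([[1, 0, 1],
                    [1, 0, 0]])

# log2-likelihood of the observed click/skip at every rank
log2_lik = clicks * jnp.log2(c_hat) + (1 - clicks) * jnp.log2(1 - c_hat)

ppl_per_rank = 2 ** -jnp.mean(log2_lik, axis=0)  # one perplexity value per rank
ppl = jnp.mean(ppl_per_rank)                     # overall perplexity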
Bases: RankBasedAverage
Examples:
Compute the mean conditional perplexity over a single query with three documents from conditional log probabilities. Use where = False to mask out padding documents:

cond_ppl = ConditionalPerplexity()
cond_ppl.update(
    cond_log_probs=jnp.array([[-0.01, -10.0, -0.7]]),
    clicks=jnp.array([[1, 0, 1]]),
    where=jnp.array([[True, True, True]]),
)
cond_ppl.compute()
Compute the mean conditional perplexity for each rank:
cond_ppl.compute_per_rank()
Source code in clax/metrics.py
Perplexity
Similar to conditional perplexity, unconditional perplexity measures how surprised a model is by the observed data, with a lower value indicating a better model fit. However, in contrast to conditional perplexity and log-likelihood, unconditional perplexity is calculated from click predictions that do not take clicks from the current user session into account. Unconditional perplexity at rank \(k\) is defined as:

\[
PPL_k = 2^{-\frac{1}{|S|} \sum_{s \in S} \left( c_k^{s} \log_2 \hat{c}_k^{s} + (1 - c_k^{s}) \log_2 \left( 1 - \hat{c}_k^{s} \right) \right)}
\]

where \(\hat{c} = P(C=1 \mid d, k)\) are a model's unconditional click predictions for a document \(d\) at rank \(k\). That is, a model has to predict all clicks in the current session without access to any clicks on earlier ranks in the session.
Bases: RankBasedAverage
Examples:
Compute the mean (unconditional) perplexity over a single query with three documents from unconditional log probabilities. Use where = False to mask out padding documents:

ppl = Perplexity()
ppl.update(
    log_probs=jnp.array([[-0.01, -10.0, -0.7]]),
    clicks=jnp.array([[1, 0, 1]]),
    where=jnp.array([[True, True, True]]),
)
ppl.compute()
Compute the mean perplexity for each rank:
ppl.compute_per_rank()
Source code in clax/metrics.py
Ranking Metrics
Instead of re-implementing ranking metrics in CLAX, we opted to integrate with RAX, the most popular ranking library in JAX. You can use any ranking metric from RAX by wrapping it in a RaxMetric object. Below, score is the relevance prediction of a click model and label is typically an expert-annotated relevance label:
import rax
from clax.metrics import RaxMetric, MultiMetric

metrics = MultiMetric(
    **{
        "dcg@10": RaxMetric(rax.dcg_metric, top_n=10),
        "mrr@10": RaxMetric(rax.mrr_metric, top_n=10),
    }
)

metrics.update(scores=scores, labels=labels, where=mask)
metrics.compute()
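Because each metric picks only the input parameters it requires, it should be possible to combine click and ranking metrics in a single MultiMetric. The following is a sketch under that assumption; variable names such as log_probs, scores, and labels are placeholders:

import rax
from clax.metrics import MultiMetric, LogLikelihood, RaxMetric

# Sketch: click and ranking metrics side by side in one MultiMetric
metrics = MultiMetric(
    **{
        "ll": LogLikelihood(),
        "ndcg@10": RaxMetric(rax.ndcg_metric, top_n=10),
    }
)

# Each metric consumes only the arguments it needs from a single update call
metrics.update(
    log_probs=log_probs,
    conditional_log_probs=cond_log_probs,
    clicks=clicks,
    scores=scores,
    labels=labels,
    where=mask,
)

metrics.compute()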