Skip to content

Overview of Click Models

Cascade Model (CM)

Bases: ClickModel

The Cascade Model (CM) assumes that users scan results from top to bottom, click on the first attractive document they find, and then stop their search.

Parameters:

Name Type Description Default
query_doc_pairs Optional[int]

Number of query-document pairs to allocate in an embedding table. This parameter is not used if a custom attraction module is provided.

None
attraction Optional[Parameter | ParameterConfig]

Custom attraction parameter, which can be a parameter config or any subclass of the Parameter base class.

None
rngs Rngs

A NNX random number generator used for model initialization and stochastic operations.

required

Examples:

Creating a basic Cascade Model:

model = CascadeModel(
    query_doc_pairs=1_000_000,
    rngs=nnx.Rngs(42),
)

Configure a deep network to use custom query-doc features:

attraction = DeepParameterConfig(
    use_feature="query_doc_features",
    features=16,
    layers=2,
    dropout=0.25,
)
model = CascadeModel(
    attraction=attraction,
    rngs=nnx.Rngs(42),
)
References

Nick Craswell, Onno Zoeter, Michael Taylor, and Bill Ramsey. "An experimental comparison of click position-bias models." In WSDM 2008.

Source code in clax/models/cm.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
class CascadeModel(ClickModel):
    """
    The Cascade Model (CM) assumes that users scan results from top to bottom,
    click on the first attractive document they find, and then stop their search.

    Args:
        query_doc_pairs (Optional[int], optional): Number of query-document
            pairs to allocate in an embedding table. This parameter is not
            used if a custom attraction module is provided.
        attraction (Optional[Parameter | ParameterConfig], optional): Custom
            attraction parameter, which can be a parameter config or any
            subclass of the Parameter base class.
        rngs (nnx.Rngs): A NNX random number generator used for model
            initialization and stochastic operations.

    Examples:
        Creating a basic Cascade Model:

            model = CascadeModel(
                query_doc_pairs=1_000_000,
                rngs=nnx.Rngs(42),
            )

        Configure a deep network to use custom query-doc features:

            attraction = DeepParameterConfig(
                use_feature="query_doc_features",
                features=16,
                layers=2,
                dropout=0.25,
            )
            model = CascadeModel(
                attraction=attraction,
                rngs=nnx.Rngs(42),
            )

    References:
        Nick Craswell, Onno Zoeter, Michael Taylor, and Bill Ramsey.
        "An experimental comparison of click position-bias models."
        In WSDM 2008.
    """

    name = "CM"

    def __init__(
        self,
        query_doc_pairs: Optional[int] = None,
        attraction: Optional[Parameter | ParameterConfig] = None,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        self.attraction = init_parameter(
            "attraction",
            attraction,
            default_config_fn=default_attraction_config,
            default_config_args={"query_doc_pairs": query_doc_pairs},
            rngs=rngs,
        )

    def compute_loss(self, batch: Dict, aggregate: bool = True):
        """
        Binary cross-entropy between observed clicks and conditional click
        log probabilities, restricted to positions where batch["mask"] is set.
        """
        y_true = batch["clicks"]
        y_predict = self.predict_conditional_clicks(batch)

        return binary_cross_entropy(
            y_predict,
            y_true,
            where=batch["mask"],
            log_probs=True,
            aggregate=aggregate,
        )

    def predict_conditional_clicks(self, batch: Dict) -> Array:
        """
        Predict click log probabilities conditioned on the observed clicks.

        The CM can explain at most one click per list, so positions after the
        first observed click have probability zero by definition.
        """
        click_log_probs = self.predict_clicks(batch)

        # Discard clicks after the first click by setting them to a minimum
        # log prob; this avoids a log-likelihood of -inf during training:
        no_clicks_before = self._no_clicks_before(batch["clicks"])
        click_log_probs = jnp.where(no_clicks_before, click_log_probs, jnp.log(1e-8))

        return jnp.where(batch["mask"], click_log_probs, -jnp.inf)

    def predict_clicks(self, batch: Dict) -> Array:
        """
        Predict unconditional click log probabilities:
        log P(C=1 | d, k) = log gamma_d + sum_{i<k} log(1 - gamma_{d_i}).
        """
        attr_logits = self.attraction.logit(batch)

        # Compute log probabilities for attraction and non-attraction:
        attr_log_probs = logits_to_log_probs(attr_logits)
        non_attr_log_probs = logits_to_complement_log_probs(attr_logits)

        # Examination at rank k requires all preceding documents to be
        # unattractive; the first item is always examined:
        exam_log_probs = jnp.roll(non_attr_log_probs, shift=1, axis=-1)
        exam_log_probs = exam_log_probs.at[:, 0].set(0)
        exam_log_probs = jnp.cumsum(exam_log_probs, axis=-1)

        click_log_probs = exam_log_probs + attr_log_probs
        return jnp.where(batch["mask"], click_log_probs, -jnp.inf)

    def sample(self, batch: Dict, rngs: nnx.Rngs) -> CascadeModelOutput:
        # Return annotation fixed: this returns a CascadeModelOutput, not a
        # bare Array, matching the sibling models' sample() signatures.
        """
        Sample attraction, examination, and clicks for a batch.
        """
        attr_probs = self.attraction.prob(batch)
        attraction = batch["mask"] & jax.random.bernoulli(rngs(), attr_probs)

        # Users examine every rank up to and including the first attractive doc:
        examination = self._no_clicks_before(attraction)
        clicks = examination & attraction

        return CascadeModelOutput(
            clicks=clicks,
            examination=examination,
            attraction=attraction,
        )

    def predict_relevance(self, batch: Dict) -> Array:
        """
        Relevance is identified with attraction in the CM.
        """
        return self.attraction.log_prob(batch)

    @staticmethod
    def _no_clicks_before(clicks):
        """
        Check if there are no clicks before each position.
        """
        # Exclusive prefix sum of clicks; zero means no earlier click:
        clicks_before = jnp.cumsum(clicks, axis=-1) - clicks
        return clicks_before == 0

Unconditional click probability

The probability of a click at rank \(k\) depends on the displayed document \(d\) being attractive \(\gamma_d\) and all preceding documents being unattractive:

\[\log P(C=1 \mid d, k) = \log \gamma_d + \sum_{i=1}^{k-1} \log(1 - \gamma_{d_i}).\]

Conditional click probability

The Cascade Model can only explain a single click per list. All other documents after the first click, by definition, have a click probability of \(0\). To avoid a log-likelihood of \(-\infty\) in the conditional click predictions, we assign a very small default click probability to all documents following a click:

\[ \log P(C=1 \mid d, k, C_{<k}) = \begin{cases} \log \gamma_d & \text{if } \sum_{i=1}^{k-1} c_i = 0 \\ \text{min_log_prob} & \text{otherwise}. \end{cases} \]

Position-based Model (PBM)

Bases: ClickModel

The Position-based model (PBM) assumes that users examine ranks independently and only click on examined and attractive results.

Parameters:

Name Type Description Default
positions Optional[int]

Number of position embeddings to allocate. This parameter is not used if a custom examination module is provided.

None
query_doc_pairs Optional[int]

Number of query-document pairs to allocate in an embedding table. This parameter is not used if a custom attraction module is provided.

None
examination Optional[Parameter | ParameterConfig]

Custom examination/bias parameter, which can be a parameter config or any subclass of the Parameter base class.

None
attraction Optional[Parameter | ParameterConfig]

Custom attraction/relevance parameter, which can be a parameter config or any subclass of the Parameter base class.

None
rngs Rngs

A NNX random number generator used for model initialization and stochastic operations.

required

Examples:

Creating a basic Position-based Model:

model = PositionBasedModel(
    positions=10,
    query_doc_pairs=1_000_000,
    rngs=nnx.Rngs(42)
)

Configure a two-tower model with a linear combination of bias features and a deep-cross network for document attraction:

model = PositionBasedModel(
    examination=LinearParameterConfig(
        use_feature="bias_features",
        features=8,
    ),
    attraction=DeepCrossParameterConfig(
        use_feature="query_doc_features",
        features=136,
        cross_layers=2,
        deep_layers=2,
        combination=Combination.STACKED,
    ),
    rngs=nnx.Rngs(42),
)
References

Matthew Richardson, Ewa Dominowska, and Robert Ragno. "Predicting Clicks: Estimating the Click-through Rate for New Ads." In WWW 2007.

Nick Craswell, Onno Zoeter, Michael Taylor, and Bill Ramsey. "An experimental comparison of click position-bias models." In WSDM 2008.

Source code in clax/models/pbm.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
class PositionBasedModel(ClickModel):
    """
    The Position-based model (PBM) assumes that users examine ranks independently
    and only click on results that are both examined and attractive.

    Args:
        positions (Optional[int], optional): Number of position embeddings to
            allocate. This parameter is not used if a custom examination module
            is provided.
        query_doc_pairs (Optional[int], optional): Number of query-document
            pairs to allocate in an embedding table. This parameter is not
            used if a custom attraction module is provided.
        examination (Optional[Parameter | ParameterConfig], optional): Custom
            examination/bias parameter, which can be a parameter config or any
            subclass of the Parameter base class.
        attraction (Optional[Parameter | ParameterConfig], optional): Custom
            attraction/relevance parameter, which can be a parameter config or
            any subclass of the Parameter base class.
        rngs (nnx.Rngs): A NNX random number generator used for model
            initialization and stochastic operations.

    Examples:
        Creating a basic Position-based Model:

            model = PositionBasedModel(
                positions=10,
                query_doc_pairs=1_000_000,
                rngs=nnx.Rngs(42)
            )

        Configure a two-tower model with a linear combination of bias features and a
        deep-cross network for document attraction:

            model = PositionBasedModel(
                examination=LinearParameterConfig(
                    use_feature="bias_features",
                    features=8,
                ),
                attraction=DeepCrossParameterConfig(
                    use_feature="query_doc_features",
                    features=136,
                    cross_layers=2,
                    deep_layers=2,
                    combination=Combination.STACKED,
                ),
                rngs=nnx.Rngs(42),
            )

    References:
        Matthew Richardson, Ewa Dominowska, and Robert Ragno.
        "Predicting Clicks: Estimating the Click-through Rate for New Ads."
        In WWW 2007.

        Nick Craswell, Onno Zoeter, Michael Taylor, and Bill Ramsey.
        "An experimental comparison of click position-bias models."
        In WSDM 2008.
    """

    name = "PBM"

    def __init__(
        self,
        positions: Optional[int] = None,
        query_doc_pairs: Optional[int] = None,
        examination: Optional[Parameter | ParameterConfig] = None,
        attraction: Optional[Parameter | ParameterConfig] = None,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        # Examination is initialized before attraction on purpose, so that a
        # shared rng stream yields the same parameter initialization order.
        self.examination = init_parameter(
            "examination",
            examination,
            default_config_fn=default_examination_config,
            default_config_args={"positions": positions},
            rngs=rngs,
        )
        self.attraction = init_parameter(
            "attraction",
            attraction,
            default_config_fn=default_attraction_config,
            default_config_args={"query_doc_pairs": query_doc_pairs},
            rngs=rngs,
        )

    def compute_loss(self, batch: Dict, aggregate: bool = True):
        """
        Binary cross-entropy between observed clicks and predicted click log
        probabilities, restricted to positions where batch["mask"] is set.
        """
        log_probs = self.predict_conditional_clicks(batch)

        return binary_cross_entropy(
            log_probs,
            batch["clicks"],
            where=batch["mask"],
            log_probs=True,
            aggregate=aggregate,
        )

    def predict_conditional_clicks(self, batch: Dict) -> Array:
        """
        Predict click log probabilities:
        log P(C=1 | d, k) = log theta_k + log gamma_d, with -inf on masked ranks.
        """
        joint_log_probs = self.examination.log_prob(batch) + self.attraction.log_prob(
            batch
        )
        return jnp.where(batch["mask"], joint_log_probs, -jnp.inf)

    def predict_clicks(self, batch: Dict) -> Array:
        """
        Examination is position-only in the PBM, so unconditional and
        conditional click predictions coincide.
        """
        return self.predict_conditional_clicks(batch)

    def predict_relevance(self, batch: Dict) -> Array:
        """
        Relevance is identified with attraction in the PBM.
        """
        return self.attraction.log_prob(batch)

    def sample(self, batch: Dict, rngs: nnx.Rngs) -> PositionBasedModelOutput:
        """
        Sample examination, attraction, and clicks for a batch.
        """
        mask = batch["mask"]
        exam_probs = self.examination.prob(batch)
        attr_probs = self.attraction.prob(batch)

        # Draw examination before attraction to keep the rng call order stable:
        examination = mask & jax.random.bernoulli(rngs(), p=exam_probs)
        attraction = mask & jax.random.bernoulli(rngs(), p=attr_probs)

        return PositionBasedModelOutput(
            clicks=examination & attraction,
            examination=examination,
            attraction=attraction,
        )

(Un)conditional click probability

The PBM assumes that a click occurs only if a user first examines the result at rank \(k\) with probability \(\theta_k\) and if the displayed document \(d\) is attractive \(\gamma_d\):

\[\log P(C = 1 \mid d, k) = \log \theta_k + \log \gamma_{d}.\]

User Browsing Model (UBM)

Bases: ClickModel

The UBM extends the PBM by making the examination probability depend on both the current position and the position of the last clicked document.

Parameters:

Name Type Description Default
positions Optional[int]

Number of positions used to allocate a 2D embedding table of shape (positions, positions). This parameter is not used if a custom examination module is provided.

required
query_doc_pairs Optional[int]

Number of query-document pairs to allocate in an embedding table. This parameter is not used if a custom attraction module is provided.

None
examination Optional[Parameter | ParameterConfig]

Custom examination/bias parameter, which can be a parameter config or any subclass of the Parameter base class.

None
attraction Optional[Parameter | ParameterConfig]

Custom attraction/relevance parameter, which can be a parameter config or any subclass of the Parameter base class.

None
rngs Rngs

A NNX random number generator used for model initialization and stochastic operations.

required

Examples:

model = UserBrowsingModel(
    positions=10,
    query_doc_pairs=1_000_000,
    rngs=nnx.Rngs(42)
)
References

Georges E. Dupret and Benjamin Piwowarski. "A user browsing model to predict search engine click data from past observations." In SIGIR 2008.

Source code in clax/models/ubm.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
class UserBrowsingModel(ClickModel):
    """
    The UBM extends the PBM by making the examination probability depend on both the
    current position and the position of the last clicked document.

    Args:
        positions (int): Number of positions used to allocate a 2D
            embedding table of shape (positions, positions).
            This parameter is not used if a custom examination module is provided.
        query_doc_pairs (Optional[int], optional): Number of query-document
            pairs to allocate in an embedding table. This parameter is not
            used if a custom attraction module is provided.
        examination (Optional[Parameter | ParameterConfig], optional): Custom
            examination/bias parameter, which can be a parameter config or any
            subclass of the Parameter base class.
        attraction (Optional[Parameter | ParameterConfig], optional): Custom
            attraction/relevance parameter, which can be a parameter config or any
            subclass of the Parameter base class.
        rngs (nnx.Rngs): A NNX random number generator used for model
            initialization and stochastic operations.

    Examples:

        model = UserBrowsingModel(
            positions=10,
            query_doc_pairs=1_000_000,
            rngs=nnx.Rngs(42)
        )

    References:
        Georges E. Dupret and Benjamin Piwowarski.
        "A user browsing model to predict search engine click data from past observations."
        In SIGIR 2008.
    """

    name = "UBM"

    def __init__(
        self,
        positions: int,
        query_doc_pairs: Optional[int] = None,
        examination: Optional[Parameter | ParameterConfig] = None,
        attraction: Optional[Parameter | ParameterConfig] = None,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        # Stored to compute flat indices into the (positions x positions)
        # examination table, see _examination_parameters below.
        self.positions = positions
        self.examination = init_parameter(
            "examination",
            examination,
            default_config_fn=default_ubm_examination_config,
            default_config_args={"positions": positions},
            rngs=rngs,
        )
        self.attraction = init_parameter(
            "attraction",
            attraction,
            default_config_fn=default_attraction_config,
            default_config_args={"query_doc_pairs": query_doc_pairs},
            rngs=rngs,
        )

    def compute_loss(self, batch: Dict, aggregate: bool = True):
        """
        Binary cross-entropy between observed clicks and conditional click
        log probabilities, restricted to positions where batch["mask"] is set.
        """
        y_true = batch["clicks"]
        y_predict = self.predict_conditional_clicks(batch)
        return binary_cross_entropy(
            y_predict,
            y_true,
            where=batch["mask"],
            log_probs=True,
            aggregate=aggregate,
        )

    def predict_conditional_clicks(self, batch: Dict) -> Array:
        """
        Predict click log probabilities given the observed click history:
        log P(C=1 | d, k, C_<k) = log theta_{k,k'} + log gamma_d,
        where k' is the position of the last click before rank k.
        """
        clicks = batch["clicks"]
        positions = batch["positions"]

        # Examination is indexed by (current position, last clicked position):
        last_clicked_positions = self._last_clicked_positions(positions, clicks)
        exam_log_probs = self.examination.log_prob(
            self._examination_parameters(
                positions,
                last_clicked_positions,
            )
        )
        attr_log_probs = self.attraction.log_prob(batch)
        click_log_probs = exam_log_probs + attr_log_probs

        return jnp.where(batch["mask"], click_log_probs, -jnp.inf)

    def predict_clicks(self, batch: Dict) -> Array:
        """
        Predict unconditional click log probabilities by marginalizing over all
        possible last-click positions before the current rank.
        """
        mask = batch["mask"]
        positions = batch["positions"]
        n_batch, n_positions = positions.shape

        click_log_probs = jnp.zeros((n_batch, n_positions))
        attr_log_probs = self.attraction.log_prob(batch)

        for current_idx in range(n_positions):
            scenario_log_probs = []

            # last_clicked_idx == -1 encodes "no click so far":
            for last_clicked_idx in range(-1, current_idx):
                # Each scenario represents one possible browsing history:
                # Predict clicks at the current_idx given the last clicked doc is at last_clicked_idx.
                last_click_log_prob = self._get_last_click_log_prob(
                    click_log_probs=click_log_probs,
                    last_clicked_idx=last_clicked_idx,
                )
                no_clicks_log_prob = self._compute_no_clicks_between_log_prob(
                    positions=positions,
                    attr_log_probs=attr_log_probs,
                    last_clicked_idx=last_clicked_idx,
                    current_idx=current_idx,
                )
                current_click_log_prob = self._compute_current_click_log_prob(
                    positions=positions,
                    attr_log_probs=attr_log_probs,
                    current_idx=current_idx,
                    last_clicked_idx=last_clicked_idx,
                )
                # The click probability of one scenario consists of:
                # The prob. of the last item to be clicked, no clicks between the
                # last item and the current item, and the current click probability.
                scenario_log_prob = (
                    last_click_log_prob + no_clicks_log_prob + current_click_log_prob
                )
                scenario_log_prob = jnp.where(
                    mask[:, current_idx], scenario_log_prob, -jnp.inf
                )
                scenario_log_probs.append(scenario_log_prob)

            # Marginalize over all scenarios:
            scenario_log_probs = jnp.stack(scenario_log_probs, axis=-1)
            scenario_log_probs = jax.scipy.special.logsumexp(
                scenario_log_probs,
                axis=-1,
            )
            click_log_probs = click_log_probs.at[:, current_idx].set(scenario_log_probs)

        return click_log_probs

    def predict_relevance(self, batch: Dict) -> Array:
        """
        Relevance is identified with attraction in the UBM.
        """
        return self.attraction.log_prob(batch)

    def sample(self, batch: Dict, rngs: nnx.Rngs) -> UserBrowsingModelOutput:
        """
        Sample examination, attraction, and clicks sequentially per rank,
        since examination depends on the last sampled click position.
        """
        mask = batch["mask"]
        positions = batch["positions"]
        n_batch, n_positions = positions.shape

        clicks = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)
        examination = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)
        # Position 0 encodes "no click yet":
        last_clicked_positions = jnp.zeros(n_batch, dtype=positions.dtype)

        attr_probs = self.attraction.prob(batch)
        attraction = mask & jax.random.bernoulli(rngs(), attr_probs)

        for idx in range(n_positions):
            exam_probs = self.examination.prob(
                self._examination_parameters(
                    positions[:, idx],
                    last_clicked_positions,
                )
            )
            examination_at_idx = jax.random.bernoulli(rngs(), p=exam_probs)
            examination = examination.at[:, idx].set(mask[:, idx] & examination_at_idx)
            clicks = clicks.at[:, idx].set(examination[:, idx] & attraction[:, idx])

            # Advance the last clicked position where a click was sampled:
            last_clicked_positions = jnp.where(
                clicks[:, idx],
                positions[:, idx],
                last_clicked_positions,
            )

        return UserBrowsingModelOutput(
            clicks=clicks,
            examination=examination,
            attraction=attraction,
        )

    def _examination_parameters(self, positions, last_clicked_positions) -> Dict:
        """
        Flatten (position, last clicked position) into a single index into
        the 2D examination embedding table of shape (positions, positions).
        """
        examination_idx = positions * self.positions + last_clicked_positions
        return {"examination_idx": examination_idx}

    def _get_last_click_log_prob(
        self,
        click_log_probs: Array,
        last_clicked_idx: int,
    ) -> Array:
        """
        Get log probability of the last click (or zero if no previous click).
        """
        # -1 is the "no previous click" sentinel; log(1) = 0 in that case.
        if last_clicked_idx == -1:
            return jnp.zeros(click_log_probs.shape[0])
        else:
            return click_log_probs[:, last_clicked_idx]

    def _compute_no_clicks_between_log_prob(
        self,
        positions: Array,
        attr_log_probs: Array,
        last_clicked_idx: int,
        current_idx: int,
    ) -> Array:
        """
        Compute log probability of no clicks between last_clicked_idx and current_idx.
        """
        log_prob = jnp.zeros(positions.shape[0])

        for intermediate_idx in range(last_clicked_idx + 1, current_idx):
            intermediate_positions = positions[:, intermediate_idx]
            last_clicked_positions = self._get_last_clicked_positions(
                positions, last_clicked_idx
            )
            exam_log_prob = self.examination.log_prob(
                self._examination_parameters(
                    intermediate_positions, last_clicked_positions
                )
            )
            click_log_prob = exam_log_prob + attr_log_probs[:, intermediate_idx]
            # log1mexp(x) computes log(1 - exp(x)), i.e. the no-click log prob:
            no_click_log_prob = log1mexp(click_log_prob)
            log_prob += no_click_log_prob

        return log_prob

    def _compute_current_click_log_prob(
        self,
        positions: Array,
        attr_log_probs: Array,
        current_idx: int,
        last_clicked_idx: int,
    ) -> Array:
        """
        Compute log probability of click at current position given last clicked position.
        """
        # Get actual positions (not indices) for parameter lookup:
        current_positions = positions[:, current_idx]
        last_clicked_positions = self._get_last_clicked_positions(
            positions, last_clicked_idx
        )
        exam_log_prob = self.examination.log_prob(
            self._examination_parameters(current_positions, last_clicked_positions)
        )

        return exam_log_prob + attr_log_probs[:, current_idx]

    def _get_last_clicked_positions(
        self,
        positions: Array,
        last_clicked_idx: int,
    ) -> Array:
        """
        Get the actual position values for the last clicked index.
        """
        # -1 means no previous click; position 0 encodes that case.
        if last_clicked_idx == -1:
            return jnp.zeros(positions.shape[0], dtype=positions.dtype)
        else:
            return positions[:, last_clicked_idx]

    @staticmethod
    def _last_clicked_positions(positions: Array, clicks: Array) -> Array:
        """
        Find the position of the last clicked document for each position.
        Formula: r' = max{k ∈ {0,...,r-1} : c_k = 1}
        """
        # Filter clicked positions, e.g.: [1, 2, 3, 4], [1, 0, 0, 1] -> [1, 0, 0, 4]
        clicked_positions = jnp.where(clicks == 1, positions, 0)
        # Find the last clicked position for each item: [1, 0, 0, 4] -> [1, 1, 1, 4]
        # Assumes positions are sorted in ascending order!
        clicked_positions = lax.cummax(clicked_positions, axis=1)
        # Shift the clicked positions to the right to align with the next item:
        clicked_positions = jnp.roll(clicked_positions, shift=1, axis=1)
        # Set the first position to 0, as there is no previously clicked position:
        return clicked_positions.at[:, 0].set(0)

Unconditional click probability

As examination in the UBM depends on the last clicked position, predicting unconditional clicks on a new list of documents requires marginalizing over all possible last click positions \(i < k\) before our current position:

\[\log P(C = 1 \mid d, k) = \log \left( \sum_{i=0}^{k - 1} P(C=1 \mid d_i, i) \cdot \left(\prod_{j=i+1}^{k - 1} (1 - \theta_{j,i}\gamma_{d_j})\right) \theta_{k,i}\gamma_{d} \right).\]

Each term in the sum represents a path to the current document: the probability of clicking at a previous rank \(i\), then not clicking on anything until rank \(k\), and finally examining and clicking the document at rank \(k\) given \(i\) was the last clicked position.

Conditional click probability

The UBM assumes that examination at position \(k\) depends also on the position of the last clicked document \(k'\). Similar to the PBM, users only click on examined and attractive documents:

\[\log P(C=1 \mid d, k, C_{<k}) = \log \theta_{k, k'} + \log \gamma_{d}.\]

Dependent Click Model (DCM)

Bases: ClickModel

The Dependent Click Model (DCM) extends the cascade model to allow multiple clicks per session by introducing rank-dependent continuation probabilities.

Parameters:

Name Type Description Default
positions Optional[int]

Number of positions used to allocate continuation probabilities. This parameter is not used if a custom continuation module is provided.

None
query_doc_pairs Optional[int]

Number of query-document pairs to allocate in an embedding table. This parameter is not used if a custom attraction module is provided.

None
attraction Optional[Parameter | ParameterConfig]

Custom attraction parameter, which can be a parameter config or any subclass of the Parameter base class.

None
continuation Optional[Parameter | ParameterConfig]

Custom continuation parameter deciding how likely a user is to continue their browsing session after clicking a document at the current position. Can be a parameter config or any subclass of the Parameter base class.

None
rngs Rngs

A NNX random number generator used for model initialization and stochastic operations.

required

Examples:

model = DependentClickModel(
    positions=10,
    query_doc_pairs=1_000_000,
    rngs=nnx.Rngs(42)
)
References

Fan Guo, Chao Liu, and Yi Min Wang. "Efficient multiple-click models in web search." In WSDM 2009.

Source code in clax/models/dcm.py
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
class DependentClickModel(ClickModel):
    """
    The Dependent Click Model (DCM) extends the cascade model to allow multiple clicks
    per session by introducing rank-dependent continuation probabilities.

    Args:
        positions (Optional[int], optional): Number of positions used to allocate
            continuation probabilities. This parameter is not used if a custom
            continuation module is provided.
        query_doc_pairs (Optional[int], optional): Number of query-document
            pairs to allocate in an embedding table. This parameter is not
            used if a custom attraction module is provided.
        attraction (Optional[Parameter | ParameterConfig], optional): Custom
            attraction parameter, which can be a parameter config or any
            subclass of the Parameter base class.
        continuation (Optional[Parameter | ParameterConfig], optional): Custom
            continuation parameter deciding how likely a user is to continue their
            browsing session after clicking a document at the current position.
            Can be a parameter config or any subclass of the Parameter base class.
        rngs (nnx.Rngs): A NNX random number generator used for model
            initialization and stochastic operations.

    Examples:

        model = DependentClickModel(
            positions=10,
            query_doc_pairs=1_000_000,
            rngs=nnx.Rngs(42)
        )

    References:
        Fan Guo, Chao Liu, and Yi Min Wang.
        "Efficient multiple-click models in web search."
        In WSDM 2009.
    """

    name = "DCM"

    def __init__(
        self,
        positions: Optional[int] = None,
        query_doc_pairs: Optional[int] = None,
        attraction: Optional[Parameter | ParameterConfig] = None,
        continuation: Optional[Parameter | ParameterConfig] = None,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        # Attraction defaults to an embedding table over query-document pairs
        # unless a custom parameter/config is supplied:
        self.attraction = init_parameter(
            "attraction",
            attraction,
            default_config_fn=default_attraction_config,
            default_config_args={"query_doc_pairs": query_doc_pairs},
            rngs=rngs,
        )
        # Continuation is rank-dependent; by default one probability per position:
        self.continuation = init_parameter(
            "continuation",
            continuation,
            default_config_fn=default_continuation_config,
            default_config_args={"positions": positions},
            rngs=rngs,
        )

    def compute_loss(self, batch: Dict, aggregate: bool = True):
        """
        Binary cross-entropy between observed clicks and the model's
        conditional click log-probabilities, restricted to valid
        positions via ``batch["mask"]``.
        """
        y_true = batch["clicks"]
        y_predict = self.predict_conditional_clicks(batch)
        return binary_cross_entropy(
            y_predict,
            y_true,
            where=batch["mask"],
            log_probs=True,
            aggregate=aggregate,
        )

    def predict_conditional_clicks(self, batch: Dict) -> Array:
        """
        Click log-probabilities conditioned on the observed click history.

        After a click at rank k, rank k+1 is examined with the continuation
        probability; after a non-click, the posterior examination probability
        is computed via Bayes' rule (see `_log_examination_after_no_click`).
        Masked positions are set to -inf.
        """
        clicks = batch["clicks"]
        log_probs = self._get_log_probabilities(batch)

        # Initialize: first document always examined (log(1) = 0):
        n_batch, n_positions = clicks.shape
        exam_log_probs = jnp.zeros((n_batch, n_positions))

        # Compute examination probabilities based on click history:
        for idx in range(n_positions - 1):
            exam_after_click = log_probs["cont"][:, idx]
            exam_after_no_click = self._log_examination_after_no_click(
                current_exam_log_prob=exam_log_probs[:, idx],
                attraction_log_prob=log_probs["attr"][:, idx],
                non_attraction_log_prob=log_probs["non_attr"][:, idx],
            )
            exam_log_probs = exam_log_probs.at[:, idx + 1].set(
                jnp.where(
                    clicks[:, idx],
                    exam_after_click,
                    exam_after_no_click,
                )
            )

        # Click requires examination and attraction (independent given rank):
        click_log_probs = exam_log_probs + log_probs["attr"]
        return jnp.where(batch["mask"], click_log_probs, -jnp.inf)

    def predict_clicks(self, batch: Dict) -> Array:
        """
        Unconditional (marginal) click log-probabilities, obtained by
        accumulating per-position examination log-increments from top to
        bottom. Masked positions are set to -inf.
        """
        log_probs = self._get_log_probabilities(batch)

        # Compute examination log probability increments for each position:
        exam_log_probs = self._log_examination_step(
            attr_log_prob=log_probs["attr"],
            non_attr_log_prob=log_probs["non_attr"],
            cont_log_prob=log_probs["cont"],
        )
        # Shift increments right by one rank: position 0 is always examined
        # (log(1) = 0), then cumulative sum yields log examination per rank:
        exam_log_probs = jnp.roll(exam_log_probs, shift=1, axis=-1)
        exam_log_probs = exam_log_probs.at[:, 0].set(0)
        exam_log_probs = jnp.cumsum(exam_log_probs, axis=-1)

        click_log_probs = exam_log_probs + log_probs["attr"]
        return jnp.where(batch["mask"], click_log_probs, -jnp.inf)

    def predict_relevance(self, batch: Dict) -> Array:
        """Relevance is identified with attraction; return its log-probability."""
        return self.attraction.log_prob(batch)

    def sample(self, batch: Dict, rngs: nnx.Rngs) -> DependentClickModelOutput:
        """
        Sample clicks position-by-position according to the DCM user model.

        Returns:
            DependentClickModelOutput: Boolean ``clicks``, ``examination``,
            and ``attraction`` arrays of shape (batch, positions).
        """
        mask = batch["mask"]
        attr_probs = self.attraction.prob(batch)
        continuation = self.continuation.prob(batch)

        n_batch, n_positions = mask.shape
        clicks = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)
        attraction = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)
        examination = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)

        # Always examine first position (if valid):
        examination = examination.at[:, 0].set(mask[:, 0])

        for idx in range(n_positions):
            attraction_at_idx = jax.random.bernoulli(rngs(), attr_probs[:, idx])
            attraction = attraction.at[:, idx].set(mask[:, idx] & attraction_at_idx)
            clicks = clicks.at[:, idx].set(examination[:, idx] & attraction[:, idx])

            if idx < n_positions - 1:
                # Determine continuation probability:
                # - If clicked: use continuation probability
                # - If examined but not clicked: always continue (prob=1)
                # - If not examined: never continue (prob=0)
                continuation_prob = jnp.where(
                    examination[:, idx],
                    jnp.where(clicks[:, idx], continuation[:, idx], 1.0),
                    0.0,
                )
                should_continue = jax.random.bernoulli(rngs(), p=continuation_prob)
                examination = examination.at[:, idx + 1].set(
                    should_continue & batch["mask"][:, idx + 1]
                )

        return DependentClickModelOutput(
            clicks=clicks,
            examination=examination,
            attraction=attraction,
        )

    def _get_log_probabilities(self, batch: Dict) -> Dict[str, Array]:
        """
        Gather log-probabilities used by the predictors: attraction,
        its complement, and continuation. Complements are derived from
        logits for numerical stability.
        """
        attr_logits = self.attraction.logit(batch)
        attr_log_probs = logits_to_log_probs(attr_logits)
        non_attr_log_probs = logits_to_complement_log_probs(attr_logits)
        cont_log_probs = self.continuation.log_prob(batch)

        return {
            "attr": attr_log_probs,
            "non_attr": non_attr_log_probs,
            "cont": cont_log_probs,
        }

    @staticmethod
    def _log_examination_after_no_click(
        current_exam_log_prob: Array,
        attraction_log_prob: Array,
        non_attraction_log_prob: Array,
    ) -> Array:
        """
        Compute examination probability after not clicking.
        Formula: P(E_{r+1} = 1 | E_r = 1, C_r = 0) = [(1-α_r) × ε_r] / [1 - α_r × ε_r]
        In log space: log(1-α_r) + log(ε_r) - log(1 - α_r × ε_r)
        """
        numerator_log = current_exam_log_prob + non_attraction_log_prob
        # log1mexp computes log(1 - exp(x)) stably for the denominator:
        denominator_log = log1mexp(current_exam_log_prob + attraction_log_prob)
        return numerator_log - denominator_log

    @staticmethod
    def _log_examination_step(
        attr_log_prob: Array,
        non_attr_log_prob: Array,
        cont_log_prob: Array,
    ) -> Array:
        """
        Compute one step of unconditional examination log probability.
        Formula: P(E_{r+1} = 1) = α_r × λ_r + (1-α_r) × 1
        In log space: log[α_r × λ_r + (1-α_r)]
        """
        return jnp.logaddexp(cont_log_prob + attr_log_prob, non_attr_log_prob)

Unconditional click probability

The DCM assumes that users examine a list from top to bottom, click on items they find attractive (with attractiveness probability \(\gamma_d\)), and after clicking have a rank-dependent probability \(\lambda_k\) of continuing to browse:

\[ \begin{split} \log P(C=1 \mid d, k) &= \log(\epsilon_{k}) + \log(\gamma_{d})\\ \log(\epsilon_{k+1}) &= \log(\epsilon_k) + \log(\gamma_{d_k} \lambda_k + (1 - \gamma_{d_k})).\\ \end{split} \]

Conditional click probability

Examination in the DCM changes based on the observed clicks. If a user clicks on a document, they continue to the next rank with probability \(\lambda_k\) and if they do not click, we calculate the posterior probability of examining the next rank given that we observed no click using Bayes' rule:

\[ \begin{split} \log P(C=1 \mid d, k, C_{<k}) &= \log(\epsilon_{k}) + \log(\gamma_{d})\\ \log(\epsilon_{k+1}) &= \log\left(c_k \lambda_k + (1 - c_k) \frac{(1 - \gamma_{d_k}) \epsilon_k}{1 - \gamma_{d_k} \epsilon_k}\right).\\ \end{split} \]

Click Chain Model (CCM)

Bases: ClickModel

Click Chain Model (CCM)

The CCM extends the DCM to allow users to abandon a session without any clicks by introducing probabilities for users to continue examination after not clicking a document, clicking but not being satisfied, and clicking and being satisfied. The CCM assumes that document attraction and satisfaction probabilities are identical.

Parameters:

Name Type Description Default
query_doc_pairs Optional[int]

Number of query-document pairs to allocate in an embedding table. This parameter is not used if a custom attraction module is provided.

None
attraction Optional[Parameter | ParameterConfig]

Custom attraction / satisfaction parameter, which can be a parameter config or any subclass of the Parameter base class.

None
rngs Rngs

A NNX random number generator used for model initialization and stochastic operations.

required

Examples:

model = ClickChainModel(
    query_doc_pairs=1_000_000,
    rngs=nnx.Rngs(42)
)
References

Fan Guo, Chao Liu, Anitha Kannan, Tom Minka, Michael Taylor, Yi-Min Wang, and Christos Faloutsos. "Click Chain Model in Web Search." In WWW 2009.

Source code in clax/models/ccm.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
class ClickChainModel(ClickModel):
    """
    Click Chain Model (CCM)

    The CCM extends the DCM to allow users to abandon a session without any clicks
    by introducing probabilities for users to continue examination after not clicking
    a document, clicking but not being satisfied, and clicking and being satisfied.
    The CCM assumes that document attraction and satisfaction probabilities are identical.

    Args:
        query_doc_pairs (Optional[int], optional): Number of query-document
            pairs to allocate in an embedding table. This parameter is not
            used if a custom attraction module is provided.
        attraction (Optional[Parameter | ParameterConfig], optional): Custom
            attraction / satisfaction parameter, which can be a parameter config or any
            subclass of the Parameter base class.
        rngs (nnx.Rngs): A NNX random number generator used for model
            initialization and stochastic operations.

    Examples:

        model = ClickChainModel(
            query_doc_pairs=1_000_000,
            rngs=nnx.Rngs(42)
        )

    References:
        Fan Guo, Chao Liu, Anitha Kannan, Tom Minka, Michael Taylor, Yi-Min Wang, and Christos Faloutsos.
        "Click Chain Model in Web Search."
        In WWW 2009.
    """

    name = "CCM"

    def __init__(
        self,
        query_doc_pairs: Optional[int] = None,
        attraction: Optional[Parameter | ParameterConfig] = None,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        # Attraction doubles as satisfaction in the CCM (see class docstring):
        self.attraction = init_parameter(
            "attraction",
            attraction,
            default_config_fn=default_attraction_config,
            default_config_args={"query_doc_pairs": query_doc_pairs},
            rngs=rngs,
        )
        # Continuation are global variables that don't depend on features.
        # These might be configurable in future versions if useful:
        self.continuation_exam_no_click = GlobalParameter(rngs=rngs)  # τ1
        self.continuation_click_satisfied = GlobalParameter(rngs=rngs)  # τ3
        self.continuation_click_not_satisfied = GlobalParameter(rngs=rngs)  # τ2

    def compute_loss(self, batch: Dict, aggregate: bool = True):
        """
        Binary cross-entropy between observed clicks and the model's
        conditional click log-probabilities, restricted to valid
        positions via ``batch["mask"]``.
        """
        y_true = batch["clicks"]
        y_predict = self.predict_conditional_clicks(batch)
        return binary_cross_entropy(
            y_predict,
            y_true,
            where=batch["mask"],
            log_probs=True,
            aggregate=aggregate,
        )

    def predict_conditional_clicks(self, batch: Dict) -> Array:
        """
        Click log-probabilities conditioned on the observed click history.

        After a click, examination of the next rank mixes the satisfied (τ3)
        and not-satisfied (τ2) continuation paths; after a non-click, the
        posterior examination probability is scaled by τ1 via Bayes' rule.
        Masked positions are set to -inf.
        """
        clicks = batch["clicks"]
        n_batch, n_positions = clicks.shape
        log_probs = self._get_log_probabilities(batch)

        # First position is always examined (log(1) = 0):
        exam_log_probs = jnp.zeros((n_batch, n_positions))

        for idx in range(n_positions - 1):
            exam_after_click = self._log_examination_after_click(
                rel_log_prob=log_probs["rel"][:, idx],
                non_rel_log_prob=log_probs["non_rel"][:, idx],
                tau2_log_prob=log_probs["tau2"],
                tau3_log_prob=log_probs["tau3"],
            )
            exam_after_no_click = self._log_examination_after_no_click(
                current_exam_log_prob=exam_log_probs[:, idx],
                rel_log_prob=log_probs["rel"][:, idx],
                non_rel_log_prob=log_probs["non_rel"][:, idx],
                tau1_log_prob=log_probs["tau1"],
            )

            exam_log_probs = exam_log_probs.at[:, idx + 1].set(
                jnp.where(
                    clicks[:, idx],
                    exam_after_click,
                    exam_after_no_click,
                )
            )

        # Click requires examination and attraction/relevance:
        click_log_probs = exam_log_probs + log_probs["rel"]
        return jnp.where(batch["mask"], click_log_probs, -jnp.inf)

    def predict_clicks(self, batch: Dict) -> Array:
        """
        Unconditional (marginal) click log-probabilities, obtained by
        accumulating per-position examination log-increments from top to
        bottom. Masked positions are set to -inf.
        """
        log_probs = self._get_log_probabilities(batch)

        exam_log_probs = self._log_examination_step(
            rel_log_prob=log_probs["rel"],
            non_rel_log_prob=log_probs["non_rel"],
            tau1_log_prob=log_probs["tau1"],
            tau2_log_prob=log_probs["tau2"],
            tau3_log_prob=log_probs["tau3"],
        )
        # Shift increments right by one rank: position 0 is always examined
        # (log(1) = 0), then cumulative sum yields log examination per rank:
        exam_log_probs = jnp.roll(exam_log_probs, shift=1, axis=-1)
        exam_log_probs = exam_log_probs.at[:, 0].set(0)
        exam_log_probs = jnp.cumsum(exam_log_probs, axis=-1)

        click_log_probs = exam_log_probs + log_probs["rel"]
        return jnp.where(batch["mask"], click_log_probs, -jnp.inf)

    def predict_relevance(self, batch: Dict) -> Array:
        """Relevance is identified with attraction; return its log-probability."""
        return self.attraction.log_prob(batch)

    def sample(self, batch: Dict, rngs: nnx.Rngs) -> ClickChainModelOutput:
        """
        Sample clicks position-by-position according to the CCM user model.

        Returns:
            ClickChainModelOutput: Boolean ``clicks``, ``examination``,
            ``attraction``, and ``satisfaction`` arrays of shape
            (batch, positions).

        NOTE(review): satisfaction is only sampled for positions before the
        last one, so the returned satisfaction array is always False at the
        final position even when it was clicked — confirm this is intended.
        """
        rel_probs = self.attraction.prob(batch)
        tau1 = self.continuation_exam_no_click.prob()
        tau2 = self.continuation_click_not_satisfied.prob()
        tau3 = self.continuation_click_satisfied.prob()
        mask = batch["mask"]

        n_batch, n_positions = rel_probs.shape
        clicks = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)
        examination = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)
        attraction = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)
        satisfaction = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)

        # If valid, always examine the first item:
        examination = examination.at[:, 0].set(batch["mask"][:, 0])

        for idx in range(n_positions):
            attraction_at_idx = jax.random.bernoulli(rngs(), rel_probs[:, idx])
            attraction = attraction.at[:, idx].set(mask[:, idx] & attraction_at_idx)
            clicks = clicks.at[:, idx].set(examination[:, idx] & attraction[:, idx])

            if idx < n_positions - 1:
                # Satisfaction probability equals relevance, but only if clicked:
                sat_probs = jnp.where(clicks[:, idx], rel_probs[:, idx], 0.0)
                satisfaction = satisfaction.at[:, idx].set(
                    jax.random.bernoulli(rngs(), p=sat_probs)
                )

                continue_after_click_satisfied = clicks[:, idx] & satisfaction[:, idx]
                continue_after_click_not_satisfied = (
                    clicks[:, idx] & ~satisfaction[:, idx]
                )
                continue_after_no_click = examination[:, idx] & ~clicks[:, idx]

                # Select τ3 / τ2 / τ1 for the three continuation scenarios;
                # users who never examined this rank cannot continue (prob=0):
                continuation_probs = jnp.where(
                    continue_after_click_satisfied,
                    tau3,
                    jnp.where(
                        continue_after_click_not_satisfied,
                        tau2,
                        jnp.where(continue_after_no_click, tau1, 0.0),
                    ),
                )

                should_continue = jax.random.bernoulli(rngs(), p=continuation_probs)
                examination = examination.at[:, idx + 1].set(
                    should_continue & mask[:, idx + 1]
                )

        return ClickChainModelOutput(
            clicks=clicks,
            examination=examination,
            attraction=attraction,
            satisfaction=satisfaction,
        )

    def _get_log_probabilities(self, batch: Dict) -> Dict[str, Array]:
        """
        Gather log-probabilities used by the predictors: relevance, its
        complement (both derived from logits for numerical stability), and
        the three global continuation parameters τ1, τ2, τ3.
        """
        rel_logits = self.attraction.logit(batch)
        rel_log_probs = logits_to_log_probs(rel_logits)
        non_rel_log_probs = logits_to_complement_log_probs(rel_logits)

        tau1_log_prob = self.continuation_exam_no_click.log_prob()
        tau2_log_prob = self.continuation_click_not_satisfied.log_prob()
        tau3_log_prob = self.continuation_click_satisfied.log_prob()

        return {
            "rel": rel_log_probs,
            "non_rel": non_rel_log_probs,
            "tau1": tau1_log_prob,
            "tau2": tau2_log_prob,
            "tau3": tau3_log_prob,
        }

    @staticmethod
    def _log_examination_after_click(
        rel_log_prob: Array,
        non_rel_log_prob: Array,
        tau2_log_prob: Array,
        tau3_log_prob: Array,
    ) -> Array:
        """
        Compute log examination probability after clicking.
        Formula: P(E_{r+1} = 1 | E_r = 1, C_r = 1) = α_r × τ3 + (1 - α_r) × τ2
        In log space: log[α_r × τ3 + (1 - α_r) × τ2]
        """
        satisfied_log = rel_log_prob + tau3_log_prob
        not_satisfied_log = non_rel_log_prob + tau2_log_prob
        return jnp.logaddexp(satisfied_log, not_satisfied_log)

    @staticmethod
    def _log_examination_after_no_click(
        current_exam_log_prob: Array,
        rel_log_prob: Array,
        non_rel_log_prob: Array,
        tau1_log_prob: Array,
    ) -> Array:
        """
        Compute log examination probability after not clicking.
        Formula: P(E_{r+1} = 1 | E_r = 1, C_r = 0) = [(1 - α_r) × ε_r × τ1] / [1 - α_r × ε_r]
        In log space: log(1 - α_r) + log ε_r + log τ1 - log(1 - α_r × ε_r)
        """
        numerator_log = non_rel_log_prob + current_exam_log_prob + tau1_log_prob
        # log1mexp computes log(1 - exp(x)) stably for the denominator:
        denominator_log = log1mexp(rel_log_prob + current_exam_log_prob)
        return numerator_log - denominator_log

    @staticmethod
    def _log_examination_step(
        rel_log_prob: Array,
        non_rel_log_prob: Array,
        tau1_log_prob: Array,
        tau2_log_prob: Array,
        tau3_log_prob: Array,
    ) -> Array:
        """
        Compute one step of unconditional examination log probability.
        Formula: P(E_{r+1} = 1) = α × ((1-α) × τ2 + α × τ3) + (1-α) × τ1
        In log space: log[α × ((1-α) × τ2 + α × τ3) + (1-α) × τ1]
        """
        attraction_term = rel_log_prob + jnp.logaddexp(
            non_rel_log_prob + tau2_log_prob,
            rel_log_prob + tau3_log_prob,
        )
        non_attraction_term = non_rel_log_prob + tau1_log_prob

        return jnp.logaddexp(attraction_term, non_attraction_term)

Unconditional click probability

The click chain model (CCM) extends the DCM assuming a total of three continuation scenarios that do not only explain continuation after clicking a document but also allow users to abandon a session without any clicks. First, \(\tau_1\) is the probability of a user continuing to the next document after not clicking on the current document. Second, if the user clicks on the current document but is not satisfied, \(\tau_2\) is the probability of the user continuing to the next position. And lastly, \(\tau_3\) is the probability that a user clicks on the current item, finds it satisfying, but still wants to continue to the next document:

\[ \begin{split} \log P(C=1 \mid d, k) &= \log(\gamma_d) + \log(\epsilon_k) \\ \log(\epsilon_{k+1}) &= \log(\epsilon_k) \\ &\quad + \log\left( \gamma_{d_k}((1-\gamma_{d_k})\tau_2 + \gamma_{d_k}\tau_3) \right. \\ &\quad \left. + (1-\gamma_{d_k})\tau_1 \right). \end{split} \]

Conditional click probability

When conditioning on the observed clicks, the update rule for the examination probability changes based on the user's action at the current rank. If a click occurred, we compute continuation based on satisfaction (equal to attractiveness \(\gamma_d\)) and the continuation probabilities \(\tau_2\) and \(\tau_3\). If no click was observed, we compute the posterior log probability of continuing to the next rank:

\[ \begin{split} \log P(C=1 \mid d, k, C_{<k}) &= \log(\gamma_d) + \log(\epsilon_k) \\ \log(\epsilon_{k+1}) &= c_k \left[ \log\left(\gamma_{d_k}\tau_3 + (1-\gamma_{d_k})\tau_2 \right) \right] \\ &\quad + (1-c_k) \left[ \log(1-\gamma_{d_k}) + \log(\epsilon_k) + \log(\tau_1) - \log(1 - \gamma_{d_k}\epsilon_k) \right]. \end{split} \]

Dynamic Bayesian Network (DBN)

Bases: ClickModel

The DBN extends the cascade model by introducing separate attraction and satisfaction parameters, allowing users to continue examining after clicks if they are not satisfied with the clicked item. Note that attraction and satisfaction parameters can be customized to use completely different sets of features.

Parameters:

Name Type Description Default
query_doc_pairs Optional[int]

Number of query-document pairs to allocate in an embedding table. This parameter is required if no custom attraction or satisfaction module is provided.

None
attraction Optional[Parameter | ParameterConfig]

Custom attraction/relevance parameter, which can be a parameter config or any subclass of the Parameter base class.

None
satisfaction Optional[Parameter | ParameterConfig]

Custom satisfaction parameter, which can be a parameter config or any subclass of the Parameter base class.

None
rngs Rngs

A NNX random number generator used for model initialization and stochastic operations.

required

Examples:

model = DynamicBayesianNetwork(
    query_doc_pairs=1_000_000,
    rngs=nnx.Rngs(42)
)
References

Olivier Chapelle and Ya Zhang. "A dynamic bayesian network click model for web search ranking." In WWW 2009.

Source code in clax/models/dbn.py
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
class DynamicBayesianNetwork(ClickModel):
    """
    The DBN extends the cascade model by introducing separate attraction and satisfaction
    parameters, allowing users to continue examining after clicks if they are not
    satisfied with the clicked item. Note that attraction and satisfaction parameters
    can be customized to use completely different sets of features.

    Args:
        query_doc_pairs (Optional[int], optional): Number of query-document
            pairs to allocate in an embedding table. This parameter is required
            if no custom attraction or satisfaction module is provided.
        attraction (Optional[Parameter | ParameterConfig], optional): Custom
            attraction/relevance parameter, which can be a parameter config or any
            subclass of the Parameter base class.
        satisfaction (Optional[Parameter | ParameterConfig], optional): Custom
            satisfaction parameter, which can be a parameter config or any
            subclass of the Parameter base class.
        fix_continuation (bool, optional): If True, the continuation
            probability is fixed to 1, yielding the Simplified DBN (SDBN)
            variant. Defaults to False.
        rngs (nnx.Rngs): A NNX random number generator used for model
            initialization and stochastic operations.

    Examples:

        model = DynamicBayesianNetwork(
            query_doc_pairs=1_000_000,
            rngs=nnx.Rngs(42)
        )

    References:
        Olivier Chapelle and Ya Zhang.
        "A dynamic bayesian network click model for web search ranking."
        In WWW 2009.
    """

    def __init__(
        self,
        query_doc_pairs: Optional[int] = None,
        attraction: Optional[Parameter | ParameterConfig] = None,
        satisfaction: Optional[Parameter | ParameterConfig] = None,
        fix_continuation: bool = False,
        *,
        rngs: nnx.Rngs,
    ):
        super().__init__()

        # Per query-document attraction (click-given-examination) parameter;
        # falls back to a default embedding-table config sized by query_doc_pairs.
        self.attraction = init_parameter(
            "attraction",
            attraction,
            default_config_fn=default_attraction_config,
            default_config_args={"query_doc_pairs": query_doc_pairs},
            rngs=rngs,
        )
        # Per query-document satisfaction-given-click parameter.
        self.satisfaction = init_parameter(
            "satisfaction",
            satisfaction,
            default_config_fn=default_satisfaction_config,
            default_config_args={"query_doc_pairs": query_doc_pairs},
            rngs=rngs,
        )

        # When fix_continuation is True, the global continuation probability is
        # treated as 1 everywhere (see sample() and _get_log_probabilities()),
        # which is the simplified DBN (SDBN) of Chapelle & Zhang (2009).
        self.fix_continuation = fix_continuation
        self.name = "SDBN" if fix_continuation else "DBN"
        # Single learnable scalar shared across all queries/documents; unused
        # (but still allocated) in the SDBN case.
        self.continuation = GlobalParameter(rngs=rngs)

    def compute_loss(self, batch: Dict, aggregate: bool = True):
        """
        Masked binary cross-entropy between observed clicks and the model's
        conditional click log-probabilities (conditioned on click history).
        """
        y_true = batch["clicks"]
        y_predict = self.predict_conditional_clicks(batch)
        return binary_cross_entropy(
            y_predict,
            y_true,
            where=batch["mask"],
            log_probs=True,
            aggregate=aggregate,
        )

    def predict_conditional_clicks(self, batch: Dict) -> Array:
        """
        Log click probability per position, conditioned on the observed clicks
        earlier in the same session. Masked positions are set to -inf
        (log probability of zero).
        """
        clicks = batch["clicks"]
        log_probs = self._get_log_probabilities(batch)

        # Initialize: first document always examined (log(1) = 0):
        n_batch, n_positions = clicks.shape
        exam_log_probs = jnp.zeros((n_batch, n_positions))

        # Compute examination probabilities based on click history:
        for idx in range(n_positions - 1):
            exam_after_click = self._log_examination_after_click(
                non_sat_log_probs=log_probs["non_sat"][:, idx],
                cont_log_prob=log_probs["cont"],
            )
            exam_after_no_click = self._log_examination_after_no_click(
                current_exam_log_prob=exam_log_probs[:, idx],
                attraction_log_prob=log_probs["attr"][:, idx],
                non_attraction_log_prob=log_probs["non_attr"][:, idx],
                cont_log_prob=log_probs["cont"],
            )
            # Branch on whether a click was observed at this position:
            exam_log_probs = exam_log_probs.at[:, idx + 1].set(
                jnp.where(
                    clicks[:, idx],
                    exam_after_click,
                    exam_after_no_click,
                )
            )

        # P(C=1) = P(E=1) * P(A=1), i.e., a sum in log space:
        click_log_probs = exam_log_probs + log_probs["attr"]
        return jnp.where(batch["mask"], click_log_probs, -jnp.inf)

    def predict_clicks(self, batch: Dict) -> Array:
        """
        Unconditional log click probability per position (marginalizing over
        click histories). Masked positions are set to -inf.
        """
        log_probs = self._get_log_probabilities(batch)

        # Per-position log continuation factor log[γ(α(1-σ) + (1-α))]:
        exam_log_probs = self._log_examination_step(
            attr_log_probs=log_probs["attr"],
            non_attr_log_probs=log_probs["non_attr"],
            non_sat_log_probs=log_probs["non_sat"],
            cont_log_prob=log_probs["cont"],
        )
        # Shift right and fix the first position to log(1) = 0, then
        # accumulate: examination at rank k is the product of all earlier
        # continuation factors.
        exam_log_probs = jnp.roll(exam_log_probs, shift=1, axis=-1)
        exam_log_probs = exam_log_probs.at[:, 0].set(0)
        exam_log_probs = jnp.cumsum(exam_log_probs, axis=-1)

        click_log_probs = exam_log_probs + log_probs["attr"]
        return jnp.where(batch["mask"], click_log_probs, -jnp.inf)

    def predict_relevance(self, batch: Dict) -> Array:
        """
        Log relevance score: log P(attraction) + log P(satisfaction),
        following the DBN notion of relevance as attractive AND satisfying.
        """
        return self.attraction.log_prob(batch) + self.satisfaction.log_prob(batch)

    def sample(self, batch: Dict, rngs: nnx.Rngs) -> DynamicBayesianNetworkOutput:
        """
        Sample a full user session (clicks, examination, attraction,
        satisfaction) from the generative DBN process, respecting the
        per-position validity mask in batch["mask"].
        """
        mask = batch["mask"]
        n_batch, n_positions = mask.shape

        attr_probs = self.attraction.prob(batch)
        sat_probs = self.satisfaction.prob(batch)
        # SDBN fixes continuation probability to 1:
        continuation = (
            jnp.array([1.0]) if self.fix_continuation else self.continuation.prob()
        )

        clicks = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)
        examination = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)
        attraction = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)
        satisfaction = jnp.zeros((n_batch, n_positions), dtype=jnp.bool_)

        # Always examine the first item (if valid)
        examination = examination.at[:, 0].set(batch["mask"][:, 0])

        for idx in range(n_positions):
            attraction_at_idx = jax.random.bernoulli(rngs(), attr_probs[:, idx])
            attraction = attraction.at[:, idx].set(mask[:, idx] & attraction_at_idx)
            # Click iff examined and attracted:
            clicks = clicks.at[:, idx].set(examination[:, idx] & attraction[:, idx])

            if idx < n_positions - 1:
                # Sample user satisfaction only for clicked items:
                satisfaction_probs = jnp.where(clicks[:, idx], sat_probs[:, idx], 0.0)
                satisfaction = satisfaction.at[:, idx].set(
                    jax.random.bernoulli(rngs(), p=satisfaction_probs)
                )

                # Users continue when not satisfied after click:
                continue_after_click = clicks[:, idx] & ~satisfaction[:, idx]
                # Users continue after examining but not clicking the current item:
                continue_without_click = examination[:, idx] & ~clicks[:, idx]
                continuation_probs = continuation * (
                    continue_after_click | continue_without_click
                )
                should_continue = jax.random.bernoulli(rngs(), p=continuation_probs)
                examination = examination.at[:, idx + 1].set(
                    should_continue & batch["mask"][:, idx + 1]
                )

        return DynamicBayesianNetworkOutput(
            clicks=clicks,
            examination=examination,
            attraction=attraction,
            satisfaction=satisfaction,
        )

    def _get_log_probabilities(self, batch: Dict) -> Dict[str, Array]:
        """
        Collect the log-space quantities shared by the prediction methods:
        log α ("attr"), log(1-α) ("non_attr"), log(1-σ) ("non_sat"), and
        log γ ("cont", 0 when continuation is fixed to 1).
        """
        attr_logits = self.attraction.logit(batch)
        attr_log_probs = logits_to_log_probs(attr_logits)
        non_attr_log_probs = logits_to_complement_log_probs(attr_logits)

        sat_logits = self.satisfaction.logit(batch)
        non_sat_log_probs = logits_to_complement_log_probs(sat_logits)

        # log(1) = 0 for the fixed-continuation (SDBN) case:
        cont_log_prob = (
            jnp.array([0.0]) if self.fix_continuation else self.continuation.log_prob()
        )

        return {
            "attr": attr_log_probs,
            "non_attr": non_attr_log_probs,
            "non_sat": non_sat_log_probs,
            "cont": cont_log_prob,
        }

    @staticmethod
    def _log_examination_after_click(
        non_sat_log_probs: Array,
        cont_log_prob: Array,
    ) -> Array:
        """
        Compute log examination probability after clicking.
        Formula: ε_{r+1} = (1 - σ_r) × γ
        In log space: log ε_{r+1} = log(1 - σ_r) + log γ
        """
        return cont_log_prob + non_sat_log_probs

    @staticmethod
    def _log_examination_after_no_click(
        current_exam_log_prob: Array,
        attraction_log_prob: Array,
        non_attraction_log_prob: Array,
        cont_log_prob: Array,
    ) -> Array:
        """
        Compute log examination probability after not clicking.
        Formula: P(E_{r+1} = 1 | E_r = 1, C_r = 0) = [(1 - α_r) × ε_r × γ] / [1 - α_r × ε_r]
        In log space: log ε_{r+1} = log(1 - α_r) + log ε_r + log γ - log(1 - α_r × ε_r)
        """
        numerator_log = current_exam_log_prob + non_attraction_log_prob + cont_log_prob
        # log1mexp computes log(1 - exp(x)) stably for the denominator:
        denominator_log = log1mexp(current_exam_log_prob + attraction_log_prob)
        return numerator_log - denominator_log

    @staticmethod
    def _log_examination_step(
        attr_log_probs: Array,
        non_attr_log_probs: Array,
        non_sat_log_probs: Array,
        cont_log_prob: Array,
    ) -> Array:
        """
        Compute one step of unconditional examination log probability.
        Formula: P(E_{r+1} = 1) = γ × [α(1-σ) + (1-α)]
        In log space: log(γ) + log[α(1-σ) + (1-α)]
        """
        return cont_log_prob + jnp.logaddexp(
            attr_log_probs + non_sat_log_probs, non_attr_log_probs
        )

Unconditional click probability

The DBN model separates the concepts of a document being attractive (\(\gamma_d\)) and being satisfying (\(\sigma_d\)). A user stops their search only if they click on an attractive document and are satisfied by it. If they do not click or are not satisfied by the clicked document, they continue browsing with a global continuation probability \(\lambda\):

\[ \begin{split} \log P(C=1 \mid d, k) &= \log(\gamma_d) + \log(\epsilon_k) \\ \log(\epsilon_{k+1}) &= \log(\epsilon_k) + \log(\lambda) + \log(1 - \gamma_{d_k}\sigma_{d_k}).\\ \end{split} \]

Conditional click probability

The conditional click probability takes the user's actions in the current session into account. If a click was observed, we compute the probability of continuation based on satisfaction. If no click was observed, we compute the posterior probability of continuing to the next item:

\[ \begin{split} \log P(C=1 \mid d, k, C_{<k}) &= \log(\gamma_d) + \log(\epsilon_k) \\ \log(\epsilon_{k+1}) &= \log(\lambda) + c_k \left[ \log(1 - \sigma_{d_k}) \right] \\ &\quad + (1-c_k)\left[ \log(1-\gamma_{d_k}) + \log(\epsilon_k) - \log(1 - \gamma_{d_k}\epsilon_k) \right]. \end{split} \]