Welcome to CLAX

CLAX is a framework for training and evaluating fast and flexible neural click models in JAX.

For example, training a User Browsing Model in CLAX is as simple as:

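from clax import Trainer, UserBrowsingModel
from flax import nnx
from optax import adamw

# train_loader, val_loader, and test_loader are click-log data loaders
# that are assumed to be defined elsewhere.
model = UserBrowsingModel(
    query_doc_pairs=100_000_000,  # number of distinct query-document pairs
    positions=10,  # number of result positions (ranks)
    rngs=nnx.Rngs(42),  # seeded random number generators for parameter init
)
trainer = Trainer(
    optimizer=adamw(0.003),  # AdamW with learning rate 0.003
    epochs=50,
)
train_df = trainer.train(model, train_loader, val_loader)
test_df = trainer.test(model, test_loader)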
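CLAX's own implementation is not shown here. The following is a hypothetical plain-JAX sketch of what a User Browsing Model computes, based on the standard UBM definition. A click at rank r requires the document to be attractive (alpha) and examined (gamma). Examination depends on r and on the rank r' of the most recent click above it, with r' = 0 when there is no earlier click. Because r' is determined by the observed clicks, all conditional click probabilities can be evaluated in parallel. The function names are illustrative and are not part of the CLAX API.

```python
import jax
import jax.numpy as jnp


def last_click_ranks(clicks: jax.Array) -> jax.Array:
    """For each 1-based rank r, the rank of the most recent click above r (0 if none)."""
    n = clicks.shape[0]
    ranks = jnp.arange(1, n + 1)
    clicked_ranks = jnp.where(clicks > 0, ranks, 0)
    # Shift right by one so that rank r only sees clicks at ranks < r.
    shifted = jnp.concatenate(
        [jnp.zeros((1,), dtype=clicked_ranks.dtype), clicked_ranks[:-1]]
    )
    # Running maximum = deepest click seen so far.
    return jax.lax.cummax(shifted, axis=0)


def ubm_click_probs(attractiveness, examination, clicks):
    """P(C_r = 1 | clicks above r) = alpha[r] * gamma[r, r'] for a single ranking.

    attractiveness: (n,) alpha of the document shown at each rank.
    examination: (n, n) table indexed as [r - 1, r'] (rank r, last click rank r').
    clicks: (n,) observed 0/1 clicks.
    """
    last = last_click_ranks(clicks)
    rank_idx = jnp.arange(clicks.shape[0])
    return attractiveness * examination[rank_idx, last]


def ubm_log_likelihood(attractiveness, examination, clicks, eps=1e-6):
    """Bernoulli log-likelihood of the observed clicks under the UBM."""
    probs = jnp.clip(
        ubm_click_probs(attractiveness, examination, clicks), eps, 1.0 - eps
    )
    return jnp.sum(clicks * jnp.log(probs) + (1.0 - clicks) * jnp.log1p(-probs))


if __name__ == "__main__":
    alpha = jnp.array([0.9, 0.5, 0.2])
    # Only entries with r' < r are reachable, so a lower-triangular table suffices.
    gamma = jnp.tril(jnp.full((3, 3), 0.7))
    clicks = jnp.array([1.0, 0.0, 1.0])
    print(ubm_click_probs(alpha, gamma, clicks))
    print(ubm_log_likelihood(alpha, gamma, clicks))
```

To score a batch of sessions that share one examination table, the per-session function can be vectorized with `jax.vmap(ubm_log_likelihood, in_axes=(0, None, 0))`. Training then amounts to maximizing this log-likelihood with respect to alpha and gamma.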

However, CLAX's modular design also supports more complex models, such as two-tower models and mixture models, and allows custom Flax modules to be plugged in as model parameters. Usage examples for getting started are provided under examples/.
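The two ideas named above can be illustrated with a minimal plain-JAX sketch. This is not CLAX's API, and every name in it is hypothetical. The two-tower model assumes a linear relevance tower over query-document feature vectors and a bias tower with one examination logit per position. The two towers are combined multiplicatively, as in a position-based model. The mixture weights the click probabilities of K component models with softmax weights.

```python
import jax
import jax.numpy as jnp


def two_tower_click_probs(params, features, positions):
    """PBM-style factorization: P(click) = relevance(features) * examination(position)."""
    # Relevance tower: a linear model over query-document features (hypothetical).
    relevance = jax.nn.sigmoid(features @ params["w"] + params["b"])
    # Bias tower: one examination logit per 0-based position.
    examination = jax.nn.sigmoid(params["exam_logits"][positions])
    return relevance * examination


def mixture_click_probs(component_probs, mixture_logits):
    """Mixture of K click models: sum_k softmax(mixture_logits)[k] * component_probs[k]."""
    weights = jax.nn.softmax(mixture_logits)
    return jnp.tensordot(weights, component_probs, axes=1)


def bce_loss(params, features, positions, clicks, eps=1e-6):
    """Binary cross-entropy between predicted click probabilities and observed clicks."""
    probs = jnp.clip(two_tower_click_probs(params, features, positions), eps, 1.0 - eps)
    return -jnp.mean(clicks * jnp.log(probs) + (1.0 - clicks) * jnp.log1p(-probs))


if __name__ == "__main__":
    params = {"w": jnp.zeros(4), "b": jnp.zeros(()), "exam_logits": jnp.zeros(10)}
    features = jnp.ones((3, 4))  # 3 query-document pairs, 4 features each
    positions = jnp.array([0, 1, 2])
    clicks = jnp.array([1.0, 0.0, 0.0])
    print(two_tower_click_probs(params, features, positions))
    grads = jax.grad(bce_loss)(params, features, positions, clicks)
    print({name: g.shape for name, g in grads.items()})
```

In practice the relevance tower could be any neural network, which is where a custom Flax module would fit. The linear layer only keeps the sketch short. Keeping the two towers as separate probabilities, instead of summing their logits, preserves the "click = relevant and examined" reading of the model.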

Installation

CLAX requires JAX. To install JAX with CUDA support, please refer to the JAX documentation. CLAX itself is available via PyPI:

pip install clax-models
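After installation, a quick check can confirm that JAX picked up the intended accelerator and that the package is importable. This small sketch assumes that `clax-models` installs the import name `clax`, as used in the example above. The helper `check_environment` is hypothetical.

```python
import importlib.util

import jax


def check_environment() -> dict:
    """Report the active JAX backend, visible devices, and whether `clax` is importable."""
    return {
        "jax_backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(d) for d in jax.devices()],
        "clax_installed": importlib.util.find_spec("clax") is not None,
    }


if __name__ == "__main__":
    print(check_environment())
```

If `jax_backend` reports `cpu` on a machine with a GPU, JAX was most likely installed without CUDA support.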

Documentation

Reference

If you use CLAX, please consider citing our paper:

@misc{hager2025clax,
  title        = {CLAX: Fast and Flexible Neural Click Models in JAX},
  author       = {Philipp Hager and Onno Zoeter and Maarten de Rijke},
  year         = {2025},
  howpublished = {arXiv}
}