Numerical Stability
Optimizing complex likelihood expressions using gradient descent requires attention to numerical stability. The marginal likelihoods of many common click models contain products of small probabilities, which can lead to numerical underflow in finite-precision computer arithmetic. Below, we cover the techniques CLAX uses to stabilize complex likelihood expressions by performing all probability computations in log-space.
Multiplication
By moving to log-probabilities, products of probabilities simplify to sums (and divisions to subtractions):

\[
\log(p \cdot q) = \log p + \log q, \qquad \log(p / q) = \log p - \log q,
\]

which essentially eliminates the concern of numerical underflow when multiplying small probabilities.
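To see the effect in practice, consider the following plain-Python sketch (illustrative only, not CLAX's actual implementation), which multiplies 20 probabilities of \(10^{-20}\) each:

```python
import math

# 20 tiny probabilities whose true product is 1e-400,
# far below the smallest positive float64 (~5e-324).
probs = [1e-20] * 20

product = math.prod(probs)                     # underflows to 0.0
log_product = sum(math.log(p) for p in probs)  # stays finite: 20 * log(1e-20)

print(product)      # 0.0
print(log_product)  # ≈ -921.03
```

The direct product silently underflows to zero, while the sum of logs remains a perfectly ordinary floating-point number.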
Addition
While multiplication becomes stable (and faster) in log-space, the addition of probabilities becomes more involved, as it requires first exponentiating the log-probabilities. This reintroduces the problems we seek to avoid: exponentiating large positive inputs leads to overflow and exponentiating large negative inputs leads to underflow. The standard solution is to avoid large inputs to the \(\exp(\cdot)\) operation via the log-sum-exp trick:1

\[
\log \sum_{i=1}^{n} \exp(a_i) = a_{\text{max}} + \log \sum_{i=1}^{n} \exp(a_i - a_{\text{max}}),
\]

where \(a = (a_1, \dots, a_n)\) is a vector of log values and \(a_{\text{max}} = \max_i(a_i)\) is the maximum input value. Since every shifted input \(a_i - a_{\text{max}} \leq 0\), the exponentials stay in \([0, 1]\) and cannot overflow. The trick is prevalent in probabilistic modeling, and we also use it to transform the output logits of neural networks \(x \in \mathbb{R}\) into log-probabilities by implementing a numerically stable version of the log-sigmoid function:

\[
\log \sigma(x) = \min(x, 0) - \log\bigl(1 + \exp(-|x|)\bigr).
\]
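Both helpers can be sketched in a few lines of plain Python (illustrative only; CLAX's actual implementation operates on arrays):

```python
import math

def logsumexp(a):
    """Stable log(sum(exp(a_i))): shift by the max so exp never overflows."""
    a_max = max(a)
    if a_max == -math.inf:  # all inputs represent log(0)
        return -math.inf
    # Every shifted exponent a_i - a_max is <= 0, so exp stays in [0, 1].
    return a_max + math.log(sum(math.exp(x - a_max) for x in a))

def log_sigmoid(x):
    """Stable log(sigmoid(x)): exp only ever sees non-positive inputs."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))

# Naive log(exp(1000) + exp(999)) overflows; the shifted version is fine:
print(logsumexp([1000.0, 999.0]))  # ≈ 1000.3133
print(log_sigmoid(-1000.0))        # ≈ -1000.0 instead of -inf
```

A naive `math.log(1 / (1 + math.exp(-x)))` would overflow for large negative logits and round to `log(1) = 0` for large positive ones; the shifted form avoids both failure modes.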
Complements and cancellation
Sometimes we need to compute the log of a complement, \(\log(1 - p)\), e.g., in the binary cross-entropy loss or when computing log-posteriors in the DBN. Performing this step directly from the log-probability \(\log p\) requires computing \(\log(1 - \exp(\log p))\).
This expression is numerically unstable in two ways: (i) underflow: when \(p\) is very small, \(\log p\) is very negative, causing \(\exp(\log p)\) to underflow to zero; and (ii) catastrophic cancellation: when \(p \approx 1\), we have \(\exp(\log p) \approx 1\), making \(1 - \exp(\log p) \approx 0\), since subtracting nearly equal floating-point numbers leads to a loss of precision.2
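The cancellation failure mode is easy to trigger; a small plain-Python illustration (not from CLAX):

```python
import math

a = -1e-18  # log p for a probability extremely close to 1

# exp(a) rounds to exactly 1.0 in float64 (the spacing near 1.0 is ~2.2e-16),
# so the naive complement cancels away completely:
print(1.0 - math.exp(a))         # 0.0 -> its log would be -inf (domain error)

# expm1 retains the tiny difference and recovers the correct answer:
print(math.log(-math.expm1(a)))  # ≈ -41.447, i.e. log(1e-18)
```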
Therefore, we compute \(\texttt{log1mexp}(x)\) as proposed by Mächler3 and adopted by major frameworks such as TensorFlow and JAX. Mächler proposes a piecewise approximation that switches between two stable expressions that are precise in different input ranges.4 For a log-probability \(a \in \mathbb{R}, a \leq 0\):

\[
\texttt{log1mexp}(a) =
\begin{cases}
\log(-\texttt{expm1}(a)) & \text{if } -\log(2) \leq a \leq 0, \\
\texttt{log1p}(-\exp(a)) & \text{if } a < -\log(2).
\end{cases}
\]
The implementation relies on the standard functions \(\text{log1p}(x)\), which accurately computes \(\log(1 + x)\), and \(\text{expm1}(x)\), which accurately computes \(\exp(x) - 1\), to avoid catastrophic cancellation.
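Putting the two branches together, a plain-Python sketch of the scheme (mirroring, but not identical to, the TensorFlow/JAX implementations):

```python
import math

LOG2 = math.log(2.0)

def log1mexp(a):
    """Stable log(1 - exp(a)) for a <= 0, using Machler's two-branch scheme."""
    if a > 0.0:
        raise ValueError("log1mexp expects a log-probability a <= 0")
    if a >= -LOG2:
        # a near 0: exp(a) is near 1, so expm1 avoids the cancellation in 1 - exp(a).
        return math.log(-math.expm1(a))
    # a very negative: exp(a) is tiny, so log1p avoids the cancellation in log(1 - tiny).
    return math.log1p(-math.exp(a))

# log(1 - p) for p = 0.3 should equal log(0.7):
print(log1mexp(math.log(0.3)))  # ≈ -0.3567
```

Both branches compute the same quantity; the switch at \(-\log(2)\) simply picks whichever of `expm1` or `log1p` absorbs the subtraction that would otherwise cancel.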
To summarize, CLAX performs all probability computations in log-space for increased numerical stability, avoiding underflow and overflow as well as catastrophic cancellation.
1. Pierre Blanchard, Desmond J. Higham, and Nicholas J. Higham. "Accurate Computation of the Log-Sum-Exp and Softmax Functions". arXiv preprint arXiv:1909.03469, 2019. ↩
2. David Goldberg. "What Every Computer Scientist Should Know About Floating-Point Arithmetic". In ACM Computing Surveys, 1991. ↩
3. Martin Mächler. "Accurately Computing \(\log(1 - \exp(-|a|))\) Assessed by the Rmpfr package". In The Comprehensive R Archive Network, 2012. ↩
4. Interested readers can find the motivation behind the switching point \(\log(2)\) in Section 2 of Mächler's note. ↩