
Sunday, July 17, 2011

Fast Approximate Lambert W

This is of questionable personal utility; after all, I've only seen Lambert's W function once in a machine learning context, and there it occurs as $W(\exp(x)) - x$, which is better to approximate directly. Nonetheless generating these fast approximate functions makes for amusing sport and has a much higher likelihood of being useful to someone somewhere than, e.g., further developing my computer gaming ability.

The Branchless Lifestyle

Lambert W is a typical function in one respect: there is a nice asymptotic approximation that can be used to initialize a Householder method, but which is only applicable to part of the domain. In particular for large $x$, $W(x) \approx \log(x) - \log\log(x) + \frac{\log\log(x)}{\log(x)}$, which is great because logarithm is cheap. Unfortunately for $x \leq 1$ this cannot be evaluated, and for $1 < x < 2$ it gives really poor results. A natural way to proceed would be to have a different initialization strategy for $x < 2$.
if (x < 2)
  {
    // alternate initialization
  }
else
  {
    // use asymptotic approximation
  }
 
// householder steps here
There are two problems with this straightforward approach. In a scalar context the branch can frustrate the pipelined architecture of modern CPUs, since a mispredicted branch stalls the pipeline. In a vector context it is even more problematic, because different components of the same vector might fall into different branches of the conditional.

What to do? It turns out statements like this
a = (x < 2) ? b : c
look like conditionals but need not be, since they can be rewritten as
a = f (x < 2) * b + (1 - f (x < 2)) * c
Here f is an indicator function which returns 0 or 1 depending upon the truth value of the argument. The SSE instruction set contains indicator functions for comparison tests, which, when combined with "floating point and" instructions, end up computing a branchless ternary operator.

The bottom line is that speculative execution can be made deterministic if both branches of a conditional are computed, and in simple enough cases there is direct hardware support for doing this quickly.
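Here is a minimal self-contained sketch of that branchless select, using raw SSE intrinsics and the __m128 type rather than fastapprox's v4sf wrapper; the function name select_lt is mine, for illustration only.

#include <xmmintrin.h>

// Branchless ternary: per lane, result = (x < t) ? b : c.
// _mm_cmplt_ps yields an all-ones or all-zeros mask in each lane,
// which bitwise and/andnot/or turn into a select with no branch.
static inline __m128
select_lt (__m128 x, __m128 t, __m128 b, __m128 c)
{
  __m128 mask = _mm_cmplt_ps (x, t);           // 0xFFFFFFFF where x < t, else 0
  return _mm_or_ps (_mm_and_ps (mask, b),      // take b where the mask is set
                    _mm_andnot_ps (mask, c));  // take c elsewhere
}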

Branchless Lambert W

So the big idea here is to have an alternate initialization for the Householder step such that it can be computed in a branchless fashion, given that for large inputs the asymptotic approximation is used. Therefore I looked for an approximation of the form $W(x) \approx a + \log(cx + d) - \log\log(cx + d) + \frac{\log\log(cx + d)}{\log(cx + d)}$, where for large $x$, $a = 0$, $c = 1$, and $d = 0$. I found values for $a$, $c$, $d$, and the cutoff value for $x$ via Mathematica. (The curious can check out the Mathematica notebook.) The vector version ends up looking like
// WARNING: this code has been updated.  Do not use this version.
// Instead get the latest version from http://code.google.com/p/fastapprox

static inline v4sf
vfastlambertw (v4sf x)
{
  static const v4sf threshold = v4sfl (2.26445f);

  // branchless selection of the initialization constants: below the
  // threshold use the fitted (a, c, d), above it fall back to the
  // plain asymptotic expansion (a = 0, c = 1, d = 0)
  v4sf under = _mm_cmplt_ps (x, threshold);
  v4sf c = _mm_or_ps (_mm_and_ps (under, v4sfl (1.546865557f)),
                      _mm_andnot_ps (under, v4sfl (1.0f)));
  v4sf d = _mm_and_ps (under, v4sfl (2.250366841f));
  v4sf a = _mm_and_ps (under, v4sfl (-0.737769969f));

  // initialization: w = a + log(cx+d) - loglog(cx+d) + loglog(cx+d)/log(cx+d)
  v4sf logterm = vfastlog (c * x + d);
  v4sf loglogterm = vfastlog (logterm);

  v4sf w = a + logterm - loglogterm + loglogterm / logterm;
  v4sf expw = vfastexp (w);
  v4sf z = w * expw;
  v4sf p = x + z;

  // one Householder-style refinement step, folded into a single
  // rational expression (w is a fixed point when w e^w = x)
  return (v4sfl (2.0f) * x + w * (v4sfl (4.0f) * x + w * p)) /
         (v4sfl (2.0f) * expw + p * (v4sfl (2.0f) + w));
}
You can get the complete code from the fastapprox project.
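For the curious, calling the vector version might look like the following; this is my own sketch, and the single-header include name is an assumption about the fastapprox distribution.

#include <stdio.h>
#include "fastonebigheader.h"  // single-header build of fastapprox (name assumed)

int
main (void)
{
  union { v4sf v; float f[4]; } in, out;  // v4sf is a 4-wide float vector

  in.f[0] = -0.25f; in.f[1] = 0.5f; in.f[2] = 1.0f; in.f[3] = 10.0f;
  out.v = vfastlambertw (in.v);

  for (int i = 0; i < 4; ++i)
    printf ("W(%g) ~= %g\n", in.f[i], out.f[i]);

  return 0;
}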

Timing and Accuracy

Timing tests are done by compiling with -O3 -finline-functions -ffast-math, on a box running 64 bit Ubuntu Lucid (so gcc 4:4.4.3-1ubuntu1 and libc 2.11.1-0ubuntu7.6). I also measured average relative accuracy for $x$ distributed as $\frac{1}{2} U(-1/e, 1) + \frac{1}{2} U(0, 100)$, i.e., a 50-50 draw from two uniform distributions. Accuracy is compared to 20 iterations of Newton's method with an initial point of 0 when $x < 5$ and the asymptotic approximation otherwise. I also tested the gsl implementation, which is much more accurate but significantly slower.
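That reference is easy to reproduce; here is a sketch of such a ground truth (my reconstruction of the described procedure, not the actual test harness), using the Newton step $w \leftarrow w - \frac{w e^w - x}{e^w (1 + w)}$ for $f(w) = w e^w - x$:

#include <math.h>

// Reference W(x): 20 Newton iterations on f(w) = w e^w - x,
// initialized at 0 for small x and at the asymptotic expansion otherwise.
static double
lambertw_reference (double x)
{
  double w = (x < 5.0) ? 0.0
           : log (x) - log (log (x)) + log (log (x)) / log (x);

  for (int i = 0; i < 20; ++i)
    {
      double ew = exp (w);
      w -= (w * ew - x) / (ew * (1.0 + w));  // Newton step
    }

  return w;
}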
fastlambertw average relative error = 5.26867e-05
fastlambertw max relative error (at 2.48955e-06) = 0.0631815
fasterlambertw average relative error = 0.00798678
fasterlambertw max relative error (at -0.00122776) = 0.926378
vfastlambertw average relative error = 5.42952e-05
vfastlambertw max relative error (at -2.78399e-06) = 0.0661513
vfasterlambertw average relative error = 0.00783347
vfasterlambertw max relative error (at -0.00125244) = 0.926431
gsl_sf_lambert_W0 average relative error = 5.90309e-09
gsl_sf_lambert_W0 max relative error (at -0.36782) = 6.67586e-07
fastlambertw million calls per second = 21.2236
fasterlambertw million calls per second = 53.4428
vfastlambertw million calls per second = 21.6723
vfasterlambertw million calls per second = 56.0154
gsl_sf_lambert_W0 million calls per second = 2.7433
These average accuracies hide the relatively poor performance at the minimum of the domain. Right at $x = -e^{-1}$, which is the minimum of the domain, the fastlambertw approximation is poor ($-0.938$, whereas the correct answer is $-1$, so relative error of 6%); but at $x = -e^{-1} + \frac{1}{100}$, the relative error drops to $2 \times 10^{-4}$.

1 comment:

  1. Thanks, this was useful. I ran into a use for Lambert W in a machine learning context, although also one where W(exp(x)) would have been more useful.

    As you hint, exponentiating y = exp(x) for large x merely to compute W(y) risks overflow, only to effectively take (something close to) the logarithm again afterwards.

    Your initialisation method is easily adapted to W(exp(x)) though, cheers!
