Faster Log Gamma and Digamma
In addition to lots of calls to exponential and logarithm, for which I already have fast approximations, there are also calls to Log Gamma and Digamma. Fortunately the leading term in Stirling's approximation is a logarithm, so I can leverage my fast logarithm approximation from my previous blog post. For both Log Gamma and Digamma I found that applying the shift formula a few times before invoking the asymptotic expansion (three shifts for Log Gamma, two for Digamma, as in the code below) gave the best tradeoff between accuracy and speed.
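Spelling out the algebra, so the magic constants in the code below are easy to verify: shifting the Log Gamma argument up by three gives
$$\log \Gamma (x) = \log \Gamma (3 + x) - \log \bigl( x (1 + x) (2 + x) \bigr),$$
and applying Stirling's series with its first correction term at $3 + x$ gives
$$\log \Gamma (3 + x) \approx \left( x + \tfrac{5}{2} \right) \log (3 + x) - (3 + x) + \tfrac{1}{2} \log (2 \pi) + \frac{1}{12 (3 + x)},$$
where $\tfrac{1}{2} \log (2 \pi) - 3 \approx -2.081061466$ and $\tfrac{1}{12} \approx 0.0833333$. Likewise, shifting Digamma up by two gives
$$\psi (x) = \psi (2 + x) - \frac{1}{x} - \frac{1}{1 + x} = \psi (2 + x) - \frac{1 + 2 x}{x (1 + x)},$$
and the asymptotic expansion at $2 + x$ gives
$$\psi (2 + x) \approx \log (2 + x) - \frac{1}{2 (2 + x)} - \frac{1}{12 (2 + x)^2} = \log (2 + x) - \frac{13 + 6 x}{12 (2 + x)^2}.$$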
// WARNING: this code has been updated.  Do not use this version.
// Instead get the latest version from http://code.google.com/p/fastapprox

inline float
fastlgamma (float x)
{
  float logterm = fastlog (x * (1.0f + x) * (2.0f + x));
  float xp3 = 3.0f + x;

  return -2.081061466f - x + 0.0833333f / xp3 - logterm + (2.5f + x) * fastlog (xp3);
}

inline float
fastdigamma (float x)
{
  float twopx = 2.0f + x;
  float logterm = fastlog (twopx);

  return - (1.0f + 2.0f * x) / (x * (1.0f + x))
         - (13.0f + 6.0f * x) / (12.0f * twopx * twopx)
         + logterm;
}
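As a quick sanity check, here is the kind of spot check I run before wiring anything in. This little driver is purely illustrative: it assumes fastlog from the previous post plus the two functions above are already defined in the same translation unit, and it compares against libc and Boost.

#include <math.h>
#include <stdio.h>
#include <boost/math/special_functions/digamma.hpp>

// Assumes fastlog (previous post), fastlgamma, and fastdigamma (above)
// have already been defined in this translation unit.

int main ()
{
  const float xs[] = { 0.01f, 0.1f, 0.5f, 1.0f, 2.5f, 10.0f };

  for (unsigned int i = 0; i < sizeof (xs) / sizeof (xs[0]); ++i)
    {
      float x = xs[i];
      printf ("x = %-5g lgammaf = %-12g fastlgamma = %-12g digamma = %-12g fastdigamma = %-12g\n",
              x, lgammaf (x), fastlgamma (x),
              boost::math::digamma (x), fastdigamma (x));
    }

  return 0;
}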
Timing and Accuracy
Timing tests are done by compiling with -O3 -finline-functions -ffast-math, on a box running 64 bit Ubuntu Lucid (so gcc 4:4.4.3-1ubuntu1 and libc 2.11.1-0ubuntu7.6). I also measured average relative accuracy for $x \in [1/100, 10]$ (using lgammaf from libc as the gold standard for Log Gamma, and boost::math::digamma as the gold standard for Digamma).

fastlgamma relative accuracy = 0.00045967
fastdigamma relative accuracy = 0.000420604
fastlgamma million calls per second = 60.259
lgammaf million calls per second = 21.4703
boost::math::lgamma million calls per second = 1.8951
fastdigamma million calls per second = 66.0639
boost::math::digamma million calls per second = 18.9672

Well, the speedup is roughly 3x for Log Gamma (but why is boost::math::lgamma so slow?) and slightly better than 3x for Digamma.
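For the curious, the methodology is along these lines (shown for Digamma only). The grid spacing and iteration counts here are placeholders, and going through a function pointer defeats inlining, so a harness like this will not exactly reproduce the numbers in the table above; it is just a sketch of the measurement.

#include <math.h>
#include <stdio.h>
#include <sys/time.h>
#include <boost/math/special_functions/digamma.hpp>

// Assumes fastlog, fastlgamma, and fastdigamma (above) are in scope.

static float boost_digamma (float x) { return boost::math::digamma (x); }

static double
million_calls_per_second (float (*f) (float))
{
  const int n = 10000000;
  volatile float sink = 0.0f;          // keep the calls from being optimized away
  timeval start, end;

  gettimeofday (&start, NULL);
  for (int i = 0; i < n; ++i)
    sink += f (0.01f + 0.01f * (i % 1000));
  gettimeofday (&end, NULL);

  double secs = (end.tv_sec - start.tv_sec) + 1e-6 * (end.tv_usec - start.tv_usec);
  return 1e-6 * n / secs;
}

int main ()
{
  double err = 0.0;
  int count = 0;

  // average relative accuracy over a grid spanning [1/100, 10]
  for (float x = 0.01f; x <= 10.0f; x += 0.01f, ++count)
    err += fabsf (fastdigamma (x) - boost_digamma (x)) / fabsf (boost_digamma (x));

  printf ("fastdigamma relative accuracy = %g\n", err / count);
  printf ("fastdigamma million calls per second = %g\n",
          million_calls_per_second (fastdigamma));
  printf ("boost::math::digamma million calls per second = %g\n",
          million_calls_per_second (boost_digamma));

  return 0;
}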
LDA Timings
So I dropped these approximations into lda.cc and did some timing tests, using the wiki1K file from the test/train-sets/ directory; but I concatenated the file with itself 100 times to slow things down.

Original Run
% (cd test; time ../vw --lda 100 --lda_alpha 0.01 --lda_rho 0.01 --lda_D 1000 -b 13 --minibatch 128 train-sets/wiki1Kx100.dat)
your learning rate is too high, setting it to 1
using no cache
Reading from train-sets/wiki1Kx100.dat
num sources = 1
Num weight bits = 13
learning rate = 1
initial_t = 1
power_t = 0.5
learning_rate set to 1
average    since       example   example    current  current  current
loss       last        counter    weight      label  predict features
10.296875  10.296875         3       3.0    unknown   0.0000       38
10.437156  10.577436         6       6.0    unknown   0.0000       14
10.347227  10.239314        11      11.0    unknown   0.0000       32
10.498633  10.650038        22      22.0    unknown   0.0000        2
10.495566  10.492500        44      44.0    unknown   0.0000      166
10.469184  10.442189        87      87.0    unknown   0.0000       29
10.068007   9.666831       174     174.0    unknown   0.0000       17
 9.477440   8.886873       348     348.0    unknown   0.0000        2
 9.020482   8.563524       696     696.0    unknown   0.0000      143
 8.482232   7.943982      1392    1392.0    unknown   0.0000       31
 8.106654   7.731076      2784    2784.0    unknown   0.0000       21
 7.850269   7.593883      5568    5568.0    unknown   0.0000       25
 7.707427   7.564560     11135   11135.0    unknown   0.0000       39
 7.627417   7.547399     22269   22269.0    unknown   0.0000       61
 7.583820   7.540222     44537   44537.0    unknown   0.0000        5
 7.558824   7.533827     89073   89073.0    unknown   0.0000      457

finished run
number of examples = 0
weighted example sum = 0
weighted label sum = 0
average loss = -nan
best constant = -nan
total feature number = 0
../vw --lda 100 --lda_alpha 0.01 --lda_rho 0.01 --lda_D 1000 -b 13 --minibatc  69.06s user 13.60s system 101% cpu 1:21.59 total
Approximate Run
% (cd test; time ../vw --lda 100 --lda_alpha 0.01 --lda_rho 0.01 --lda_D 1000 -b 13 --minibatch 128 train-sets/wiki1Kx100.dat)
your learning rate is too high, setting it to 1
using no cache
Reading from train-sets/wiki1Kx100.dat
num sources = 1
Num weight bits = 13
learning rate = 1
initial_t = 1
power_t = 0.5
learning_rate set to 1
average    since       example   example    current  current  current
loss       last        counter    weight      label  predict features
10.297077  10.297077         3       3.0    unknown   0.0000       38
10.437259  10.577440         6       6.0    unknown   0.0000       14
10.347348  10.239455        11      11.0    unknown   0.0000       32
10.498796  10.650243        22      22.0    unknown   0.0000        2
10.495748  10.492700        44      44.0    unknown   0.0000      166
10.469374  10.442388        87      87.0    unknown   0.0000       29
10.068179   9.666985       174     174.0    unknown   0.0000       17
 9.477574   8.886968       348     348.0    unknown   0.0000        2
 9.020435   8.563297       696     696.0    unknown   0.0000      143
 8.482178   7.943921      1392    1392.0    unknown   0.0000       31
 8.106636   7.731093      2784    2784.0    unknown   0.0000       21
 7.849980   7.593324      5568    5568.0    unknown   0.0000       25
 7.707124   7.564242     11135   11135.0    unknown   0.0000       39
 7.627207   7.547283     22269   22269.0    unknown   0.0000       61
 7.583952   7.540696     44537   44537.0    unknown   0.0000        5
 7.559260   7.534566     89073   89073.0    unknown   0.0000      457

finished run
number of examples = 0
weighted example sum = 0
weighted label sum = 0
average loss = -nan
best constant = -nan
total feature number = 0
../vw --lda 100 --lda_alpha 0.01 --lda_rho 0.01 --lda_D 1000 -b 13 --minibatc  41.97s user 7.69s system 102% cpu 48.622 total

So from 81 seconds down to 49 seconds, roughly a 40% reduction in running time. But does it work? Well, I'll let Matt be the final judge of that (I've sent him a modified lda.cc for testing), but initial indications are that using approximate versions of the expensive functions doesn't spoil the learning process.
Next Step: Vectorization
There is definitely more speed to be had, since lda.cc is loaded with code like

for (size_t i = 0; i < global.lda; i++)
  {
    sum += gamma[i];
    gamma[i] = mydigamma (gamma[i]);
  }

which just screams for vectorization. Stay tuned!
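To make that concrete, here is a rough sketch of what a four-wide SSE2 version of fastdigamma and that loop might look like. None of this is the eventual vw code: vfastlog2 is a lane-by-lane transliteration of the scalar fastlog2 (constants copied from fastapprox, so treat them as an assumption), and the commented-out loop assumes global.lda is a multiple of four purely to keep the sketch short.

#include <emmintrin.h>  // SSE2

// Four-wide transliteration of the scalar fastlog2/fastlog approximation.
static inline __m128
vfastlog2 (__m128 x)
{
  __m128i xi = _mm_castps_si128 (x);
  // keep the mantissa bits and force the exponent so mx lands in [0.5, 1)
  __m128 mx = _mm_castsi128_ps (
      _mm_or_si128 (_mm_and_si128 (xi, _mm_set1_epi32 (0x007FFFFF)),
                    _mm_set1_epi32 (0x3f000000)));
  __m128 y = _mm_mul_ps (_mm_cvtepi32_ps (xi),
                         _mm_set1_ps (1.1920928955078125e-7f));

  return _mm_sub_ps (y,
           _mm_add_ps (_mm_set1_ps (124.22551499f),
             _mm_add_ps (_mm_mul_ps (_mm_set1_ps (1.498030302f), mx),
                         _mm_div_ps (_mm_set1_ps (1.72587999f),
                                     _mm_add_ps (_mm_set1_ps (0.3520887068f), mx)))));
}

static inline __m128
vfastlog (__m128 x)
{
  return _mm_mul_ps (_mm_set1_ps (0.69314718f), vfastlog2 (x));
}

// Same formula as the scalar fastdigamma above, four lanes at a time.
static inline __m128
vfastdigamma (__m128 x)
{
  __m128 one = _mm_set1_ps (1.0f);
  __m128 twopx = _mm_add_ps (_mm_set1_ps (2.0f), x);
  __m128 term1 = _mm_div_ps (_mm_add_ps (one, _mm_add_ps (x, x)),
                             _mm_mul_ps (x, _mm_add_ps (one, x)));
  __m128 term2 = _mm_div_ps (_mm_add_ps (_mm_set1_ps (13.0f),
                                         _mm_mul_ps (_mm_set1_ps (6.0f), x)),
                             _mm_mul_ps (_mm_set1_ps (12.0f),
                                         _mm_mul_ps (twopx, twopx)));

  return _mm_sub_ps (vfastlog (twopx), _mm_add_ps (term1, term2));
}

// The loop above, four topics per iteration (assumes global.lda % 4 == 0):
//
//   __m128 vsum = _mm_setzero_ps ();
//   for (size_t i = 0; i < global.lda; i += 4)
//     {
//       __m128 g = _mm_loadu_ps (&gamma[i]);
//       vsum = _mm_add_ps (vsum, g);
//       _mm_storeu_ps (&gamma[i], vfastdigamma (g));
//     }
//   // sum = horizontal add of the four lanes of vsum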