First consider the two class case. Imagine looking for features of the form $\phi (w^\top x)$, where $w \in \mathbb{R}^d$ is a “weight vector” and $\phi$ is some nonlinearity. What is a simple criterion for defining a good feature? One idea is for the feature to have small average value on one class and large average value on another. Assuming $\phi$ is non-negative, that suggests maximizing the ratio \[
w^* = \arg \max_w \frac{\mathbb{E}[\phi (w^\top x) | y = 1]}{\mathbb{E}[\phi (w^\top x) | y = 0]}.
\] For the specific choice of $\phi (z) = z^2$ this is tractable, as it results in a Rayleigh quotient between two class-conditional second moments, \[
w^* = \arg \max_w \frac{w^\top \mathbb{E}[x x^\top | y = 1] w}{w^\top \mathbb{E}[x x^\top | y = 0] w},
\] which can be solved via a generalized eigenvalue decomposition. Generalized eigenvalue problems have been extensively studied in machine learning and elsewhere, and the above idea looks very similar to many other proposals (e.g., Fisher LDA), but it is different and empirically more effective. I'll refer you to the paper for a more thorough discussion, but I will mention that after the paper was accepted someone pointed out the similarity to CSP, a technique from time-series analysis (cf. Ecclesiastes 1:4-11).
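For concreteness, here is a minimal two-class Matlab sketch of the above (not code from the paper): it assumes a design matrix X with examples in rows and a binary label vector y, and the small ridge term gamma is my own choice for numerical conditioning.

A = double(X(y==1,:)') * double(X(y==1,:)) / sum(y==1);  % E[x x' | y = 1]
B = double(X(y==0,:)') * double(X(y==0,:)) / sum(y==0);  % E[x x' | y = 0]
gamma = 0.1;                                 % small ridge for conditioning (assumed)
[V,D] = eig(A, B + gamma*eye(size(B,1)));    % generalized eigendecomposition
[~,imax] = max(real(diag(D)));               % largest generalized eigenvalue
w = V(:,imax);                               % maximizes (w'*A*w) / (w'*B*w)
f = (X*w).^2;                                % the feature phi(w'x) with phi(z) = z^2

The full script below does essentially this, but whitens with a Cholesky factor and calls eigs to grab several leading generalized eigenvectors at once.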
The features that result from this procedure pass the smell test. For example, starting from a raw pixel representation on mnist, the weight vectors can be visualized as images; the first weight vector for discriminating 3 vs. 2 looks like a pen stroke, cf. figure 1D of Ranzato et al.
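If you want to make that kind of picture yourself, a minimal sketch (assuming w is a 784-dimensional weight vector over the raw pixels as laid out in mnist_all.mat) is:

imagesc(reshape(w, 28, 28)');   % view the weight vector as a 28x28 image
colormap gray; axis image; axis off;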
We make several additional observations in the paper. The first is that multiple isolated local optima of the Rayleigh quotient are useful if the associated generalized eigenvalues are large, i.e., one can extract multiple features from a single Rayleigh quotient. The second is that, for a moderate number of classes $k$, we can extract features for each class pair independently and use all the resulting features to get good results. The third is that the resulting directions have additional structure which is not completely captured by a squaring non-linearity, which motivates a (univariate) basis function expansion. The fourth is that, once the original representation has been augmented with additional features, the procedure can be repeated, which sometimes yields additional improvements. Finally, we can compose this with randomized feature maps to approximate the corresponding operations in an RKHS, which sometimes yields additional improvements. We also made a throw-away comment in the paper that computing class-conditional second moment matrices is easily done in a map-reduce style distributed framework; this was actually a major motivation for us to explore in this direction, but it didn't fit well into the exposition of the paper so we de-emphasized it.
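To see why that last point matters, note that each class-conditional second moment is just an average of outer products, so partial sums computed on separate shards of data simply add. A minimal sketch, where shards is a hypothetical cell array holding one class's examples split across workers and d is the feature dimension:

S = zeros(d,d); n = 0;
for s = 1:numel(shards)                        % "map": per-shard partial sums
  S = S + double(shards{s}') * double(shards{s});
  n = n + size(shards{s},1);
end
secondMoment = S / n;                          % "reduce": E[x x' | y = class]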
Combining the above ideas with Nikos' preconditioned gradient learning for multiclass, described in a previous post, leads to the following Matlab script, which gets 91 test errors on (permutation invariant) mnist. Note: you'll need to download mnist_all.mat from Sam Roweis' site to run this.
function calgevsquared

more off;
clear all;
close all;

start=tic;

% load mnist and build one-hot label matrices
load('mnist_all.mat');
xxt=[train0; train1; train2; train3; train4; train5; ...
     train6; train7; train8; train9];
xxs=[test0; test1; test2; test3; test4; test5; test6; test7; test8; test9];
kt=single(xxt)/255;
ks=single(xxs)/255;
st=[size(train0,1); size(train1,1); size(train2,1); size(train3,1); ...
    size(train4,1); size(train5,1); size(train6,1); size(train7,1); ...
    size(train8,1); size(train9,1)];
ss=[size(test0,1); size(test1,1); size(test2,1); size(test3,1); ...
    size(test4,1); size(test5,1); size(test6,1); size(test7,1); ...
    size(test8,1); size(test9,1)];
paren = @(x, varargin) x(varargin{:});
yt=zeros(60000,10);
ys=zeros(10000,10);
I10=eye(10);
lst=1;
for i=1:10; yt(lst:lst+st(i)-1,:)=repmat(I10(i,:),st(i),1); lst=lst+st(i); end
lst=1;
for i=1:10; ys(lst:lst+ss(i)-1,:)=repmat(I10(i,:),ss(i),1); lst=lst+ss(i); end
clear i st ss lst
clear xxt xxs
clear train0 train1 train2 train3 train4 train5 train6 train7 train8 train9
clear test0 test1 test2 test3 test4 test5 test6 test7 test8 test9

[n,k]=size(yt);
[m,d]=size(ks);

% class-conditional second moment matrices on the raw pixels
gamma=0.1;
top=20;
for i=1:k
  ind=find(yt(:,i)==1);
  kind=kt(ind,:);
  ni=length(ind);
  covs(:,:,i)=double(kind'*kind)/ni;
  clear ind kind;
end

% first pass: top generalized eigenvectors for every ordered class pair
filters=zeros(d,top*k*(k-1),'single');
last=0;
threshold=0;
for j=1:k
  covj=squeeze(covs(:,:,j));
  l=chol(covj+gamma*eye(d))';
  for i=1:k
    if j~=i
      covi=squeeze(covs(:,:,i));
      C=l\covi/l';
      CS=0.5*(C+C');
      [v,L]=eigs(CS,top);
      V=l'\v;
      take=find(diag(L)>=threshold);
      batch=length(take);
      fprintf('%u,%u,%u ', i, j, batch);
      filters(:,last+1:last+batch)=V(:,take);
      last=last+batch;
    end
  end
  fprintf('\n');
end
clear covi covj covs C CS V v L

% project onto the filters and expand with a split (positive/negative part)
% sqrt nonlinearity
% NB: augmenting kt/ks with .^2 terms is very slow and doesn't help
filters=filters(:,1:last);
ft=kt*filters;
clear kt;
kt=[ones(n,1,'single') sqrt(1+max(ft,0))-1 sqrt(1+max(-ft,0))-1];
clear ft;
fs=ks*filters;
clear ks filters;
ks=[ones(m,1,'single') sqrt(1+max(fs,0))-1 sqrt(1+max(-fs,0))-1];
clear fs;

[n,k]=size(yt);
[m,d]=size(ks);

% second pass: repeat on the augmented representation, now thresholding
% the generalized eigenvalues
for i=1:k
  ind=find(yt(:,i)==1);
  kind=kt(ind,:);
  ni=length(ind);
  covs(:,:,i)=double(kind'*kind)/ni;
  clear ind kind;
end
filters=zeros(d,top*k*(k-1),'single');
last=0;
threshold=7.5;
for j=1:k
  covj=squeeze(covs(:,:,j));
  l=chol(covj+gamma*eye(d))';
  for i=1:k
    if j~=i
      covi=squeeze(covs(:,:,i));
      C=l\covi/l';
      CS=0.5*(C+C');
      [v,L]=eigs(CS,top);
      V=l'\v;
      take=find(diag(L)>=threshold);
      batch=length(take);
      fprintf('%u,%u,%u ', i, j, batch);
      filters(:,last+1:last+batch)=V(:,take);
      last=last+batch;
    end
  end
  fprintf('\n');
end
fprintf('gamma=%g,top=%u,threshold=%g\n',gamma,top,threshold);
fprintf('last=%u filtered=%u\n', last, size(filters,2) - last);
clear covi covj covs C CS V v L
filters=filters(:,1:last);
ft=kt*filters;
clear kt;
kt=[sqrt(1+max(ft,0))-1 sqrt(1+max(-ft,0))-1];
clear ft;
fs=ks*filters;
clear ks filters;
ks=[sqrt(1+max(fs,0))-1 sqrt(1+max(-fs,0))-1];
clear fs;

% final design matrix, ridge regression onto the one-hot labels, and a few
% rounds of corrections with simplex projection of the predictions
trainx=[ones(n,1,'single') kt kt.^2];
clear kt;
testx=[ones(m,1,'single') ks ks.^2];
clear ks;
C=chol(0.5*(trainx'*trainx)+sqrt(n)*eye(size(trainx,2)),'lower');
w=C'\(C\(trainx'*yt));
pt=trainx*w;
ps=testx*w;
[~,trainy]=max(yt,[],2);
[~,testy]=max(ys,[],2);
for i=1:5
  xn=[pt pt.^2/2 pt.^3/6 pt.^4/24];
  xm=[ps ps.^2/2 ps.^3/6 ps.^4/24];
  c=chol(xn'*xn+sqrt(n)*eye(size(xn,2)),'lower');
  ww=c'\(c\(xn'*yt));
  ppt=SimplexProj(xn*ww);
  pps=SimplexProj(xm*ww);
  w=C'\(C\(trainx'*(yt-ppt)));
  pt=ppt+trainx*w;
  ps=pps+testx*w;
  [~,yhatt]=max(pt,[],2);
  [~,yhats]=max(ps,[],2);

  errort=sum(yhatt~=trainy)/n;
  errors=sum(yhats~=testy)/m;
  fprintf('%u,%g,%g\n',i,errort,errors)
end

% print the test set confusion matrix
fprintf('%4s\t', 'pred');
for true=1:k
  fprintf('%5u', true-1);
end
fprintf('%5s\n%4s\n', '!=', 'true');
for true=1:k
  fprintf('%4u\t', true-1);
  trueidx=find(testy==true);
  for predicted=1:k
    predidx=find(yhats(trueidx)==predicted);
    fprintf('%5u', sum(predidx>0));
  end
  predidx=find(yhats(trueidx)~=true);
  fprintf('%5u\n', sum(predidx>0));
end
toc(start)
end

% Euclidean projection onto the probability simplex
% http://arxiv.org/pdf/1309.1541v1.pdf
function X = SimplexProj(Y)
  [N,D] = size(Y);
  X = sort(Y,2,'descend');
  Xtmp = bsxfun(@times,cumsum(X,2)-1,(1./(1:D)));
  X = max(bsxfun(@minus,Y,Xtmp(sub2ind([N,D],(1:N)',sum(X>Xtmp,2)))),0);
end

When I run this on my desktop machine it yields
>> calgevsquared
2,1,20 3,1,20 4,1,20 5,1,20 6,1,20 7,1,20 8,1,20 9,1,20 10,1,20
1,2,20 3,2,20 4,2,20 5,2,20 6,2,20 7,2,20 8,2,20 9,2,20 10,2,20
1,3,20 2,3,20 4,3,20 5,3,20 6,3,20 7,3,20 8,3,20 9,3,20 10,3,20
1,4,20 2,4,20 3,4,20 5,4,20 6,4,20 7,4,20 8,4,20 9,4,20 10,4,20
1,5,20 2,5,20 3,5,20 4,5,20 6,5,20 7,5,20 8,5,20 9,5,20 10,5,20
1,6,20 2,6,20 3,6,20 4,6,20 5,6,20 7,6,20 8,6,20 9,6,20 10,6,20
1,7,20 2,7,20 3,7,20 4,7,20 5,7,20 6,7,20 8,7,20 9,7,20 10,7,20
1,8,20 2,8,20 3,8,20 4,8,20 5,8,20 6,8,20 7,8,20 9,8,20 10,8,20
1,9,20 2,9,20 3,9,20 4,9,20 5,9,20 6,9,20 7,9,20 8,9,20 10,9,20
1,10,20 2,10,20 3,10,20 4,10,20 5,10,20 6,10,20 7,10,20 8,10,20 9,10,20
2,1,15 3,1,20 4,1,20 5,1,20 6,1,20 7,1,20 8,1,20 9,1,20 10,1,20
1,2,20 3,2,20 4,2,20 5,2,20 6,2,20 7,2,20 8,2,20 9,2,20 10,2,20
1,3,20 2,3,11 4,3,17 5,3,20 6,3,20 7,3,19 8,3,18 9,3,18 10,3,19
1,4,20 2,4,12 3,4,20 5,4,20 6,4,12 7,4,20 8,4,19 9,4,15 10,4,20
1,5,20 2,5,12 3,5,20 4,5,20 6,5,20 7,5,20 8,5,16 9,5,20 10,5,9
1,6,18 2,6,13 3,6,20 4,6,12 5,6,20 7,6,18 8,6,20 9,6,13 10,6,18
1,7,20 2,7,14 3,7,20 4,7,20 5,7,20 6,7,20 8,7,20 9,7,20 10,7,20
1,8,20 2,8,14 3,8,20 4,8,20 5,8,20 6,8,20 7,8,20 9,8,20 10,8,12
1,9,20 2,9,9 3,9,20 4,9,15 5,9,18 6,9,11 7,9,20 8,9,17 10,9,16
1,10,20 2,10,14 3,10,20 4,10,20 5,10,14 6,10,20 7,10,20 8,10,12 9,10,20
gamma=0.1,top=20,threshold=7.5
last=1630 filtered=170
1,0.0035,0.0097
2,0.00263333,0.0096
3,0.00191667,0.0092
4,0.00156667,0.0093
5,0.00141667,0.0091
pred        0    1    2    3    4    5    6    7    8    9   !=
true
   0      977    0    1    0    0    1    0    1    0    0    3
   1        0 1129    2    1    0    0    1    1    1    0    6
   2        1    1 1020    0    1    0    0    6    3    0   12
   3        0    0    1 1004    0    1    0    2    1    1    6
   4        0    0    0    0  972    0    4    0    2    4   10
   5        1    0    0    5    0  883    2    1    0    0    9
   6        4    2    0    0    2    2  947    0    1    0   11
   7        0    2    5    0    0    0    0 1018    1    2   10
   8        1    0    1    1    1    1    0    1  966    2    8
   9        1    1    0    2    5    2    0    4    1  993   16
Elapsed time is 186.147659 seconds.

That's a pretty good confusion matrix, comparable to state-of-the-art deep learning results on (permutation invariant) mnist. In the paper we report a slightly worse number (96 test errors) because for a paper we have to choose hyperparameters via cross-validation on the training set rather than cherry-pick them as for a blog post.
The technique as stated here is really only useful for tall-and-thin design matrices (i.e., lots of examples but not too many features): if the original feature dimensionality is too large (e.g., $> 10^4$), then naive use of standard generalized eigensolvers becomes slow or infeasible, and other tricks are required. Furthermore, if the number of classes is too large, then solving $O(k^2)$ generalized eigenvalue problems is also not reasonable. We're working on remedying these issues, and we're also excited about extending this strategy to structured prediction. Hopefully we'll have more to say about it at the next few conferences.