In the meantime, although it is a bit pokey, the reduction I'm presenting here is still practical. In addition, sometimes just seeing an implementation can really crystallize the concepts, so I thought I'd present the reduction here.
The Strategy
The starting point is the Filter Tree reduction of cost-sensitive multiclass classification to importance weighted binary classification. In this reduction, class labels are arranged into a March Madness-style tournament, with winners playing winners until one class label emerges victorious: that is the resulting prediction. When two class labels ``play each other'', what really happens is that an importance weighted classifier decides who wins based upon the associated instance features $x$.
In practice I'm using a particular kind of filter tree which I call a scoring filter tree. Here the importance weighted classifier is constrained to be of the form \[
\Psi_{\nu} (x) = 1_{f (x; \lambda) > f (x; \phi)}.
\] Here $\lambda$ and $\phi$ are the two class labels that are ``playing each other'' to see who advances in the tournament. What this equation says is:
- There is a function $f$ which scores each class label given the instance features $x$; in this implementation a lower score is better.
- The better class label always beats the other class label.
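As a concrete illustration of the two rules above, here is a minimal Perl sketch of a scoring filter tree at prediction time. It assumes lower scores are better (the test-time rule below picks the lowest vowpal output), so $\Psi_{\nu} = 1$ is read as ``$\phi$ advances''. The scores in `@f` are made up, and `play` and `tournament` are illustrative names, not functions from the filter-tree script.

```perl
#!/usr/bin/env perl
use strict;
use warnings;

# Made-up scores f(x; n) for labels n = 0 .. 4 (lower is better).
my @f = (0.9, 0.4, 0.7, 1.2, 0.5);

# One game: Psi(x) = 1[f(x; lambda) > f(x; phi)].
# Psi = 1 means phi advances; Psi = 0 means lambda advances.
sub play {
    my ($lambda, $phi) = @_;
    my $psi = $f[$lambda] > $f[$phi] ? 1 : 0;
    return $psi ? $phi : $lambda;
}

# Single-elimination bracket: pair adjacent survivors each round,
# and let an odd survivor out advance on a bye.
sub tournament {
    my @alive = @_;
    while (@alive > 1) {
        my @next;
        while (@alive >= 2) {
            my ($left, $right) = splice @alive, 0, 2;
            push @next, play ($left, $right);
        }
        push @next, @alive;    # the bye, if any
        @alive = @next;
    }
    return $alive[0];
}

print "winner: ", tournament (0 .. $#f), "\n";    # winner: 1
```

Pairing adjacent survivors and giving an odd one out a bye happens to reproduce, for seven labels, the bracket hard-coded in the filter-tree script: (0,1), (2,3), (4,5) with 6 getting a pass, then two more rounds. Because the better-scoring label wins every game it plays, the winner is always the lowest-scoring label, which is why test time can simply take the argmin.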
The Implementation
I'll assume that we're trying to classify among $|K|$ labels denoted by the integers $\{ 1, \ldots, |K|\}$. I'll also assume an input format which is very close to vowpal's native input format, but with a cost vector instead of a label: \[c_1,\ldots,c_{|K|}\; \textrm{importance}\; \textrm{tag}|\textrm{namespace}\; \textrm{feature} \ldots
\] So, for instance, an input line for a 3-class problem might look like \[
0.7,0.2,1.3\; 0.6\; \textrm{idiocracy}|\textrm{items}\; \textrm{hotlatte}\; |\textrm{desires}\; \textrm{i}\; \textrm{like}\; \textrm{money}
\] The best choice (lowest cost) class here is 2.
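A small Perl sketch of reading this format, following the description above: the costs are the first whitespace-separated field before the first `|`, and the importance and tag are the second and third. `parse_line` and `best_class` are illustrative helpers, and defaulting a missing importance to 1 is an assumption (it matches vowpal's convention and the script's own default).

```perl
#!/usr/bin/env perl
use strict;
use warnings;

# Split a line of the form "c_1,...,c_K importance tag|namespace feature ..."
# into costs, importance, tag and the vowpal feature part.
sub parse_line {
    my ($line) = @_;
    chomp $line;
    my ($pre, $features) = split /\|/, $line, 2;
    die "no features in: $line" unless defined $features;
    my ($costs, $importance, $tag) = split ' ', $pre;
    die "no cost vector in: $line" unless defined $costs;
    return {
        costs      => [ split /,/, $costs ],
        importance => defined $importance ? $importance : 1,    # assumed default
        tag        => $tag,
        features   => "|$features",
    };
}

# 1-based index of the lowest-cost class (the first one wins ties).
sub best_class {
    my (@c) = @_;
    my $best = 0;
    for my $n (1 .. $#c) {
        $best = $n if $c[$n] < $c[$best];
    }
    return $best + 1;
}

my $ex = parse_line ("0.7,0.2,1.3 0.6 idiocracy|items hotlatte |desires i like money\n");
printf "costs=%s importance=%s tag=%s best=%d\n",
    join (",", @{ $ex->{costs} }), $ex->{importance}, $ex->{tag},
    best_class (@{ $ex->{costs} });
# costs=0.7,0.2,1.3 importance=0.6 tag=idiocracy best=2
```

The filter-tree script further down parses a variant of this layout tied to a particular problem: there the costs follow the tag, separated by commas, and exactly eight of them are expected.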
Test Time
Applying the model is easier to understand than training it, so I'll start there. Within the Perl script I transform each input line into a set of vowpal input lines, one per class label $n$, \[|\textrm{namespace}n\; \textrm{feature} \ldots\; |\_n\; k
\] Essentially the cost vector and importance weight are stripped out (since no learning is happening at this point), the tag is stripped out (I handle that separately), and each namespace has the class label appended to its name. Since vowpal uses the first letter to identify namespaces, options that operate on namespaces (e.g., -q, --ignore) continue to work as expected. So, continuing the example above, we would generate three lines \[
|\textrm{items}1\; \textrm{hotlatte}\; |\textrm{desires}1\; \textrm{i}\; \textrm{like}\; \textrm{money}\; |\_1\; k
\] \[
|\textrm{items}2\; \textrm{hotlatte}\; |\textrm{desires}2\; \textrm{i}\; \textrm{like}\; \textrm{money}\; |\_2\; k
\] \[
|\textrm{items}3\; \textrm{hotlatte}\; |\textrm{desires}3\; \textrm{i}\; \textrm{like}\; \textrm{money}\; |\_3\; k
\] Each of these lines is fed to vowpal, and the class label whose line gets the lowest vowpal output is selected as the winner of the tournament. The last feature, $k$ in namespace $\_n$, provides a class-label-localized version of the constant feature that vowpal silently adds to every example.
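The per-class line construction can be sketched in isolation. This Perl snippet reuses the substitution from the script's `tweak_line`; `class_line` is an illustrative helper. Like the original, it assumes namespace names carry no `:weight` suffix, since the label would otherwise be appended after the weight.

```perl
#!/usr/bin/env perl
use strict;
use warnings;

# Append $suffix to every namespace name ("|ns" becomes "|ns$suffix"),
# using the same substitution as tweak_line in filter-tree.
sub tweak_line {
    my ($suffix, $rest) = @_;
    $rest =~ s/\|(\S*)/|${1}${suffix}/g;
    return $rest;
}

# Test-time vowpal line for class $n: label-suffixed namespaces plus the
# per-class constant feature k in namespace _$n.
sub class_line {
    my ($n, $features) = @_;
    return tweak_line ($n, "$features |_ k") . "\n";
}

my $features = "|items hotlatte |desires i like money";
for my $n (1 .. 3) {
    my $line = class_line ($n, $features);
    print $line;
}
# |items1 hotlatte |desires1 i like money |_1 k
# |items2 hotlatte |desires2 i like money |_2 k
# |items3 hotlatte |desires3 i like money |_3 k
```

Each of these lines is then scored separately, and the label whose line receives the lowest prediction is the answer.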
Train Time
At train time I essentially run the tournament, but since I know the actual costs, I update the classifier based upon who ``should have won''. The importance weight of the update is the absolute difference in costs between the two teams that just played. So for two class labels $i$ and $j$ there will be a training input fed to vowpal of the form \[y\; \omega\; |\textrm{namespace$i$:1}\; \textrm{feature} \ldots |\textrm{namespace$j$:-1}\; \textrm{feature} \ldots |\textrm{\_$i$:1} \; k\; |\textrm{\_$j$:-1}\; k
\] where $y = -1$ if $c_i < c_j$ and $y = 1$ otherwise (the cheaper label should get the lower score), and $\omega = \textrm{importance} \cdot |c_i - c_j|$, i.e., the original importance weight scaled by the absolute difference in the actual costs. Here I'm leveraging the ability to provide a weight on a namespace, which multiplies the weights on all the features in the namespace. (What about that pesky constant feature that vowpal always provides? It's still there, and really it shouldn't be. The right way to deal with this would be to patch vowpal to accept an option not to provide the constant feature. However, I want to present something that works with an unpatched vowpal, so instead I feed a second training input with everything, label and namespace weights alike, negated in order to force the constant feature to stay near zero. The order of the two inputs is randomized.)
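Below is a sketch of assembling the two training lines for one game. It mirrors the string building in the script's `train_node`, but `train_lines` is a hypothetical helper that returns the lines instead of sending them to vowpal, and it always emits the un-negated line first, whereas the script randomizes the order.

```perl
#!/usr/bin/env perl
use strict;
use warnings;

# Append $suffix to every namespace name, as tweak_line does in filter-tree.
sub tweak_line {
    my ($suffix, $rest) = @_;
    $rest =~ s/\|(\S*)/|${1}${suffix}/g;
    return $rest;
}

# Training lines for one game between labels $i and $j. Namespaces of $i
# get weight +1 and those of $j get -1; the label is -1 when $i is cheaper.
# The second line negates the label and all namespace weights, which keeps
# vowpal's implicit constant feature near zero.
sub train_lines {
    my ($i, $j, $ci, $cj, $importance, $features) = @_;
    my $w = $importance * abs ($ci - $cj);
    return () if $w == 0;    # equal costs: nothing to learn from this game
    my @lines;
    for my $s (1, -1) {
        my $y = ($ci < $cj ? -1 : 1) * $s;
        push @lines, "$y $w"
            . tweak_line ("$i:" . $s, " $features |_ k")
            . tweak_line ("$j:" . -$s, " $features |_ k")
            . "\n";
    }
    return @lines;
}

print for train_lines (1, 2, 0.7, 0.2, 0.6,
                       "|items hotlatte |desires i like money");
# 1 0.3 |items1:1 hotlatte |desires1:1 i like money |_1:1 k |items2:-1 hotlatte |desires2:-1 i like money |_2:-1 k
# -1 0.3 |items1:-1 hotlatte |desires1:-1 i like money |_1:-1 k |items2:1 hotlatte |desires2:1 i like money |_2:1 k
```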
When a team wins a game it should not have won, it still advances in the tournament. Intuitively, the classifier needs to learn to recover gracefully from mistakes made earlier in the tournament, so training on the games that actually occur is the right thing to do.
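The effect of letting mistaken winners advance can be seen in a small simulation. This sketch, with a hypothetical `training_games` helper and made-up scores and costs, decides each game by the current scores (lower wins, ties to the left, as in `train_node`) and records who should have won along with the importance weight $|c_a - c_b|$ the update would carry.

```perl
#!/usr/bin/env perl
use strict;
use warnings;

# Walk a scoring filter tree bracket at train time. Games are decided by the
# current scores (lower wins, ties to the left), so a label that should have
# lost still advances; each game records the cost-based target and weight.
sub training_games {
    my ($scores, $costs, @alive) = @_;
    my @games;
    while (@alive > 1) {
        my @next;
        while (@alive >= 2) {
            my ($l, $r) = splice @alive, 0, 2;
            push @games, {
                pair   => "$l v $r",
                winner => $scores->[$l] <= $scores->[$r] ? $l : $r,
                should => $costs->[$l] <= $costs->[$r] ? $l : $r,
                weight => abs ($costs->[$l] - $costs->[$r]),
            };
            push @next, $games[-1]{winner};
        }
        push @next, @alive;    # bye
        @alive = @next;
    }
    return @games;
}

my @scores = (0.3, 0.1, 0.2, 0.4);    # made-up f(x; n), lower is better
my @costs  = (0, 1, 0.5, 1);          # made-up costs; label 0 is best
for my $g (training_games (\@scores, \@costs, 0 .. $#scores)) {
    printf "%s: advances %d, should be %d, weight %.1f\n",
        @$g{qw (pair winner should weight)};
}
# 0 v 1: advances 1, should be 0, weight 1.0
# 2 v 3: advances 2, should be 2, weight 0.5
# 1 v 2: advances 1, should be 2, weight 0.5
```

Label 1 wrongly beats label 0 in the first round and advances anyway. The second-round game against label 2 still yields an update (weight 0.5) saying that 2 should beat 1, which is the kind of recovery described above.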
What's Missing
Here are some things I'd like to improve:
- Implement the reduction inside vowpal instead of outside via IPC.
- In the implementation I manually design the tournament based upon a particular number of classes. It would be better to automatically construct the tournament.
- It would be nice to have a concise way to specify sparse cost-vectors. For example when all errors are equally bad all that is needed is the identity of the correct label.
- The above strategy doesn't work with hinge loss, and I don't know why (it appears to work with squared and logistic loss). Probably I've made a mistake somewhere. Caveat emptor!
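For the sparse cost-vector item, one possible shape is an expansion step in front of the parser. The specification syntax here is purely hypothetical (a bare label means 0/1 costs; otherwise `label:cost` pairs, with unlisted classes getting a default cost), as is the helper name `dense_costs`.

```perl
#!/usr/bin/env perl
use strict;
use warnings;

# Hypothetical sparse syntax, expanded to costs for classes 1 .. $K:
#   "3"          0/1 costs: class 3 costs 0, every other class costs 1
#   "2:0.5,4:0"  listed classes get the given cost, the rest get $default
sub dense_costs {
    my ($spec, $K, $default) = @_;
    $default = 1 unless defined $default;
    if ($spec =~ /^\d+$/) {
        die "label $spec out of range" unless $spec >= 1 && $spec <= $K;
        return map { $_ == $spec ? 0 : 1 } 1 .. $K;
    }
    my @costs = ($default) x $K;
    for my $pair (split /,/, $spec) {
        my ($n, $c) = split /:/, $pair, 2;
        die "bad entry '$pair'"
            unless defined $c && $n =~ /^\d+$/ && $n >= 1 && $n <= $K;
        $costs[$n - 1] = $c;
    }
    return @costs;
}

print join (",", dense_costs ("2", 3)), "\n";            # 1,0,1
print join (",", dense_costs ("2:0.5,3:0", 4)), "\n";    # 1,0.5,0,1
```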
The Code
There are two pieces:
- vowpal.pm: this encapsulates the communication with vowpal. You'll need it to get things working, but it's mostly boring Unix IPC stuff.
  - It's not very good at detecting that the underlying vw did not start successfully (e.g., due to attempting to load a model that does not exist). However, you will notice, since it just hangs.
- filter-tree: the Perl script where the reduction implementation actually lives. You invoke this to get going. Mostly it takes the same arguments as vw itself and just passes them through, with some exceptions:
  - You have to read data from standard input. I could intercept --data arguments and emulate them, but I don't.
  - You can't use the --passes argument because of the previous statement.
  - I do intercept the -p argument (for outputting predictions) and emulate it at the reduction level.
The output you see from filter-tree looks like the output from vw, but it is not. It actually comes from the Perl script, and is designed to look like vw output suitably modified for the multiclass case.
Here's an example invocation:
```
% zcat traindata.gz | head -1000 | ./filter-tree --adaptive -l 1 -b 22 --loss_function logistic -f model.users.b22
average    since       example  example  current  current  current
loss       last        counter   weight    label  predict features
1.000000   1.000000          1      1.0   1.0000   0.0000       16
0.500000   0.000000          2      2.0   1.0000   1.0000       15
0.500000   0.500000          4      4.0   2.0000   1.0000       20
0.375000   0.250000          8      8.0   2.0000   2.0000       19
0.562500   0.750000         16     16.0   5.0000   2.0000       23
0.437500   0.312500         32     32.0   0.0000   1.0000       14
0.281250   0.125000         64     64.0   1.0000   1.0000       16
0.312500   0.343750        128    128.0   0.0000   1.0000       16
0.347656   0.382812        256    256.0   1.0000   1.0000       13
0.322266   0.296875        512    512.0   1.0000   1.0000       20
finished run
number of examples = 1000
weighted examples sum = 1000
average cost-sensitive loss = 0.287
average classification loss = 0.287
best constant for cost-sensitive = 1
best constant cost-sensitive loss = 0.542
best constant for classification = 1
best constant classification loss = 0.542
minimum possible loss = 0.000
confusion matrix
15    1    0    1    0    1    0
77  416   53   23    5    0    1
14   41  281   56    8    3    2
 0    0    0    1    0    1    0
 0    0    0    0    0    0    0
 0    0    0    0    0    0    0
 0    0    0    0    0    0    0
```

The -p argument outputs a tab-separated set of columns: the first column is the predicted class label, the next $|K|$ columns are the scoring function values per class label, and the last column is the instance tag.
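Reading the -p output back takes only a few lines. This sketch uses a hypothetical `parse_prediction` helper and a made-up line with three classes. Note that the script writes the predicted label as a 0-based index into the scores (its classes run from ZERO to SIX), so the prediction should coincide with the position of the lowest score.

```perl
#!/usr/bin/env perl
use strict;
use warnings;

# Split one tab-separated -p line: predicted label, one score per class,
# then the tag. The number of classes is inferred from the field count.
sub parse_prediction {
    my ($line) = @_;
    chomp $line;
    my @fields = split /\t/, $line, -1;    # -1 keeps an empty trailing tag
    die "too few fields: $line" unless @fields >= 3;
    my $predicted = shift @fields;
    my $tag = pop @fields;
    return { predicted => $predicted, scores => \@fields, tag => $tag };
}

my $p = parse_prediction ("1\t0.52\t0.11\t0.87\tidiocracy\n");
my @s = @{ $p->{scores} };
my ($lowest) = sort { $s[$a] <=> $s[$b] } 0 .. $#s;
printf "tag %s: predicted %s, lowest score at index %d of %d\n",
    $p->{tag}, $p->{predicted}, $lowest, scalar @s;
# tag idiocracy: predicted 1, lowest score at index 1 of 3
```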
As is typical, the source code is (unfortunately) the best documentation.
filter-tree
```perl
#! /usr/bin/env perl

use warnings;
use strict;

# vowpal.pm is expected next to this script ("." is not in @INC on Perl >= 5.26)
use FindBin;
use lib $FindBin::Bin;

use IO::File;
use vowpal;

$SIG{INT} = sub { die "caught SIGINT"; };

# if this looks stupid it is: these used to be actual class names,
# but i didn't want to release code with the actual class labels that i'm using

use constant {
  ZERO => 0,
  ONE => 1,
  TWO => 2,
  THREE => 3,
  FOUR => 4,
  FIVE => 5,
  SIX => 6,
};

sub argmin (@) {
  my (@list) = @_;
  my $argmin = 0;

  foreach my $x (1 .. $#list) {
    if ($list[$x] < $list[$argmin]) {
      $argmin = $x;
    }
  }

  return $argmin;
}

# append $suffix to every namespace name ("|ns" becomes "|ns$suffix")
sub tweak_line ($$) {
  my ($suffix, $rest) = @_;

  $rest =~ s/\|(\S*)/\|${1}${suffix}/g;

  return $rest;
}

# play one game (la vs. lb): if the costs differ, send an importance weighted
# update plus its negated copy; return pa - pb (<= 0 means la advances)
sub train_node ($$$$$$$$$) {
  my ($m, $la, $lb, $pa, $pb, $ca, $cb, $i, $rest) = @_;

  my $argmin = ($ca < $cb) ? -1 : 1;
  my $absdiff = abs ($ca - $cb);

  if ($absdiff > 0) {
    chomp $rest;
    my $w = $i * $absdiff;
    my $plusone = 1;
    my $minusone = -1;
    # randomize which of the two mirrored updates is sent first
    my $chirp = (rand () < 0.5) ? 1 : -1;

    $argmin *= $chirp;
    $plusone *= $chirp;
    $minusone *= $chirp;

    $m->send ("$argmin $w",
              tweak_line ("${la}:$plusone", " |$rest |_ k"),
              tweak_line ("${lb}:$minusone", " |$rest |_ k\n"))->recv ()
      or die "vowpal failed to respond";

    # negated copy keeps vowpal's implicit constant feature near zero
    $argmin *= -1;
    $plusone *= -1;
    $minusone *= -1;

    $m->send ("$argmin $w",
              tweak_line ("${la}:$plusone", " |$rest |_ k"),
              tweak_line ("${lb}:$minusone", " |$rest |_ k\n"))->recv ()
      or die "vowpal failed to respond";
  }

  return $pa - $pb;
}

sub print_update ($$$$$$$$) {
  my ($total_loss, $since_last, $delta_weight, $example_counter,
      $example_weight, $current_label, $current_predict,
      $current_features) = @_;

  printf STDERR "%-10.6f %-10.6f %8lld %8.1f %s %8.4f %8lu\n",
    $example_weight > 0 ? $total_loss / $example_weight : -1,
    $delta_weight > 0 ? $since_last / $delta_weight : -1,
    $example_counter,
    $example_weight,
    defined ($current_label) ? sprintf ("%8.4f", $current_label)
                             : " unknown",
    $current_predict,
    $current_features;
}

#---------------------------------------------------------------------
#                                main
#---------------------------------------------------------------------

srand 69;

my @my_argv;
my $pred_fh;

while (@ARGV) {
  my $arg = shift @ARGV;
  last if $arg eq '--';

  if ($arg eq "-p") {
    my $pred_file = shift @ARGV or die "-p argument missing";
    $pred_fh = IO::File->new ($pred_file, "w") or die "$pred_file: $!";
  }
  else {
    push @my_argv, $arg;
  }
}

my $model = vowpal->new (join " ", @my_argv);

print STDERR <<EOD;
average    since       example  example  current  current  current
loss       last        counter   weight    label  predict features
EOD

my $total_loss = 0;
my $since_last = 0;
my $example_counter = 0;
my $example_weight = 0;
my $delta_weight = 0;
my $dump_interval = 1;
my @best_constant_loss = map { 0 } (ZERO .. SIX);
my @best_constant_classification_loss = map { 0 } (ZERO .. SIX);
my $minimum_possible_loss = 0;
my $classification_loss = 0;
my $mismatch = 0;
my %confusion;

while (defined ($_ = <>)) {
  # strip the newline here; every line sent to vw supplies its own
  chomp;
  my ($preline, $rest) = split /\|/, $_, 2;

  die "bad preline $preline"
    unless $preline =~ /^([\d\.]+)?\s+([\d\.]+\s+)?(\S+)?$/;

  my $label = $1;
  my $importance = $2 ? $2 : 1;
  my $tag = $3;

  my ($actual_tag, @costs) = split /,/, $tag;

  die "bad tag $tag" unless @costs == 0 || @costs == 8;

  # one newline-terminated line per class, so each send yields one prediction
  my @scores =
    map { my $s = $model->send (tweak_line ($_, " |$rest |_ k\n"))->recv ();
          chomp $s;
          $s }
        (ZERO .. SIX);

  my $current_prediction = argmin @scores;

  if (@costs == 8) {
    # it turned out better for my problem to combine classes 6 and 7.
    # costs are already inverted and subtracted from 1, so,
    # have to subtract 1 when doing this

    my $class_seven = pop @costs;
    $costs[SIX] += $class_seven - 1;

    # zero level

    my $zero_one = train_node ($model, ZERO, ONE, $scores[ZERO], $scores[ONE],
                               $costs[ZERO], $costs[ONE], $importance, $rest) <= 0
                   ? ZERO : ONE;

    my $two_three = train_node ($model, TWO, THREE, $scores[TWO], $scores[THREE],
                                $costs[TWO], $costs[THREE], $importance, $rest) <= 0
                    ? TWO : THREE;

    my $four_five = train_node ($model, FOUR, FIVE, $scores[FOUR], $scores[FIVE],
                                $costs[FOUR], $costs[FIVE], $importance, $rest) <= 0
                    ? FOUR : FIVE;

    # SIX gets a pass

    # first level: (zero_one vs. two_three), (four_five vs. SIX)

    my $fleft = train_node ($model, $zero_one, $two_three,
                            $scores[$zero_one], $scores[$two_three],
                            $costs[$zero_one], $costs[$two_three],
                            $importance, $rest) <= 0
                ? $zero_one : $two_three;

    my $fright = train_node ($model, $four_five, SIX,
                             $scores[$four_five], $scores[SIX],
                             $costs[$four_five], $costs[SIX],
                             $importance, $rest) <= 0
                 ? $four_five : SIX;

    # second level: fleft vs. fright

    my $root = train_node ($model, $fleft, $fright,
                           $scores[$fleft], $scores[$fright],
                           $costs[$fleft], $costs[$fright],
                           $importance, $rest) <= 0
               ? $fleft : $fright;

    $total_loss += $importance * $costs[$root];
    $since_last += $importance * $costs[$root];
    $example_weight += $importance;
    $delta_weight += $importance;

    my $best_prediction = argmin @costs;

    foreach my $c (ZERO .. SIX) {
      $best_constant_loss[$c] += $importance * $costs[$c];
      if ($c != $best_prediction) {
        $best_constant_classification_loss[$c] += $importance;
      }
    }

    $minimum_possible_loss += $importance * $costs[$best_prediction];
    # weighted by importance, like the other losses
    $classification_loss += ($current_prediction == $best_prediction) ? 0 : $importance;
    ++$confusion{"$current_prediction:$best_prediction"};
    ++$mismatch if $root ne $current_prediction;
  }

  print $pred_fh (join "\t", $current_prediction, @scores, $actual_tag), "\n"
    if defined $pred_fh;

  ++$example_counter;

  if ($example_counter >= $dump_interval) {
    my @features = split /\s+/, $rest; # TODO: not really
    print_update ($total_loss, $since_last, $delta_weight, $example_counter,
                  $example_weight, (@costs) ? (argmin @costs) : undef,
                  $current_prediction, scalar @features);
    $dump_interval *= 2;
    $since_last = 0;
    $delta_weight = 0;
  }
}

my $average_loss = sprintf "%.3f", $example_weight > 0 ? $total_loss / $example_weight : -1;

my $best_constant = argmin @best_constant_loss;
my $best_constant_average_loss = sprintf "%.3f", $example_weight > 0 ? $best_constant_loss[$best_constant] / $example_weight : -1;

my $best_constant_classification = argmin @best_constant_classification_loss;
my $best_constant_classification_average_loss = sprintf "%.3f", $example_weight > 0 ? $best_constant_classification_loss[$best_constant_classification] / $example_weight : -1;

my $minimum_possible_average_loss = sprintf "%.3f", $example_weight > 0 ? $minimum_possible_loss / $example_weight : -1;

my $classification_average_loss = sprintf "%.3f", $example_weight > 0 ? $classification_loss / $example_weight : -1;

print <<EOD;
finished run
number of examples = $example_counter
weighted examples sum = $example_weight
average cost-sensitive loss = $average_loss
average classification loss = $classification_average_loss
best constant for cost-sensitive = $best_constant
best constant cost-sensitive loss = $best_constant_average_loss
best constant for classification = $best_constant_classification
best constant classification loss = $best_constant_classification_average_loss
minimum possible loss = $minimum_possible_average_loss
confusion matrix
EOD
#train/test mismatch = $mismatch

foreach my $pred (ZERO .. SIX) {
  print join "\t", map { $confusion{"$pred:$_"} || 0 } (ZERO .. SIX);
  print "\n";
}
```
vowpal.pm
```perl
# vowpal.pm

package vowpal;

use warnings;
use strict;

use File::Temp qw (tempdir);
use POSIX qw (mkfifo);
use IO::File;
use IO::Handle;
use IO::Pipe;
use IO::Poll;

sub new ($$) {
  my $class = shift;
  my $args = shift;

  # POSIX::tmpnam is gone as of Perl 5.26, so put the fifo in a private temp dir
  my $pred_dir = tempdir (CLEANUP => 1);
  my $pred_pipename = "$pred_dir/predictions";
  mkfifo ($pred_pipename, 0700) or die "mkfifo $pred_pipename: $!";

  # open the read end non-blocking (so we don't wait for vw to open the
  # write end), then switch the handle back to blocking mode
  my $pred_fd = POSIX::open ($pred_pipename, &POSIX::O_RDONLY | &POSIX::O_NONBLOCK | &POSIX::O_NOCTTY);
  die "$pred_pipename: $!" unless defined $pred_fd;
  my $pred_fh = IO::Handle->new;
  $pred_fh->fdopen ($pred_fd, "r") or die $!;
  POSIX::fcntl ($pred_fh, &POSIX::F_SETFL,
                POSIX::fcntl ($pred_fh, &POSIX::F_GETFL, 0) & ~&POSIX::O_NONBLOCK);

  my $data_fh = IO::Pipe->new or die $!;

  # start vw with our STDOUT and STDERR temporarily pointed at /dev/null
  open my $oldout, ">&STDOUT" or die "Can't dup STDOUT: $!";
  eval {
    open STDOUT, ">", "/dev/null" or die "Can't redirect STDOUT: $!";
    eval {
      open my $olderr, ">&STDERR" or die "Can't dup STDERR: $!";
      eval {
        open STDERR, ">", "/dev/null" or die "Can't redirect STDERR: $!";
        $data_fh->writer ("vw $args -p $pred_pipename --quiet") or die $!;
        $data_fh->autoflush (1);
      };
      open STDERR, ">&", $olderr or die "Can't restore STDERR: $!";
      die $@ if $@;
    };
    open STDOUT, ">&", $oldout or die "Can't restore STDOUT: $!";
    die $@ if $@;
  };
  die $@ if $@;

  # wait until vw accepts input, then poll on the prediction fifo from now on
  my $poll = IO::Poll->new;
  $poll->mask ($data_fh => POLLOUT);
  $poll->poll ();
  $poll->remove ($data_fh);
  $poll->mask ($pred_fh => POLLIN);

  my $self = { data_fh => $data_fh,
               pred_fh => $pred_fh,
               pred_file => $pred_pipename,
               poll => $poll,
               args => $args };

  bless $self, $class;
  return $self;
}

sub send ($@) {
  my $self = shift;

  $self->{'data_fh'}->print (@_);

  return $self;
}

sub recv ($) {
  my $self = shift;

  $self->{'poll'}->poll ();
  return $self->{'pred_fh'}->getline ();
}

sub DESTROY {
  my $self = shift;

  $self->{'data_fh'}->close ();
  $self->{'pred_fh'}->close ();
  unlink $self->{'pred_file'};
}

1;
```