Since I'm using multinomial logistic regression inside playerpiano, I was curious whether there is an importance-aware update for it. The loss function I'm using is the cross-entropy between a target probability vector $q$ and a predicted probability vector $p$ computed from weights $w$ and input features $x$, \[
\begin{aligned}
l (x, w, q) &= - \sum_{j \in J} q_j \log p_j (x, w), \\
p_k (x, w) &= \frac{\exp (x^\top w_k)}{\sum_{j \in J} \exp (x^\top w_j)}, \\
w_0 &= 0.
\end{aligned}
\]
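For concreteness, here is a minimal sketch of these definitions in Python with numpy; the names `probs` and `cross_entropy` are mine (not playerpiano's), and `W` holds only the rows for the non-reference classes $j > 0$, with class $0$ pinned at $w_0 = 0$.

```python
import numpy as np

def probs(x, W):
    """p_k(x, w): softmax over x^T w_k, with the reference class w_0 = 0."""
    logits = np.concatenate(([0.0], W @ x))  # prepend x^T w_0 = 0
    logits -= logits.max()                   # stabilize the exponentials
    e = np.exp(logits)
    return e / e.sum()

def cross_entropy(x, W, q):
    """l(x, w, q) = -sum_j q_j log p_j(x, w)."""
    return -np.dot(q, np.log(probs(x, W)))
```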
In general an importance-aware update is derived by integrating the gradient dynamics of the instantaneous loss function over an interval whose length scales with the importance weight times the learning rate; the usual SGD update step is the first-order Euler approximation of that integral. For $j > 0$, the gradient descent dynamics for the weights are \[
\begin{aligned}
\frac{d w_j (t)}{d t} &= - \frac{\partial l (x, w (t), q)}{\partial w_j (t)} \\
&= \bigl( q_j - p_j (x, w (t)) \bigr) x.
\end{aligned}
\]
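As a sanity check, a single explicit Euler step of these dynamics with step size $\eta$ recovers the ordinary SGD update $w_j \leftarrow w_j + \eta (q_j - p_j) x$. A sketch, reusing the hypothetical `probs` helper above:

```python
def sgd_step(x, W, q, eta):
    """One explicit Euler step of the dynamics, i.e. the usual SGD update."""
    p = probs(x, W)                        # current predicted probabilities
    W += eta * np.outer(q[1:] - p[1:], x)  # row k of W corresponds to class k + 1
    return W
```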
Happily every gradient is a multiple of $x$, so I will look for a solution of the form $w_j (t) = w_j + s_j (t) x$, yielding \[
\begin{aligned}
\frac{d s_j (t)}{d t} &= q_j - \tilde p_j (x, w, s (t)), \\
\tilde p_k (x, w, s) &= \frac{\exp (x^\top w_k + s_k x^\top x)}{\sum_{j \in J} \exp (x^\top w_j + s_j x^\top x)} \\
&= \frac{p_k (x, w) \exp (s_k x^\top x)}{\sum_{j \in J} p_j (x, w) \exp (s_j x^\top x)}, \\
s_j (0) &= 0.
\end{aligned}
\] I'm unable to make analytic progress past this point. However, this is now a $(|J|-1)$-dimensional ODE whose right-hand side can be evaluated in $O (|J|)$, since $p (x, w)$ and $x^\top x$ can be memoized. Thus in practice it can be numerically integrated without significant overhead (I'm only seeing about a 10% overall slowdown in playerpiano). A similar trick works for the Polytomous Rasch model in the ordinal case.
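To illustrate, here is one way that numerical integration could look. This is a sketch under my own assumptions (fixed-step explicit Euler with `n_steps` substeps, an importance weight `h`, and the hypothetical `probs` helper from above), not the actual playerpiano implementation.

```python
def importance_aware_step(x, W, q, eta, h, n_steps=10):
    """Integrate ds_j/dt = q_j - tilde-p_j from t = 0 to t = eta * h,
    then apply w_j <- w_j + s_j x for the non-reference classes j > 0."""
    p = probs(x, W)         # memoized base probabilities, length |J|
    xx = np.dot(x, x)       # memoized x^T x
    s = np.zeros_like(p)    # s_j(0) = 0; s_0 stays 0 since w_0 is pinned
    dt = eta * h / n_steps
    for _ in range(n_steps):
        z = p * np.exp(s * xx)                 # unnormalized tilde-p_k(x, w, s)
        p_tilde = z / z.sum()
        s[1:] += dt * (q[1:] - p_tilde[1:])    # the (|J|-1)-dimensional ODE
    W += np.outer(s[1:], x)                    # apply the integrated update
    return W
```

With `n_steps = 1` this collapses back to the plain SGD step scaled by the importance weight; more substeps (or an adaptive integrator) trade a little extra computation for fidelity to the dynamics.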
I get improved results even on data sets where all the importance weights are 1. It's not an earth-shattering lift, but I do see a consistent mild improvement in generalization error on several problems. I suspect that if I exhaustively searched the space of learning parameters (initial learning rate $\eta$ and power-law decay $\rho$) I could find settings that achieve the lift without an importance-aware update. However, that's one of the benefits of the importance-aware update: it makes the final result less sensitive to the choice of learning-rate parameters.