Thinking about how Sparse Auto-encoders (SAEs) aim to learn a sparse over-complete basis (where you are trying to triangulate a larger number of sources than you have signals) got me thinking about Independent Component Analysis again. In particular, I wanted to see if I could articulate a mapping between ICA and SAEs. This would provide a more mathematical framework for thinking about SAEs. I do so in this post by walking through this paper: Lewicki and Sejnowski
Lewicki and Sejnowski
This paper investigates “using a Laplacian prior” to learn “representations that are sparse and [a] non-linear function of the data” as a “method for blind source separation of fewer mixtures than sources”.
Motivations
- Overcomplete representations “(i.e. more basis vectors than input variables), can provide a better representation, because the basis vectors can be specialized for a larger variety of features present in the entire ensemble of data”
- But, “a criticism of overcomplete representations is that they are redundant, i.e. a given data point may have many possible representations”
- This paper is about imposing a prior probability on the basis coefficients, which “specifies the probability of the alternative representations”. I’m not sure I fully understand this yet.
Problem Setup
\[\begin{align*} x & = As + \varepsilon \end{align*}\]where $A \in \mathbb{R}^{m \times n}, m < n$. $x$ is the observation (the shorter vector, in $\mathbb{R}^m$), while $s$ is the hidden “true” source signal vector (in $\mathbb{R}^n$). Further, they assume “Gaussian additive noise so that $\log P(x \mid A, s) \propto - \lambda (x - As)^2 / 2$.” What a cumbersome way of simply saying:
\[\varepsilon \sim \mathcal{N} (0, \lambda^{-1} I)\]where $\lambda$ plays the role of a noise precision. They further define a “density for the basis coefficients, $P(s)$, which specifies the probability of the alternative representations. The most probable representation, $\hat{s}$, is found by maximizing the posterior distribution:”
\[\hat{s} = \max_s P(s \mid A, x) = \max_s P(s) P(x \mid A, s)\]They really do need an editor for this paper. Firstly, they mean to say:
\[\hat{s} = \text{argmax}_s P(s \mid A, x) = \text{argmax}_s P(s) P(x \mid A, s)\]And I should note that this is simply Bayes’ rule being applied: the prior $P(s)$ is made explicit and assumed to be Laplacian.
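To make the setup concrete, here is a tiny synthetic instance of the generative model. The dimensions, scales, and variable names below are my own arbitrary choices, not the paper's:

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 4, 8                          # fewer mixtures (observations) than sources
A = rng.normal(size=(m, n))
A /= np.linalg.norm(A, axis=0)       # unit-norm basis vectors (columns of A)

s = rng.laplace(scale=1.0, size=n)   # "true" source coefficients, drawn from the Laplacian prior
lam = 100.0                          # noise precision: log P(x|A,s) is proportional to -lam/2 * ||x - A s||^2
x = A @ s + rng.normal(scale=lam ** -0.5, size=m)   # observation with Gaussian noise
```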
Approaches considered
They considered these optimization approaches:
- Find $\text{argmax}_s P(s \mid A, x)$ by using the gradient of the log posterior ($\log P(s \mid A, x)$).
- Use linear programming methods to find $A$ and $s$ maximizing $P(s \mid A, x)$ while minimizing $\mathbf{1}^\top s = |s|_1$. This is essentially the sparsity objective of SAEs. (A code sketch of this MAP inference problem follows this list.)
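Here is a minimal sketch of the first approach: MAP inference of $s$ by gradient ascent on the log posterior under a Laplacian prior. The function name, step size, and plain (sub)gradient ascent are my choices; a real implementation would use a proper $L_1$ solver (basis pursuit / LASSO):

```python
import numpy as np

def map_infer_s(x, A, lam=100.0, theta=1.0, n_steps=5000):
    """MAP estimate of s under a Laplacian prior, via (sub)gradient ascent on the
    log posterior, which is proportional to -theta*||s||_1 - (lam/2)*||x - A s||^2."""
    s = np.zeros(A.shape[1])
    lr = 1.0 / (lam * np.linalg.norm(A, 2) ** 2)   # conservative step size for the quadratic term
    for _ in range(n_steps):
        grad = -theta * np.sign(s) + lam * A.T @ (x - A @ s)   # (sub)gradient of the log posterior
        s = s + lr * grad
    return s
```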
Learning Objective
“The learning objective is to adapt $A$ to maximize the probability of the data which is computed by marginalizing over the internal state:”
\[\begin{align*} P(x \mid A) & = \int P(s) P(x \mid A, s) \text{ } ds \end{align*}\]This, I understand. A helpful note: the prior $P(s)$ over the coefficients is the modelling choice here; it is often taken to be Gaussian, but in this paper they propose a Laplacian (and the posterior over $s$ will shortly be approximated by a Gaussian centred around $\hat{s}$). Remember also that while $s \sim \text{Laplacian}$, the noise $\varepsilon$ in the data is still Gaussian.
They continue: “this integral cannot be evaluated analytically but can be approximated with a Gaussian integral around $\hat{s}$, yielding:”
\[\begin{align*} \log P(x \mid A) \approx \text{const.} + \log P(\hat{s}) - \frac{\lambda}{2}(x - A\hat{s})^2 - \frac{1}{2} \log \text{det} H \end{align*}\]“where $H$ is the Hessian of the log posterior at $\hat{s}$.” This, I did not understand, but I was able to trace through the derivation with Claude’s help and so I’ll write it down before it’s lost once again to the ether.
Derivation of Approximation
First, we denote the log of the integrand as $f(s)$:
\[\begin{align*} f(s) & = \log P(s) + \log P(x \mid A, s) \\ \therefore P(x \mid A) & = \int e^{f(s)} \text{ } ds \end{align*}\]We know that the mean of a Gaussian (or Laplacian, for that matter) distribution is where its probability density is maximized. This will turn out to be a useful fact about $\hat{s}$, which we incorporate by expressing $f(s)$ in terms of $f(\hat{s})$ using a second-order Taylor expansion:
\[\begin{align*} f(s) & \approx f(\hat{s}) + \nabla f(\hat{s})^\top \left(s - \hat{s} \right) + \frac{1}{2} \left( s - \hat{s} \right)^\top H \left( s - \hat{s}\right) \end{align*}\]Since, by Bayes’ rule,
\[\begin{align*} \log P (s \mid x, A) & = \log P(s) + \log P(x \mid A, s) - \log P(x \mid A) \\ & = f(s) - \log P(x \mid A) \end{align*}\]and the $\log P(x \mid A)$ term does not depend on $s$, the log posterior equals $f(s)$ up to an additive constant. Since $\hat{s} = \text{argmax}_s P(s \mid x, A)$, this means that $f$ is also maximized at $\hat{s}$, and hence $\nabla f(\hat{s}) = 0$. Therefore:
\[f(s) \approx f(\hat{s}) + \frac{1}{2}\left( s - \hat{s} \right)^\top H \left(s - \hat{s} \right)\]Substituting back into the integral, we have:
\[\begin{align*} P(x \mid A) & \approx \int e^{f(\hat{s})} e^{\frac{1}{2}(s - \hat{s})^\top H (s - \hat{s})} \text{ } ds \\ & = e^{f(\hat{s})} \int e^{-\frac{1}{2}(s - \hat{s})^\top K (s - \hat{s})} \text{ } ds \end{align*}\]Where $K = -H$. Crucially, the second factor is a Gaussian integral, with the known solution below. Note that because $H$ is the Hessian of $f$ at its maximum (a concave quadratic approximation), $H$ is negative definite; its determinant is positive if $s$ is even-dimensional and negative if $s$ is odd-dimensional. Since we’re using $K = -H$, this is not a problem: $K$ is positive definite and has a positive determinant.
\[\begin{align*} \int e^{-\frac{1}{2}(s - \hat{s})^\top K (s - \hat{s})} \text{ } ds & = \sqrt{\frac{(2 \pi)^d}{| K |}} \end{align*}\]Where $|K |$ is the determinant of $K$. Therefore:
\[\begin{align*} P(x \mid A) & \approx e^{f(\hat{s})} \int e^{-\frac{1}{2}(s - \hat{s})^\top K (s - \hat{s})} \text{ } ds \\ & = e^{f(\hat{s})} (2 \pi)^\frac{d}{2} \cdot | K | ^{-\frac{1}{2}} \\ \Rightarrow \log P(x \mid A) & \approx f (\hat{s}) + \frac{d}{2} \log (2 \pi) - \frac{1}{2} \log |K| \\ & = \log P(\hat{s}) + \log P(x \mid A, \hat{s}) + \frac{d}{2} \log (2 \pi) - \frac{1}{2} \log |K| \\ \end{align*}\]At this point, we pretty much have what we want. We just have to note that, for the $\log P(x \mid A, \hat{s})$ term, the Gaussian log-density is (up to an additive constant) $-\frac{\lambda}{2}(x - \mu)^2$, so we simply have:
\[\begin{align*} \log P(x \mid A, \hat{s}) = -\frac{\lambda}{2} \left( x - A \hat{s} \right)^2 + \text{const} \end{align*}\]And further noting that the $\frac{d}{2} \log (2 \pi)$ term is a constant, we have our approximation:
\[\begin{align*} \log P(x \mid A) \approx \text{const} + \log P (\hat{s}) - \frac{\lambda}{2}(x - A\hat{s})^2 - \frac{1}{2} \log |K| \end{align*}\]Where $|K|$ is the determinant of $K = -H$; the $\log \det H$ in the paper’s expression only makes sense as $\log \lvert \det H \rvert$, since $H$ itself is negative definite.
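To make the approximation concrete, here is a sketch that evaluates it numerically. I use a smoothed stand-in for the Laplacian log-prior (my own choice, so that the Hessian exists and is negative definite) and drop constants:

```python
import numpy as np

def approx_log_evidence(x, A, s_hat, lam=100.0, theta=1.0, eps=1e-2):
    """Laplace approximation to log P(x | A) around the MAP estimate s_hat:
    log P(s_hat) + log P(x | A, s_hat) + (d/2) log(2*pi) - (1/2) log det(-H).
    Uses a smoothed Laplacian log-prior -theta * sum(sqrt(s^2 + eps^2)); constants dropped."""
    d = A.shape[1]
    log_prior = -theta * np.sum(np.sqrt(s_hat ** 2 + eps ** 2))
    resid = x - A @ s_hat
    log_lik = -0.5 * lam * resid @ resid                              # Gaussian likelihood term
    prior_curv = -theta * eps ** 2 / (s_hat ** 2 + eps ** 2) ** 1.5   # diagonal of d^2/ds^2 log P(s)
    H = np.diag(prior_curv) - lam * A.T @ A                           # Hessian of the log posterior at s_hat
    _, logdet_K = np.linalg.slogdet(-H)                               # K = -H is positive definite
    return log_prior + log_lik + 0.5 * d * np.log(2 * np.pi) - 0.5 * logdet_K
```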
Learning Rule
As with normal gradient ascent algorithms, our learning rule is to update $A$ with $A + \Delta A$, where $\Delta A$ is proportional to the gradient of the maximization objective, in this case $\log P(x \mid A)$. I’ll skim the derivation of the learning rule:
First term:
\[\begin{align*} \nabla_A \log P(\hat{s}) & = \nabla_{\hat{s}} \log P(\hat{s}) \cdot \nabla_{A} \hat{s} \\ \end{align*}\]Things to note / denote:
- $z = \nabla_s \log P(s)$, evaluated at $\hat{s}$
- $W \approx A^{-\top}$, in the sense that rows of $W^\top$ corresponding to non-zero $s$ indices are the same as those of $A^{-\top}$, and hence $A^\top W^\top = I$
So first term becomes (I don’t fully follow the assumptions that allow the conflation of $\hat{s}$ and $s$, and $A^{-\top}$ and $W^\top$, but I can big-picture follow the chain rule):
\[\begin{align*} \nabla_A \log P(\hat{s}) & = -W^\top z s^\top \end{align*}\]Second term:
\[-\frac{\lambda}{2} (x - A\hat{s})^2\]This is just the reconstruction-noise term: at the MAP solution the residual $x - A\hat{s}$ is essentially irreducible Gaussian noise, so its gradient contribution w.r.t. $A$ is dropped.
Third term: I’m not even going to trace through the derivation for this one:
\[\begin{align*} \nabla_A \frac{1}{2} \log \text{det} (H) & = \lambda A H^{-1} - 2W^\top y \hat{s}^\top \end{align*}\]Where $y$ “hides a computation involving the inverse Hessian… It can be shown that if $ \log P(s)$ and its derivatives are smooth, $y$ vanishes for large $\lambda$.”
Putting it together:
\[\begin{align*} \nabla_A \log P(x \mid A) & = -W^\top z s^\top - \lambda AH^{-1} + A y s^\top \\ & = -W^\top z s^\top - \lambda AH^{-1} \text{ (the $y$ term vanishes for large $\lambda$)} \end{align*}\]The authors then choose to rescale this gradient by $AA^\top$, which they do not explain in much detail. My guess is that this is a natural-gradient-style rescaling: $AA^\top$ captures the geometry of $A$, in the sense that its eigenvalues tell you the stretching factor (and hence the “variance”) in each direction, so multiplying by it makes the update better conditioned with respect to that geometry.
\[\begin{align*} AA^\top \nabla_A \log P(x \mid A) & = -A z s^\top - \lambda A A^\top AH^{-1} \end{align*}\]More assumptions: if $\lambda$ is large (low noise, since large $\lambda$ means low noise variance), “then the Hessian is dominated by $\lambda A^\top A$, and we have”:
\[\begin{align*} - \lambda A A^\top AH^{-1} & = -\lambda A A^\top A (\lambda A^\top A + B)^{-1} \approx -A \end{align*}\]where $B$ is the remaining, prior-dependent part of the Hessian. And we have this final update rule:
\[\begin{align*} \Delta A & = - \alpha (A z s^\top + A) = -\alpha A \left( z s^\top + I \right) \end{align*}\]
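For concreteness, here is what one such update step could look like in code. The Laplacian-prior gradient $z = -\theta \, \text{sign}(\hat{s})$ and the hyperparameters are my own choices; this is a sketch of my reading of the rule, not the paper's implementation:

```python
import numpy as np

def update_A(A, s_hat, theta=1.0, alpha=1e-3):
    """One step of the final learning rule: Delta A = -alpha * (A z s_hat^T + A),
    with z = grad_s log P(s) at s_hat, which for a Laplacian prior is -theta * sign(s_hat)."""
    z = -theta * np.sign(s_hat)
    return A - alpha * (A @ np.outer(z, s_hat) + A)
```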
Comparison with SAEs
Denoting a simplified SAE (encoder and decoder) as:
\[\begin{align*} s & = \text{ReLU} (W_\text{enc}x + b_\text{enc}) \\ x' & = W_\text{dec} s + b_\text{dec} \end{align*}\]Both SAEs and ICA suppose that the data is a compressed mixture and try to recover the underlying sources, but we can already see two differences. The first is that relying on $\text{ReLU}$ to hydrate your individual features presupposes that your data is sparse. In particular, because the latent space will always look something like this:
(Figure: “Superposition - An Actual Image of Latent Spaces”)
You can notice that not all types of data vectors $x$ can be reconstructed. If $x$ is one-hot, you’re in good shape, because the embedding of $x$ will sit along that feature’s vector and lie only in the activation zone of that one feature. However, if $x$ is two-hot, then the 2 active features had better be one of the 6 pairs whose activation zones overlap (corresponding to the edges of the hexagon in the latent-space plot). If $x$ is three-hot, you’re out of luck, because there is no zone in the latent space where the activation zones of 3 features overlap. In a 3D latent space, such zones would correspond to a face, but you can see how density quickly hinders your reconstruction ability.
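For reference, here is how small the SAE's "decompression" step is: a single thresholded linear map. A minimal sketch, with shapes and names that are illustrative rather than taken from any particular SAE implementation:

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Minimal SAE forward pass matching the equations above.
    Illustrative shapes: x is (m,), W_enc is (n, m), W_dec is (m, n), with n > m."""
    s = np.maximum(0.0, W_enc @ x + b_enc)   # ReLU decides which features are "on"
    x_recon = W_dec @ s + b_dec
    return s, x_recon
```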
ICA does not rely on $\text{ReLU}$ to decompress features. Instead, it relies on statistical arguments: minimizing the statistical dependence between features and maximizing the a posteriori probability of the representation under some prior, which brings us to the second difference.
The update rule of Lewicki and Sejnowski has $z$ in it. Remember that $z = \nabla_s \log P(s)$, evaluated at $\hat{s}$. $P(s)$, in particular, is defined by the user; in this case it is Laplacian, simply because we said so. This variant of ICA allows us to build explicit hypotheses about the distribution of “true feature” / source signal activations (coefficients) into the prior.
Looking at activation histograms that Anthropic has generated, I’d say that the Laplacian distribution is a reasonable approximation to a large number of feature activation distributions, and that there is a possible research question worth exploring: what if we used Lewicki and Sejnowski to try and find features instead of SAEs?
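Here is a toy sketch of what that experiment could look like, stitching together the MAP inference and the update rule from above. Everything here (hyperparameters, the plain subgradient solver) is an arbitrary choice of mine rather than a faithful reproduction of the paper:

```python
import numpy as np

def learn_overcomplete_basis(X, n, lam=100.0, theta=1.0, alpha=1e-3, n_epochs=10, n_inner=500):
    """Toy end-to-end loop: for each data point, infer the MAP coefficients under a
    Laplacian prior, then apply Delta A = -alpha * (A z s^T + A)."""
    rng = np.random.default_rng(0)
    m = X.shape[1]
    A = rng.normal(size=(m, n))
    A /= np.linalg.norm(A, axis=0)                     # unit-norm basis vectors
    for _ in range(n_epochs):
        for x in X:
            # MAP inference of s by (sub)gradient ascent on the log posterior
            s = np.zeros(n)
            lr = 1.0 / (lam * np.linalg.norm(A, 2) ** 2)
            for _ in range(n_inner):
                s += lr * (-theta * np.sign(s) + lam * A.T @ (x - A @ s))
            # Basis update from the final learning rule
            z = -theta * np.sign(s)
            A = A - alpha * (A @ np.outer(z, s) + A)
    return A
```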