Thinking about how Sparse Auto-encoders (SAEs) aim to learn a sparse over-complete basis (where you are trying to triangulate a larger number of sources than you have signals) got me thinking about Independent Component Analysis again. In particular, I wanted to see if I could articulate a mapping between ICA and SAEs. This would provide a more mathematical framework for thinking about SAEs. I do so in this post by walking through this paper: Lewicki and Sejnowski
Lewicki and Sejnowski
This paper investigates “using a Laplacian prior” to learn “representations that are sparse and [a] non-linear function of the data” as a “method for blind source separation of fewer mixtures than sources”.
Motivations
- Overcomplete representations “(i.e. more basis vectors than input variables), can provide a better representation, because the basis vectors can be specialized for a larger variety of features present in the entire ensemble of data”
- But, “a criticism of overcomplete representations is that they are redundant, i.e. a given data point may have many possible representations”
- This paper is about imposing a prior probability of the basis coefficients which “specifies the probability of the alternative representations”. I’m not sure I understand this for now.
Problem Setup
\[\begin{align*} x & = As + \varepsilon \end{align*}\]where $A \in \mathbb{R}^{m \times n}, m < n$. $x$ is the observation (short vector), while $s$ is the hidden “true” source signal vector. Further, they assume “Gaussian additive noise so that $\log P(x \mid A, s) \propto - \lambda (x - As)^2 / 2$.” In other words, the noise has precision $\lambda$ (variance $1/\lambda$):
\[\varepsilon \sim \mathcal{N} \left(0, \frac{1}{\lambda}\right)\]They further define a “density for the basis coefficients, $P(s)$, which specifies the probability of the alternative representations. The most probable representation, $\hat{s}$, is found by maximizing the posterior distribution:”
\[\hat{s} = \max_s P(s \mid A, x) = \max_s P(s) P(x \mid A, s)\]Firstly, they mean to say:
\[\hat{s} = \text{argmax}_s P(s \mid A, x) = \text{argmax}_s P(s) P(x \mid A, s)\]And I should note that this is simply Bayes’ rule being applied, with the denominator $P(x \mid A)$ dropped because it does not depend on $s$. In this case, $P(s)$, a.k.a. the prior, has been written out explicitly and assumed to be Laplacian.
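To make the role of the prior concrete, here is a minimal NumPy sketch of the unnormalized log posterior. It compares two coefficient vectors that reconstruct the same $x$: a sparse one, and a dense one obtained by adding a null-space vector of $A$. The basis `A`, the source vector, and the parameter names `lam` (noise precision) and `b` (Laplacian scale) are illustrative assumptions, not values from the paper.

```python
import numpy as np

def log_posterior(s, x, A, lam=100.0, b=1.0):
    """Unnormalized log P(s | A, x) = log P(s) + log P(x | A, s) + const.

    lam: noise precision (noise variance is 1 / lam).
    b:   scale of the Laplacian prior (assumed name).
    """
    log_prior = -np.sum(np.abs(s)) / b                # Laplacian prior
    log_lik = -0.5 * lam * np.sum((x - A @ s) ** 2)   # Gaussian noise model
    return log_prior + log_lik

c = 1.0 / np.sqrt(2.0)
A = np.array([[1.0, 0.0, c],
              [0.0, 1.0, c]])            # 2 mixtures, 3 sources (overcomplete)
s_sparse = np.array([1.5, 0.0, 0.0])
x = A @ s_sparse                         # noiseless observation, for simplicity

# A vector in the null space of A: adding it to s does not change A @ s.
null_vec = np.array([1.0, 1.0, -np.sqrt(2.0)]) / 2.0
s_dense = s_sparse + 0.5 * null_vec      # alternative representation of the same x

lp_sparse = log_posterior(s_sparse, x, A)
lp_dense = log_posterior(s_dense, x, A)
print(lp_sparse, lp_dense)
```

Both vectors explain $x$ equally well, so the likelihood term cannot tell them apart. Only the prior ranks them, and the Laplacian prior favors the sparse one.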
Approaches considered
They considered these optimization approaches:
- Find $\text{argmax}_s P(s \mid A, x)$ by using the gradient of the log posterior ($\log P(s \mid A, x)$).
- Use linear programming methods to find the $s$ that maximizes $P(s \mid A, x)$ by minimizing $\mathbf{1}^\top s = \|s\|_1$ (the equality holds for non-negative $s$). This closely resembles the sparsity objective of SAEs.
Learning Objective
“The learning objective is to adapt $A$ to maximize the probability of the data which is computed by marginalizing over the internal state:”
\[\begin{align*} P(x \mid A) & = \int P(s) P(x \mid A, s) \text{ } ds \end{align*}\]This, I understand. A helpful note is that $s$ is distributed around some mean ($\hat{s}$), and that distribution is usually Gaussian, but in this paper, they propose for it to be Laplacian. Remember also that while $s \sim \text{Laplacian}$, the noise $\varepsilon$ in the data is still Gaussian.
They continue: “this integral cannot be evaluated analytically but can be approximated with a Gaussian integral (hence why it’s usually Gaussian) around $\hat{s}$, yielding:”
\[\begin{align*} \log P(x \mid A) \approx \text{const.} + \log P(\hat{s}) - \frac{\lambda}{2}(x - A\hat{s})^2 - \frac{1}{2} \log \text{det} H \end{align*}\]“where $H$ is the Hessian of the log posterior at $\hat{s}$.” This, I did not understand, but I was able to trace through the derivation with Claude’s help and so I’ll write it down before it’s lost once again to the ether.
Derivation of Approximation
First, we denote the log of the integrand as $f(s)$:
\[\begin{align*} f(s) & = \log P(s) + \log P(x \mid A, s) \\ \therefore P(x \mid A) & = \int e^{f(s)} \text{ } ds \end{align*}\]We know that the mean of a Gaussian (and Laplacian for that matter) distribution has the maximum probability density. This is a useful fact about $\hat{s}$, which we will try to incorporate by expressing $f(s)$ in terms of $f(\hat{s})$ using the Taylor expansion:
\[\begin{align*} f(s) & \approx f(\hat{s}) + \nabla f(\hat{s})^\top \left(s - \hat{s} \right) + \frac{1}{2} \left( s - \hat{s} \right)^\top H \left( s - \hat{s}\right) \end{align*}\]Since, by Bayes rule,
\[\begin{align*} \log P (s \mid x, A) & = \log P(s) + \log P(x \mid A, s) - \log P(x \mid A) \\ & = f(s) - \log P(x \mid A), \end{align*}\]and $\log P(x \mid A)$ does not depend on $s$, the log posterior and $f(s)$ differ only by a constant. Since $\hat{s} = \text{argmax}_s P(s \mid x, A)$, this also means that $f(s)$ is maximized at $\hat{s}$. Hence, $\nabla f(\hat{s}) = 0$. Therefore:
\[f(s) \approx f(\hat{s}) + \frac{1}{2}\left( s - \hat{s} \right)^\top H \left(s - \hat{s} \right)\]Substituting back into the integral, we have:
\[\begin{align*} P(x \mid A) & \approx \int e^{f(\hat{s})} e^{\frac{1}{2}(s - \hat{s})^\top H (s - \hat{s})} \text{ } ds \\ & = e^{f(\hat{s})} \int e^{-\frac{1}{2}(s - \hat{s})^\top K (s - \hat{s})} \text{ } ds \end{align*}\]where $K = -H$. Crucially, the second factor is a Gaussian integral, with the known solution below. Note that because $f$ is approximated by a concave quadratic around its maximum, $H$ is negative definite, so its determinant is positive if $s$ is even-dimensional and negative if $s$ is odd-dimensional. Working with $K = -H$ avoids this problem, because $K$ is positive definite and always has a positive determinant.
\[\begin{align*} \int e^{-\frac{1}{2}(s - \hat{s})^\top K (s - \hat{s})} \text{ } ds & = \sqrt{\frac{(2 \pi)^d}{| K |}} \end{align*}\]Where $|K |$ is the determinant of $K$. Therefore:
\[\begin{align*} P(x \mid A) & \approx e^{f(\hat{s})} \int e^{-\frac{1}{2}(s - \hat{s})^\top K (s - \hat{s})} \text{ } ds \\ & = e^{f(\hat{s})} (2 \pi)^\frac{d}{2} \cdot | K | ^{-\frac{1}{2}} \\ \Rightarrow \log P(x \mid A) & \approx f (\hat{s}) + \frac{d}{2} \log (2 \pi) - \frac{1}{2} \log |K| \\ & = \log P(\hat{s}) + \log P(x \mid A, \hat{s}) + \frac{d}{2} \log (2 \pi) - \frac{1}{2} \log |K| \end{align*}\]At this point, we pretty much have what we want. We just have to note that for the $\log P(x \mid A, \hat{s})$ term, since a Gaussian PDF is given by some scalar multiple of $\exp(-(x - \mu)^2)$, we simply have:
\[\begin{align*} \log P(x \mid A, \hat{s}) = -k \left( x - A \hat{s} \right)^2 + \text{const} \end{align*}\]where $k = \lambda / 2$ under the noise model above. And further noting that the $\frac{d}{2} \log (2 \pi)$ term is a constant, we have our approximation:
\[\begin{align*} \log P(x \mid A) \approx \text{const} + \log P (\hat{s}) - k(x - A\hat{s})^2 - \frac{1}{2} \log |K| \end{align*}\]Where $|K|$ is explicitly the absolute value of the determinant of $H$, and not the determinant of $H$, as the paper suggests.
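As a sanity check on the derivation, the approximation can be compared against brute-force integration in one dimension. This is a sketch under assumptions of my own choosing: a smooth sparse-ish prior $\log P(s) = -\log \cosh(s)$ (so that the Hessian exists everywhere), and toy values for `lam`, `a`, and `x`.

```python
import numpy as np

lam, a, x = 50.0, 1.0, 0.8    # assumed toy values: noise precision, 1-D "basis", observation

def f(s):
    """Log of the integrand: smooth log prior plus Gaussian log likelihood."""
    return -np.log(np.cosh(s)) - 0.5 * lam * (x - a * s) ** 2

# Brute-force value of the integral of exp(f(s)) via a fine Riemann sum.
grid = np.linspace(-5.0, 5.0, 100_001)
ds = grid[1] - grid[0]
brute_force = np.sum(np.exp(f(grid))) * ds

# Approximation: expand f to second order around its maximum s_hat.
s_hat = grid[np.argmax(f(grid))]
H = -1.0 / np.cosh(s_hat) ** 2 - lam * a ** 2   # second derivative of f at s_hat
K = -H                                          # positive, so sqrt is well defined
approx = np.exp(f(s_hat)) * np.sqrt(2.0 * np.pi / K)

rel_err = abs(approx - brute_force) / brute_force
print(brute_force, approx, rel_err)
```

With a large `lam` the posterior is sharply peaked around `s_hat`, which is the regime in which the quadratic expansion is expected to be accurate.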
Learning Rule
As with normal gradient ascent algorithms, our learning rule is to update $A$ with $A + \Delta A$, where $\Delta A$ is proportional to the gradient of the maximization objective, in this case $\log P(x \mid A)$. I’ll skim the derivation of the learning rule:
First term:
\[\begin{align*} \nabla_A \log P(\hat{s}) & = \nabla_{\hat{s}} \log P(\hat{s}) \cdot \nabla_{A} \hat{s} \\ \end{align*}\]Things to note / denote:
- $z = \nabla_s \log P(s)$
- $W \approx A^{-\top}$, in the sense that rows of $W^\top$ corresponding to non-zero $s$ indices are the same as those of $A^{-\top}$, and hence $A^\top W^\top = I$
So first term becomes (I don’t fully follow the assumptions that allow the conflation of $\hat{s}$ and $s$, and $A^{-\top}$ and $W^\top$, but I can big-picture follow the chain rule):
\[\begin{align*} \nabla_A \log P(\hat{s}) & = -W^\top z s^\top \end{align*}\]Second term:
\[- k (x - A\hat{s})^2\]This is simply a noise term, and the Gaussian noise is irreducible error. There’s no gradient w.r.t $A$ from this term.
Third term: I’m not even going to trace through the derivation for this one:
\[\begin{align*} \nabla_A \frac{1}{2} \log \text{det} (H) & = \lambda A H^{-1} - 2W^\top y \hat{s}^\top \end{align*}\]Where $y$ “hides a computation involving the inverse Hessian… It can be shown that if $ \log P(s)$ and its derivatives are smooth, $y$ vanishes for large $\lambda$.”
Putting it together:
\[\begin{align*} \nabla_A \log P(x \mid A) & = -W^\top z s^\top - \lambda AH^{-1} + 2W^\top y \hat{s}^\top \\ & = -W^\top z s^\top - \lambda AH^{-1} \text{ (last term vanished)} \end{align*}\]The authors then choose to pre-multiply this by $AA^\top$, which they do not explain, and I do not understand. A wild guess is that usually, $AA^\top$ captures geometry of $A$, in the sense that $AA^\top$’s eigenvalues will tell you the stretching factor in each dimension, and hence the “variance” in each dimension, and you might hence want to scale up your updates accordingly.
\[\begin{align*} AA^\top \nabla_A \log P(x \mid A) & = -A z s^\top - \lambda A A^\top AH^{-1} \end{align*}\]where the first term uses $A^\top W^\top = I$. More assumptions: if $\lambda$ is large (low noise, as large $\lambda \Rightarrow$ low standard deviation), “then the Hessian is dominated by $\lambda A^\top A$, and we have”:
\[\begin{align*} - \lambda A A^\top AH^{-1} & = - \lambda A A^\top A (\lambda A^\top A + B)^{-1} \approx -A \end{align*}\]where $B$ denotes the remaining part of $H$. And we have this final update rule:
\[\begin{align*} \Delta A & = - \alpha (A z s^\top + A) \end{align*}\]
Comparison with SAEs
vs SAEs
Denoting a simplified SAE encoder function as:
\[\begin{align*} s & = \text{ReLU} (W_\text{enc}x + b_\text{enc}) \\ x' & = W_\text{dec} s + b_\text{dec} \end{align*}\]Both SAEs and ICA are supposing that the data is compressed and are trying to do decompression, but we can already see 2 differences. The first is that relying on $\text{ReLU}$ to hydrate your individual features presupposes that your data is sparse. In particular, the latent space will always look something like this:
“Superposition - An Actual Image of Latent Spaces”
You can notice that not all types of data vectors $x$ can be reconstructed. If $x$ is one-hot, you’re in good shape, because the embeddings of $x$ ($s$) will sit along its feature vector and be only in the activation zone of that one feature. However, if $x$ is two-hot, then the 2 features that are active better be one of the 6 pairs for which there is an overlap of those 2 features’ activation zones (corresponding to the edges of the hexagon in the latent zone plot). If $x$ is three-hot, you’re out of luck, because there isn’t a zone in the latent where the latent zones of 3 features overlap. In a 3D latent, such zones would correspond to a face, but you can see how density quickly hinders your reconstruction ability.
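The simplified SAE equations above can be sketched in a few lines of NumPy. The setup here is my own toy assumption, not something from an actual SAE: six unit-norm feature directions arranged as a hexagon in a 2-D space, tied encoder and decoder weights, and a hand-picked negative encoder bias of `-0.6` that plays the role of the activation threshold.

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Simplified SAE: ReLU encoder followed by a linear decoder."""
    s = np.maximum(0.0, W_enc @ x + b_enc)   # feature activations
    x_rec = W_dec @ s + b_dec                # reconstruction
    return s, x_rec

# Six feature directions at 60-degree spacing (columns of D), in a 2-D space.
angles = np.arange(6) * np.pi / 3.0
D = np.stack([np.cos(angles), np.sin(angles)])   # shape (2, 6)

W_dec, W_enc = D, D.T                 # tied weights (assumption)
b_enc = -0.6 * np.ones(6)             # threshold that silences neighboring features
b_dec = np.zeros(2)

# One active direction: only the matching feature fires.
s_one, _ = sae_forward(D[:, 0], W_enc, b_enc, W_dec, b_dec)

# Two adjacent directions active at once: exactly those two features fire.
s_two, _ = sae_forward(D[:, 0] + D[:, 1], W_enc, b_enc, W_dec, b_dec)

print(np.flatnonzero(s_one), np.flatnonzero(s_two))
```

The bias is what carves the space into activation zones: without it, the two neighbors of an active feature would also fire, since their directions have a positive dot product with it.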
ICA does not rely on $\text{ReLU}$ to decompress features. ICA relies on statistical arguments of minimizing covariance of features and maximizing the a posteriori likelihood of the data given some prior, which brings us to the second difference.
The update rule of Lewicki and Sejnowski has $z$ in it. Remember that $z = \nabla_s \log P(s)$. $P(s)$ in particular, is defined by the user; in this case, it is Laplacian, simply because we said so. This variant of ICA allows us to build in explicit hypotheses about the distribution of “true feature” / source signal activations (coefficients) as our prior.
Looking at activation histograms that Anthropic has generated, I’d say that the Laplacian distribution is a reasonable approximation to a large number of feature activation distributions, and that there is a possible research question worth exploring: what if we used Lewicki and Sejnowski to try and find features instead of SAEs?
Experiment Log
E1. Data Generation
Our goal is to have data that looks like this:
Desired Dataset (from “Examples” of Lewicki & Sejnowski)
A Laplacian prior does not give a dataset that looks like this. It gives a symmetric spherical dataset, much like a multivariate Gaussian. I then found that they wrote that “the elements of $s$ [are] distributed according to an exponential distribution with unit mean.” I was able to recover something like that using an exponential prior:
Exponential Prior Dataset ($\lambda = 0.5$, i.e. mean of 2)
They further state that “identical results were obtained by drawing $s$ from a Laplacian prior (positive and negative coefficients).” This is incredible, because it means that the update rule is readily applicable to the feature (non-negative coefficients) extraction setting. And indeed, the PDF for each half of the Laplace distribution belongs to the same family of distributions as the PDF of the exponential distribution. Their shapes are exactly the same up to a scaling factor, and the scaling factor is completely irrelevant in the context of gradient descent! 😄 😄
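A minimal sketch of the two data-generating processes, assuming a hypothetical 2x3 basis `A_true` and NumPy's random generator. Note that NumPy parameterizes the exponential distribution by its scale (the mean), so a rate of 0.5 corresponds to `scale=2.0`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5000

c = 1.0 / np.sqrt(2.0)
A_true = np.array([[1.0, 0.0, c],
                   [0.0, 1.0, c]])   # hypothetical basis: 3 sources, 2 mixtures

# Non-negative coefficients: exponential with rate 0.5, i.e. mean 2.
s_exp = rng.exponential(scale=2.0, size=(3, n_samples))
# Signed coefficients: Laplacian with the same scale.
s_lap = rng.laplace(loc=0.0, scale=2.0, size=(3, n_samples))

X_exp = A_true @ s_exp   # one observation per column
X_lap = A_true @ s_lap

# The magnitude of a Laplacian draw follows the exponential distribution
# with the same scale, so the two sample means should roughly agree.
print(s_exp.mean(), np.abs(s_lap).mean())
```

For $s > 0$ both log-densities have the same slope, $-1/\text{scale}$, which is the sense in which each half of the Laplacian matches the exponential.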
E2. Solving for $\hat{s}$ Given a Fixed $A$
Remember that the update rule is:
\[\Delta A = - \alpha (A z \hat{s}^\top + A)\]So first, we have to solve for $\hat{s}$. The paper states that “coefficients were solved using BPMPD,” where BPMPD is a primal-dual interior point algorithm. Sorry Stephen Boyd, as much as I loved convex optimization, I’m not about to re-write a convex optimization algorithm. We will instead use good old gradient descent.
Given a fixed $A$ and a data matrix $X \in \mathbb{R}^{m \times p}$ (one $m$-dimensional observation per column), we shall find $\hat{s}$. Remember that this is the definition of $\hat{s}$:
\[\hat{s} = \text{argmax}_s P(s \mid A, x) = \text{argmax}_s P(s) P(x \mid A, s)\]Log it:
\[\hat{s} = \text{argmax}_s \log P(s \mid A, x) = \text{argmax}_s \left\{ \log P(s) + \log P(x \mid A, s) \right\}\]The first term is easy to compute; just plug $s$ into the equation for the Laplace Distribution PDF:
\[\begin{align*} P(s) & = \text{(Some constant)} \exp \left( - \frac{|s - (\mu_s = 0)|}{\lambda} \right) \\ \log P(s) & = \text{(Some constant) } - \frac{\|s\|_1}{\lambda} \end{align*}\]Note that $\lambda$ here is the scale of the Laplace prior, not the noise precision from the problem setup. The second term is harder. We know that the noise model is:
\[x | A, s \sim \mathcal{N}(As, \sigma^2 I)\]Therefore, the likelihood and log-likelihood are:
\[\begin{align*} P(x \mid A, s) &= (2 \pi \sigma^2)^{-\frac{m}{2}} \exp \left(-\frac{\|x - As\|_2^2}{2 \sigma^2} \right) \\ \log P(x \mid A, s) &= - \frac{m}{2} \log (2 \pi \sigma^2 ) - \frac{\|x - As\|_2^2}{2 \sigma^2} \end{align*}\]where $m$ is the signal (measurement) dimensionality. Putting these together, we get the LASSO objective, aka minimize:
\[\begin{align*} \mathcal{L}_s = \frac{\|x - As\|_2^2}{2 \sigma^2} + \frac{\|s\|_1}{\lambda} \end{align*}\]This looks scarily similar to the SAE cost function:
\[\mathcal{L}_\text{SAE}(A, s) = \text{MSE}(x, As) + \lambda_\text{sparsity} \|s\|_1\]But there are 2 main differences. In training SAEs, the optimization is done on $A$ and $s$ simultaneously, whereas in ICA, the optimization is done on $s$ first to find $\hat{s}$. Then, in addition to this, we still have to find $z$, to optimize $A$. I’m not sure how isomorphic these end up being.
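Here is a minimal sketch of minimizing $\mathcal{L}_s$ for a single observation with projected gradient descent. Several things are my own assumptions: the toy basis `A`, the defaults for `sigma` and `lam`, the explicit step size (set from the largest singular value of `A`), and the use of $\mathbf{1}/\lambda$ as the gradient of the penalty, which is valid because the projection keeps $s$ non-negative so that $\|s\|_1 = \mathbf{1}^\top s$.

```python
import numpy as np

def lasso_loss(s, x, A, sigma=0.1, lam=1.0):
    """L_s = ||x - A s||^2 / (2 sigma^2) + ||s||_1 / lam."""
    return np.sum((x - A @ s) ** 2) / (2 * sigma**2) + np.sum(np.abs(s)) / lam

def solve_s(x, A, sigma=0.1, lam=1.0, n_steps=500):
    """Projected gradient descent on L_s, keeping s non-negative."""
    # Step size 1 / L, where L bounds the curvature of the quadratic term.
    step = sigma**2 / np.linalg.norm(A, 2) ** 2
    s = np.zeros(A.shape[1])
    for _ in range(n_steps):
        grad = -A.T @ (x - A @ s) / sigma**2 + np.ones_like(s) / lam
        s = np.maximum(0.0, s - step * grad)   # ReLU acts as the projection
    return s

c = 1.0 / np.sqrt(2.0)
A = np.array([[1.0, 0.0, c],
              [0.0, 1.0, c]])
x = A @ np.array([1.5, 0.0, 0.0])   # observation generated by a single source

s_hat = solve_s(x, A)
print(s_hat, lasso_loss(s_hat, x, A))
```

The recovered coefficient on the active source comes out slightly below 1.5. That shrinkage is the usual price of the $\ell_1$ penalty, and it grows with `sigma**2 / lam`.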
Implementing the learning of $\hat{s}$ is rather straightforward, with the update equation simply being:
\[s := \text{ReLU}(s - \nabla_s \mathcal{L}_s)\]
E3 Updating $A$ given $\hat{s}$
The update equation here is simply from the paper:
\[\Delta A = - \alpha A(z \hat{s}^\top + I)\]Since $z = \nabla_s \log P(\hat{s})$, and $s \sim \text{Laplace}(0, \lambda)$, it turns out that $z$ is simply $-\frac{1}{\lambda} \text{sign}(\hat{s})$ (the minus sign comes from differentiating $-\|s\|_1 / \lambda$). The update to $A$ is very problematic. I was able to get perfect learning behavior when initializing $A$ with a perturbed version of $A_\text{true}$:
However, there is a failure mode where the updates to $A$ fall into a runaway regime, resulting in exploding losses:
I’ve noticed that this happens when the learning rate ($\alpha$) is unsuitable (either too high, or no learning rate scheduler is used), and / or the initialization of $A$ is random:
Solving this problem / figuring out why this happens is crucial before trying to implement this in real language models to extract features, because our features are definitely going to be randomly initialized, and we don’t know how many features (columns) to provision for in $A$.
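For reference, one step of the learning rule can be sketched as below. Averaging over a batch of inferred coefficients is my assumption, as are the names `update_A`, `S_hat`, and the default `alpha`. Here $z$ is taken literally as the gradient of $\log P(s) = \text{const} - \|s\|_1 / \lambda$, which gives $-\text{sign}(\hat{s}) / \lambda$.

```python
import numpy as np

def update_A(A, S_hat, alpha=0.01, lam=1.0):
    """One step A <- A + alpha * dA, with dA = -A (z s_hat^T + I), batch-averaged.

    S_hat: (n, p) array whose columns are inferred coefficient vectors.
    lam:   scale of the Laplacian prior.
    """
    n, p = S_hat.shape
    Z = -np.sign(S_hat) / lam                      # gradient of the log prior
    dA = -A @ (Z @ S_hat.T / p + np.eye(n))        # average of z s_hat^T over the batch
    return A + alpha * dA

rng = np.random.default_rng(0)
c = 1.0 / np.sqrt(2.0)
A = np.array([[1.0, 0.0, c],
              [0.0, 1.0, c]])

# Coefficients that already match the prior: independent Laplacian, scale lam.
S_hat = rng.laplace(loc=0.0, scale=1.0, size=(3, 20_000))
A_new = update_A(A, S_hat, alpha=0.01, lam=1.0)
dA_est = (A_new - A) / 0.01
print(np.abs(dA_est).max())
```

With this sign convention the batch-averaged step is close to zero when the coefficients are independent and have the scale the prior expects, because the average of $\text{sign}(s_i) s_j$ is then approximately $\lambda$ on the diagonal and zero elsewhere. A single sample, by contrast, always produces a non-zero step.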
E4 Re-Examining the Update Rule
\[\Delta A = - \alpha A(z \hat{s}^\top + I)\]Being forced to stare at this update rule, I now notice that it doesn’t seem to be proportional to the error the way neural-network / general loss back-propagation update rules usually are. In particular, consider the case of perfect data reconstruction:
- $z$ is still $-\frac{1}{\lambda} \text{sign}(\hat{s})$, as is the case in imperfect reconstruction
- $\hat{s}^\top$ is still $\hat{s}^\top$; it’s not as if $\hat{s}$ vanishes as reconstruction error vanishes
- $I$ is still $I$
I wonder if this is a problem. In particular, perhaps the fact that we’re throwing in a $\text{ReLU}$ to make all the source feature coefficients ($\hat{s}$) non-negative is disrupting a symmetry that is crucial to the update rule working.
E4.1 Laplace vs Exponential
So, I tested removing the $\text{ReLU}$ that was applied to $\hat{s}$ after each iteration of learning $\hat{s}$. This effectively makes it such that $s \sim \text{Laplace}$ instead of $s \sim \text{Exponential}$. These are the results I got:
There’s a new problem where the features recovered in $A$ do not rotate to align with the true features anymore. There are 2 observations that make it unclear why exactly achieving perfect alignment is dis-preferred here:
- Imagine that $A$ were $A_\text{true}$, but flipped vertically. If we allowed negative coefficients, we can achieve the same sparsity that is achievable by $A_\text{true}$. All the coefficients just have to have the opposite sign. There is no incentive to rotate. However, this is not true if $A$ were some general rotation of $A_\text{true}$, and not exactly vertically flipped, as is the case shown above.
- The sparsity penalty ($\|s\|_1$) can be reduced by increasing the magnitude of the features in $A$, since larger features need smaller coefficients.
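Both observations are easy to check numerically. In this sketch the basis, the coefficients, and the scale factor `c_scale` are arbitrary choices of mine, and the "flip" is simplified to negating every basis vector.

```python
import numpy as np

c = 1.0 / np.sqrt(2.0)
A = np.array([[1.0, 0.0, c],
              [0.0, 1.0, c]])
s = np.array([1.5, 0.0, -0.5])      # signed coefficients are allowed here
x = A @ s

# Observation 1: negate the basis and the coefficients together.
A_flip, s_flip = -A, -s             # same reconstruction, same L1 norm

# Observation 2: scale the basis up and the coefficients down.
c_scale = 10.0
A_big, s_small = c_scale * A, s / c_scale   # same reconstruction, smaller L1 norm

l1 = np.sum(np.abs(s))
l1_flip = np.sum(np.abs(s_flip))
l1_small = np.sum(np.abs(s_small))
print(l1, l1_flip, l1_small)
```

Neither change affects the reconstruction, so the data term alone cannot rule them out. The second is one reason dictionary-learning setups commonly constrain the norm of the basis vectors.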
In any case, this seems to be introducing a new problem, so it seems more fruitful to look elsewhere.
E4.2 Gaussian Noise?
In my prior experiments, I did not add Gaussian noise to $x$, as the paper supposes. Multiple times in the paper they allude to noise being crucial:
- “Recently Olshausen and Field (1996) presented an algorithm that allows an overcomplete basis to be learned. This algorithm […], including a tendency to breakdown in the case of low noise levels and when learning bases with higher degrees of overcompleteness.”
- Many times in the paper they simply write that they are assuming low noise.
So… I just added Gaussian Noise:
\[x = As + \left(\varepsilon \sim \mathcal{N}(0, 0.05^2) \right)\]Didn’t work; still exhibited the increase of error past a certain point: