ACE mitigates simplicity bias
One potential obstacle to successful concept extrapolation is the difference in simplicity between spuriously correlated features.
For example, consider the HappyFaces dataset. Every time we trained a vanilla (no concept extrapolation) classifier on the HappyFaces training data, we got a model that classified based on the text rather than the facial expression. More specifically, we always got a model that looked for the letter Y. Even more specifically, we always got a model that looked for straight lines meeting in upward-facing 60-degree angles.
The reason for this comes down to inductive biases: even though there are many possible solutions to a given problem, ML models tend to preferentially learn certain ones. Here, our models have a strong inductive bias towards classifying by the presence or absence of the letter Y (rather than by the presence or absence of the letter H or of a smile). This preference for the classify-by-Y solution probably has something to do with simplicity bias, the tendency of models to prefer "simpler" solutions. For the rest of this post, we’ll frame the discussion in terms of simplicity bias, but this is itself a simplification; other inductive biases may be in play.
Do differences in simplicity make concept extrapolation harder?
Doing concept extrapolation on HappyFaces means getting a model to learn both the classify-by-text and classify-by-face solutions. To do so, we need our model to learn a solution (classify-by-face) that it would not ordinarily learn; in other words, we need to push against simplicity bias. It might seem, then, that the relevant obstacle is the difference in simplicity between the classify-by-text and classify-by-face solutions. A natural question is: is concept extrapolation more difficult when this simplicity difference is larger?
To give an example, consider a dataset of labeled images, each of which is either (a) a photograph of a smiling person on top of a red rectangle, or (b) a photograph of a not-smiling person on top of a blue rectangle. (As with HappyFaces, let's imagine that this dataset is accompanied by an unlabelled dataset containing images of smiling people on top of blue rectangles and not-smiling people on top of red rectangles, as well as images of types (a) and (b).) For this dataset, the difference in simplicity between the classify-by-rectangle-color and classify-by-face solutions would seem to be larger than for HappyFaces.
A dataset in which the difference in simplicity between the spuriously correlated features is larger than in HappyFaces.
In this setting, our question becomes: will our concept extrapolation techniques have a harder time identifying both solutions for this dataset than they do for HappyFaces?
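The labeled/unlabelled structure described above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not Aligned AI's data pipeline: the `Example` record, its `feature_a`/`feature_b` fields, and `make_splits` are hypothetical names. The labeled split keeps only images where the two features agree, so they are perfectly correlated there; the unlabelled split is a random sample of the full pool, which here contains all four attribute combinations.

```python
import random
from typing import List, NamedTuple, Sequence, Tuple


class Example(NamedTuple):
    # Hypothetical record: one image tagged with two binary attributes.
    image_id: int
    feature_a: bool  # e.g. the person is smiling
    feature_b: bool  # e.g. the rectangle is red


def make_splits(
    examples: Sequence[Example], n_unlabelled: int, seed: int = 0
) -> Tuple[List[Tuple[Example, bool]], List[Example]]:
    """Return (labeled, unlabelled) splits with a spurious correlation."""
    rng = random.Random(seed)
    pool = list(examples)
    rng.shuffle(pool)
    # Unlabelled split: a random sample of the whole pool, so it contains
    # examples where the features agree and where they disagree.
    unlabelled = pool[:n_unlabelled]
    # Labeled split: only examples where the features agree, so the two
    # features are perfectly correlated; the label equals either feature.
    labeled = [
        (ex, ex.feature_a)
        for ex in pool[n_unlabelled:]
        if ex.feature_a == ex.feature_b
    ]
    return labeled, unlabelled
```

Because the labeled data never separates the two features, any classifier fit only to it is free to use either one; only the unlabelled data contains evidence that they are distinct.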
Making datasets of varying simplicity difference
To study this question, we first needed a set of features of varying simplicity. Our approach was to use images from the CelebA dataset to generate datasets in which various pairs of features were spuriously correlated. We then trained a network on each dataset and checked which feature it learned to use; we took that to be the simpler feature. For example, in one of our datasets, blond hair was spuriously correlated with wearing eyeglasses, and we found that our network learned to classify by whether the person was wearing eyeglasses. (Given that there are many varieties of blond hair, with sometimes subtle distinctions between blond and not-blond, it seems intuitive that the glasses/no-glasses classification boundary was simpler to learn.)
By performing many pairwise comparisons of this sort, we were able to come up with a crude ranking of some CelebA features by complexity. For example, we determined that eyeglasses/no-eyeglasses < blond/not-blond < young/not-young in terms of complexity; young/not-young was the most complex feature among those we considered. Given this ranking, we could make datasets (with labeled and unlabelled data, as in HappyFaces) with varying simplicity differences.
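The post does not say exactly how the "which feature did it learn" check was performed. One way to do it, assuming access to held-out images in which the two features disagree: on such images, a classifier that relies on one feature necessarily disagrees with the other, so its agreement rate with each feature reveals which one it uses. `learned_feature`, `Example`, and the `predict` callable are hypothetical names for this sketch.

```python
from typing import Callable, Iterable, NamedTuple, Tuple


class Example(NamedTuple):
    image_id: int
    feature_a: bool  # e.g. wearing eyeglasses
    feature_b: bool  # e.g. blond hair


def learned_feature(
    predict: Callable[[Example], bool], examples: Iterable[Example]
) -> Tuple[str, float]:
    """Guess which feature a trained binary classifier relies on.

    Returns the feature name and the classifier's agreement rate with
    feature_a on examples where the two features disagree.
    """
    conflicting = [ex for ex in examples if ex.feature_a != ex.feature_b]
    if not conflicting:
        raise ValueError("need held-out examples where the features disagree")
    agree_a = sum(predict(ex) == ex.feature_a for ex in conflicting)
    frac_a = agree_a / len(conflicting)
    # On conflicting examples, agreeing with feature_a means disagreeing
    # with feature_b, so a single rate decides between the two.
    return ("feature_a" if frac_a >= 0.5 else "feature_b"), frac_a
```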
The less-than sign indicates that we expect the dataset on the left to have a smaller simplicity difference than the one on the right. (It’s not important that the rectangles in the third dataset are on the side instead of at the bottom.)
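Turning pairwise outcomes like these into a ranking amounts to a topological sort. A minimal sketch using Python's standard-library `graphlib` (Python 3.9+), assuming each experiment is recorded as a hypothetical `(simpler, more_complex)` pair; `rank_by_simplicity` is an illustrative name, not part of the authors' code.

```python
from graphlib import TopologicalSorter


def rank_by_simplicity(comparisons):
    """Order features from simplest to most complex.

    comparisons: iterable of (simpler, more_complex) pairs, where the
    first feature is the one a network learned when the two were
    spuriously correlated.
    """
    graph = {}
    for simpler, harder in comparisons:
        # Each feature must come after every feature found simpler than it.
        graph.setdefault(harder, set()).add(simpler)
        graph.setdefault(simpler, set())
    # static_order() raises graphlib.CycleError if the comparisons
    # contradict each other (e.g. a < b and b < a).
    return list(TopologicalSorter(graph).static_order())
```

For example, `rank_by_simplicity([("eyeglasses", "blond"), ("blond", "young")])` returns `["eyeglasses", "blond", "young"]`. When the comparisons only partially order the features, the result is one consistent ordering among several, which fits the "crude ranking" described above.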
Finally, we tried to do concept extrapolation on these datasets: for each dataset, we used Aligned AI’s concept extrapolation algorithm to train a two-headed classifier, with each head learning to classify by a different feature.
The results: for most of the datasets we’d constructed, our concept extrapolation algorithm had no trouble learning both of the spuriously correlated features. However, for a few of the datasets with the largest simplicity differences, concept extrapolation training became extremely path-dependent: some runs successfully learned both features, while others learned only the simpler feature and failed to learn the more complex feature entirely.
We turned to the question of how to mitigate this failure of concept extrapolation for datasets in which the spuriously correlated features had a large simplicity difference.
Mitigating the effect of simplicity difference
To improve our concept extrapolation techniques on datasets with a large simplicity difference, we tried two interventions:
using a base model with a more sophisticated pretraining;
using the linear probing then fine-tuning (LP-FT) technique from Kumar et al., 2022.
To explain these ideas, we’ll need to dig into the architecture of our models a bit. At a high level, our models had two parts: a pretrained base model followed by a linear classifier.
This is a very common architecture: given an input (e.g. an image from our dataset in which having blond hair was spuriously correlated with wearing eyeglasses), first feed it through a base model that converts it into a vector encoding “everything the model knows about the image.” The base model is typically pretrained on a much larger dataset (e.g. for ImageNet classification). This vector is then fed through a much smaller, simpler model (in this case, a linear classifier) that tries to classify the image from the base model’s representation of it.
One way of thinking about this is that during pretraining, the base model learns a bunch of concepts (for example, it might learn that “humans” and “blond hair” are things) and learns to represent its inputs as best it can using those concepts. Our linear classifier then tries to figure out which of the concepts encoded in these representations are relevant and useful for the classification task we’re training on.
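The architecture described above can be sketched in plain Python as a shared representation feeding a list of linear heads: one head for ordinary classification, or two for the two-headed classifiers used in concept extrapolation. In this sketch `base_model` stands in for a pretrained network (such as an ImageNet classifier or CLIP's image encoder) and is any callable returning a list of floats; `LinearHead` and `MultiHeadClassifier` are hypothetical names.

```python
import math
from typing import Callable, List, Sequence


def sigmoid(t: float) -> float:
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


class LinearHead:
    """A linear binary classifier on top of the base model's vector."""

    def __init__(self, weights: Sequence[float], bias: float = 0.0):
        self.weights = list(weights)
        self.bias = bias

    def prob(self, z: Sequence[float]) -> float:
        logit = sum(w * zi for w, zi in zip(self.weights, z)) + self.bias
        return sigmoid(logit)


class MultiHeadClassifier:
    """A shared base model whose output vector feeds several heads."""

    def __init__(self, base_model: Callable[[object], List[float]],
                 heads: List[LinearHead]):
        self.base_model = base_model
        self.heads = heads

    def __call__(self, x) -> List[float]:
        z = self.base_model(x)  # representation computed once, shared by heads
        return [head.prob(z) for head in self.heads]
```

Because every head reads the same vector, a head can only classify by a concept that the base model's representation makes linearly accessible, which is why the choice of pretraining matters in the discussion that follows.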
So our first possible mitigation was to use a base model with a more sophisticated pretraining, i.e. one that we expected to have learned more useful concepts. In our case, this meant switching from a base model pretrained for ImageNet classification (where useful concepts are things that help distinguish cats from dogs from lizards, etc.) to OpenAI’s CLIP model (which was trained to match arbitrary images from the Internet with their captions, a task that incentivizes learning much richer concepts). We speculated that this would reduce the effects of simplicity difference: once concepts have been learned in pretraining, we expected them to be represented roughly equally simply for the purposes of fine-tuning on our concept extrapolation tasks. In other words, we speculated that pretraining flattens the complexity landscape among the learned features.
To explain our second attempted mitigation, note that when we trained our models for concept extrapolation, we trained the whole model (including the base model and the linear classifier) end-to-end. In this setting, the results of Kumar et al. suggest that training could distort the concepts the base model learned during pretraining, preventing those prelearned concepts from guiding concept extrapolation as much as we’d like. Kumar et al. also propose a possible mitigation: linear probing then fine-tuning (LP-FT), which involves first freezing the weights of the pretrained base model and training only the linear classifier, and then unfreezing the weights and training end-to-end.
The idea is that this forces the linear classifier to first identify the most useful prelearned concepts (while those concepts are held fixed), and only lets the concepts themselves change once the useful ones have been identified. Kumar et al. show that LP-FT improves out-of-distribution performance, i.e. accuracy when the test data are distributed differently from the fine-tuning data. We wanted to know whether LP-FT would help for concept extrapolation as well.
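A toy, self-contained illustration of the two-phase LP-FT schedule, assuming a one-parameter "base model" `z = a * x` and a logistic-regression head `sigmoid(w * z + b)` trained by plain gradient descent on cross-entropy. This is not the authors' training code, and the step counts and learning rates are arbitrary choices for the toy; the point is that phase 1 updates only the head (`w`, `b`) while the base parameter `a` stays frozen, and phase 2 updates everything.

```python
import math


def sigmoid(t):
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def mean_loss(params, data):
    """Mean binary cross-entropy of the toy model on (x, y) pairs."""
    a, w, b = params["a"], params["w"], params["b"]
    total = 0.0
    for x, y in data:
        p = min(max(sigmoid(w * (a * x) + b), 1e-12), 1 - 1e-12)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(data)


def gradients(params, data):
    """Analytic gradients of mean_loss with respect to a, w and b."""
    a, w, b = params["a"], params["w"], params["b"]
    g = {"a": 0.0, "w": 0.0, "b": 0.0}
    for x, y in data:
        z = a * x                     # the base model's "representation"
        err = sigmoid(w * z + b) - y  # d(loss)/d(logit)
        g["w"] += err * z
        g["b"] += err
        g["a"] += err * w * x         # chain rule back into the base model
    return {k: v / len(data) for k, v in g.items()}


def train(params, data, trainable, steps, lr):
    params = dict(params)  # don't mutate the caller's parameters
    for _ in range(steps):
        g = gradients(params, data)
        for name in trainable:  # frozen parameters are never updated
            params[name] -= lr * g[name]
    return params


def lp_ft(params, data, lp_steps=200, ft_steps=200, lp_lr=0.5, ft_lr=0.05):
    # Phase 1, linear probing: base model ("a") frozen, head trained.
    probed = train(params, data, ("w", "b"), lp_steps, lp_lr)
    # Phase 2, fine-tuning: everything unfrozen and trained end-to-end.
    tuned = train(probed, data, ("a", "w", "b"), ft_steps, ft_lr)
    return probed, tuned
```

In a real framework the same schedule is usually expressed by disabling gradients for the base model's parameters during phase 1 and re-enabling them for phase 2.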
And indeed, we found that swapping out our base model for CLIP and then implementing LP-FT improved concept extrapolation on datasets with a large simplicity difference.
Concept extrapolation performance on a dataset in which young/not-young was spuriously correlated with red/blue rectangles on the left-hand side of the image. The base model for both experiments was CLIP.
Worst-head validation loss was our metric for successful concept extrapolation; lower is better.
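The post does not spell out how worst-head validation loss is computed. One plausible version, offered purely as an assumption: on validation data where the features are decorrelated, match each head to a feature via the lowest-total-loss assignment, compute each head's binary cross-entropy against its matched feature, and report the largest. Under this reading, two heads that both collapse onto the simpler feature give a high worst-head loss, which is the failure mode described earlier. `bce` and `worst_head_loss` are hypothetical names.

```python
import math
from itertools import permutations


def bce(probs, labels, eps=1e-12):
    """Mean binary cross-entropy of predicted probabilities vs 0/1 labels."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1 - eps)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(labels)


def worst_head_loss(head_probs, feature_labels):
    """head_probs: one list of predicted probabilities per head.
    feature_labels: one list of 0/1 labels per feature, on the same
    validation examples. Heads are matched to features by the assignment
    with the lowest total loss (an assumption of this sketch); the result
    is the largest per-head loss under that assignment."""
    best = None
    for perm in permutations(range(len(feature_labels)), len(head_probs)):
        losses = [bce(h, feature_labels[f]) for h, f in zip(head_probs, perm)]
        if best is None or sum(losses) < sum(best):
            best = losses
    return max(best)
```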
By Joe Kwon, Samuel Marks, Brady Pelkey, and Matthew Watkins