friday / writing

The Early Exit

Transformer language models process text through stacked layers. Each layer refines the representation — the first layers capture syntax, the middle layers capture semantics, and the final layers produce the prediction. In natural language processing, the last layer is usually the best. Intermediate representations are less refined, less accurate. Skipping to an early exit — using an intermediate layer's output instead of the final layer's — trades accuracy for speed.
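The mechanism is simple to sketch. Below is a minimal toy, assuming the "model" is just a list of randomly initialized layer functions rather than a real protein LM — exiting at layer k returns the intermediate representation and skips the remaining layers' compute entirely.

```python
# Toy sketch of early exit: a "model" that is just a stack of layer
# functions applied in sequence. All weights here are random and
# purely illustrative, not a trained network.
import numpy as np

rng = np.random.default_rng(1)
n_layers, dim = 12, 8
# Hypothetical "layers": random linear maps followed by a nonlinearity.
weights = [rng.normal(scale=0.3, size=(dim, dim)) for _ in range(n_layers)]

def forward(x, exit_layer=None):
    """Run the stack, optionally stopping after `exit_layer` layers."""
    layers_applied = 0
    for w in weights[:exit_layer]:  # slice with None runs all layers
        x = np.tanh(x @ w)
        layers_applied += 1
    return x, layers_applied

x = rng.normal(size=dim)
full, n_full = forward(x)                   # final-layer representation
early, n_early = forward(x, exit_layer=6)   # early exit at layer 6
print(n_full, n_early)  # 12 6
```

The early representation is produced at half the cost of the full forward pass; the question the paper raises is when that cheaper representation is also the more useful one.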

Hart et al. (arXiv:2602.20449) find that for protein language models, this tradeoff reverses. Protein transformer models achieve higher accuracy on non-structural property prediction when allowed to select intermediate layers rather than always using the final layer. The early exit is not just faster. It is better. The gain ranges from 0.4 to 7 percentage points, with simultaneous efficiency improvements exceeding 10%.
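One way to see how per-task layer selection works in practice: fit a cheap probe on each layer's embeddings and keep the layer that scores best on held-out data. The sketch below uses synthetic embeddings in place of real protein-LM hidden states, with the task signal planted most strongly at a middle layer — the layer index, dimensions, and planting scheme are all invented for illustration, not taken from the paper.

```python
# Per-layer probing sketch on synthetic data. We plant a linear task
# signal that peaks at a middle layer and decays toward the final one,
# then fit a ridge-regression probe per layer and select the best.
import numpy as np

rng = np.random.default_rng(0)
n_train, n_val, dim, n_layers = 200, 100, 16, 12
y_train = rng.integers(0, 2, n_train) * 2 - 1   # labels in {-1, +1}
y_val = rng.integers(0, 2, n_val) * 2 - 1
signal = rng.normal(size=dim)                   # planted signal direction

def layer_embeddings(y, layer):
    # Signal strength peaks at layer 7, mimicking the idea that late
    # layers refine past the task-relevant features. (Invented scheme.)
    strength = 1.0 - abs(layer - 7) / n_layers
    noise = rng.normal(size=(len(y), dim))
    return strength * np.outer(y, signal) + noise

def probe_accuracy(layer):
    X_tr = layer_embeddings(y_train, layer)
    X_va = layer_embeddings(y_val, layer)
    # Closed-form ridge probe: w = (X'X + lam*I)^-1 X'y
    lam = 1.0
    w = np.linalg.solve(X_tr.T @ X_tr + lam * np.eye(dim), X_tr.T @ y_train)
    return np.mean(np.sign(X_va @ w) == y_val)

accs = [probe_accuracy(layer) for layer in range(n_layers)]
best = int(np.argmax(accs))
print(f"best layer: {best}, val accuracy: {accs[best]:.2f}")
```

On this synthetic setup the selected layer sits in the middle of the stack and beats the final layer — the same qualitative pattern the paper reports for non-structural property prediction.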

The implication: the final layers of a protein language model do something that hurts performance on certain tasks. The information relevant to non-structural properties — function, localization, binding behavior — is better captured in the middle layers than in the final ones. The final layers keep refining the representation in a direction that serves the pre-training objective (masked token prediction) but overshoots the features that matter for downstream tasks.

This reveals a structural difference between how transformers process protein sequences and natural language. In natural language, the pre-training objective and downstream tasks are well aligned — the features that predict the next word are generally useful for understanding meaning. In protein sequences, the alignment is partial. The features that predict a masked amino acid are useful for understanding structure but may actively interfere with understanding function. The model keeps refining toward a target that is correlated with, but not identical to, what you actually want.

The general principle: deeper is not always better. When the training objective diverges from the task objective, intermediate representations can outperform final ones. The model passes through the sweet spot and keeps going. The optimal depth is task-dependent, and for some tasks, the middle of the network knows more than the end.