Making transformers thinner in the middle can save compute and slightly improve language modeling
Researchers tested a simple change to transformer language models: letting some layers be narrower than others instead of giving every layer the same width. They focus on an "X-shaped" profile that keeps the early and late layers wide but narrows the middle layers. Across decoder-only language models from about 200 million to 2 billion parameters (plus a 3-billion-parameter mixture-of-experts model), these X-shaped models reach lower language modeling loss than parameter-matched constant-width models, while also reducing some compute and memory costs.
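The summary does not spell out the exact width schedule, so the sketch below assumes a simple linear taper to make the "X-shaped" profile concrete. The function name `x_shaped_widths`, the linear shape, and the rounding to multiples of 64 are all illustrative assumptions, not the authors' recipe.

```python
from typing import List


def x_shaped_widths(num_layers: int, d_outer: int, d_inner: int,
                    multiple: int = 64) -> List[int]:
    """Hypothetical X-shaped width schedule: wide at both ends, narrow in the middle.

    Width is interpolated linearly from d_outer at the first and last layers
    toward d_inner at the center (reached exactly only when num_layers is odd),
    then rounded to a multiple of `multiple` (an assumed hardware-friendly choice).
    """
    if num_layers < 1:
        raise ValueError("need at least one layer")
    mid = (num_layers - 1) / 2
    widths = []
    for layer in range(num_layers):
        # t is 1.0 at the outermost layers and 0.0 at the exact center.
        t = abs(layer - mid) / mid if mid > 0 else 1.0
        d = d_inner + t * (d_outer - d_inner)
        widths.append(int(round(d / multiple)) * multiple)
    return widths


print(x_shaped_widths(5, d_outer=1024, d_inner=512))
# [1024, 768, 512, 768, 1024]
```

Any symmetric schedule that is widest at the ends and narrowest in the middle fits the description; the linear taper is just the simplest one to write down.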
The key technical idea is to give different layers different hidden sizes while keeping a single residual stream that is as wide as the widest layer. When a layer is narrower than the one before it, the extra coordinates are simply dropped (truncated). When a later layer widens again, the model restores the missing coordinates by copying them from the most recent earlier layer that had them. Because this resizing is copy-based and parameter-free, it needs no extra projection layers and keeps the skip-connection structure consistent.
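A minimal plain-Python sketch of how the truncate-and-copy mechanism can be read, assuming each layer of width d operates on the first d coordinates (a prefix) of the shared stream. `run_variable_width` and the toy block `add_one` are hypothetical names; a real model would use tensors and attention/MLP blocks.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Block = Callable[[Vector], Vector]  # maps a width-d input to a width-d update


def run_variable_width(x: Sequence[float], widths: Sequence[int],
                       blocks: Sequence[Block]) -> Vector:
    """Apply residual blocks of different widths to one shared residual stream.

    The stream is as wide as the widest block. Block l reads only the first
    widths[l] coordinates (truncation) and adds its update back to them.
    Coordinates beyond widths[l] are left untouched, so a later, wider block
    sees the values last written by the most recent earlier block that was
    wide enough: a copy-based, parameter-free resize.
    """
    if len(widths) != len(blocks):
        raise ValueError("need one width per block")
    if len(x) != max(widths):
        raise ValueError("stream must be as wide as the widest layer")
    stream = list(x)
    for d, block in zip(widths, blocks):
        h = stream[:d]            # truncate: this layer works at width d
        update = block(h)
        if len(update) != d:
            raise ValueError("a block must return an update of its own width")
        for j in range(d):        # residual (skip) connection on active coordinates
            stream[j] = h[j] + update[j]
    return stream


def add_one(h: Vector) -> Vector:
    """Toy block whose update is +1 on every coordinate it sees."""
    return [1.0] * len(h)


out = run_variable_width([0.0, 0.0, 0.0, 0.0], [4, 2, 4], [add_one] * 3)
print(out)  # [3.0, 3.0, 2.0, 2.0]
```

In this wide-narrow-wide toy, coordinates 2 and 3 skip the narrow middle layer and re-enter the last layer carrying the values (1.0) written by the first layer, which is the "copy from the most recent earlier layer" behavior with no projection weights involved.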
In experiments, the X-shaped models consistently outperform uniform-width baselines with the same total number of parameters. Reported gains include about a 3% relative improvement in perplexity for models between 200M and 2B parameters, along with lower resource use in the reported runs: a roughly 10% smaller key-value (KV) cache and about 3% fewer floating-point operations (FLOPs). The paper also reports larger, model-selection-style gains derived from fitted scaling curves: about a 22% reduction in FLOPs and a 15% reduction in KV cache and I/O cost. Because these larger numbers come from a different, extrapolation-based analysis, they should be read as a separate, higher-level finding.
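Perplexity is the exponential of the per-token cross-entropy loss (measured in nats), so a relative perplexity improvement r corresponds to a fixed loss gap regardless of the baseline. The worked equation below is our own arithmetic relating the two metrics, not a figure taken from the paper:

```latex
\mathrm{PPL} = e^{\mathcal{L}}, \qquad
\frac{\mathrm{PPL}_{X}}{\mathrm{PPL}_{\text{base}}}
  = e^{\mathcal{L}_{X} - \mathcal{L}_{\text{base}}} = 1 - r
\;\Longrightarrow\;
\mathcal{L}_{\text{base}} - \mathcal{L}_{X} = -\ln(1 - r),
\qquad r = 0.03 \;\Rightarrow\; -\ln(0.97) \approx 0.030 \text{ nats}.
```

For small r the loss gap is approximately r itself, so a 3% perplexity gain is roughly a 0.03-nat reduction in language modeling loss.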
Why this matters: parameter and compute budgets are central to how language models are built and run. Because a layer's parameter count grows with the square of its width, narrowing some layers can free budget for other parts of the model and lower the average layer width. A lower average width, in turn, reduces attention-related work that scales linearly with width, which lowers the memory needed for KV caches and the cost of moving activations during inference. The authors also show that the X-shaped design changes the internal representations the model forms and helps avoid a mid-layer collapse seen in some baselines.
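The width-versus-cache argument can be checked with a few lines of arithmetic. If per-layer parameters scale as d², then for a fixed parameter budget the Cauchy-Schwarz inequality gives Σ d_l ≤ sqrt(L · Σ d_l²), with equality only when every layer has the same width; any non-uniform profile with matched parameters therefore has a smaller mean width. The sketch assumes roughly 12·d² non-embedding parameters per layer, keys and values each as wide as the layer, and 2 bytes per cached value; the widths and sequence length are hypothetical, and the resulting 4% is a toy number, not the paper's measured 10%.

```python
from typing import Sequence

# Rule-of-thumb constant (an assumption): about 4*d^2 attention weights plus
# 8*d^2 MLP weights (4x expansion) per layer, ignoring biases and norms.
PARAMS_PER_WIDTH_SQUARED = 12


def layer_params(widths: Sequence[int]) -> int:
    """Approximate non-embedding parameters: grows with the square of width."""
    return sum(PARAMS_PER_WIDTH_SQUARED * d * d for d in widths)


def kv_cache_bytes(widths: Sequence[int], seq_len: int, batch: int = 1,
                   bytes_per_value: int = 2) -> int:
    """KV cache size, assuming keys and values are each as wide as the layer.

    Each layer stores one key and one value vector per token, so the cache
    grows linearly with the sum (equivalently, the mean) of layer widths.
    """
    return 2 * sum(widths) * seq_len * batch * bytes_per_value


uniform = [1250, 1250, 1250, 1250]
x_shaped = [1550, 850, 850, 1550]   # hypothetical, parameter-matched profile

assert layer_params(uniform) == layer_params(x_shaped)
ratio = kv_cache_bytes(x_shaped, 4096) / kv_cache_bytes(uniform, 4096)
print(f"KV cache ratio: {ratio:.2f}")  # KV cache ratio: 0.96
```

Both profiles use the same parameter budget, but the X-shaped one has a mean width of 1200 instead of 1250, so its cache is about 4% smaller. How large the saving is in practice depends on how aggressively the middle layers are narrowed.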