Adaptive Layer-skipping in Pre-trained LLMs

News: We have updated our training method for improved results! For details about the updated training method and datasets, please refer to our GitHub and Hugging Face repositories.
Xuan Luo1,   Weizhi Wang1,   Xifeng Yan1  
1University of California, Santa Barbara

Abstract

Various layer-skipping methods have been proposed to accelerate token generation in large language models (LLMs). However, they have overlooked a fundamental question: How do computational demands vary across the generation of different tokens? In this work, we introduce FlexiDepth, a method that dynamically adjusts the number of Transformer layers used in text generation. By incorporating a plug-in router and adapter, FlexiDepth enables adaptive layer-skipping in LLMs without modifying their original parameters. Applying FlexiDepth to Llama-3-8B skips 8 of its 32 layers while maintaining 100% of the original benchmark performance. Experimental results with FlexiDepth demonstrate that computational demands in LLMs vary significantly by token type. Specifically, generating repetitive tokens or fixed phrases requires fewer layers, whereas producing tokens involving computation or high uncertainty requires more layers. Interestingly, this adaptive allocation pattern aligns with human intuition. To advance research in this area, we open-source FlexiDepth and a dataset documenting its layer allocation patterns for future exploration.

Main Idea

Computational demand varies across the generation of different tokens, and many Transformer layers of pre-trained LLMs can be skipped without compromising performance. Below is an example from FlexiDepth showing that many tokens use only a few layers.

Layer-skipping patterns
Layer-skipping patterns (Llama-3-8B-Instruct) for a language task (left) and a math task (right). The light-to-dark blue gradient represents layer usage from 16 to 32.

For more results on the layer-skipping patterns of different tokens, please refer to flexipatterns.

Architecture

Our architecture consists of:

  • A router that makes per-token layer-skipping decisions.
  • An adapter that aligns the hidden states of skipped tokens with the fully processed hidden states.
FlexiDepth Architecture
The FlexiDepth layer. Left: Full-processing path where hidden states undergo the pre-trained attention and FFN modules. Right: Skipping path where hidden states bypass the attention module and are processed by a lightweight adapter. The router and adapter (in red) are the only trainable components within the FlexiDepth Block.
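To make the figure concrete, here is a minimal PyTorch sketch of a FlexiDepth-style block. This is an illustration under assumptions, not the released implementation: the class name `FlexiDepthBlock`, the single-linear-layer router, the bottleneck `adapter_dim`, and the soft sigmoid gate are hypothetical choices; the actual router/adapter design and the hard skipping at inference follow the paper and the open-sourced code.

```python
# Hypothetical sketch of a FlexiDepth-style block (not the authors' exact code).
# A per-token router gates between the frozen pre-trained sub-layers and a
# lightweight adapter on the skipping path; only the router and adapter train.
import torch
import torch.nn as nn


class FlexiDepthBlock(nn.Module):
    def __init__(self, pretrained_layer: nn.Module, hidden_size: int, adapter_dim: int = 64):
        super().__init__()
        # pretrained_layer is assumed to map hidden states to hidden states
        # (attention + FFN); it stays frozen.
        self.layer = pretrained_layer
        for p in self.layer.parameters():
            p.requires_grad = False
        self.router = nn.Linear(hidden_size, 1)   # per-token keep/skip score
        self.adapter = nn.Sequential(             # lightweight bottleneck adapter
            nn.Linear(hidden_size, adapter_dim),
            nn.GELU(),
            nn.Linear(adapter_dim, hidden_size),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # gate in (0, 1): probability of taking the full-processing path
        gate = torch.sigmoid(self.router(hidden_states))          # (batch, seq, 1)
        full_path = self.layer(hidden_states)                     # pre-trained path
        skip_path = hidden_states + self.adapter(hidden_states)   # skipping path
        # Soft mixture during training; at inference, tokens whose gate falls
        # below a threshold can bypass attention/FFN entirely for real speed-ups.
        return gate * full_path + (1.0 - gate) * skip_path
```

The design choice to keep the pre-trained parameters frozen and train only the router and adapter is what makes the method "plug-in": the base model's behavior on the full-processing path is untouched.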

Results

We evaluated FlexiDepth on different tasks and observed a "bowl-like" pattern in layer usage: earlier and later layers are used more, while middle layers are used less. We also found differences across tasks: "continue writing" tasks use more layers than "summarization" and "copy" tasks, and "product" tasks use more layers than "addition" and "repetition" tasks.

Layer Usage Patterns (left and right panels)
Percentage of tokens processed by Transformer layers 17 to 32. The x-axis is the layer index; the y-axis is the percentage of tokens processed by each layer.
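The per-layer percentages plotted above can be derived from logged routing decisions. The snippet below is a hypothetical post-processing example (the `layer_usage_percentage` helper and the random `decisions` tensor are ours, not part of the released code): it averages, for each skippable layer, how many tokens took the full-processing path.

```python
# Hypothetical post-processing of logged routing decisions.
import torch


def layer_usage_percentage(keep_mask: torch.Tensor) -> torch.Tensor:
    """keep_mask: bool tensor of shape (num_tokens, num_layers);
    True means the token went through the full-processing path of that layer.
    Returns the percentage of tokens processed by each layer."""
    return keep_mask.float().mean(dim=0) * 100.0


# Example: 1000 generated tokens over 16 skippable layers (layers 17-32)
decisions = torch.rand(1000, 16) > 0.3
print(layer_usage_percentage(decisions))  # one percentage per layer
```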

We compared FlexiDepth with other layer-skipping methods to evaluate benchmark performance.

Influence of Alpha for Layer Skipping
Performance comparison based on Llama-3-8B-Instruct, which consists of 32 layers. Retain % denotes the percentage of average benchmark performance retained.