Hybrid kv cache for LLaMA4 (#6563)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: tarinkk <rt572@physics.rutger.edu>
Co-authored-by: tarinkk <rt572@rutgers.physics.edu>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
tarinkk
2025-06-27 21:58:55 -04:00
committed by GitHub
parent 357921aa51
commit eb6c2c1663
11 changed files with 519 additions and 59 deletions

View File

@@ -59,6 +59,7 @@ class ModelConfig:
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
@@ -86,6 +87,18 @@ class ModelConfig:
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None
)
self.is_hybrid = is_hybrid_model(
self.hf_config.architectures,
hybrid_kvcache_ratio=hybrid_kvcache_ratio,
context_length=context_length,
attention_chunk_size=self.attention_chunk_size,
)
if self.is_hybrid is not None:
self.swa_attention_layer_ids, self.full_attention_layer_ids = (
get_hybrid_layer_ids(
self.hf_config.architectures, self.hf_text_config.num_hidden_layers
)
)
if enable_multimodal is None:
mm_disabled_models = [
@@ -264,6 +277,7 @@ class ModelConfig:
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
impl=server_args.impl,
**kwargs,
)
@@ -633,3 +647,36 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
):
    """Return the hybrid KV cache ratio when the hybrid cache applies, else None.

    The ratio is returned unchanged only if all of the following hold:

    * ``hybrid_kvcache_ratio`` is set and strictly positive,
    * the first entry of ``model_architectures`` is
      ``"Llama4ForConditionalGeneration"``,
    * both ``context_length`` and ``attention_chunk_size`` are known and
      ``context_length`` is strictly greater than ``attention_chunk_size``.

    In every other case the function returns ``None``.

    Args:
        model_architectures: Architecture names of the model; only the first
            entry is inspected.
        hybrid_kvcache_ratio: Requested ratio, or ``None`` if not requested.
        context_length: Context length to compare against the chunk size, or
            ``None`` if unknown.
        attention_chunk_size: Attention chunk size of the model, or ``None``
            if the model does not define one.

    Returns:
        ``hybrid_kvcache_ratio`` if the conditions above are met, else ``None``.
    """
    # `not ... > 0` (rather than `<= 0`) keeps NaN on the "disabled" side.
    if hybrid_kvcache_ratio is None or not hybrid_kvcache_ratio > 0:
        return None

    # Guard against a missing/empty architecture list before indexing it.
    if (
        not model_architectures
        or model_architectures[0] != "Llama4ForConditionalGeneration"
    ):
        return None

    # Both operands are Optional; ordering None against an int raises
    # TypeError, so an unknown value is treated as "not hybrid".
    if context_length is None or attention_chunk_size is None:
        return None

    if context_length > attention_chunk_size:
        return hybrid_kvcache_ratio
    return None
def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    """Split layer ids ``0 .. num_hidden_layers - 1`` into SWA and full-attention lists.

    For ``Llama4ForConditionalGeneration`` every fourth layer (0-based ids
    3, 7, 11, ...) goes into the full-attention list; all remaining ids go
    into the SWA list. Both lists are in ascending order.

    Args:
        model_architectures: Architecture names of the model; must contain
            ``"Llama4ForConditionalGeneration"``.
        num_hidden_layers: Total number of layers to partition.

    Returns:
        A ``(swa_attention_layer_ids, full_attention_layer_ids)`` tuple.

    Raises:
        ValueError: If ``"Llama4ForConditionalGeneration"`` is not among
            ``model_architectures``.
    """
    if "Llama4ForConditionalGeneration" not in model_architectures:
        raise ValueError(
            "get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
        )

    swa_attention_layer_ids = []
    full_attention_layer_ids = []
    for layer_id in range(num_hidden_layers):
        # 1-based positions divisible by 4 are the full-attention layers.
        if (layer_id + 1) % 4 == 0:
            full_attention_layer_ids.append(layer_id)
        else:
            swa_attention_layer_ids.append(layer_id)
    return swa_attention_layer_ids, full_attention_layer_ids