Hybrid kv cache for LLaMA4 (#6563)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Co-authored-by: tarinkk <rt572@physics.rutger.edu> Co-authored-by: tarinkk <rt572@rutgers.physics.edu> Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
@@ -59,6 +59,7 @@ class ModelConfig:
|
||||
quantization: Optional[str] = None,
|
||||
override_config_file: Optional[str] = None,
|
||||
is_draft_model: bool = False,
|
||||
hybrid_kvcache_ratio: Optional[float] = None,
|
||||
impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||
) -> None:
|
||||
|
||||
@@ -86,6 +87,18 @@ class ModelConfig:
|
||||
self.attention_chunk_size = getattr(
|
||||
self.hf_text_config, "attention_chunk_size", None
|
||||
)
|
||||
self.is_hybrid = is_hybrid_model(
|
||||
self.hf_config.architectures,
|
||||
hybrid_kvcache_ratio=hybrid_kvcache_ratio,
|
||||
context_length=context_length,
|
||||
attention_chunk_size=self.attention_chunk_size,
|
||||
)
|
||||
if self.is_hybrid is not None:
|
||||
self.swa_attention_layer_ids, self.full_attention_layer_ids = (
|
||||
get_hybrid_layer_ids(
|
||||
self.hf_config.architectures, self.hf_text_config.num_hidden_layers
|
||||
)
|
||||
)
|
||||
|
||||
if enable_multimodal is None:
|
||||
mm_disabled_models = [
|
||||
@@ -264,6 +277,7 @@ class ModelConfig:
|
||||
enable_multimodal=server_args.enable_multimodal,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||
impl=server_args.impl,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -633,3 +647,36 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
) -> Optional[float]:
    """Return the hybrid KV-cache ratio if the model qualifies, else ``None``.

    A model qualifies for the hybrid (SWA + full-attention) KV cache when the
    user requested a positive ratio, the primary architecture is Llama4, and
    the context length actually exceeds the attention chunk size (otherwise a
    plain full cache suffices).

    Args:
        model_architectures: ``hf_config.architectures``; only the first entry
            is checked, matching the original behavior.
        hybrid_kvcache_ratio: user-requested ratio; ``None`` disables hybrid.
        context_length: model context length, may be ``None``.
        attention_chunk_size: chunked-attention window, may be ``None``.

    Returns:
        ``hybrid_kvcache_ratio`` when all conditions hold, otherwise ``None``.
    """
    if hybrid_kvcache_ratio is None:
        return None
    # Both lengths must be present before comparing: the parameters are
    # Optional, and `None > int` raises TypeError in Python 3.
    if (
        hybrid_kvcache_ratio > 0
        and model_architectures[0] == "Llama4ForConditionalGeneration"
        and context_length is not None
        and attention_chunk_size is not None
        and context_length > attention_chunk_size
    ):
        return hybrid_kvcache_ratio
    return None
|
||||
|
||||
|
||||
def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    """Partition layer indices into SWA and full-attention groups for Llama4.

    Every fourth layer (1-based: layers 4, 8, 12, ...) uses full attention;
    the remaining layers use sliding-window attention.

    Args:
        model_architectures: ``hf_config.architectures``; must contain
            ``"Llama4ForConditionalGeneration"``.
        num_hidden_layers: total number of transformer layers.

    Returns:
        Tuple ``(swa_attention_layer_ids, full_attention_layer_ids)``.

    Raises:
        ValueError: if the architecture list does not include Llama4.
    """
    if "Llama4ForConditionalGeneration" not in model_architectures:
        raise ValueError(
            "get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
        )
    swa_attention_layer_ids = []
    full_attention_layer_ids = []
    for layer_id in range(num_hidden_layers):
        # 1-based position divisible by 4 -> full attention, else SWA.
        if (layer_id + 1) % 4 == 0:
            full_attention_layer_ids.append(layer_id)
        else:
            swa_attention_layer_ids.append(layer_id)
    return swa_attention_layer_ids, full_attention_layer_ids
|
||||
|
||||
Reference in New Issue
Block a user