Add support for NVIDIA ModelOpt fp8 kv cache (#3223)

This commit is contained in:
Zhiyu
2025-02-21 15:04:58 -08:00
committed by GitHub
parent 20b765a26e
commit c66b2c9cf1
4 changed files with 65 additions and 2 deletions

View File

@@ -47,6 +47,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import make_layers
from sglang.utils import get_exception_traceback
@@ -457,6 +458,11 @@ class LlamaForCausalLM(nn.Module):
continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
# Handle FP8 kv-scale remapping
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: