diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index 3e5f996ed..38f2e7b2d 100644
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -5,12 +5,14 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
 
+from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -70,7 +72,13 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        return ModelOptFp8LinearMethod(self) if isinstance(layer, LinearBase) else None
+
+        if isinstance(layer, LinearBase):
+            return ModelOptFp8LinearMethod(self)
+        if isinstance(layer, AttentionBackend):
+            return ModelOptFp8KVCacheMethod(self)
+
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -171,3 +179,12 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
         )
+
+
+class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        super().__init__(quant_config)
diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py
index c07a346f4..822e28844 100644
--- a/python/sglang/srt/model_loader/weight_utils.py
+++ b/python/sglang/srt/model_loader/weight_utils.py
@@ -644,9 +644,20 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
         return remapped_name
 
     possible_scale_names = [".k_scale", ".v_scale"]
+    modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
     for scale_name in possible_scale_names:
         if name.endswith(scale_name):
-            remapped_name = name.replace(scale_name, f".attn{scale_name}")
+            # Check and remap the name based on modelopt scale names
+            if any(
+                modelopt_scale_name in name
+                for modelopt_scale_name in modelopt_scale_names
+            ):
+                remapped_name = name.replace(
+                    f".self_attn.{scale_name[1]}_proj{scale_name}",
+                    f".self_attn.attn{scale_name}",
+                )
+            else:
+                remapped_name = name.replace(scale_name, f".attn{scale_name}")
             if remapped_name not in params_dict:
                 print_warning_once(
                     f"Found {scale_name} in the checkpoint (e.g. {name}), "
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index cd37535f4..9cb13e944 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -47,6 +47,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
 )
 from sglang.srt.utils import make_layers
 from sglang.utils import get_exception_traceback
@@ -457,6 +458,11 @@ class LlamaForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
diff --git a/test/srt/test_modelopt_fp8kvcache.py b/test/srt/test_modelopt_fp8kvcache.py
new file mode 100644
index 000000000..da6bb3651
--- /dev/null
+++ b/test/srt/test_modelopt_fp8kvcache.py
@@ -0,0 +1,29 @@
+import unittest
+
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp8Config,
+    ModelOptFp8KVCacheMethod,
+)
+
+
+class TestModelOptFp8KVCacheMethod(unittest.TestCase):
+    def test_kv_cache_method_initialization(self):
+        """Test that ModelOptFp8KVCacheMethod can be instantiated and
+        inherits from BaseKVCacheMethod."""
+        # Create a ModelOptFp8Config object
+        quant_config = ModelOptFp8Config(is_checkpoint_fp8_serialized=True)
+
+        # Instantiate the KV cache method
+        kv_cache_method = ModelOptFp8KVCacheMethod(quant_config)
+
+        # Check inheritance
+        self.assertIsInstance(kv_cache_method, BaseKVCacheMethod)
+
+        # Check that the quant_config is stored
+        self.assertEqual(kv_cache_method.quant_config, quant_config)
+
+
+if __name__ == "__main__":
+    unittest.main()