Add support for NVIDIA ModelOpt FP8 KV cache (#3223)
This commit is contained in:
@@ -5,12 +5,14 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
apply_fp8_linear,
|
apply_fp8_linear,
|
||||||
cutlass_fp8_supported,
|
cutlass_fp8_supported,
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention import AttentionBackend
|
||||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
@@ -70,7 +72,13 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
def get_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
    """Pick the quantization method object for ``layer``.

    Args:
        layer: The module a quantization method is being requested for.
        prefix: Name prefix of the layer; not consulted by this method.

    Returns:
        A ``ModelOptFp8LinearMethod`` when ``layer`` is a ``LinearBase``,
        a ``ModelOptFp8KVCacheMethod`` when ``layer`` is an
        ``AttentionBackend``, and ``None`` for every other layer type.
    """
    if isinstance(layer, LinearBase):
        return ModelOptFp8LinearMethod(self)
    # NOTE(review): ``layer`` is annotated as ``torch.nn.Module``; confirm that
    # the attention layers passed in here really are ``AttentionBackend``
    # instances, otherwise this branch never selects the KV-cache method.
    if isinstance(layer, AttentionBackend):
        return ModelOptFp8KVCacheMethod(self)

    return None
|
||||||
|
|
||||||
def get_scaled_act_names(self) -> List[str]:
    """Return an empty list (no scaled activation names are reported).

    A new list object is created on every call, so callers may mutate the
    result without affecting later calls.
    """
    return []
|
||||||
@@ -171,3 +179,12 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
cutlass_fp8_supported=self.cutlass_fp8_supported,
|
cutlass_fp8_supported=self.cutlass_fp8_supported,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints.

    This subclass adds no behavior of its own: the constructor forwards the
    config to ``BaseKVCacheMethod`` and everything else is inherited.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        """Forward ``quant_config`` unchanged to the base-class constructor."""
        super().__init__(quant_config)
|
||||||
|
|||||||
@@ -644,9 +644,20 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
|
|||||||
return remapped_name
|
return remapped_name
|
||||||
|
|
||||||
possible_scale_names = [".k_scale", ".v_scale"]
|
possible_scale_names = [".k_scale", ".v_scale"]
|
||||||
|
modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
|
||||||
for scale_name in possible_scale_names:
|
for scale_name in possible_scale_names:
|
||||||
if name.endswith(scale_name):
|
if name.endswith(scale_name):
|
||||||
remapped_name = name.replace(scale_name, f".attn{scale_name}")
|
# Check and remap the name based on modelopt scale names
|
||||||
|
if any(
|
||||||
|
modelopt_scale_name in name
|
||||||
|
for modelopt_scale_name in modelopt_scale_names
|
||||||
|
):
|
||||||
|
remapped_name = name.replace(
|
||||||
|
f".self_attn.{scale_name[1]}_proj{scale_name}",
|
||||||
|
f".self_attn.attn{scale_name}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
remapped_name = name.replace(scale_name, f".attn{scale_name}")
|
||||||
if remapped_name not in params_dict:
|
if remapped_name not in params_dict:
|
||||||
print_warning_once(
|
print_warning_once(
|
||||||
f"Found {scale_name} in the checkpoint (e.g. {name}), "
|
f"Found {scale_name} in the checkpoint (e.g. {name}), "
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|||||||
from sglang.srt.model_loader.weight_utils import (
|
from sglang.srt.model_loader.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
kv_cache_scales_loader,
|
kv_cache_scales_loader,
|
||||||
|
maybe_remap_kv_scale_name,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import make_layers
|
from sglang.srt.utils import make_layers
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
@@ -457,6 +458,11 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
# Handle FP8 kv-scale remapping
|
||||||
|
if "scale" in name:
|
||||||
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
|
|||||||
29
test/srt/test_modelopt_fp8kvcache.py
Normal file
29
test/srt/test_modelopt_fp8kvcache.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.modelopt_quant import (
|
||||||
|
ModelOptFp8Config,
|
||||||
|
ModelOptFp8KVCacheMethod,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelOptFp8KVCacheMethod(unittest.TestCase):
    """Construction smoke test for ``ModelOptFp8KVCacheMethod``."""

    def test_kv_cache_method_initialization(self):
        """Test that ModelOptFp8KVCacheMethod can be instantiated and
        inherits from BaseKVCacheMethod."""
        # Create a ModelOptFp8Config object
        quant_config = ModelOptFp8Config(is_checkpoint_fp8_serialized=True)

        # Instantiate the KV cache method
        kv_cache_method = ModelOptFp8KVCacheMethod(quant_config)

        # Check inheritance
        self.assertIsInstance(kv_cache_method, BaseKVCacheMethod)

        # Check that the quant_config is stored
        self.assertEqual(kv_cache_method.quant_config, quant_config)
|
||||||
|
|
||||||
|
|
||||||
|
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user