Add support for nvidia modelopt fp8 kv cache (#3223)
This commit is contained in:
@@ -5,12 +5,14 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
apply_fp8_linear,
|
||||
cutlass_fp8_supported,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.attention import AttentionBackend
|
||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
@@ -70,7 +72,13 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
return ModelOptFp8LinearMethod(self) if isinstance(layer, LinearBase) else None
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return ModelOptFp8LinearMethod(self)
|
||||
if isinstance(layer, AttentionBackend):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
@@ -171,3 +179,12 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
bias=bias,
|
||||
cutlass_fp8_supported=self.cutlass_fp8_supported,
|
||||
)
|
||||
|
||||
|
||||
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config):
|
||||
super().__init__(quant_config)
|
||||
|
||||
Reference in New Issue
Block a user