Add support for nvidia modelopt fp8 kv cache (#3223)

This commit is contained in:
Zhiyu
2025-02-21 15:04:58 -08:00
committed by GitHub
parent 20b765a26e
commit c66b2c9cf1
4 changed files with 65 additions and 2 deletions

View File

@@ -5,12 +5,14 @@ from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
requantize_with_max_scale,
)
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
@@ -70,7 +72,13 @@ class ModelOptFp8Config(QuantizationConfig):
def get_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
    """Select the quantization method object for ``layer``.

    Args:
        layer: The module being set up for quantization.
        prefix: Name prefix of the layer; not consulted by this dispatch.

    Returns:
        ``ModelOptFp8LinearMethod`` for ``LinearBase`` layers,
        ``ModelOptFp8KVCacheMethod`` for ``AttentionBackend`` layers, and
        ``None`` for every other layer type.
    """
    # The earlier single-expression return (linear-or-None) is intentionally
    # gone: left in place it would return before the KV-cache branch below
    # could ever be reached.
    if isinstance(layer, LinearBase):
        return ModelOptFp8LinearMethod(self)
    # NOTE(review): `layer` is annotated as torch.nn.Module; confirm that the
    # attention layers callers pass here are AttentionBackend instances.
    if isinstance(layer, AttentionBackend):
        return ModelOptFp8KVCacheMethod(self)
    return None
def get_scaled_act_names(self) -> List[str]:
    """Return the scaled-activation names for this config (always empty)."""
    no_scaled_acts: List[str] = []
    return no_scaled_acts
@@ -171,3 +179,12 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
)
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache quantization method for ModelOpt FP8 checkpoints.

    Handles loading FP8 kv-cache scaling factors from modelopt quantized
    checkpoints. All behavior comes from ``BaseKVCacheMethod``; this subclass
    only binds it to a ``ModelOptFp8Config``.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        """Forward ``quant_config`` unchanged to the base-class initializer."""
        super().__init__(quant_config)