Fix loading KV quantization scale; Enable modelopt kv cache (#4686)
Co-authored-by: qingquansong <ustcsqq@gmail.com>
This commit is contained in:
@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
"""
|
||||
Quant method that adds `_k_scale` and `_v_scale` attributes to the
|
||||
Quant method that adds `k_scale` and `v_scale` attributes to the
|
||||
Attention layer to support loading those scaling factors from checkpoints.
|
||||
The k/v_scale will be used to:
|
||||
- quantize k/v_cache entries before saving them to the cache
|
||||
@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
# Initialize the KV cache scales to -1.0, which is an invalid value.
|
||||
# If the k/v_scale appears in the checkpoint, it will be
|
||||
# overwritten when loading weights.
|
||||
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.k_scale = torch.nn.Parameter(
|
||||
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.v_scale = torch.nn.Parameter(
|
||||
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_fp8_fnuz(cls) -> bool:
|
||||
@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
||||
# regardless whether the kv-scale is available in the checkpoint.
|
||||
# No need to process kv scales after loading if we are going to
|
||||
# calculate them on the fly.
|
||||
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
|
||||
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
||||
# We prefer to use separate k_scale and v_scale if present
|
||||
k_scale = layer.k_scale.to("cpu").tolist()
|
||||
v_scale = layer.v_scale.to("cpu").tolist()
|
||||
if _is_hip and self.is_fp8_fnuz():
|
||||
k_scale *= 2
|
||||
v_scale *= 2
|
||||
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
||||
# If no scales were loaded (both scales are invalid negative
|
||||
# values), use the default value of 1.0
|
||||
k_scale = 1.0
|
||||
v_scale = 1.0
|
||||
else:
|
||||
# If we find a single kv_scale in the checkpoint, we remap
|
||||
# kv_scale to k_scale during weight loading, and duplicate
|
||||
# k_scale to v_scale here
|
||||
assert layer.k_scale > 0.0
|
||||
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
||||
k_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
v_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
if _is_hip and self.is_fp8_fnuz():
|
||||
k_scale *= 2
|
||||
v_scale *= 2
|
||||
def process_weights_after_loading(self, layer: RadixAttention) -> None:
|
||||
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
||||
# We prefer to use separate k_scale and v_scale if present
|
||||
k_scale = layer.k_scale.to("cpu").tolist()
|
||||
v_scale = layer.v_scale.to("cpu").tolist()
|
||||
if _is_hip and self.is_fp8_fnuz():
|
||||
k_scale *= 2
|
||||
v_scale *= 2
|
||||
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
||||
# If no scales were loaded (both scales are invalid negative
|
||||
# values), use the default value of 1.0
|
||||
k_scale = 1.0
|
||||
v_scale = 1.0
|
||||
else:
|
||||
# If we find a single kv_scale in the checkpoint, we remap
|
||||
# kv_scale to k_scale during weight loading, and duplicate
|
||||
# k_scale to v_scale here
|
||||
assert layer.k_scale > 0.0
|
||||
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
||||
k_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
v_scale = scale_to_duplicate.to("cpu").tolist()
|
||||
if _is_hip and self.is_fp8_fnuz():
|
||||
k_scale *= 2
|
||||
v_scale *= 2
|
||||
|
||||
if not isinstance(k_scale, float) or not isinstance(v_scale, float):
|
||||
raise ValueError(
|
||||
"Only support per-tensor scaling factor " "for fp8 KV cache"
|
||||
)
|
||||
if not isinstance(k_scale, float) or not isinstance(v_scale, float):
|
||||
raise ValueError(
|
||||
"Only support per-tensor scaling factor " "for fp8 KV cache"
|
||||
)
|
||||
|
||||
# These are used in the final Attention.forward()
|
||||
layer._k_scale.copy_(k_scale)
|
||||
layer._v_scale.copy_(v_scale)
|
||||
layer._k_scale_float = k_scale
|
||||
layer._v_scale_float = v_scale
|
||||
if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
|
||||
logger.warning(
|
||||
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
|
||||
"may cause accuracy issues. Please make sure k/v_scale "
|
||||
"scaling factors are available in the fp8 checkpoint."
|
||||
)
|
||||
|
||||
del layer.k_scale
|
||||
del layer.v_scale
|
||||
# These are used in the final Attention.forward()
|
||||
layer.k_scale.copy_(k_scale)
|
||||
layer.v_scale.copy_(v_scale)
|
||||
layer.k_scale_float = k_scale
|
||||
layer.v_scale_float = v_scale
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.utils import (
|
||||
convert_to_channelwise,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
|
||||
# Initialize logger for the module
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -33,12 +33,19 @@ ACTIVATION_SCHEMES = ["static"]
|
||||
class ModelOptFp8Config(QuantizationConfig):
|
||||
"""Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
|
||||
|
||||
def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
kv_cache_quant_method: Optional[str] = None,
|
||||
exclude_modules: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
|
||||
"""
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
self.kv_cache_quant_method = kv_cache_quant_method
|
||||
self.exclude_modules = exclude_modules
|
||||
if is_checkpoint_fp8_serialized:
|
||||
logger.warning(
|
||||
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
|
||||
@@ -63,6 +70,12 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
|
||||
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
|
||||
kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
|
||||
"kv_cache_quant_algo"
|
||||
)
|
||||
exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
|
||||
"exclude_modules"
|
||||
)
|
||||
|
||||
if "FP8" not in quant_method:
|
||||
raise ValueError(
|
||||
@@ -70,15 +83,23 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
"Check the `hf_quant_config.json` file for your model's configuration."
|
||||
)
|
||||
|
||||
return cls(is_checkpoint_fp8_serialized=True)
|
||||
return cls(
|
||||
is_checkpoint_fp8_serialized=True,
|
||||
kv_cache_quant_method=kv_cache_quant_method,
|
||||
exclude_modules=exclude_modules,
|
||||
)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
if self.exclude_modules and any(
|
||||
module in prefix for module in self.exclude_modules
|
||||
):
|
||||
return None
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return ModelOptFp8LinearMethod(self)
|
||||
if isinstance(layer, AttentionBackend):
|
||||
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user