Fix loading KV quantization scale; Enable modelopt kv cache (#4686)
Co-authored-by: qingquansong <ustcsqq@gmail.com>
This commit is contained in:
@@ -239,7 +239,7 @@ class ModelConfig:
|
|||||||
# check if is modelopt model -- modelopt doesn't have corresponding field
|
# check if is modelopt model -- modelopt doesn't have corresponding field
|
||||||
# in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
|
# in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
|
||||||
# example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
|
# example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
|
||||||
is_local = os.path.isdir(self.model_path)
|
is_local = os.path.exists(self.model_path)
|
||||||
modelopt_quant_config = {"quant_method": "modelopt"}
|
modelopt_quant_config = {"quant_method": "modelopt"}
|
||||||
if not is_local:
|
if not is_local:
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
|
|||||||
@@ -292,6 +292,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.decode_cuda_graph_metadata = {}
|
self.decode_cuda_graph_metadata = {}
|
||||||
self.target_verify_metadata = {}
|
self.target_verify_metadata = {}
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
self.kv_cache_dtype = model_runner.kv_cache_dtype
|
||||||
|
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
|
||||||
self.page_size = model_runner.page_size
|
self.page_size = model_runner.page_size
|
||||||
self.use_mla = (
|
self.use_mla = (
|
||||||
model_runner.model_config.attention_arch == AttentionArch.MLA
|
model_runner.model_config.attention_arch == AttentionArch.MLA
|
||||||
@@ -520,6 +522,12 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
if layer.sliding_window_size is not None
|
if layer.sliding_window_size is not None
|
||||||
else (-1, -1)
|
else (-1, -1)
|
||||||
)
|
)
|
||||||
|
k_descale, v_descale = None, None
|
||||||
|
if self.kv_cache_dtype_str != "auto":
|
||||||
|
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
|
||||||
|
k_descale = layer.k_scale.expand(descale_shape)
|
||||||
|
v_descale = layer.v_scale.expand(descale_shape)
|
||||||
|
q = q.to(self.kv_cache_dtype)
|
||||||
causal = not layer.is_cross_attention
|
causal = not layer.is_cross_attention
|
||||||
|
|
||||||
# Check if we should use local attention
|
# Check if we should use local attention
|
||||||
@@ -576,8 +584,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
causal=causal,
|
causal=causal,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=layer.k_scale,
|
k_descale=k_descale,
|
||||||
v_descale=layer.v_scale,
|
v_descale=v_descale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Do absorbed multi-latent attention
|
# Do absorbed multi-latent attention
|
||||||
@@ -609,8 +617,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=True,
|
||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=layer.k_scale,
|
k_descale=k_descale,
|
||||||
v_descale=layer.v_scale,
|
v_descale=v_descale,
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
@@ -657,6 +665,13 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
causal = not layer.is_cross_attention
|
causal = not layer.is_cross_attention
|
||||||
|
|
||||||
|
k_descale, v_descale = None, None
|
||||||
|
if self.kv_cache_dtype_str != "auto":
|
||||||
|
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
|
||||||
|
k_descale = layer.k_scale.expand(descale_shape)
|
||||||
|
v_descale = layer.v_scale.expand(descale_shape)
|
||||||
|
q = q.to(self.kv_cache_dtype)
|
||||||
|
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
# Do multi-head attention
|
# Do multi-head attention
|
||||||
|
|
||||||
@@ -694,8 +709,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
causal=causal,
|
causal=causal,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=layer.k_scale,
|
k_descale=k_descale,
|
||||||
v_descale=layer.v_scale,
|
v_descale=v_descale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Do absorbed multi-latent attention
|
# Do absorbed multi-latent attention
|
||||||
@@ -729,8 +744,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=True,
|
||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=layer.k_scale,
|
k_descale=k_descale,
|
||||||
v_descale=layer.v_scale,
|
v_descale=v_descale,
|
||||||
)
|
)
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
|
||||||
|
|||||||
@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
self.skip_prefill = skip_prefill
|
self.skip_prefill = skip_prefill
|
||||||
self.is_multimodal = model_runner.model_config.is_multimodal
|
self.is_multimodal = model_runner.model_config.is_multimodal
|
||||||
|
self.kv_cache_dtype = model_runner.kv_cache_dtype
|
||||||
|
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
|
||||||
|
|
||||||
assert not (
|
assert not (
|
||||||
model_runner.sliding_window_size is not None
|
model_runner.sliding_window_size is not None
|
||||||
@@ -391,6 +393,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
):
|
):
|
||||||
|
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
|
||||||
|
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
|
||||||
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
||||||
self._get_wrapper_idx(layer)
|
self._get_wrapper_idx(layer)
|
||||||
]
|
]
|
||||||
@@ -407,7 +411,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
assert v is not None
|
assert v is not None
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
layer, cache_loc, k, v, k_scale, v_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
o = prefill_wrapper_paged.forward(
|
o = prefill_wrapper_paged.forward(
|
||||||
@@ -417,8 +421,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
window_left=layer.sliding_window_size,
|
window_left=layer.sliding_window_size,
|
||||||
logits_soft_cap=logits_soft_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
k_scale=layer.k_scale,
|
k_scale=k_scale,
|
||||||
v_scale=layer.v_scale,
|
v_scale=v_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
||||||
@@ -445,7 +449,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
layer, cache_loc, k, v, k_scale, v_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
@@ -459,6 +463,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
):
|
):
|
||||||
|
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
|
||||||
|
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
|
||||||
decode_wrapper = self.forward_metadata.decode_wrappers[
|
decode_wrapper = self.forward_metadata.decode_wrappers[
|
||||||
self._get_wrapper_idx(layer)
|
self._get_wrapper_idx(layer)
|
||||||
]
|
]
|
||||||
@@ -472,7 +478,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
assert v is not None
|
assert v is not None
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
layer, cache_loc, k, v, k_scale, v_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
o = decode_wrapper.forward(
|
o = decode_wrapper.forward(
|
||||||
@@ -480,8 +486,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
logits_soft_cap=layer.logit_cap,
|
logits_soft_cap=layer.logit_cap,
|
||||||
k_scale=layer.k_scale,
|
k_scale=k_scale,
|
||||||
v_scale=layer.v_scale,
|
v_scale=v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class BaseKVCacheMethod(QuantizeMethodBase):
|
class BaseKVCacheMethod(QuantizeMethodBase):
|
||||||
"""
|
"""
|
||||||
Quant method that adds `_k_scale` and `_v_scale` attributes to the
|
Quant method that adds `k_scale` and `v_scale` attributes to the
|
||||||
Attention layer to support loading those scaling factors from checkpoints.
|
Attention layer to support loading those scaling factors from checkpoints.
|
||||||
The k/v_scale will be used to:
|
The k/v_scale will be used to:
|
||||||
- quantize k/v_cache entries before saving them to the cache
|
- quantize k/v_cache entries before saving them to the cache
|
||||||
@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
|||||||
# Initialize the KV cache scales to -1.0, which is an invalid value.
|
# Initialize the KV cache scales to -1.0, which is an invalid value.
|
||||||
# If the k/v_scale appears in the checkpoint, it will be
|
# If the k/v_scale appears in the checkpoint, it will be
|
||||||
# overwritten when loading weights.
|
# overwritten when loading weights.
|
||||||
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
layer.k_scale = torch.nn.Parameter(
|
||||||
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
layer.v_scale = torch.nn.Parameter(
|
||||||
|
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_fp8_fnuz(cls) -> bool:
|
def is_fp8_fnuz(cls) -> bool:
|
||||||
@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
|||||||
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||||
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: RadixAttention) -> None:
|
||||||
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
||||||
# regardless whether the kv-scale is available in the checkpoint.
|
# We prefer to use separate k_scale and v_scale if present
|
||||||
# No need to process kv scales after loading if we are going to
|
k_scale = layer.k_scale.to("cpu").tolist()
|
||||||
# calculate them on the fly.
|
v_scale = layer.v_scale.to("cpu").tolist()
|
||||||
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
|
if _is_hip and self.is_fp8_fnuz():
|
||||||
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
k_scale *= 2
|
||||||
# We prefer to use separate k_scale and v_scale if present
|
v_scale *= 2
|
||||||
k_scale = layer.k_scale.to("cpu").tolist()
|
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
||||||
v_scale = layer.v_scale.to("cpu").tolist()
|
# If no scales were loaded (both scales are invalid negative
|
||||||
if _is_hip and self.is_fp8_fnuz():
|
# values), use the default value of 1.0
|
||||||
k_scale *= 2
|
k_scale = 1.0
|
||||||
v_scale *= 2
|
v_scale = 1.0
|
||||||
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
else:
|
||||||
# If no scales were loaded (both scales are invalid negative
|
# If we find a single kv_scale in the checkpoint, we remap
|
||||||
# values), use the default value of 1.0
|
# kv_scale to k_scale during weight loading, and duplicate
|
||||||
k_scale = 1.0
|
# k_scale to v_scale here
|
||||||
v_scale = 1.0
|
assert layer.k_scale > 0.0
|
||||||
else:
|
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
||||||
# If we find a single kv_scale in the checkpoint, we remap
|
k_scale = scale_to_duplicate.to("cpu").tolist()
|
||||||
# kv_scale to k_scale during weight loading, and duplicate
|
v_scale = scale_to_duplicate.to("cpu").tolist()
|
||||||
# k_scale to v_scale here
|
if _is_hip and self.is_fp8_fnuz():
|
||||||
assert layer.k_scale > 0.0
|
k_scale *= 2
|
||||||
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
v_scale *= 2
|
||||||
k_scale = scale_to_duplicate.to("cpu").tolist()
|
|
||||||
v_scale = scale_to_duplicate.to("cpu").tolist()
|
|
||||||
if _is_hip and self.is_fp8_fnuz():
|
|
||||||
k_scale *= 2
|
|
||||||
v_scale *= 2
|
|
||||||
|
|
||||||
if not isinstance(k_scale, float) or not isinstance(v_scale, float):
|
if not isinstance(k_scale, float) or not isinstance(v_scale, float):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Only support per-tensor scaling factor " "for fp8 KV cache"
|
"Only support per-tensor scaling factor " "for fp8 KV cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
# These are used in the final Attention.forward()
|
# These are used in the final Attention.forward()
|
||||||
layer._k_scale.copy_(k_scale)
|
layer.k_scale.copy_(k_scale)
|
||||||
layer._v_scale.copy_(v_scale)
|
layer.v_scale.copy_(v_scale)
|
||||||
layer._k_scale_float = k_scale
|
layer.k_scale_float = k_scale
|
||||||
layer._v_scale_float = v_scale
|
layer.v_scale_float = v_scale
|
||||||
if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
|
|
||||||
logger.warning(
|
|
||||||
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
|
|
||||||
"may cause accuracy issues. Please make sure k/v_scale "
|
|
||||||
"scaling factors are available in the fp8 checkpoint."
|
|
||||||
)
|
|
||||||
|
|
||||||
del layer.k_scale
|
|
||||||
del layer.v_scale
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
|
||||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.utils import (
|
|||||||
convert_to_channelwise,
|
convert_to_channelwise,
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
|
||||||
# Initialize logger for the module
|
# Initialize logger for the module
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -33,12 +33,19 @@ ACTIVATION_SCHEMES = ["static"]
|
|||||||
class ModelOptFp8Config(QuantizationConfig):
|
class ModelOptFp8Config(QuantizationConfig):
|
||||||
"""Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
|
"""Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
|
||||||
|
|
||||||
def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
is_checkpoint_fp8_serialized: bool = False,
|
||||||
|
kv_cache_quant_method: Optional[str] = None,
|
||||||
|
exclude_modules: Optional[List[str]] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
|
is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
|
||||||
"""
|
"""
|
||||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||||
|
self.kv_cache_quant_method = kv_cache_quant_method
|
||||||
|
self.exclude_modules = exclude_modules
|
||||||
if is_checkpoint_fp8_serialized:
|
if is_checkpoint_fp8_serialized:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
|
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
|
||||||
@@ -63,6 +70,12 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
|
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
|
||||||
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
|
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
|
||||||
|
kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
|
||||||
|
"kv_cache_quant_algo"
|
||||||
|
)
|
||||||
|
exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
|
||||||
|
"exclude_modules"
|
||||||
|
)
|
||||||
|
|
||||||
if "FP8" not in quant_method:
|
if "FP8" not in quant_method:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -70,15 +83,23 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
"Check the `hf_quant_config.json` file for your model's configuration."
|
"Check the `hf_quant_config.json` file for your model's configuration."
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(is_checkpoint_fp8_serialized=True)
|
return cls(
|
||||||
|
is_checkpoint_fp8_serialized=True,
|
||||||
|
kv_cache_quant_method=kv_cache_quant_method,
|
||||||
|
exclude_modules=exclude_modules,
|
||||||
|
)
|
||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
|
if self.exclude_modules and any(
|
||||||
|
module in prefix for module in self.exclude_modules
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
return ModelOptFp8LinearMethod(self)
|
return ModelOptFp8LinearMethod(self)
|
||||||
if isinstance(layer, AttentionBackend):
|
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
|
||||||
return ModelOptFp8KVCacheMethod(self)
|
return ModelOptFp8KVCacheMethod(self)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -13,8 +13,12 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Radix attention."""
|
"""Radix attention."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
|
|
||||||
@@ -34,6 +38,7 @@ class RadixAttention(nn.Module):
|
|||||||
v_head_dim: int = -1,
|
v_head_dim: int = -1,
|
||||||
sliding_window_size: int = -1,
|
sliding_window_size: int = -1,
|
||||||
is_cross_attention: bool = False,
|
is_cross_attention: bool = False,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
use_irope: bool = False,
|
use_irope: bool = False,
|
||||||
):
|
):
|
||||||
@@ -49,9 +54,16 @@ class RadixAttention(nn.Module):
|
|||||||
self.logit_cap = logit_cap
|
self.logit_cap = logit_cap
|
||||||
self.sliding_window_size = sliding_window_size or -1
|
self.sliding_window_size = sliding_window_size or -1
|
||||||
self.is_cross_attention = is_cross_attention
|
self.is_cross_attention = is_cross_attention
|
||||||
|
self.use_irope = use_irope
|
||||||
self.k_scale = None
|
self.k_scale = None
|
||||||
self.v_scale = None
|
self.v_scale = None
|
||||||
self.use_irope = use_irope
|
self.k_scale_float = None
|
||||||
|
self.v_scale_float = None
|
||||||
|
self.quant_method = None
|
||||||
|
if quant_config is not None:
|
||||||
|
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||||
|
if self.quant_method is not None:
|
||||||
|
self.quant_method.create_weights(self)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ class BaiChuanAttention(nn.Module):
|
|||||||
scaling,
|
scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -194,6 +195,7 @@ class BaiChuanAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ class GLMAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -204,6 +204,7 @@ class CohereAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
if self.use_qk_norm:
|
if self.use_qk_norm:
|
||||||
|
|||||||
@@ -249,6 +249,7 @@ class DbrxAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -255,6 +255,7 @@ class DeepseekAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -489,6 +489,7 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_local_heads,
|
num_kv_heads=self.num_local_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -669,6 +670,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
num_kv_heads=1,
|
num_kv_heads=1,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
v_head_dim=self.kv_lora_rank,
|
v_head_dim=self.kv_lora_rank,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn_mqa", prefix),
|
prefix=add_prefix("attn_mqa", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -679,6 +681,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
num_kv_heads=self.num_local_heads,
|
num_kv_heads=self.num_local_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
v_head_dim=self.v_head_dim,
|
v_head_dim=self.v_head_dim,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn_mha", prefix),
|
prefix=add_prefix("attn_mha", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ class ExaoneAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -137,6 +137,7 @@ class GemmaAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -163,6 +163,7 @@ class Gemma2Attention(nn.Module):
|
|||||||
if use_sliding_window
|
if use_sliding_window
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -193,6 +193,7 @@ class Gemma3Attention(nn.Module):
|
|||||||
# Module must also define `get_attention_sliding_window_size` to correctly initialize
|
# Module must also define `get_attention_sliding_window_size` to correctly initialize
|
||||||
# attention backend in `ForwardBatch`.
|
# attention backend in `ForwardBatch`.
|
||||||
sliding_window_size=self.sliding_window,
|
sliding_window_size=self.sliding_window,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ class GPT2Attention(nn.Module):
|
|||||||
scaling=self.scale,
|
scaling=self.scale,
|
||||||
num_kv_heads=total_num_heads,
|
num_kv_heads=total_num_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
|
|||||||
scaling=self.scale,
|
scaling=self.scale,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -158,6 +158,7 @@ class GraniteAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -215,6 +215,7 @@ class Grok1Attention(nn.Module):
|
|||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -145,6 +145,7 @@ class InternLM2Attention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
layer_id,
|
layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -170,6 +170,7 @@ class LlamaAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ class MiniCPMAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -192,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_local_heads,
|
num_kv_heads=self.num_local_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -343,6 +344,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
|||||||
num_kv_heads=1,
|
num_kv_heads=1,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
v_head_dim=self.kv_lora_rank,
|
v_head_dim=self.kv_lora_rank,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -169,6 +169,7 @@ class MixtralAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -232,6 +232,7 @@ class MixtralAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -535,6 +535,7 @@ class MllamaTextCrossAttention(nn.Module):
|
|||||||
self.num_local_key_value_heads,
|
self.num_local_key_value_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
is_cross_attention=True,
|
is_cross_attention=True,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ class OlmoAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_heads,
|
num_kv_heads=self.num_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ class Olmo2Attention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -170,6 +170,7 @@ class OlmoeAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -202,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module):
|
|||||||
self.scale,
|
self.scale,
|
||||||
num_kv_heads=self.num_kv_heads_per_partion,
|
num_kv_heads=self.num_kv_heads_per_partion,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -133,6 +133,7 @@ class QWenAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_heads,
|
num_kv_heads=self.num_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -154,6 +154,7 @@ class Qwen2Attention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -231,6 +231,7 @@ class Qwen2MoeAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -149,6 +149,7 @@ class StablelmAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_key_value_heads,
|
num_kv_heads=self.num_key_value_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -153,6 +153,7 @@ class XverseAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -252,6 +252,7 @@ class XverseAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("attn", prefix),
|
prefix=add_prefix("attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -37,11 +37,6 @@ DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST = (
|
|||||||
DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST = (
|
DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST = (
|
||||||
"nvidia/Llama-3.1-8B-Instruct-FP8"
|
"nvidia/Llama-3.1-8B-Instruct-FP8"
|
||||||
)
|
)
|
||||||
# TODO(yundai424): right now specifying to an older revision since the latest one
|
|
||||||
# carries kv cache quantization which doesn't work yet
|
|
||||||
DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION = (
|
|
||||||
"13858565416dbdc0b4e7a4a677fadfbd5b9e5bb9"
|
|
||||||
)
|
|
||||||
|
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
|
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
|
|||||||
Reference in New Issue
Block a user