From 2695ab05375456b4a2ac625cca12b95c3b90e5ae Mon Sep 17 00:00:00 2001 From: Yun Dai Date: Tue, 8 Apr 2025 09:11:35 -0700 Subject: [PATCH] Fix loading KV quantization scale; Enable modelopt kv cache (#4686) Co-authored-by: qingquansong --- python/sglang/srt/configs/model_config.py | 2 +- .../attention/flashattention_backend.py | 31 +++++-- .../layers/attention/flashinfer_backend.py | 20 ++-- .../srt/layers/quantization/kv_cache.py | 91 +++++++++---------- .../srt/layers/quantization/modelopt_quant.py | 29 +++++- python/sglang/srt/layers/radix_attention.py | 14 ++- python/sglang/srt/models/baichuan.py | 2 + python/sglang/srt/models/chatglm.py | 1 + python/sglang/srt/models/commandr.py | 1 + python/sglang/srt/models/dbrx.py | 1 + python/sglang/srt/models/deepseek.py | 1 + python/sglang/srt/models/deepseek_v2.py | 3 + python/sglang/srt/models/exaone.py | 1 + python/sglang/srt/models/gemma.py | 1 + python/sglang/srt/models/gemma2.py | 1 + python/sglang/srt/models/gemma3_causal.py | 1 + python/sglang/srt/models/gpt2.py | 1 + python/sglang/srt/models/gpt_bigcode.py | 1 + python/sglang/srt/models/granite.py | 1 + python/sglang/srt/models/grok.py | 1 + python/sglang/srt/models/internlm2.py | 1 + python/sglang/srt/models/llama.py | 1 + python/sglang/srt/models/minicpm.py | 1 + python/sglang/srt/models/minicpm3.py | 2 + python/sglang/srt/models/mixtral.py | 1 + python/sglang/srt/models/mixtral_quant.py | 1 + python/sglang/srt/models/mllama.py | 1 + python/sglang/srt/models/olmo.py | 1 + python/sglang/srt/models/olmo2.py | 1 + python/sglang/srt/models/olmoe.py | 1 + python/sglang/srt/models/phi3_small.py | 1 + python/sglang/srt/models/qwen.py | 1 + python/sglang/srt/models/qwen2.py | 1 + python/sglang/srt/models/qwen2_moe.py | 1 + python/sglang/srt/models/stablelm.py | 1 + python/sglang/srt/models/xverse.py | 1 + python/sglang/srt/models/xverse_moe.py | 1 + python/sglang/test/test_utils.py | 5 - 38 files changed, 151 insertions(+), 76 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 860568097..e23f089f1 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -239,7 +239,7 @@ class ModelConfig: # check if is modelopt model -- modelopt doesn't have corresponding field # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main - is_local = os.path.isdir(self.model_path) + is_local = os.path.exists(self.model_path) modelopt_quant_config = {"quant_method": "modelopt"} if not is_local: from huggingface_hub import HfApi diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index e6cb04875..78efc4332 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -292,6 +292,8 @@ class FlashAttentionBackend(AttentionBackend): self.decode_cuda_graph_metadata = {} self.target_verify_metadata = {} self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.kv_cache_dtype = model_runner.kv_cache_dtype + self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype self.page_size = model_runner.page_size self.use_mla = ( model_runner.model_config.attention_arch == AttentionArch.MLA @@ -520,6 +522,12 @@ class FlashAttentionBackend(AttentionBackend): if layer.sliding_window_size is not None else (-1, -1) ) + k_descale, v_descale = None, None + if self.kv_cache_dtype_str != "auto": + descale_shape = (forward_batch.batch_size, layer.tp_k_head_num) + k_descale = layer.k_scale.expand(descale_shape) + v_descale = layer.v_scale.expand(descale_shape) + q = q.to(self.kv_cache_dtype) causal = not layer.is_cross_attention # Check if we should use local attention @@ -576,8 +584,8 @@ class FlashAttentionBackend(AttentionBackend): causal=causal, window_size=window_size, softcap=layer.logit_cap, - k_descale=layer.k_scale, - v_descale=layer.v_scale, + k_descale=k_descale, + v_descale=v_descale, ) else: # Do absorbed multi-latent attention @@ -609,8 +617,8 @@ class FlashAttentionBackend(AttentionBackend): softmax_scale=layer.scaling, causal=True, softcap=layer.logit_cap, - k_descale=layer.k_scale, - v_descale=layer.v_scale, + k_descale=k_descale, + v_descale=v_descale, ) return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) @@ -657,6 +665,13 @@ class FlashAttentionBackend(AttentionBackend): ) causal = not layer.is_cross_attention + k_descale, v_descale = None, None + if self.kv_cache_dtype_str != "auto": + descale_shape = (forward_batch.batch_size, layer.tp_k_head_num) + k_descale = layer.k_scale.expand(descale_shape) + v_descale = layer.v_scale.expand(descale_shape) + q = q.to(self.kv_cache_dtype) + if not self.use_mla: # Do multi-head attention @@ -694,8 +709,8 @@ class FlashAttentionBackend(AttentionBackend): causal=causal, window_size=window_size, softcap=layer.logit_cap, - k_descale=layer.k_scale, - v_descale=layer.v_scale, + k_descale=k_descale, + v_descale=v_descale, ) else: # Do absorbed multi-latent attention @@ -729,8 +744,8 @@ class FlashAttentionBackend(AttentionBackend): softmax_scale=layer.scaling, causal=True, softcap=layer.logit_cap, - k_descale=layer.k_scale, - v_descale=layer.v_scale, + k_descale=k_descale, + v_descale=v_descale, ) return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 6f7ed8523..2cb91f094 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend): self.max_context_len = model_runner.model_config.context_len self.skip_prefill = skip_prefill self.is_multimodal = model_runner.model_config.is_multimodal + self.kv_cache_dtype = model_runner.kv_cache_dtype + self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype assert not ( model_runner.sliding_window_size is not None @@ -391,6 +393,8 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch: ForwardBatch, save_kv_cache=True, ): + k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None + v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[ self._get_wrapper_idx(layer) ] @@ -407,7 +411,7 @@ class FlashInferAttnBackend(AttentionBackend): assert v is not None if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale + layer, cache_loc, k, v, k_scale, v_scale ) o = prefill_wrapper_paged.forward( @@ -417,8 +421,8 @@ class FlashInferAttnBackend(AttentionBackend): sm_scale=layer.scaling, window_left=layer.sliding_window_size, logits_soft_cap=logits_soft_cap, - k_scale=layer.k_scale, - v_scale=layer.v_scale, + k_scale=k_scale, + v_scale=v_scale, ) else: o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( @@ -445,7 +449,7 @@ class FlashInferAttnBackend(AttentionBackend): if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale + layer, cache_loc, k, v, k_scale, v_scale ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) @@ -459,6 +463,8 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch: ForwardBatch, save_kv_cache=True, ): + k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None + v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None decode_wrapper = self.forward_metadata.decode_wrappers[ self._get_wrapper_idx(layer) ] @@ -472,7 +478,7 @@ class FlashInferAttnBackend(AttentionBackend): assert v is not None if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale + layer, cache_loc, k, v, k_scale, v_scale ) o = decode_wrapper.forward( @@ -480,8 +486,8 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), sm_scale=layer.scaling, logits_soft_cap=layer.logit_cap, - k_scale=layer.k_scale, - v_scale=layer.v_scale, + k_scale=k_scale, + v_scale=v_scale, ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) diff --git a/python/sglang/srt/layers/quantization/kv_cache.py b/python/sglang/srt/layers/quantization/kv_cache.py index 4275bca52..da6d91a9b 100644 --- a/python/sglang/srt/layers/quantization/kv_cache.py +++ b/python/sglang/srt/layers/quantization/kv_cache.py @@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.utils import is_hip _is_hip = is_hip() @@ -17,7 +18,7 @@ logger = logging.getLogger(__name__) class BaseKVCacheMethod(QuantizeMethodBase): """ - Quant method that adds `_k_scale` and `_v_scale` attributes to the + Quant method that adds `k_scale` and `v_scale` attributes to the Attention layer to support loading those scaling factors from checkpoints. The k/v_scale will be used to: - quantize k/v_cache entries before saving them to the cache @@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase): # Initialize the KV cache scales to -1.0, which is an invalid value. # If the k/v_scale appears in the checkpoint, it will be # overwritten when loading weights. - layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) - layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer.k_scale = torch.nn.Parameter( + torch.tensor(-1.0, dtype=torch.float32), requires_grad=False + ) + layer.v_scale = torch.nn.Parameter( + torch.tensor(-1.0, dtype=torch.float32), requires_grad=False + ) @classmethod def is_fp8_fnuz(cls) -> bool: @@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase): def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 - # regardless whether the kv-scale is available in the checkpoint. - # No need to process kv scales after loading if we are going to - # calculate them on the fly. - if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: - if layer.k_scale > 0.0 and layer.v_scale > 0.0: - # We prefer to use separate k_scale and v_scale if present - k_scale = layer.k_scale.to("cpu").tolist() - v_scale = layer.v_scale.to("cpu").tolist() - if _is_hip and self.is_fp8_fnuz(): - k_scale *= 2 - v_scale *= 2 - elif layer.k_scale < 0.0 and layer.v_scale < 0.0: - # If no scales were loaded (both scales are invalid negative - # values), use the default value of 1.0 - k_scale = 1.0 - v_scale = 1.0 - else: - # If we find a single kv_scale in the checkpoint, we remap - # kv_scale to k_scale during weight loading, and duplicate - # k_scale to v_scale here - assert layer.k_scale > 0.0 - scale_to_duplicate = max(layer.k_scale, layer.v_scale) - k_scale = scale_to_duplicate.to("cpu").tolist() - v_scale = scale_to_duplicate.to("cpu").tolist() - if _is_hip and self.is_fp8_fnuz(): - k_scale *= 2 - v_scale *= 2 + def process_weights_after_loading(self, layer: RadixAttention) -> None: + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if _is_hip and self.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if _is_hip and self.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 - if not isinstance(k_scale, float) or not isinstance(v_scale, float): - raise ValueError( - "Only support per-tensor scaling factor " "for fp8 KV cache" - ) + if not isinstance(k_scale, float) or not isinstance(v_scale, float): + raise ValueError( + "Only support per-tensor scaling factor " "for fp8 KV cache" + ) - # These are used in the final Attention.forward() - layer._k_scale.copy_(k_scale) - layer._v_scale.copy_(v_scale) - layer._k_scale_float = k_scale - layer._v_scale_float = v_scale - if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: - logger.warning( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This " - "may cause accuracy issues. Please make sure k/v_scale " - "scaling factors are available in the fp8 checkpoint." - ) - - del layer.k_scale - del layer.v_scale + # These are used in the final Attention.forward() + layer.k_scale.copy_(k_scale) + layer.v_scale.copy_(v_scale) + layer.k_scale_float = k_scale + layer.v_scale_float = v_scale diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 9961054d8..eea9fa573 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.linear import LinearBase, LinearMethodBase from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( @@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.utils import ( convert_to_channelwise, requantize_with_max_scale, ) +from sglang.srt.layers.radix_attention import RadixAttention # Initialize logger for the module logger = logging.getLogger(__name__) @@ -33,12 +33,19 @@ ACTIVATION_SCHEMES = ["static"] class ModelOptFp8Config(QuantizationConfig): """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks.""" - def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None: + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + kv_cache_quant_method: Optional[str] = None, + exclude_modules: Optional[List[str]] = None, + ) -> None: """ Args: is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format. """ self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + self.kv_cache_quant_method = kv_cache_quant_method + self.exclude_modules = exclude_modules if is_checkpoint_fp8_serialized: logger.warning( "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change." @@ -63,6 +70,12 @@ class ModelOptFp8Config(QuantizationConfig): @classmethod def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo") + kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get( + "kv_cache_quant_algo" + ) + exclude_modules = cls.get_from_keys(config, ["quantization"]).get( + "exclude_modules" + ) if "FP8" not in quant_method: raise ValueError( @@ -70,15 +83,23 @@ class ModelOptFp8Config(QuantizationConfig): "Check the `hf_quant_config.json` file for your model's configuration." ) - return cls(is_checkpoint_fp8_serialized=True) + return cls( + is_checkpoint_fp8_serialized=True, + kv_cache_quant_method=kv_cache_quant_method, + exclude_modules=exclude_modules, + ) def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: + if self.exclude_modules and any( + module in prefix for module in self.exclude_modules + ): + return None if isinstance(layer, LinearBase): return ModelOptFp8LinearMethod(self) - if isinstance(layer, AttentionBackend): + if self.kv_cache_quant_method and isinstance(layer, RadixAttention): return ModelOptFp8KVCacheMethod(self) return None diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 69c105997..3bb30bc15 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -13,8 +13,12 @@ # ============================================================================== """Radix attention.""" +from typing import Optional + from torch import nn +from sglang.srt.layers.linear import UnquantizedLinearMethod +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -34,6 +38,7 @@ class RadixAttention(nn.Module): v_head_dim: int = -1, sliding_window_size: int = -1, is_cross_attention: bool = False, + quant_config: Optional[QuantizationConfig] = None, prefix: str = "", use_irope: bool = False, ): @@ -49,9 +54,16 @@ class RadixAttention(nn.Module): self.logit_cap = logit_cap self.sliding_window_size = sliding_window_size or -1 self.is_cross_attention = is_cross_attention + self.use_irope = use_irope self.k_scale = None self.v_scale = None - self.use_irope = use_irope + self.k_scale_float = None + self.v_scale_float = None + self.quant_method = None + if quant_config is not None: + self.quant_method = quant_config.get_quant_method(self, prefix=prefix) + if self.quant_method is not None: + self.quant_method.create_weights(self) def forward( self, diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index 578935012..6e8d3b4a3 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -178,6 +178,7 @@ class BaiChuanAttention(nn.Module): scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) else: @@ -194,6 +195,7 @@ class BaiChuanAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 4692a5812..9cf585a02 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -113,6 +113,7 @@ class GLMAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 7cdf0e135..ebbf8ed64 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -204,6 +204,7 @@ class CohereAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) if self.use_qk_norm: diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index b1bc79872..15cef015c 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -249,6 +249,7 @@ class DbrxAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 216aca9c2..2b963be16 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -255,6 +255,7 @@ class DeepseekAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index fa6de84b2..d973f1b88 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -489,6 +489,7 @@ class DeepseekV2Attention(nn.Module): self.scaling, num_kv_heads=self.num_local_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) @@ -669,6 +670,7 @@ class DeepseekV2AttentionMLA(nn.Module): num_kv_heads=1, layer_id=layer_id, v_head_dim=self.kv_lora_rank, + quant_config=quant_config, prefix=add_prefix("attn_mqa", prefix), ) @@ -679,6 +681,7 @@ class DeepseekV2AttentionMLA(nn.Module): num_kv_heads=self.num_local_heads, layer_id=layer_id, v_head_dim=self.v_head_dim, + quant_config=quant_config, prefix=add_prefix("attn_mha", prefix), ) diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 5b301c801..430c1d58b 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -155,6 +155,7 @@ class ExaoneAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, ) def forward( diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 8ab8abd4f..d8074487c 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -137,6 +137,7 @@ class GemmaAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 87cd7dbe0..9056b0b0c 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -163,6 +163,7 @@ class Gemma2Attention(nn.Module): if use_sliding_window else None ), + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index d9e0293b7..e34715571 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -193,6 +193,7 @@ class Gemma3Attention(nn.Module): # Module must also define `get_attention_sliding_window_size` to correctly initialize # attention backend in `ForwardBatch`. sliding_window_size=self.sliding_window, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py index 15374afaa..1ec33406f 100644 --- a/python/sglang/srt/models/gpt2.py +++ b/python/sglang/srt/models/gpt2.py @@ -78,6 +78,7 @@ class GPT2Attention(nn.Module): scaling=self.scale, num_kv_heads=total_num_heads, layer_id=layer_id, + quant_config=quant_config, ) def forward( diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 631da1298..f49a5d304 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -87,6 +87,7 @@ class GPTBigCodeAttention(nn.Module): scaling=self.scale, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py index 086a8fb82..26fccc48d 100644 --- a/python/sglang/srt/models/granite.py +++ b/python/sglang/srt/models/granite.py @@ -158,6 +158,7 @@ class GraniteAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 2ef25daef..a8cde8e09 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -215,6 +215,7 @@ class Grok1Attention(nn.Module): num_kv_heads=self.num_kv_heads, layer_id=layer_id, logit_cap=logit_cap, + quant_config=quant_config, ) def forward( diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index fe39dd1a4..28f51cef4 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -145,6 +145,7 @@ class InternLM2Attention(nn.Module): self.scaling, self.num_kv_heads, layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index d53100d2c..57707c349 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -170,6 +170,7 @@ class LlamaAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index f7133bcce..2df170f38 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -146,6 +146,7 @@ class MiniCPMAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index f1c08c5fe..eae2bf007 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -192,6 +192,7 @@ class MiniCPM3Attention(nn.Module): self.scaling, num_kv_heads=self.num_local_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) @@ -343,6 +344,7 @@ class MiniCPM3AttentionMLA(nn.Module): num_kv_heads=1, layer_id=layer_id, v_head_dim=self.kv_lora_rank, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 058f96fdd..8d5a03f0a 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -169,6 +169,7 @@ class MixtralAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index c3ba17bc9..cfb23cdf3 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -232,6 +232,7 @@ class MixtralAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 8347ca2e7..f8bd9b9b6 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -535,6 +535,7 @@ class MllamaTextCrossAttention(nn.Module): self.num_local_key_value_heads, layer_id=layer_id, is_cross_attention=True, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 686cb01ac..08239374d 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -93,6 +93,7 @@ class OlmoAttention(nn.Module): self.scaling, num_kv_heads=self.num_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py index 716ae99e4..75834e6fb 100644 --- a/python/sglang/srt/models/olmo2.py +++ b/python/sglang/srt/models/olmo2.py @@ -118,6 +118,7 @@ class Olmo2Attention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index df3bd0dbf..612120fe9 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -170,6 +170,7 @@ class OlmoeAttention(nn.Module): self.scaling, layer_id=layer_id, num_kv_heads=self.num_kv_heads, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index d99d09c06..c59d296a6 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -202,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module): self.scale, num_kv_heads=self.num_kv_heads_per_partion, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index cd94a9103..f0660f62d 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -133,6 +133,7 @@ class QWenAttention(nn.Module): self.scaling, num_kv_heads=self.num_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 6fad7a488..b6646b5fb 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -154,6 +154,7 @@ class Qwen2Attention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index fa00b35e1..4cbb0df4a 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -231,6 +231,7 @@ class Qwen2MoeAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 45ac90c97..1566893ee 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -149,6 +149,7 @@ class StablelmAttention(nn.Module): self.scaling, num_kv_heads=self.num_key_value_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index 2162f7a44..f84755b03 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -153,6 +153,7 @@ class XverseAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index a7c79ec8c..0ea9ed950 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -252,6 +252,7 @@ class XverseAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + quant_config=quant_config, prefix=add_prefix("attn", prefix), ) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index fe876c960..7d68dcf37 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -37,11 +37,6 @@ DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST = ( DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST = ( "nvidia/Llama-3.1-8B-Instruct-FP8" ) -# TODO(yundai424): right now specifying to an older revision since the latest one -# carries kv cache quantization which doesn't work yet -DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION = ( - "13858565416dbdc0b4e7a4a677fadfbd5b9e5bb9" -) DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct" DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"