Fix loading KV quantization scale; Enable modelopt kv cache (#4686)

Co-authored-by: qingquansong <ustcsqq@gmail.com>
Author: Yun Dai
Date: 2025-04-08 09:11:35 -07:00
Committed by: GitHub
Parent: 88d6fd9a11
Commit: 2695ab0537
38 changed files with 151 additions and 76 deletions
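
The diff below threads a quant_config into RadixAttention so that, when a quantization config is present, its per-layer quant method can attach KV-cache scale parameters (k_scale / v_scale) to the attention layer at construction time; the checkpoint loader can then fill those parameters with the scales exported by the quantizer (for example by NVIDIA TensorRT Model Optimizer, per the commit title). As an illustrative, hypothetical sketch only (helper name and checkpoint key layout are assumptions, not taken from this commit), a loader hook for such scales could look like:

import torch

def load_kv_scale(attn_layer, name: str, loaded_weight: torch.Tensor) -> None:
    # Hypothetical helper, not part of this commit: copy a KV-cache scale from a
    # quantized checkpoint onto the parameter that quant_method.create_weights()
    # attached to the attention layer. Checkpoint key naming varies by exporter;
    # here we only assume the key ends in "k_scale" or "v_scale".
    if name.endswith("k_scale"):
        attn_layer.k_scale.data.copy_(loaded_weight.to(torch.float32))
    elif name.endswith("v_scale"):
        attn_layer.v_scale.data.copy_(loaded_weight.to(torch.float32))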


@@ -13,8 +13,12 @@
 # ==============================================================================
 """Radix attention."""

+from typing import Optional
+
 from torch import nn

+from sglang.srt.layers.linear import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -34,6 +38,7 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -49,9 +54,16 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.use_irope = use_irope
         self.k_scale = None
         self.v_scale = None
-        self.use_irope = use_irope
+        self.k_scale_float = None
+        self.v_scale_float = None
+        self.quant_method = None
+        if quant_config is not None:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if self.quant_method is not None:
+            self.quant_method.create_weights(self)

     def forward(
         self,
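
For context, the hook added above asks the quantization config for a per-layer quant method and calls its create_weights(self), so the method can attach scale parameters to the attention layer; the new k_scale_float / v_scale_float attributes suggest that plain-float copies of the scales are cached after loading, though that step lives in parts of the 38-file commit not shown here. A minimal sketch of that pattern, under the assumption of a vLLM-style KV-cache quant method and an FP8 KV cache, and not taken from this repository's implementation:

import torch
from torch import nn

class SketchKVCacheQuantMethod:
    # Minimal, hypothetical KV-cache quant method illustrating the hook above.
    def create_weights(self, layer: nn.Module) -> None:
        # Scales default to 1.0; the weight loader overwrites them with the
        # values exported by the quantizer.
        layer.k_scale = nn.Parameter(torch.tensor(1.0), requires_grad=False)
        layer.v_scale = nn.Parameter(torch.tensor(1.0), requires_grad=False)

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        # Cache plain Python floats so attention backends can read the scales
        # without touching tensors on the hot path; mirrors the k_scale_float /
        # v_scale_float attributes introduced in the diff.
        layer.k_scale_float = layer.k_scale.item()
        layer.v_scale_float = layer.v_scale.item()

def quantize_kv(layer: nn.Module, k: torch.Tensor, v: torch.Tensor):
    # Hypothetical use of the scales when writing FP8 KV-cache entries
    # (requires a PyTorch build with float8 support).
    k_q = (k / layer.k_scale_float).to(torch.float8_e4m3fn)
    v_q = (v / layer.v_scale_float).to(torch.float8_e4m3fn)
    return k_q, v_q

The scale direction in the sketch follows the usual FP8 convention: divide by the scale when quantizing, multiply by it when dequantizing.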