Fix loading KV quantization scale; Enable modelopt kv cache (#4686)
Co-authored-by: qingquansong <ustcsqq@gmail.com>
@@ -239,7 +239,7 @@ class ModelConfig:
         # check if is modelopt model -- modelopt doesn't have corresponding field
         # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
         # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
-        is_local = os.path.isdir(self.model_path)
+        is_local = os.path.exists(self.model_path)
         modelopt_quant_config = {"quant_method": "modelopt"}
         if not is_local:
             from huggingface_hub import HfApi
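The comment above describes the detection heuristic this hunk relies on: a ModelOpt checkpoint has no quantization field in `config.json`, only a standalone `hf_quant_config.json` at the repo root. A minimal sketch of that check, using a hypothetical helper rather than the actual ModelConfig code:

    import os
    from huggingface_hub import HfApi

    def has_modelopt_quant_config(model_path: str) -> bool:
        # Local path (file or directory, matching the os.path.exists check above):
        if os.path.exists(model_path):
            return os.path.isfile(os.path.join(model_path, "hf_quant_config.json"))
        # Otherwise treat model_path as a Hugging Face repo id and list its files.
        return "hf_quant_config.json" in HfApi().list_repo_files(model_path)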
@@ -292,6 +292,8 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {}
         self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
@@ -520,6 +522,12 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        k_descale, v_descale = None, None
+        if self.kv_cache_dtype_str != "auto":
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
         causal = not layer.is_cross_attention

         # Check if we should use local attention
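This new block only runs when `--kv-cache-dtype` is not `auto`: the per-tensor `k_scale` / `v_scale` parameters are broadcast to the (batch, kv-head) descale layout the FlashAttention kernel takes, and `q` is cast to the cache dtype. A minimal sketch of the broadcast, with made-up sizes:

    import torch

    k_scale = torch.tensor(0.023, dtype=torch.float32)  # per-tensor scale from the checkpoint
    batch_size, tp_k_head_num = 4, 8
    descale_shape = (batch_size, tp_k_head_num)

    k_descale = k_scale.expand(descale_shape)  # broadcast view; no data is copied
    assert k_descale.shape == descale_shape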
@@ -576,8 +584,8 @@ class FlashAttentionBackend(AttentionBackend):
                 causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention
@@ -609,8 +617,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -657,6 +665,13 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention

+        k_descale, v_descale = None, None
+        if self.kv_cache_dtype_str != "auto":
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
+
         if not self.use_mla:
             # Do multi-head attention
@@ -694,8 +709,8 @@ class FlashAttentionBackend(AttentionBackend):
                 causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention
@@ -729,8 +744,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype

         assert not (
             model_runner.sliding_window_size is not None
@@ -391,6 +393,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -407,7 +411,7 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )

             o = prefill_wrapper_paged.forward(
@@ -417,8 +421,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -445,7 +449,7 @@ class FlashInferAttnBackend(AttentionBackend):

             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer, cache_loc, k, v, k_scale, v_scale
                 )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -459,6 +463,8 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
+        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -472,7 +478,7 @@ class FlashInferAttnBackend(AttentionBackend):
         assert v is not None
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                layer, cache_loc, k, v, k_scale, v_scale
             )

         o = decode_wrapper.forward(
@@ -480,8 +486,8 @@ class FlashInferAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
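With these changes the FlashInfer paths pass plain Python floats (`k_scale_float` / `v_scale_float`) instead of the tensor parameters, and only when a non-`auto` KV cache dtype is configured. Conceptually, a per-tensor scale means K/V are divided by the scale before being stored in fp8 and multiplied back on read. A rough sketch of that round trip (not the actual `set_kv_buffer` implementation, and assuming a PyTorch build that exposes `torch.float8_e4m3fn`):

    import torch

    def quantize_kv(x: torch.Tensor, scale: float) -> torch.Tensor:
        # Scale down, then cast to the fp8 storage dtype of the KV cache.
        return (x / scale).to(torch.float8_e4m3fn)

    def dequantize_kv(x_fp8: torch.Tensor, scale: float) -> torch.Tensor:
        # Cast back up and undo the scaling.
        return x_fp8.to(torch.float32) * scale

    k = torch.randn(16, 8, 128)
    k_roundtrip = dequantize_kv(quantize_kv(k, 0.05), 0.05)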
@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import is_hip

 _is_hip = is_hip()
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)

 class BaseKVCacheMethod(QuantizeMethodBase):
     """
-    Quant method that adds `_k_scale` and `_v_scale` attributes to the
+    Quant method that adds `k_scale` and `v_scale` attributes to the
     Attention layer to support loading those scaling factors from checkpoints.
     The k/v_scale will be used to:
     - quantize k/v_cache entries before saving them to the cache
@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
         # Initialize the KV cache scales to -1.0, which is an invalid value.
         # If the k/v_scale appears in the checkpoint, it will be
         # overwritten when loading weights.
-        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
-        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+        layer.k_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )
+        layer.v_scale = torch.nn.Parameter(
+            torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
+        )

     @classmethod
     def is_fp8_fnuz(cls) -> bool:
@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")

-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
-        # regardless whether the kv-scale is available in the checkpoint.
-        # No need to process kv scales after loading if we are going to
-        # calculate them on the fly.
-        if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
-            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
-                # We prefer to use separate k_scale and v_scale if present
-                k_scale = layer.k_scale.to("cpu").tolist()
-                v_scale = layer.v_scale.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
-            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
-                # If no scales were loaded (both scales are invalid negative
-                # values), use the default value of 1.0
-                k_scale = 1.0
-                v_scale = 1.0
-            else:
-                # If we find a single kv_scale in the checkpoint, we remap
-                # kv_scale to k_scale during weight loading, and duplicate
-                # k_scale to v_scale here
-                assert layer.k_scale > 0.0
-                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
-                k_scale = scale_to_duplicate.to("cpu").tolist()
-                v_scale = scale_to_duplicate.to("cpu").tolist()
-                if _is_hip and self.is_fp8_fnuz():
-                    k_scale *= 2
-                    v_scale *= 2
+    def process_weights_after_loading(self, layer: RadixAttention) -> None:
+        if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+            # We prefer to use separate k_scale and v_scale if present
+            k_scale = layer.k_scale.to("cpu").tolist()
+            v_scale = layer.v_scale.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2
+        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+            # If no scales were loaded (both scales are invalid negative
+            # values), use the default value of 1.0
+            k_scale = 1.0
+            v_scale = 1.0
+        else:
+            # If we find a single kv_scale in the checkpoint, we remap
+            # kv_scale to k_scale during weight loading, and duplicate
+            # k_scale to v_scale here
+            assert layer.k_scale > 0.0
+            scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+            k_scale = scale_to_duplicate.to("cpu").tolist()
+            v_scale = scale_to_duplicate.to("cpu").tolist()
+            if _is_hip and self.is_fp8_fnuz():
+                k_scale *= 2
+                v_scale *= 2

-            if not isinstance(k_scale, float) or not isinstance(v_scale, float):
-                raise ValueError(
-                    "Only support per-tensor scaling factor " "for fp8 KV cache"
-                )
+        if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+            raise ValueError(
+                "Only support per-tensor scaling factor " "for fp8 KV cache"
+            )

-            # These are used in the final Attention.forward()
-            layer._k_scale.copy_(k_scale)
-            layer._v_scale.copy_(v_scale)
-            layer._k_scale_float = k_scale
-            layer._v_scale_float = v_scale
-            if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
-                logger.warning(
-                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
-                    "may cause accuracy issues. Please make sure k/v_scale "
-                    "scaling factors are available in the fp8 checkpoint."
-                )
-
-            del layer.k_scale
-            del layer.v_scale
+        # These are used in the final Attention.forward()
+        layer.k_scale.copy_(k_scale)
+        layer.v_scale.copy_(v_scale)
+        layer.k_scale_float = k_scale
+        layer.v_scale_float = v_scale
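The rewritten `process_weights_after_loading` keeps the same three-way scale resolution but writes the results back onto the public `k_scale` / `v_scale` parameters and the new `k_scale_float` / `v_scale_float` attributes, instead of the old underscore-prefixed buffers. A standalone sketch of just the branch logic, as a hypothetical helper rather than code from the patch (-1.0 is the "not loaded" sentinel set in `create_weights`; on ROCm fp8-fnuz builds the resolved scales are additionally doubled, as above):

    def resolve_kv_scales(k_scale: float, v_scale: float) -> tuple[float, float]:
        if k_scale > 0.0 and v_scale > 0.0:
            # Separate k_scale and v_scale were loaded from the checkpoint.
            return k_scale, v_scale
        if k_scale < 0.0 and v_scale < 0.0:
            # Nothing was loaded; fall back to 1.0 (no rescaling).
            return 1.0, 1.0
        # A single kv_scale was remapped to k_scale at load time; duplicate it.
        shared = max(k_scale, v_scale)
        return shared, shared

    assert resolve_kv_scales(-1.0, -1.0) == (1.0, 1.0)
    assert resolve_kv_scales(0.5, -1.0) == (0.5, 0.5)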
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter

-from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     requantize_with_max_scale,
 )
+from sglang.srt.layers.radix_attention import RadixAttention

 # Initialize logger for the module
 logger = logging.getLogger(__name__)
@@ -33,12 +33,19 @@ ACTIVATION_SCHEMES = ["static"]
 class ModelOptFp8Config(QuantizationConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""

-    def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+        kv_cache_quant_method: Optional[str] = None,
+        exclude_modules: Optional[List[str]] = None,
+    ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.kv_cache_quant_method = kv_cache_quant_method
+        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
@@ -63,6 +70,12 @@ class ModelOptFp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
         quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
+        kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
+            "kv_cache_quant_algo"
+        )
+        exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
+            "exclude_modules"
+        )

         if "FP8" not in quant_method:
             raise ValueError(
@@ -70,15 +83,23 @@ class ModelOptFp8Config(QuantizationConfig):
                 "Check the `hf_quant_config.json` file for your model's configuration."
             )

-        return cls(is_checkpoint_fp8_serialized=True)
+        return cls(
+            is_checkpoint_fp8_serialized=True,
+            kv_cache_quant_method=kv_cache_quant_method,
+            exclude_modules=exclude_modules,
+        )

     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
+        if self.exclude_modules and any(
+            module in prefix for module in self.exclude_modules
+        ):
+            return None
+
         if isinstance(layer, LinearBase):
             return ModelOptFp8LinearMethod(self)
-        if isinstance(layer, AttentionBackend):
+        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)

         return None
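For reference, the `quantization` section that `from_config` reads from a ModelOpt `hf_quant_config.json` typically has the following shape (illustrative values, not copied from a specific checkpoint):

    example_hf_quant_config = {
        "quantization": {
            "quant_algo": "FP8",
            "kv_cache_quant_algo": "FP8",
            "exclude_modules": ["lm_head"],
        }
    }

`exclude_modules` entries are matched as substrings of the layer prefix, which is why `get_quant_method` above returns `None` for excluded layers, and a non-empty `kv_cache_quant_algo` is what enables `ModelOptFp8KVCacheMethod` on `RadixAttention` layers.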
@@ -13,8 +13,12 @@
 # ==============================================================================
 """Radix attention."""

+from typing import Optional
+
 from torch import nn

+from sglang.srt.layers.linear import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -34,6 +38,7 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -49,9 +54,16 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
-        self.use_irope = use_irope
         self.k_scale = None
         self.v_scale = None
+        self.use_irope = use_irope
+        self.k_scale_float = None
+        self.v_scale_float = None
+        self.quant_method = None
+        if quant_config is not None:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if self.quant_method is not None:
+            self.quant_method.create_weights(self)

     def forward(
         self,
@@ -178,6 +178,7 @@ class BaiChuanAttention(nn.Module):
                 scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                quant_config=quant_config,
                 prefix=add_prefix("attn", prefix),
             )
         else:

@@ -194,6 +195,7 @@ class BaiChuanAttention(nn.Module):
                 self.scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                quant_config=quant_config,
                 prefix=add_prefix("attn", prefix),
             )

@@ -113,6 +113,7 @@ class GLMAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -204,6 +204,7 @@ class CohereAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
         if self.use_qk_norm:

@@ -249,6 +249,7 @@ class DbrxAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -255,6 +255,7 @@ class DeepseekAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -489,6 +489,7 @@ class DeepseekV2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -669,6 +670,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mqa", prefix),
         )

@@ -679,6 +681,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
             v_head_dim=self.v_head_dim,
+            quant_config=quant_config,
             prefix=add_prefix("attn_mha", prefix),
         )

@@ -155,6 +155,7 @@ class ExaoneAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
         )

     def forward(

@@ -137,6 +137,7 @@ class GemmaAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -163,6 +163,7 @@ class Gemma2Attention(nn.Module):
                 if use_sliding_window
                 else None
             ),
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -193,6 +193,7 @@ class Gemma3Attention(nn.Module):
             # Module must also define `get_attention_sliding_window_size` to correctly initialize
             # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -78,6 +78,7 @@ class GPT2Attention(nn.Module):
             scaling=self.scale,
             num_kv_heads=total_num_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
         )

     def forward(

@@ -87,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
             scaling=self.scale,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -158,6 +158,7 @@ class GraniteAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -215,6 +215,7 @@ class Grok1Attention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=logit_cap,
+            quant_config=quant_config,
         )

     def forward(

@@ -145,6 +145,7 @@ class InternLM2Attention(nn.Module):
             self.scaling,
             self.num_kv_heads,
             layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -170,6 +170,7 @@ class LlamaAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -146,6 +146,7 @@ class MiniCPMAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -192,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_local_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -343,6 +344,7 @@ class MiniCPM3AttentionMLA(nn.Module):
             num_kv_heads=1,
             layer_id=layer_id,
             v_head_dim=self.kv_lora_rank,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -169,6 +169,7 @@ class MixtralAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -232,6 +232,7 @@ class MixtralAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -535,6 +535,7 @@ class MllamaTextCrossAttention(nn.Module):
             self.num_local_key_value_heads,
             layer_id=layer_id,
             is_cross_attention=True,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -93,6 +93,7 @@ class OlmoAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -118,6 +118,7 @@ class Olmo2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -170,6 +170,7 @@ class OlmoeAttention(nn.Module):
             self.scaling,
             layer_id=layer_id,
             num_kv_heads=self.num_kv_heads,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -202,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.scale,
             num_kv_heads=self.num_kv_heads_per_partion,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -133,6 +133,7 @@ class QWenAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -154,6 +154,7 @@ class Qwen2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -231,6 +231,7 @@ class Qwen2MoeAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -149,6 +149,7 @@ class StablelmAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_key_value_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -153,6 +153,7 @@ class XverseAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -252,6 +252,7 @@ class XverseAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -37,11 +37,6 @@ DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST = (
 DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST = (
     "nvidia/Llama-3.1-8B-Instruct-FP8"
 )
-# TODO(yundai424): right now specifying to an older revision since the latest one
-# carries kv cache quantization which doesn't work yet
-DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION = (
-    "13858565416dbdc0b4e7a4a677fadfbd5b9e5bb9"
-)

 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"