Fix loading KV quantization scale; Enable modelopt kv cache (#4686)

Co-authored-by: qingquansong <ustcsqq@gmail.com>
This commit is contained in:
Yun Dai
2025-04-08 09:11:35 -07:00
committed by GitHub
parent 88d6fd9a11
commit 2695ab0537
38 changed files with 151 additions and 76 deletions

View File

@@ -239,7 +239,7 @@ class ModelConfig:
# check if is modelopt model -- modelopt doesn't have corresponding field
# in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
# example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
is_local = os.path.isdir(self.model_path)
is_local = os.path.exists(self.model_path)
modelopt_quant_config = {"quant_method": "modelopt"}
if not is_local:
from huggingface_hub import HfApi

View File

@@ -292,6 +292,8 @@ class FlashAttentionBackend(AttentionBackend):
self.decode_cuda_graph_metadata = {}
self.target_verify_metadata = {}
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
self.page_size = model_runner.page_size
self.use_mla = (
model_runner.model_config.attention_arch == AttentionArch.MLA
@@ -520,6 +522,12 @@ class FlashAttentionBackend(AttentionBackend):
if layer.sliding_window_size is not None
else (-1, -1)
)
k_descale, v_descale = None, None
if self.kv_cache_dtype_str != "auto":
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
causal = not layer.is_cross_attention
# Check if we should use local attention
@@ -576,8 +584,8 @@ class FlashAttentionBackend(AttentionBackend):
causal=causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
else:
# Do absorbed multi-latent attention
@@ -609,8 +617,8 @@ class FlashAttentionBackend(AttentionBackend):
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -657,6 +665,13 @@ class FlashAttentionBackend(AttentionBackend):
)
causal = not layer.is_cross_attention
k_descale, v_descale = None, None
if self.kv_cache_dtype_str != "auto":
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
if not self.use_mla:
# Do multi-head attention
@@ -694,8 +709,8 @@ class FlashAttentionBackend(AttentionBackend):
causal=causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
else:
# Do absorbed multi-latent attention
@@ -729,8 +744,8 @@ class FlashAttentionBackend(AttentionBackend):
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

View File

@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
self.is_multimodal = model_runner.model_config.is_multimodal
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
assert not (
model_runner.sliding_window_size is not None
@@ -391,6 +393,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
self._get_wrapper_idx(layer)
]
@@ -407,7 +411,7 @@ class FlashInferAttnBackend(AttentionBackend):
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
layer, cache_loc, k, v, k_scale, v_scale
)
o = prefill_wrapper_paged.forward(
@@ -417,8 +421,8 @@ class FlashInferAttnBackend(AttentionBackend):
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
logits_soft_cap=logits_soft_cap,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
k_scale=k_scale,
v_scale=v_scale,
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -445,7 +449,7 @@ class FlashInferAttnBackend(AttentionBackend):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
layer, cache_loc, k, v, k_scale, v_scale
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -459,6 +463,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
decode_wrapper = self.forward_metadata.decode_wrappers[
self._get_wrapper_idx(layer)
]
@@ -472,7 +478,7 @@ class FlashInferAttnBackend(AttentionBackend):
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
layer, cache_loc, k, v, k_scale, v_scale
)
o = decode_wrapper.forward(
@@ -480,8 +486,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
k_scale=k_scale,
v_scale=v_scale,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)

View File

@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_hip
_is_hip = is_hip()
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
class BaseKVCacheMethod(QuantizeMethodBase):
"""
Quant method that adds `_k_scale` and `_v_scale` attributes to the
Quant method that adds `k_scale` and `v_scale` attributes to the
Attention layer to support loading those scaling factors from checkpoints.
The k/v_scale will be used to:
- quantize k/v_cache entries before saving them to the cache
@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
# Initialize the KV cache scales to -1.0, which is an invalid value.
# If the k/v_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
layer.k_scale = torch.nn.Parameter(
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
)
layer.v_scale = torch.nn.Parameter(
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
)
@classmethod
def is_fp8_fnuz(cls) -> bool:
@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to
# calculate them on the fly.
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale = 1.0
v_scale = 1.0
else:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
# k_scale to v_scale here
assert layer.k_scale > 0.0
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
def process_weights_after_loading(self, layer: RadixAttention) -> None:
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale = 1.0
v_scale = 1.0
else:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
# k_scale to v_scale here
assert layer.k_scale > 0.0
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
if not isinstance(k_scale, float) or not isinstance(v_scale, float):
raise ValueError(
"Only support per-tensor scaling factor " "for fp8 KV cache"
)
if not isinstance(k_scale, float) or not isinstance(v_scale, float):
raise ValueError(
"Only support per-tensor scaling factor " "for fp8 KV cache"
)
# These are used in the final Attention.forward()
layer._k_scale.copy_(k_scale)
layer._v_scale.copy_(v_scale)
layer._k_scale_float = k_scale
layer._v_scale_float = v_scale
if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
logger.warning(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint."
)
del layer.k_scale
del layer.v_scale
# These are used in the final Attention.forward()
layer.k_scale.copy_(k_scale)
layer.v_scale.copy_(v_scale)
layer.k_scale_float = k_scale
layer.v_scale_float = v_scale

View File

@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.utils import (
convert_to_channelwise,
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
# Initialize logger for the module
logger = logging.getLogger(__name__)
@@ -33,12 +33,19 @@ ACTIVATION_SCHEMES = ["static"]
class ModelOptFp8Config(QuantizationConfig):
"""Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
kv_cache_quant_method: Optional[str] = None,
exclude_modules: Optional[List[str]] = None,
) -> None:
"""
Args:
is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
"""
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.kv_cache_quant_method = kv_cache_quant_method
self.exclude_modules = exclude_modules
if is_checkpoint_fp8_serialized:
logger.warning(
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
@@ -63,6 +70,12 @@ class ModelOptFp8Config(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
"kv_cache_quant_algo"
)
exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
"exclude_modules"
)
if "FP8" not in quant_method:
raise ValueError(
@@ -70,15 +83,23 @@ class ModelOptFp8Config(QuantizationConfig):
"Check the `hf_quant_config.json` file for your model's configuration."
)
return cls(is_checkpoint_fp8_serialized=True)
return cls(
is_checkpoint_fp8_serialized=True,
kv_cache_quant_method=kv_cache_quant_method,
exclude_modules=exclude_modules,
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if self.exclude_modules and any(
module in prefix for module in self.exclude_modules
):
return None
if isinstance(layer, LinearBase):
return ModelOptFp8LinearMethod(self)
if isinstance(layer, AttentionBackend):
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self)
return None

View File

@@ -13,8 +13,12 @@
# ==============================================================================
"""Radix attention."""
from typing import Optional
from torch import nn
from sglang.srt.layers.linear import UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -34,6 +38,7 @@ class RadixAttention(nn.Module):
v_head_dim: int = -1,
sliding_window_size: int = -1,
is_cross_attention: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_irope: bool = False,
):
@@ -49,9 +54,16 @@ class RadixAttention(nn.Module):
self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention
self.use_irope = use_irope
self.k_scale = None
self.v_scale = None
self.use_irope = use_irope
self.k_scale_float = None
self.v_scale_float = None
self.quant_method = None
if quant_config is not None:
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
if self.quant_method is not None:
self.quant_method.create_weights(self)
def forward(
self,

View File

@@ -178,6 +178,7 @@ class BaiChuanAttention(nn.Module):
scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
else:
@@ -194,6 +195,7 @@ class BaiChuanAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -113,6 +113,7 @@ class GLMAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -204,6 +204,7 @@ class CohereAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
if self.use_qk_norm:

View File

@@ -249,6 +249,7 @@ class DbrxAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -255,6 +255,7 @@ class DeepseekAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -489,6 +489,7 @@ class DeepseekV2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
@@ -669,6 +670,7 @@ class DeepseekV2AttentionMLA(nn.Module):
num_kv_heads=1,
layer_id=layer_id,
v_head_dim=self.kv_lora_rank,
quant_config=quant_config,
prefix=add_prefix("attn_mqa", prefix),
)
@@ -679,6 +681,7 @@ class DeepseekV2AttentionMLA(nn.Module):
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
v_head_dim=self.v_head_dim,
quant_config=quant_config,
prefix=add_prefix("attn_mha", prefix),
)

View File

@@ -155,6 +155,7 @@ class ExaoneAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
)
def forward(

View File

@@ -137,6 +137,7 @@ class GemmaAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -163,6 +163,7 @@ class Gemma2Attention(nn.Module):
if use_sliding_window
else None
),
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -193,6 +193,7 @@ class Gemma3Attention(nn.Module):
# Module must also define `get_attention_sliding_window_size` to correctly initialize
# attention backend in `ForwardBatch`.
sliding_window_size=self.sliding_window,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -78,6 +78,7 @@ class GPT2Attention(nn.Module):
scaling=self.scale,
num_kv_heads=total_num_heads,
layer_id=layer_id,
quant_config=quant_config,
)
def forward(

View File

@@ -87,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
scaling=self.scale,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -158,6 +158,7 @@ class GraniteAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -215,6 +215,7 @@ class Grok1Attention(nn.Module):
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
logit_cap=logit_cap,
quant_config=quant_config,
)
def forward(

View File

@@ -145,6 +145,7 @@ class InternLM2Attention(nn.Module):
self.scaling,
self.num_kv_heads,
layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -170,6 +170,7 @@ class LlamaAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -146,6 +146,7 @@ class MiniCPMAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -192,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
@@ -343,6 +344,7 @@ class MiniCPM3AttentionMLA(nn.Module):
num_kv_heads=1,
layer_id=layer_id,
v_head_dim=self.kv_lora_rank,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -169,6 +169,7 @@ class MixtralAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -232,6 +232,7 @@ class MixtralAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -535,6 +535,7 @@ class MllamaTextCrossAttention(nn.Module):
self.num_local_key_value_heads,
layer_id=layer_id,
is_cross_attention=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -93,6 +93,7 @@ class OlmoAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -118,6 +118,7 @@ class Olmo2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -170,6 +170,7 @@ class OlmoeAttention(nn.Module):
self.scaling,
layer_id=layer_id,
num_kv_heads=self.num_kv_heads,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -202,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module):
self.scale,
num_kv_heads=self.num_kv_heads_per_partion,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -133,6 +133,7 @@ class QWenAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -154,6 +154,7 @@ class Qwen2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -231,6 +231,7 @@ class Qwen2MoeAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -149,6 +149,7 @@ class StablelmAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_key_value_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -153,6 +153,7 @@ class XverseAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -252,6 +252,7 @@ class XverseAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)

View File

@@ -37,11 +37,6 @@ DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST = (
DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST = (
"nvidia/Llama-3.1-8B-Instruct-FP8"
)
# TODO(yundai424): right now specifying to an older revision since the latest one
# carries kv cache quantization which doesn't work yet
DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION = (
"13858565416dbdc0b4e7a4a677fadfbd5b9e5bb9"
)
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"