Support Llama4 fp8 inference (#5194)
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: zhyncs <me@zhyncs.com>
@@ -55,6 +55,7 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
+    channel_quant_to_tensor_quant,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
@@ -1411,27 +1412,34 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w = self_attn.kv_b_proj.weight
                 # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
                 # This may affect the accuracy of fp8 model.
-                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                if w.dtype in (
                     torch.float8_e4m3fn,
                     torch.float8_e4m3fnuz,
                 ):
-                    weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        if _is_hip:
-                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                                weight=w,
-                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                                input_scale=None,
-                            )
-                        else:
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                        w, scale = block_quant_to_tensor_quant(
-                            weight, weight_scale, weight_block_size
-                        )
+                    if hasattr(self.quant_config, "weight_block_size"):
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            if _is_hip:
+                                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                    weight=w,
+                                    weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                    input_scale=None,
+                                )
+                            else:
+                                weight = w
+                                weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                            w, scale = block_quant_to_tensor_quant(
+                                weight, weight_scale, weight_block_size
+                            )
+                            self_attn.w_scale = scale
+                    else:
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale
+                        w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                         self_attn.w_scale = scale
 
                 if w.dtype == torch.int8:
                     if hasattr(self.quant_config, "weight_block_size"):
                         # block-wise int8 need it
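For context on the new `else` branch: block-wise fp8 checkpoints carry a `weight_scale_inv` with one scale per weight block, while channel-wise fp8 checkpoints (which the Llama4 fp8 weights this commit targets appear to use) carry a single `weight_scale` per output channel. Because `bmm_fp8` only accepts a per-tensor scale, both layouts must be requantized to one scale for the whole tensor. The sketch below shows what such a channel-to-tensor requantization can look like; the helper name `requant_channel_to_tensor` and the `(out_features, 1)` scale layout are illustrative assumptions, not the actual `fp8_utils.channel_quant_to_tensor_quant` implementation.

```python
import torch


def requant_channel_to_tensor(weight_fp8: torch.Tensor, channel_scale: torch.Tensor):
    """Sketch: fold a per-channel fp8 scale into a single per-tensor scale.

    Assumes `weight_fp8` is (out_features, in_features) in float8_e4m3fn and
    `channel_scale` is (out_features, 1) in float32 -- illustrative layout only.
    """
    # Dequantize to float32 using the per-channel scale (broadcast over rows).
    dequant = weight_fp8.to(torch.float32) * channel_scale
    # Derive one scale for the whole tensor from its absolute maximum.
    finfo = torch.finfo(torch.float8_e4m3fn)
    tensor_scale = dequant.abs().max().clamp(min=1e-12) / finfo.max
    # Requantize with the per-tensor scale, clamping to the fp8 range.
    requant = (dequant / tensor_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return requant, tensor_scale
```

Whatever the real helper returns, the diff then uses it the same way as the block-wise branch: the requantized weight replaces `w` and the single scale is assigned to `self_attn.w_scale` for the per-tensor `bmm_fp8` path.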