[perf] dsv3 bmm fallback to bf16 (#5662)

This commit is contained in:
JieXin Liang
2025-05-09 02:43:39 +08:00
committed by GitHub
parent fa7d7fd9e5
commit 5e02330137
2 changed files with 48 additions and 3 deletions

View File

@@ -63,6 +63,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_mla_deep_gemm_masked_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_dequant,
block_quant_to_tensor_quant,
channel_quant_to_tensor_quant,
normalize_e4m3fn_to_e4m3fnuz,
@@ -1589,13 +1590,22 @@ class DeepseekV2ForCausalLM(nn.Module):
if (
_is_cuda
and _ENABLE_JIT_DEEPGEMM
and weight_block_size[0] == 128
and weight_block_size[1] == 128
and model_dtype == torch.bfloat16
):
block_scale = weight_scale
use_deep_gemm_bmm = True
if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
"SGL_USE_DEEPGEMM_BMM", "false"
):
block_scale = weight_scale
use_deep_gemm_bmm = True
else:
w = block_quant_dequant(
weight,
weight_scale,
weight_block_size,
model_dtype,
)
else:
w, scale = block_quant_to_tensor_quant(
weight, weight_scale, weight_block_size