[perf] introduce deep gemm group_gemm_masked as bmm (#5432)

Author: JieXin Liang
Date: 2025-04-20 15:38:27 +08:00
Committed by: GitHub
Parent: d07e797ace
Commit: 99456bcacb
3 changed files with 361 additions and 20 deletions
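
Change summary: on CUDA builds with the JIT DeepGEMM bmm path enabled (`_enable_jit_deepgemm_bmm`), a `kv_b_proj` quantized to FP8 with 128x128 weight blocks, and a bfloat16 model dtype, the MLA `w_kc`/`w_vc` batched matmuls now go through DeepGEMM's `m_grouped_gemm_fp8_fp8_bf16_nt_masked` on the original block scales instead of requantizing the weight to a per-tensor scale for `bmm_fp8`. The sketch below mirrors the pattern used at both call sites in the diff; `masked_bmm_fp8` is a hypothetical helper name used only for illustration, and it assumes an SGLang build that already contains this change plus DeepGEMM installed.

# Hypothetical helper for illustration only; the model code below inlines these steps.
import torch
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
from sglang.srt.layers.quantization.fp8_kernel import (
    per_tensor_quant_mla_deep_gemm_masked_fp8,
)

def masked_bmm_fp8(x, w_fp8, w_block_scale):
    # x: [num_heads, m, k] bf16 activations; w_fp8: [num_heads, n, k] FP8 weight,
    # w_block_scale: its per-128x128-block scales (see w_scale_k / w_scale_v below).
    x_q, x_scale, masked_m, expected_m, aligned_m = (
        per_tensor_quant_mla_deep_gemm_masked_fp8(x, dtype=torch.float8_e4m3fn)
    )
    # The grouped kernel writes aligned_m rows per head; the extra rows are padding.
    out = x.new_empty((x.shape[0], aligned_m, w_fp8.shape[1]))
    m_grouped_gemm_fp8_fp8_bf16_nt_masked(
        (x_q, x_scale), (w_fp8, w_block_scale), out, masked_m, expected_m
    )
    return out[:, :expected_m, :]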


@@ -57,7 +57,11 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
from sglang.srt.layers.quantization.fp8_kernel import (
    _enable_jit_deepgemm_bmm,
    per_tensor_quant_mla_deep_gemm_masked_fp8,
    per_tensor_quant_mla_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
    block_quant_to_tensor_quant,
    channel_quant_to_tensor_quant,
@@ -82,6 +86,7 @@ _is_hip = is_hip()
_is_cuda = is_cuda()
if _is_cuda:
    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
else:
    from vllm._custom_ops import awq_dequantize
@@ -530,6 +535,10 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.w_vc = None
        self.w_scale = None
        self.w_scale_k = None
        self.w_scale_v = None
        self.use_deep_gemm_bmm = False
        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
@@ -684,7 +693,24 @@ class DeepseekV2AttentionMLA(nn.Module):
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        if self.w_kc.dtype == torch.float8_e4m3fnuz:
        if self.use_deep_gemm_bmm:
            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
                per_tensor_quant_mla_deep_gemm_masked_fp8(
                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
                )
            )
            q_nope_out = q_nope.new_empty(
                (self.num_local_heads, aligned_m, self.kv_lora_rank)
            )
            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
                (q_nope_val, q_nope_scale),
                (self.w_kc, self.w_scale_k),
                q_nope_out,
                masked_m,
                expected_m,
            )
            q_nope_out = q_nope_out[:, :expected_m, :]
        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
@@ -716,7 +742,24 @@ class DeepseekV2AttentionMLA(nn.Module):
        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
        if self.w_vc.dtype == torch.float8_e4m3fnuz:
        if self.use_deep_gemm_bmm:
            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
                per_tensor_quant_mla_deep_gemm_masked_fp8(
                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
                )
            )
            attn_bmm_output = attn_output.new_empty(
                (self.num_local_heads, aligned_m, self.v_head_dim)
            )
            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
                (attn_output_val, attn_output_scale),
                (self.w_vc, self.w_scale_v),
                attn_bmm_output,
                masked_m,
                expected_m,
            )
            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
@@ -1439,6 +1482,10 @@ class DeepseekV2ForCausalLM(nn.Module):
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            # Fix deepseek v3 blockwise bmm by using deep_gemm
            use_deep_gemm_bmm = False
            model_dtype = torch.get_default_dtype()
            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
@@ -1457,10 +1504,20 @@ class DeepseekV2ForCausalLM(nn.Module):
                            weight = w
                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
                        w, scale = block_quant_to_tensor_quant(
                            weight, weight_scale, weight_block_size
                        )
                        self_attn.w_scale = scale
                        if (
                            _is_cuda
                            and _enable_jit_deepgemm_bmm
                            and weight_block_size[0] == 128
                            and weight_block_size[1] == 128
                            and model_dtype == torch.bfloat16
                        ):
                            block_scale = weight_scale
                            use_deep_gemm_bmm = True
                        else:
                            w, scale = block_quant_to_tensor_quant(
                                weight, weight_scale, weight_block_size
                            )
                            self_attn.w_scale = scale
                else:
                    weight = w
                    weight_scale = self_attn.kv_b_proj.weight_scale
@@ -1483,18 +1540,31 @@ class DeepseekV2ForCausalLM(nn.Module):
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
            if (
                hasattr(self_attn.kv_b_proj, "weight_scale")
                and self_attn.w_scale is None
            ):
                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                if _is_hip:
                    self_attn.w_scale *= 2.0
            if not use_deep_gemm_bmm:
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                    if _is_hip:
                        self_attn.w_scale *= 2.0
            else:
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
                self_attn.w_scale_v = ws_vc.contiguous()
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
                self_attn.w_vc = w_vc.contiguous()
                self_attn.use_deep_gemm_bmm = True

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
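
For reference, a quick shape check of the block-scale split above, assuming DeepSeek-V3-like MLA dimensions (kv_lora_rank = 512, qk_nope_head_dim = 128, v_head_dim = 128) and the 128x128 weight blocks this path requires; num_heads is an arbitrary placeholder:

import torch

num_heads, kv_lora_rank = 16, 512
qk_nope_head_dim, v_head_dim, block = 128, 128, 128
num_tiles_k = qk_nope_head_dim // block  # 1
num_tiles_n = v_head_dim // block        # 1

# One scale per 128x128 tile of the [out_features, kv_lora_rank] kv_b_proj weight.
block_scale = torch.rand(
    num_heads * (num_tiles_k + num_tiles_n), kv_lora_rank // block
)
ws_kc, ws_vc = block_scale.unflatten(0, (-1, num_tiles_k + num_tiles_n)).split(
    [num_tiles_k, num_tiles_n], dim=1
)
w_scale_k = ws_kc.transpose(1, 2).contiguous()
w_scale_v = ws_vc.contiguous()
print(w_scale_k.shape)  # torch.Size([16, 4, 1]) -> (groups, n/128, k/128) for w_kc
print(w_scale_v.shape)  # torch.Size([16, 1, 4]) -> (groups, n/128, k/128) for w_vc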