[perf] introduce deep gemm group_gemm_masked as bmm (#5432)
This commit is contained in:
@@ -57,7 +57,11 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
_enable_jit_deepgemm_bmm,
|
||||
per_tensor_quant_mla_deep_gemm_masked_fp8,
|
||||
per_tensor_quant_mla_fp8,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import (
|
||||
block_quant_to_tensor_quant,
|
||||
channel_quant_to_tensor_quant,
|
||||
@@ -82,6 +86,7 @@ _is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
||||
else:
|
||||
from vllm._custom_ops import awq_dequantize
|
||||
@@ -530,6 +535,10 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.w_vc = None
|
||||
self.w_scale = None
|
||||
|
||||
self.w_scale_k = None
|
||||
self.w_scale_v = None
|
||||
self.use_deep_gemm_bmm = False
|
||||
|
||||
self.flashinfer_mla_disable_ragged = global_server_args_dict[
|
||||
"flashinfer_mla_disable_ragged"
|
||||
]
|
||||
@@ -684,7 +693,24 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
)
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
if self.w_kc.dtype == torch.float8_e4m3fnuz:
|
||||
if self.use_deep_gemm_bmm:
|
||||
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
|
||||
per_tensor_quant_mla_deep_gemm_masked_fp8(
|
||||
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
|
||||
)
|
||||
)
|
||||
q_nope_out = q_nope.new_empty(
|
||||
(self.num_local_heads, aligned_m, self.kv_lora_rank)
|
||||
)
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
(q_nope_val, q_nope_scale),
|
||||
(self.w_kc, self.w_scale_k),
|
||||
q_nope_out,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
q_nope_out = q_nope_out[:, :expected_m, :]
|
||||
elif self.w_kc.dtype == torch.float8_e4m3fnuz:
|
||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||
q_nope_out = torch.bmm(
|
||||
q_nope.to(torch.bfloat16).transpose(0, 1),
|
||||
@@ -716,7 +742,24 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
if self.w_vc.dtype == torch.float8_e4m3fnuz:
|
||||
if self.use_deep_gemm_bmm:
|
||||
attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
|
||||
per_tensor_quant_mla_deep_gemm_masked_fp8(
|
||||
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
|
||||
)
|
||||
)
|
||||
attn_bmm_output = attn_output.new_empty(
|
||||
(self.num_local_heads, aligned_m, self.v_head_dim)
|
||||
)
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
(attn_output_val, attn_output_scale),
|
||||
(self.w_vc, self.w_scale_v),
|
||||
attn_bmm_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
attn_bmm_output = attn_bmm_output[:, :expected_m, :]
|
||||
elif self.w_vc.dtype == torch.float8_e4m3fnuz:
|
||||
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
|
||||
attn_bmm_output = torch.bmm(
|
||||
attn_output.to(torch.bfloat16).transpose(0, 1),
|
||||
@@ -1439,6 +1482,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
w = self_attn.kv_b_proj.weight
|
||||
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|
||||
# This may affect the accuracy of fp8 model.
|
||||
# Fix deepseek v3 blockwise bmm by using deep_gemm
|
||||
use_deep_gemm_bmm = False
|
||||
model_dtype = torch.get_default_dtype()
|
||||
|
||||
if w.dtype in (
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
@@ -1457,10 +1504,20 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
weight = w
|
||||
weight_scale = self_attn.kv_b_proj.weight_scale_inv
|
||||
|
||||
w, scale = block_quant_to_tensor_quant(
|
||||
weight, weight_scale, weight_block_size
|
||||
)
|
||||
self_attn.w_scale = scale
|
||||
if (
|
||||
_is_cuda
|
||||
and _enable_jit_deepgemm_bmm
|
||||
and weight_block_size[0] == 128
|
||||
and weight_block_size[1] == 128
|
||||
and model_dtype == torch.bfloat16
|
||||
):
|
||||
block_scale = weight_scale
|
||||
use_deep_gemm_bmm = True
|
||||
else:
|
||||
w, scale = block_quant_to_tensor_quant(
|
||||
weight, weight_scale, weight_block_size
|
||||
)
|
||||
self_attn.w_scale = scale
|
||||
else:
|
||||
weight = w
|
||||
weight_scale = self_attn.kv_b_proj.weight_scale
|
||||
@@ -1483,18 +1540,31 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
|
||||
torch.bfloat16
|
||||
)
|
||||
|
||||
w_kc, w_vc = w.unflatten(
|
||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
||||
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
||||
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
|
||||
if (
|
||||
hasattr(self_attn.kv_b_proj, "weight_scale")
|
||||
and self_attn.w_scale is None
|
||||
):
|
||||
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
||||
if _is_hip:
|
||||
self_attn.w_scale *= 2.0
|
||||
if not use_deep_gemm_bmm:
|
||||
self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
||||
self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
|
||||
if (
|
||||
hasattr(self_attn.kv_b_proj, "weight_scale")
|
||||
and self_attn.w_scale is None
|
||||
):
|
||||
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
||||
if _is_hip:
|
||||
self_attn.w_scale *= 2.0
|
||||
else:
|
||||
num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
|
||||
num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
|
||||
ws_kc, ws_vc = block_scale.unflatten(
|
||||
0, (-1, (num_tiles_k + num_tiles_n))
|
||||
).split([num_tiles_k, num_tiles_n], dim=1)
|
||||
self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
|
||||
self_attn.w_scale_v = ws_vc.contiguous()
|
||||
self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
|
||||
self_attn.w_vc = w_vc.contiguous()
|
||||
self_attn.use_deep_gemm_bmm = True
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
|
||||
Reference in New Issue
Block a user