[feature] enable pre compile jit deep_gemm (#5580)

This commit is contained in:
JieXin Liang
2025-04-22 07:52:53 +08:00
committed by GitHub
parent e69a219074
commit c2942907d5
7 changed files with 549 additions and 45 deletions

View File

@@ -57,8 +57,8 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import (
_enable_jit_deepgemm_bmm,
per_tensor_quant_mla_deep_gemm_masked_fp8,
per_tensor_quant_mla_fp8,
)
@@ -86,8 +86,11 @@ _is_hip = is_hip()
_is_cuda = is_cuda()
if _is_cuda:
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
from sglang.srt.layers.quantization.deep_gemm import (
grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
)
else:
from vllm._custom_ops import awq_dequantize
@@ -702,7 +705,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope.new_empty(
(self.num_local_heads, aligned_m, self.kv_lora_rank)
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
(q_nope_val, q_nope_scale),
(self.w_kc, self.w_scale_k),
q_nope_out,
@@ -751,7 +754,7 @@ class DeepseekV2AttentionMLA(nn.Module):
attn_bmm_output = attn_output.new_empty(
(self.num_local_heads, aligned_m, self.v_head_dim)
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
(attn_output_val, attn_output_scale),
(self.w_vc, self.w_scale_v),
attn_bmm_output,
@@ -1520,7 +1523,7 @@ class DeepseekV2ForCausalLM(nn.Module):
if (
_is_cuda
and _enable_jit_deepgemm_bmm
and _ENABLE_JIT_DEEPGEMM
and weight_block_size[0] == 128
and weight_block_size[1] == 128
and model_dtype == torch.bfloat16