[feature] enable pre compile jit deep_gemm (#5580)
This commit is contained in:
@@ -57,8 +57,8 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
_enable_jit_deepgemm_bmm,
|
||||
per_tensor_quant_mla_deep_gemm_masked_fp8,
|
||||
per_tensor_quant_mla_fp8,
|
||||
)
|
||||
@@ -86,8 +86,11 @@ _is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import (
|
||||
grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
|
||||
)
|
||||
else:
|
||||
from vllm._custom_ops import awq_dequantize
|
||||
|
||||
@@ -702,7 +705,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_nope_out = q_nope.new_empty(
|
||||
(self.num_local_heads, aligned_m, self.kv_lora_rank)
|
||||
)
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
|
||||
(q_nope_val, q_nope_scale),
|
||||
(self.w_kc, self.w_scale_k),
|
||||
q_nope_out,
|
||||
@@ -751,7 +754,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
attn_bmm_output = attn_output.new_empty(
|
||||
(self.num_local_heads, aligned_m, self.v_head_dim)
|
||||
)
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
|
||||
(attn_output_val, attn_output_scale),
|
||||
(self.w_vc, self.w_scale_v),
|
||||
attn_bmm_output,
|
||||
@@ -1520,7 +1523,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
|
||||
if (
|
||||
_is_cuda
|
||||
and _enable_jit_deepgemm_bmm
|
||||
and _ENABLE_JIT_DEEPGEMM
|
||||
and weight_block_size[0] == 128
|
||||
and weight_block_size[1] == 128
|
||||
and model_dtype == torch.bfloat16
|
||||
|
||||
Reference in New Issue
Block a user