Refactor DeepGEMM integration (#7150)

This commit is contained in:
fzyzcjy
2025-06-14 11:41:03 +08:00
committed by GitHub
parent 8b8f2e7463
commit b4c41f7276
12 changed files with 207 additions and 147 deletions

View File

@@ -54,8 +54,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
per_tensor_quant_mla_fp8,
@@ -110,10 +110,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _is_cuda:
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
-from sglang.srt.layers.quantization.deep_gemm import (
-grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
-)
else:
from vllm._custom_ops import awq_dequantize
@@ -981,7 +977,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = q_nope.new_empty(
(self.num_local_heads, aligned_m, self.kv_lora_rank)
)
-deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
(q_nope_val, q_nope_scale),
(self.w_kc, self.w_scale_k),
q_nope_out,
@@ -1851,7 +1847,7 @@ class DeepseekV2ForCausalLM(nn.Module):
and weight_block_size[1] == 128
and model_dtype == torch.bfloat16
):
-if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and get_bool_env_var(
"SGL_USE_DEEPGEMM_BMM", "false"
):
block_scale = weight_scale