Refactor DeepGEMM integration (#7150)
This commit is contained in:
@@ -54,8 +54,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
is_fp8_fnuz,
|
||||
per_tensor_quant_mla_fp8,
|
||||
@@ -110,10 +110,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import (
|
||||
grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
|
||||
)
|
||||
else:
|
||||
from vllm._custom_ops import awq_dequantize
|
||||
|
||||
@@ -981,7 +977,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_nope_out = q_nope.new_empty(
|
||||
(self.num_local_heads, aligned_m, self.kv_lora_rank)
|
||||
)
|
||||
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
(q_nope_val, q_nope_scale),
|
||||
(self.w_kc, self.w_scale_k),
|
||||
q_nope_out,
|
||||
@@ -1851,7 +1847,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
and weight_block_size[1] == 128
|
||||
and model_dtype == torch.bfloat16
|
||||
):
|
||||
if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and get_bool_env_var(
|
||||
"SGL_USE_DEEPGEMM_BMM", "false"
|
||||
):
|
||||
block_scale = weight_scale
|
||||
|
||||
Reference in New Issue
Block a user