Apply dsv3_fused_a_gemm kernel (#7635)
@@ -96,6 +96,7 @@ from sglang.srt.utils import (
     bind_or_assign,
     cpu_has_amx_support,
     get_bool_env_var,
+    get_device_sm,
     get_int_env_var,
     is_cpu,
     is_cuda,
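Note: the newly imported get_device_sm is what gates the fused kernel to SM 90+ GPUs further down. A minimal sketch of what such a helper typically does, assuming the common major * 10 + minor convention; the name get_device_sm_sketch and its body are illustrative, not sglang's actual implementation:

import torch

def get_device_sm_sketch() -> int:
    # Hypothetical stand-in for sglang.srt.utils.get_device_sm: report the
    # CUDA compute capability as major * 10 + minor (e.g. 90 on Hopper),
    # or 0 when no CUDA device is present.
    if not torch.cuda.is_available():
        return 0
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor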
@@ -112,7 +113,7 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
 if _is_cuda:
-    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+    from sgl_kernel import awq_dequantize, bmm_fp8, dsv3_fused_a_gemm, merge_state_v2
 elif _is_cpu and _is_cpu_amx_available:
     pass
 else:
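The kernel import sits behind the existing _is_cuda guard, so non-CUDA builds never touch sgl_kernel. For standalone experiments outside this module, an equivalent defensive pattern, assuming only that sgl_kernel exposes dsv3_fused_a_gemm on CUDA builds (as the hunk above shows), is a try/except import:

try:
    from sgl_kernel import dsv3_fused_a_gemm
except ImportError:  # non-CUDA build or an sgl_kernel without this kernel
    dsv3_fused_a_gemm = None  # callers fall back to the plain projection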
@@ -875,6 +876,15 @@ class DeepseekV2AttentionMLA(nn.Module):
             weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
         )
 
+        self.use_min_latency_fused_a_gemm = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
+            and _is_cuda
+            and get_device_sm() >= 90
+        )
+
         self.qkv_proj_with_rope_is_int8 = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
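The hard-coded 2112 x 7168 shape pins the kernel to DeepSeek-V3/R1-sized weights. A sketch of where those numbers come from, using the published DeepSeek-V3 config values (they are not stated in this diff):

# The fused A projection stacks three outputs on top of the 7168-wide
# hidden state: 1536 + 512 + 64 = 2112 rows of the weight matrix.
q_lora_rank = 1536       # DeepSeek-V3 config value
kv_lora_rank = 512       # DeepSeek-V3 config value
qk_rope_head_dim = 64    # DeepSeek-V3 config value
hidden_size = 7168       # DeepSeek-V3 config value

assert q_lora_rank + kv_lora_rank + qk_rope_head_dim == 2112  # weight.shape[0]
assert hidden_size == 7168                                    # weight.shape[1]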
@@ -1114,7 +1124,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
             k_nope = latent_cache[..., : self.kv_lora_rank]
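On the small-batch (<= 16 tokens) path the fused kernel must reproduce what the else-branch computes through the plain linear projection. A hedged correctness sketch, assuming only the call signature shown above and bf16-level tolerance:

import torch
from sgl_kernel import dsv3_fused_a_gemm

x = torch.randn(8, 7168, dtype=torch.bfloat16, device="cuda")     # batch <= 16
w = torch.randn(2112, 7168, dtype=torch.bfloat16, device="cuda")  # fused A weight

fused = dsv3_fused_a_gemm(x, w.T)  # min-latency kernel path
ref = x @ w.T                      # what the unfused projection computes
assert torch.allclose(fused, ref, atol=1e-2, rtol=1e-2)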