[XPU][CPU] Enable the native path of DeepSeek (#4086)

Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>
This commit is contained in:
Meng, Hengyu
2025-03-13 13:26:29 +08:00
committed by GitHub
parent c76040e31b
commit 71046fcd71
16 changed files with 501 additions and 223 deletions

View File

@@ -55,7 +55,7 @@ from sglang.srt.layers.quantization.int8_utils import (
block_dequant as int8_block_dequant,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
@@ -305,7 +305,6 @@ class DeepseekV2Attention(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
)
if rope_scaling:
@@ -501,7 +500,7 @@ class DeepseekV2AttentionMLA(nn.Module):
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope(
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
@@ -646,19 +645,20 @@ class DeepseekV2AttentionMLA(nn.Module):
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
if self.w_kc.dtype == torch.float8_e4m3fnuz:
if self.w_kc.dtype == torch.float8_e4m3fnuz: # hip only
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
self.w_kc.to(torch.bfloat16) * self.w_scale,
)
elif self.w_kc.dtype == torch.float8_e4m3fn:
elif self.w_kc.dtype == torch.float8_e4m3fn and is_cuda_available():
q_nope_val, q_nope_scale = input_to_float8(
q_nope.transpose(0, 1), torch.float8_e4m3fn
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
)
else:
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
@@ -677,13 +677,13 @@ class DeepseekV2AttentionMLA(nn.Module):
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.w_vc.dtype == torch.float8_e4m3fnuz:
if self.w_vc.dtype == torch.float8_e4m3fnuz or not is_cuda_available():
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
attn_bmm_output = torch.bmm(
attn_output.to(torch.bfloat16).transpose(0, 1),
self.w_vc.to(torch.bfloat16) * self.w_scale,
)
elif self.w_vc.dtype == torch.float8_e4m3fn:
elif self.w_vc.dtype == torch.float8_e4m3fn and is_cuda_available():
attn_output_val, attn_output_scale = input_to_float8(
attn_output.transpose(0, 1), torch.float8_e4m3fn
)
@@ -694,6 +694,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_scale,
torch.bfloat16,
)
else:
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)