Revert "[XPU][CPU] Enable the native path of DeepSeek" (#4367)

This commit is contained in:
Lianmin Zheng
2025-03-12 23:45:52 -07:00
committed by GitHub
parent 71046fcd71
commit 45de89719c
16 changed files with 221 additions and 499 deletions

View File

@@ -55,7 +55,7 @@ from sglang.srt.layers.quantization.int8_utils import (
block_dequant as int8_block_dequant,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
@@ -305,6 +305,7 @@ class DeepseekV2Attention(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
)
if rope_scaling:
@@ -500,7 +501,7 @@ class DeepseekV2AttentionMLA(nn.Module):
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope_wrapper(
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
@@ -645,20 +646,19 @@ class DeepseekV2AttentionMLA(nn.Module):
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
if self.w_kc.dtype == torch.float8_e4m3fnuz: # hip only
if self.w_kc.dtype == torch.float8_e4m3fnuz:
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
self.w_kc.to(torch.bfloat16) * self.w_scale,
)
elif self.w_kc.dtype == torch.float8_e4m3fn and is_cuda_available():
elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = input_to_float8(
q_nope.transpose(0, 1), torch.float8_e4m3fn
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
)
else:
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
@@ -677,13 +677,13 @@ class DeepseekV2AttentionMLA(nn.Module):
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.w_vc.dtype == torch.float8_e4m3fnuz or not is_cuda_available():
if self.w_vc.dtype == torch.float8_e4m3fnuz:
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
attn_bmm_output = torch.bmm(
attn_output.to(torch.bfloat16).transpose(0, 1),
self.w_vc.to(torch.bfloat16) * self.w_scale,
)
elif self.w_vc.dtype == torch.float8_e4m3fn and is_cuda_available():
elif self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = input_to_float8(
attn_output.transpose(0, 1), torch.float8_e4m3fn
)
@@ -694,7 +694,6 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_scale,
torch.bfloat16,
)
else:
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)