[XPU][CPU] Enable the native path of DeepSeek (#4086)

Route rotary-embedding construction through get_rope_wrapper and gate the
fp8 bmm kernels on is_cuda_available(), so DeepseekV2's MLA attention falls
back to native PyTorch bmm on non-CUDA devices (CPU, XPU).

Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>
@@ -55,7 +55,7 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
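Only the import list changes here: get_rope is dropped because its last direct
call site (in DeepseekV2AttentionMLA, below) now goes through get_rope_wrapper.
As a rough mental model, the wrapper can be thought of as a device-aware shim
around get_rope; the sketch below is an assumption for illustration, not the
actual sglang.srt.layers.rotary_embedding code.

    # Sketch only: an assumed device-dispatch shim around get_rope.
    from sglang.srt.layers.rotary_embedding import get_rope

    def get_rope_wrapper_sketch(head_size, rotary_dim, max_position, base,
                                rope_scaling=None, is_neox_style=True,
                                device=None):
        rope = get_rope(head_size, rotary_dim, max_position, base,
                        is_neox_style, rope_scaling)
        if device not in (None, "cuda"):
            # keep cos/sin caches on the native device (assumed behavior)
            rope = rope.to(device)
        return rope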
@@ -305,7 +305,6 @@ class DeepseekV2Attention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
-            device=global_server_args_dict["device"],
         )

         if rope_scaling:
@@ -501,7 +500,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         if rope_scaling:
             rope_scaling["rope_type"] = "deepseek_yarn"

-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
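The MLA attention previously built its rotary embedding with get_rope directly;
switching to get_rope_wrapper presumably lets non-CUDA backends select a rope
that runs as plain PyTorch ops. The deepseek_yarn rope_type override is
unchanged; only the constructor entry point moves. For reference, a native
application of rotary embeddings in the non-neox (interleaved-pair) style this
model uses (is_neox_style=False) can be written as below; this is an
illustrative sketch, not sglang's kernel.

    import torch

    def apply_rope_interleaved(x, cos, sin):
        # x: [..., rotary_dim]; cos/sin: [..., rotary_dim // 2]
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out = torch.empty_like(x)
        out[..., 0::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x2 * cos + x1 * sin
        return out

    q = torch.randn(2, 8, 64)  # [tokens, heads, rotary_dim] (made-up sizes)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
    angles = torch.outer(torch.arange(2).float(), inv_freq)  # [tokens, rotary_dim // 2]
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]
    print(apply_rope_interleaved(q, cos, sin).shape)  # torch.Size([2, 8, 64])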
@@ -646,19 +645,20 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

-        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+        if self.w_kc.dtype == torch.float8_e4m3fnuz:  # hip only
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
-        elif self.w_kc.dtype == torch.float8_e4m3fn:
+        elif self.w_kc.dtype == torch.float8_e4m3fn and is_cuda_available():
             q_nope_val, q_nope_scale = input_to_float8(
                 q_nope.transpose(0, 1), torch.float8_e4m3fn
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
             )
+
         else:
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
         q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
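The q_nope projection now dispatches on both weight dtype and platform:
float8_e4m3fnuz weights (HIP) keep the upcast-to-bf16 torch.bmm path, the
bmm_fp8 kernel is gated to CUDA via is_cuda_available(), and everything else
falls through to a plain bmm. The quantize-then-multiply step rests on
per-tensor fp8 quantization; below is a minimal sketch in the spirit of
input_to_float8 (assumed behavior, not a copy of sglang's helper; it needs a
PyTorch build with float8 support).

    import torch

    def to_float8(x, dtype=torch.float8_e4m3fn):
        # Per-tensor scaling: map the max magnitude onto the fp8 range,
        # return the quantized tensor plus the inverse scale for the matmul.
        finfo = torch.finfo(dtype)
        scale = finfo.max / x.abs().amax().clamp(min=1e-12)
        x_q = (x.float() * scale).clamp(finfo.min, finfo.max).to(dtype)
        return x_q, scale.float().reciprocal()

    x = torch.randn(4, 16)
    x_q, inv_scale = to_float8(x)
    print(x_q.dtype, inv_scale.item())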
@@ -677,13 +677,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

-        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+        if self.w_vc.dtype == torch.float8_e4m3fnuz or not is_cuda_available():
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
             )
-        elif self.w_vc.dtype == torch.float8_e4m3fn:
+        elif self.w_vc.dtype == torch.float8_e4m3fn and is_cuda_available():
             attn_output_val, attn_output_scale = input_to_float8(
                 attn_output.transpose(0, 1), torch.float8_e4m3fn
             )
@@ -694,6 +694,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_scale,
                 torch.bfloat16,
             )
+
         else:
             attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
         attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
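The w_vc output projection mirrors the w_kc dispatch above, with one extra
clause: `or not is_cuda_available()` sends every non-CUDA device (CPU, XPU)
into the first branch, so the projection upcasts to bfloat16 and runs through
PyTorch's native batched matmul instead of the CUDA fp8 kernel. In isolation
that fallback amounts to the following (a sketch with made-up shapes, using
torch.cuda.is_available() in place of sglang's is_cuda_available helper).

    import torch

    def project_native(x, w, w_scale=1.0):
        # Non-CUDA / fnuz fallback: fold the weight scale into a bf16 bmm.
        return torch.bmm(x.to(torch.bfloat16).transpose(0, 1),
                         w.to(torch.bfloat16) * w_scale)

    x = torch.randn(3, 16, 8)   # [tokens, heads, kv_lora_rank]
    w = torch.randn(16, 8, 4)   # [heads, kv_lora_rank, v_head_dim]
    out = project_native(x, w)  # -> [heads, tokens, v_head_dim]
    print(out.transpose(0, 1).flatten(1, 2).shape)  # torch.Size([3, 64])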