[XPU][CPU] Enable the native path of DeepSeek (#4086)

commit 71046fcd71 (parent c76040e31b)
Author: Meng, Hengyu
Date: 2025-03-13 13:26:29 +08:00
Committed by: GitHub
Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>

16 changed files with 501 additions and 223 deletions


@@ -69,6 +69,7 @@ class RotaryEmbedding(CustomOp):
         base: int,
         is_neox_style: bool,
         dtype: torch.dtype,
+        device: str,
     ) -> None:
         super().__init__()
         self.head_size = head_size
@@ -77,6 +78,7 @@ class RotaryEmbedding(CustomOp):
         self.base = base
         self.is_neox_style = is_neox_style
         self.dtype = dtype
+        self.device = device
         cache = self._compute_cos_sin_cache()
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
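
Context for reviewers: the base class now receives the target device at construction time, so the cos/sin cache can be materialized directly on CPU or XPU instead of implicitly on CUDA. A minimal self-contained sketch of the pattern (illustrative names and abbreviated cache math, not this file's exact code):

import torch

class RotaryEmbeddingSketch:
    def __init__(self, head_size, rotary_dim, max_positions, base, dtype, device):
        self.rotary_dim = rotary_dim
        self.max_positions = max_positions
        self.base = base
        self.device = device  # remembered so all caches land on the right backend
        self.cos_sin_cache = self._compute_cos_sin_cache().to(dtype)

    def _compute_inv_freq(self):
        # Standard RoPE inverse frequencies, allocated directly on self.device.
        exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
        return 1.0 / (self.base ** (exponents / self.rotary_dim))

    def _compute_cos_sin_cache(self):
        t = torch.arange(self.max_positions, dtype=torch.float, device=self.device)
        freqs = torch.einsum("i,j->ij", t, self._compute_inv_freq())
        return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

rope = RotaryEmbeddingSketch(64, 64, 4096, 10000, torch.bfloat16, device="cpu")
print(rope.cos_sin_cache.shape, rope.cos_sin_cache.device)  # torch.Size([4096, 64]) cpu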
@@ -283,12 +285,19 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
         is_neox_style: bool,
         scaling_factors: Union[List[float], float],
         dtype: torch.dtype,
+        device: str,
     ) -> None:
         if isinstance(scaling_factors, float):
             scaling_factors = [scaling_factors]
         self.scaling_factors: List[float] = scaling_factors  # noqa
         super().__init__(
-            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+            head_size,
+            rotary_dim,
+            max_position_embeddings,
+            base,
+            is_neox_style,
+            dtype,
+            device,
         )
         # Lazy initialized.
         self._scaling_factor_to_offset: Dict[float, int]
@@ -347,10 +356,17 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
+        device: str,
     ) -> None:
         self.scaling_factor = scaling_factor
         super().__init__(
-            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+            head_size,
+            rotary_dim,
+            max_position_embeddings,
+            base,
+            is_neox_style,
+            dtype,
+            device,
         )

     def _compute_cos_sin_cache(self) -> torch.Tensor:
@@ -434,6 +450,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
+        device: str,
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
@@ -448,7 +465,13 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         # Get n-d magnitude scaling corrected for interpolation
         self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
         super().__init__(
-            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+            head_size,
+            rotary_dim,
+            max_position_embeddings,
+            base,
+            is_neox_style,
+            dtype,
+            device,
         )

     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -645,6 +668,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
+        device: str,
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
@@ -652,7 +676,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         beta_slow: int = 1,
         mscale: float = 1,
         mscale_all_dim: float = 0,
-        device: Optional[str] = "cuda",
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -665,9 +688,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
             * attn_factor
         )
         self.device = device
         super().__init__(
-            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+            head_size,
+            rotary_dim,
+            max_position_embeddings,
+            base,
+            is_neox_style,
+            dtype,
+            device,
         )

     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
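
With the device now threaded through, DeepSeek's YaRN rope can be built natively on CPU or XPU. A hedged sketch of how a caller might request it via get_rope: the module path is assumed, and the rope_scaling keys and values are illustrative, patterned on DeepSeek-V2-style configs rather than copied from this PR.

import torch
from sglang.srt.layers.rotary_embedding import get_rope  # module path assumed

rope = get_rope(
    head_size=64,
    rotary_dim=64,
    max_position=163840,
    base=10000,
    is_neox_style=False,
    rope_scaling={
        "rope_type": "deepseek_yarn",
        "factor": 40.0,
        "original_max_position_embeddings": 4096,
        "beta_fast": 32,
        "beta_slow": 1,
        "mscale": 1.0,
        "mscale_all_dim": 1.0,
    },
    dtype=torch.bfloat16,
    device="cpu",  # native path; previously this class defaulted to "cuda"
)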
@@ -762,6 +790,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         dtype: torch.dtype,
+        device: str,
         scaling_factor: float,
         low_freq_factor: float,
         high_freq_factor: float,
@@ -772,7 +801,13 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
         self.high_freq_factor = high_freq_factor
         self.orig_max_position = orig_max_position
         super().__init__(
-            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+            head_size,
+            rotary_dim,
+            max_position_embeddings,
+            base,
+            is_neox_style,
+            dtype,
+            device,
         )

     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
@@ -810,10 +845,17 @@ class MRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         dtype: torch.dtype,
+        device: str,
         mrope_section: Optional[List[int]] = None,
     ) -> None:
         super().__init__(
-            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+            head_size,
+            rotary_dim,
+            max_position_embeddings,
+            base,
+            is_neox_style,
+            dtype,
+            device,
         )
         self.mrope_section = mrope_section
@@ -1003,9 +1045,14 @@ def get_rope(
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
     partial_rotary_factor: float = 1.0,
+    device: Optional[str] = None,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
+    if device is None:
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+        device = global_server_args_dict["device"]
     if rope_scaling is not None:
         # Transforms every value that is a list into a tuple for caching calls
         rope_scaling_tuple = {
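
The fallback added above resolves the device from the server-wide configuration whenever a caller does not pass one. A standalone sketch of that resolution logic, using a stand-in dict in place of sglang's global_server_args_dict:

from typing import Optional

# Stand-in for sglang's global_server_args_dict, which is populated from the
# server launch arguments.
global_server_args_dict = {"device": "xpu"}

def resolve_device(device: Optional[str]) -> str:
    # An explicit argument wins; otherwise fall back to the server-wide
    # setting, mirroring the branch added to get_rope above.
    if device is None:
        return global_server_args_dict["device"]
    return device

assert resolve_device(None) == "xpu"
assert resolve_device("cpu") == "cpu"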
@@ -1030,7 +1077,7 @@ def get_rope(
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(
-            head_size, rotary_dim, max_position, base, is_neox_style, dtype
+            head_size, rotary_dim, max_position, base, is_neox_style, dtype, device
         )
     else:
         if "rope_type" in rope_scaling:
@@ -1052,6 +1099,7 @@ def get_rope(
                 base,
                 is_neox_style,
                 dtype,
+                device,
                 scaling_factor,
                 low_freq_factor,
                 high_freq_factor,
@@ -1066,6 +1114,7 @@ def get_rope(
                     base,
                     is_neox_style,
                     dtype,
+                    device,
                     mrope_section=rope_scaling["mrope_section"],
                 )
             else:
@@ -1076,6 +1125,7 @@ def get_rope(
                     base,
                     is_neox_style,
                     dtype,
+                    device,
                 )
         elif scaling_type == "linear":
             scaling_factor = rope_scaling["factor"]
@@ -1087,6 +1137,7 @@ def get_rope(
                 is_neox_style,
                 scaling_factor,
                 dtype,
+                device,
             )
         elif scaling_type == "dynamic":
             scaling_factor = rope_scaling["factor"]
@@ -1098,6 +1149,7 @@ def get_rope(
                 is_neox_style,
                 scaling_factor,
                 dtype,
+                device,
             )
         elif scaling_type == "yarn":
             scaling_factor = rope_scaling["factor"]
@@ -1116,6 +1168,7 @@ def get_rope(
                 is_neox_style,
                 scaling_factor,
                 dtype,
+                device,
                 **extra_kwargs,
             )
         elif scaling_type == "deepseek_yarn":
@@ -1143,6 +1196,7 @@ def get_rope(
                 is_neox_style,
                 scaling_factor,
                 dtype,
+                device,
                 **extra_kwargs,
             )
         elif scaling_type == "longrope":
@@ -1253,21 +1307,8 @@ def get_rope_wrapper(
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
     partial_rotary_factor: float = 1.0,
     device: Optional[str] = None,
 ):
-    if device != "cpu":
-        return get_rope(
-            head_size,
-            rotary_dim,
-            max_position,
-            base,
-            is_neox_style,
-            rope_scaling,
-            dtype,
-            partial_rotary_factor,
-        )
-    return get_rope_cpu(
+    return get_rope(
         head_size,
         rotary_dim,
         max_position,
@@ -1276,5 +1317,4 @@ def get_rope_wrapper(
         rope_scaling,
         dtype,
         partial_rotary_factor,
-        device,
     )
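
With the CPU-only branch gone, one code path serves every backend. A hedged usage sketch; the module path is assumed and the argument values are illustrative:

import torch
from sglang.srt.layers.rotary_embedding import get_rope_wrapper  # module path assumed

rope = get_rope_wrapper(
    head_size=64,
    rotary_dim=64,
    max_position=4096,
    base=10000,
    is_neox_style=True,
    rope_scaling=None,
    dtype=torch.bfloat16,
    device="cpu",  # no longer routes to a separate get_rope_cpu implementation
)

Note that per the final hunk the wrapper no longer forwards its device argument; get_rope receives device=None and falls back to the device recorded in global_server_args_dict.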