Revert "[XPU][CPU] Enable the native path of DeepSeek" (#4367)

This commit is contained in:
Lianmin Zheng
2025-03-12 23:45:52 -07:00
committed by GitHub
parent 71046fcd71
commit 45de89719c
16 changed files with 221 additions and 499 deletions

View File

@@ -69,7 +69,6 @@ class RotaryEmbedding(CustomOp):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
) -> None:
super().__init__()
self.head_size = head_size
@@ -78,7 +77,6 @@ class RotaryEmbedding(CustomOp):
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
self.device = device
cache = self._compute_cos_sin_cache()
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
@@ -285,19 +283,12 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factors: Union[List[float], float],
dtype: torch.dtype,
device: str,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors: List[float] = scaling_factors # noqa
super().__init__(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
# Lazy initialized.
self._scaling_factor_to_offset: Dict[float, int]
@@ -356,17 +347,10 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
device: str,
) -> None:
self.scaling_factor = scaling_factor
super().__init__(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_cos_sin_cache(self) -> torch.Tensor:
@@ -450,7 +434,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
device: str,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
@@ -465,13 +448,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -668,7 +645,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
device: str,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
@@ -676,6 +652,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
device: Optional[str] = "cuda",
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
@@ -688,14 +665,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
* attn_factor
)
self.device = device
super().__init__(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -790,7 +762,6 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
scaling_factor: float,
low_freq_factor: float,
high_freq_factor: float,
@@ -801,13 +772,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position
super().__init__(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
str,
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
@@ -845,17 +810,10 @@ class MRotaryEmbedding(RotaryEmbedding):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
mrope_section: Optional[List[int]] = None,
) -> None:
super().__init__(
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
self.mrope_section = mrope_section
@@ -1045,14 +1003,9 @@ def get_rope(
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
device: str = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
from sglang.srt.managers.schedule_batch import global_server_args_dict
device = global_server_args_dict["device"]
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
@@ -1077,7 +1030,7 @@ def get_rope(
if rope_scaling is None:
rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype, device
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
else:
if "rope_type" in rope_scaling:
@@ -1099,7 +1052,6 @@ def get_rope(
base,
is_neox_style,
dtype,
device,
scaling_factor,
low_freq_factor,
high_freq_factor,
@@ -1114,7 +1066,6 @@ def get_rope(
base,
is_neox_style,
dtype,
device,
mrope_section=rope_scaling["mrope_section"],
)
else:
@@ -1125,7 +1076,6 @@ def get_rope(
base,
is_neox_style,
dtype,
device,
)
elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]
@@ -1137,7 +1087,6 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
)
elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"]
@@ -1149,7 +1098,6 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
@@ -1168,7 +1116,6 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":
@@ -1196,7 +1143,6 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
**extra_kwargs,
)
elif scaling_type == "longrope":
@@ -1307,8 +1253,21 @@ def get_rope_wrapper(
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
device: Optional[str] = None,
):
return get_rope(
if device != "cpu":
return get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling,
dtype,
partial_rotary_factor,
)
return get_rope_cpu(
head_size,
rotary_dim,
max_position,
@@ -1317,4 +1276,5 @@ def get_rope_wrapper(
rope_scaling,
dtype,
partial_rotary_factor,
device,
)