Revert "[XPU][CPU] Enable the native path of DeepSeek" (#4367)
This commit is contained in:
@@ -69,7 +69,6 @@ class RotaryEmbedding(CustomOp):
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
@@ -78,7 +77,6 @@ class RotaryEmbedding(CustomOp):
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
|
||||
@@ -285,19 +283,12 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
scaling_factors: Union[List[float], float],
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
) -> None:
|
||||
if isinstance(scaling_factors, float):
|
||||
scaling_factors = [scaling_factors]
|
||||
self.scaling_factors: List[float] = scaling_factors # noqa
|
||||
super().__init__(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
# Lazy initialized.
|
||||
self._scaling_factor_to_offset: Dict[float, int]
|
||||
@@ -356,17 +347,10 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
@@ -450,7 +434,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
@@ -465,13 +448,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
||||
# Get n-d magnitude scaling corrected for interpolation
|
||||
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
|
||||
super().__init__(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||
@@ -668,7 +645,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
@@ -676,6 +652,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1,
|
||||
mscale_all_dim: float = 0,
|
||||
device: Optional[str] = "cuda",
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
@@ -688,14 +665,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
|
||||
* attn_factor
|
||||
)
|
||||
self.device = device
|
||||
super().__init__(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||
@@ -790,7 +762,6 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
scaling_factor: float,
|
||||
low_freq_factor: float,
|
||||
high_freq_factor: float,
|
||||
@@ -801,13 +772,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
|
||||
self.high_freq_factor = high_freq_factor
|
||||
self.orig_max_position = orig_max_position
|
||||
super().__init__(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
str,
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
@@ -845,17 +810,10 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
mrope_section: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
)
|
||||
|
||||
self.mrope_section = mrope_section
|
||||
@@ -1045,14 +1003,9 @@ def get_rope(
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
device: str = None,
|
||||
) -> RotaryEmbedding:
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
if device is None:
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
device = global_server_args_dict["device"]
|
||||
if rope_scaling is not None:
|
||||
# Transforms every value that is a list into a tuple for caching calls
|
||||
rope_scaling_tuple = {
|
||||
@@ -1077,7 +1030,7 @@ def get_rope(
|
||||
|
||||
if rope_scaling is None:
|
||||
rotary_emb = RotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype, device
|
||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype
|
||||
)
|
||||
else:
|
||||
if "rope_type" in rope_scaling:
|
||||
@@ -1099,7 +1052,6 @@ def get_rope(
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
scaling_factor,
|
||||
low_freq_factor,
|
||||
high_freq_factor,
|
||||
@@ -1114,7 +1066,6 @@ def get_rope(
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
mrope_section=rope_scaling["mrope_section"],
|
||||
)
|
||||
else:
|
||||
@@ -1125,7 +1076,6 @@ def get_rope(
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
elif scaling_type == "linear":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
@@ -1137,7 +1087,6 @@ def get_rope(
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
@@ -1149,7 +1098,6 @@ def get_rope(
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
elif scaling_type == "yarn":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
@@ -1168,7 +1116,6 @@ def get_rope(
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
device,
|
||||
**extra_kwargs,
|
||||
)
|
||||
elif scaling_type == "deepseek_yarn":
|
||||
@@ -1196,7 +1143,6 @@ def get_rope(
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
device,
|
||||
**extra_kwargs,
|
||||
)
|
||||
elif scaling_type == "longrope":
|
||||
@@ -1307,8 +1253,21 @@ def get_rope_wrapper(
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
device: Optional[str] = None,
|
||||
):
|
||||
return get_rope(
|
||||
if device != "cpu":
|
||||
return get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
rope_scaling,
|
||||
dtype,
|
||||
partial_rotary_factor,
|
||||
)
|
||||
|
||||
return get_rope_cpu(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
@@ -1317,4 +1276,5 @@ def get_rope_wrapper(
|
||||
rope_scaling,
|
||||
dtype,
|
||||
partial_rotary_factor,
|
||||
device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user