[XPU][CPU] Enable the native path of DeepSeek (#4086)
Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>
This commit is contained in:
@@ -69,6 +69,7 @@ class RotaryEmbedding(CustomOp):
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
@@ -77,6 +78,7 @@ class RotaryEmbedding(CustomOp):
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
|
||||
@@ -283,12 +285,19 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
scaling_factors: Union[List[float], float],
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
) -> None:
|
||||
if isinstance(scaling_factors, float):
|
||||
scaling_factors = [scaling_factors]
|
||||
self.scaling_factors: List[float] = scaling_factors # noqa
|
||||
super().__init__(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
# Lazy initialized.
|
||||
self._scaling_factor_to_offset: Dict[float, int]
|
||||
@@ -347,10 +356,17 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
@@ -434,6 +450,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
@@ -448,7 +465,13 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
||||
# Get n-d magnitude scaling corrected for interpolation
|
||||
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
|
||||
super().__init__(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||
@@ -645,6 +668,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
*,
|
||||
extrapolation_factor: float = 1,
|
||||
attn_factor: float = 1,
|
||||
@@ -652,7 +676,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
beta_slow: int = 1,
|
||||
mscale: float = 1,
|
||||
mscale_all_dim: float = 0,
|
||||
device: Optional[str] = "cuda",
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
self.extrapolation_factor = extrapolation_factor
|
||||
@@ -665,9 +688,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
|
||||
* attn_factor
|
||||
)
|
||||
self.device = device
|
||||
super().__init__(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||
@@ -762,6 +790,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
scaling_factor: float,
|
||||
low_freq_factor: float,
|
||||
high_freq_factor: float,
|
||||
@@ -772,7 +801,13 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
|
||||
self.high_freq_factor = high_freq_factor
|
||||
self.orig_max_position = orig_max_position
|
||||
super().__init__(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
str,
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
@@ -810,10 +845,17 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
mrope_section: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
|
||||
self.mrope_section = mrope_section
|
||||
@@ -1003,9 +1045,14 @@ def get_rope(
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
device: str = None,
|
||||
) -> RotaryEmbedding:
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
if device is None:
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
device = global_server_args_dict["device"]
|
||||
if rope_scaling is not None:
|
||||
# Transforms every value that is a list into a tuple for caching calls
|
||||
rope_scaling_tuple = {
|
||||
@@ -1030,7 +1077,7 @@ def get_rope(
|
||||
|
||||
if rope_scaling is None:
|
||||
rotary_emb = RotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype
|
||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype, device
|
||||
)
|
||||
else:
|
||||
if "rope_type" in rope_scaling:
|
||||
@@ -1052,6 +1099,7 @@ def get_rope(
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
scaling_factor,
|
||||
low_freq_factor,
|
||||
high_freq_factor,
|
||||
@@ -1066,6 +1114,7 @@ def get_rope(
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
mrope_section=rope_scaling["mrope_section"],
|
||||
)
|
||||
else:
|
||||
@@ -1076,6 +1125,7 @@ def get_rope(
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
elif scaling_type == "linear":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
@@ -1087,6 +1137,7 @@ def get_rope(
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
@@ -1098,6 +1149,7 @@ def get_rope(
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
device,
|
||||
)
|
||||
elif scaling_type == "yarn":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
@@ -1116,6 +1168,7 @@ def get_rope(
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
device,
|
||||
**extra_kwargs,
|
||||
)
|
||||
elif scaling_type == "deepseek_yarn":
|
||||
@@ -1143,6 +1196,7 @@ def get_rope(
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
device,
|
||||
**extra_kwargs,
|
||||
)
|
||||
elif scaling_type == "longrope":
|
||||
@@ -1253,21 +1307,8 @@ def get_rope_wrapper(
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
device: Optional[str] = None,
|
||||
):
|
||||
if device != "cpu":
|
||||
return get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
rope_scaling,
|
||||
dtype,
|
||||
partial_rotary_factor,
|
||||
)
|
||||
|
||||
return get_rope_cpu(
|
||||
return get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
@@ -1276,5 +1317,4 @@ def get_rope_wrapper(
|
||||
rope_scaling,
|
||||
dtype,
|
||||
partial_rotary_factor,
|
||||
device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user