[Fix] fix _yarn_linear_ramp_mask with device parameter (#4337)
This commit is contained in:
@@ -403,12 +403,12 @@ def _yarn_find_correction_range(
|
|||||||
|
|
||||||
|
|
||||||
def _yarn_linear_ramp_mask(
|
def _yarn_linear_ramp_mask(
|
||||||
low: float, high: float, dim: int, dtype: torch.dtype
|
low: float, high: float, dim: int, dtype: torch.dtype, device: torch.device = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if low == high:
|
if low == high:
|
||||||
high += 0.001 # Prevent singularity
|
high += 0.001 # Prevent singularity
|
||||||
|
|
||||||
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
|
linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
|
||||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||||
return ramp_func
|
return ramp_func
|
||||||
|
|
||||||
@@ -688,7 +688,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
# Get n-d rotational scaling corrected for extrapolation
|
# Get n-d rotational scaling corrected for extrapolation
|
||||||
inv_freq_mask = (
|
inv_freq_mask = (
|
||||||
1
|
1
|
||||||
- _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
|
- _yarn_linear_ramp_mask(
|
||||||
|
low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device
|
||||||
|
)
|
||||||
) * self.extrapolation_factor
|
) * self.extrapolation_factor
|
||||||
inv_freq = (
|
inv_freq = (
|
||||||
inv_freq_interpolation * (1 - inv_freq_mask)
|
inv_freq_interpolation * (1 - inv_freq_mask)
|
||||||
|
|||||||
Reference in New Issue
Block a user