[Fix] Fix _yarn_linear_ramp_mask to accept a device parameter (#4337)
This commit is contained in:
@@ -403,12 +403,12 @@ def _yarn_find_correction_range(
|
||||
|
||||
|
||||
def _yarn_linear_ramp_mask(
|
||||
low: float, high: float, dim: int, dtype: torch.dtype
|
||||
low: float, high: float, dim: int, dtype: torch.dtype, device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
if low == high:
|
||||
high += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
|
||||
linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
@@ -688,7 +688,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
# Get n-d rotational scaling corrected for extrapolation
|
||||
inv_freq_mask = (
|
||||
1
|
||||
- _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
|
||||
- _yarn_linear_ramp_mask(
|
||||
low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device
|
||||
)
|
||||
) * self.extrapolation_factor
|
||||
inv_freq = (
|
||||
inv_freq_interpolation * (1 - inv_freq_mask)
|
||||
|
||||
Reference in New Issue
Block a user