diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index c31c2e0b5..732395765 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -403,12 +403,12 @@ def _yarn_find_correction_range(
 
 
 def _yarn_linear_ramp_mask(
-    low: float, high: float, dim: int, dtype: torch.dtype
+    low: float, high: float, dim: int, dtype: torch.dtype, device: torch.device = None
 ) -> torch.Tensor:
     if low == high:
         high += 0.001  # Prevent singularity
 
-    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
+    linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
     ramp_func = torch.clamp(linear_func, 0, 1)
     return ramp_func
 
@@ -688,7 +688,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         # Get n-d rotational scaling corrected for extrapolation
         inv_freq_mask = (
             1
-            - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
+            - _yarn_linear_ramp_mask(
+                low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device
+            )
         ) * self.extrapolation_factor
         inv_freq = (
             inv_freq_interpolation * (1 - inv_freq_mask)