From 5a176c92dfa13183deca012fe4c43d9d75815390 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sun, 19 Jan 2025 21:33:27 +0800 Subject: [PATCH] fix deepseek v2 with cpu device (#2975) --- python/sglang/srt/layers/rotary_embedding.py | 114 ++++++++++++++++++- python/sglang/srt/models/deepseek_v2.py | 4 +- python/sglang/srt/models/minicpmv.py | 2 +- python/sglang/srt/models/olmo2.py | 0 4 files changed, 115 insertions(+), 5 deletions(-) mode change 100755 => 100644 python/sglang/srt/models/olmo2.py diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 7c18c683e..bc38fa8c0 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -664,6 +664,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): beta_slow: int = 1, mscale: float = 1, mscale_all_dim: float = 0, + device: Optional[str] = "cuda", ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor @@ -676,13 +677,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * attn_factor ) + self.device = device super().__init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: pos_freqs = self.base ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda") + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) / self.rotary_dim ) inv_freq_extrapolation = 1.0 / pos_freqs @@ -710,7 +712,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): inv_freq = self._compute_inv_freq(self.scaling_factor) t = torch.arange( self.max_position_embeddings * self.scaling_factor, - device="cuda", + device=self.device, dtype=torch.float32, ) freqs = torch.einsum("i,j -> ij", t, inv_freq) @@ -1174,3 +1176,111 @@ def get_rope( raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb return rotary_emb + + +def get_rope_cpu( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, + device: Optional[str] = None, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = ( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling_args, + dtype, + ) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + assert rope_scaling is not None + scaling_type = rope_scaling["rope_type"] + assert ( + scaling_type == "deepseek_yarn" + ), "Only deepseek_yarn is supported for CPU for now" + + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ) + } + extra_kwargs["device"] = device + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + + _ROPE_DICT[key] = rotary_emb + return rotary_emb + + +def get_rope_wrapper( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, + device: Optional[str] = None, +): + if device != "cpu": + return get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling, + dtype, + partial_rotary_factor, + ) + + return get_rope_cpu( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling, + dtype, + partial_rotary_factor, + device, + ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0d327c0ca..17d7fcf89 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -48,7 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import ( normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -271,7 +271,7 @@ class DeepseekV2Attention(nn.Module): quant_config=quant_config, ) rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope( + self.rotary_emb = get_rope_wrapper( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py index 5ff941b6c..23147529a 100644 --- a/python/sglang/srt/models/minicpmv.py +++ b/python/sglang/srt/models/minicpmv.py @@ -39,12 +39,12 @@ from PIL import Image from torch import nn from torch.nn.init import trunc_normal_ from transformers import PretrainedConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata +from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ( diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py old mode 100755 new mode 100644