fix deepseek v2 with cpu device (#2975)
@@ -664,6 +664,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         beta_slow: int = 1,
         mscale: float = 1,
         mscale_all_dim: float = 0,
+        device: Optional[str] = "cuda",
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -676,13 +677,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
             * attn_factor
         )
+        self.device = device
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )

     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base ** (
-            torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda")
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
             / self.rotary_dim
         )
         inv_freq_extrapolation = 1.0 / pos_freqs
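Note on ordering: the new `self.device = device` assignment lands before the `super().__init__(...)` call because the base-class constructor builds the cos/sin cache eagerly, which dispatches back into the subclass's `_compute_inv_freq`. A minimal sketch of that assumed base-class behavior (names follow the vLLM-style `RotaryEmbedding`; the actual base class is not part of this diff):

    import torch

    class RotaryEmbedding(torch.nn.Module):
        # Sketch only (assumed behavior, not part of this diff): the base
        # constructor builds the cos/sin cache eagerly, dispatching into
        # _compute_inv_freq, which DeepseekScalingRotaryEmbedding overrides
        # to read self.device. Hence self.device must be set *before*
        # super().__init__() runs in the subclass above.
        def __init__(self, head_size, rotary_dim, max_position_embeddings,
                     base, is_neox_style, dtype):
            super().__init__()
            self.rotary_dim = rotary_dim
            self.base = base
            self.max_position_embeddings = max_position_embeddings
            cache = self._compute_cos_sin_cache()  # virtual dispatch happens here
            self.register_buffer("cos_sin_cache", cache, persistent=False)

        def _compute_inv_freq(self, base):
            return 1.0 / (base ** (
                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
            ))

        def _compute_cos_sin_cache(self):
            inv_freq = self._compute_inv_freq(self.base)
            t = torch.arange(self.max_position_embeddings, dtype=torch.float)
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            return torch.cat((freqs.cos(), freqs.sin()), dim=-1)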
@@ -710,7 +712,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         inv_freq = self._compute_inv_freq(self.scaling_factor)
         t = torch.arange(
             self.max_position_embeddings * self.scaling_factor,
-            device="cuda",
+            device=self.device,
             dtype=torch.float32,
         )
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
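With the device threaded through both `_compute_inv_freq` and the position index `t` above, the embedding can be constructed without a CUDA runtime at all. A hypothetical smoke test (argument values are illustrative; the keyword names are inferred from the constructor above, and `cos_sin_cache` is the buffer name assumed from the vLLM-style base class):

    import torch

    rope = DeepseekScalingRotaryEmbedding(
        head_size=64,
        rotary_dim=64,
        max_position_embeddings=4096,
        base=10000,
        is_neox_style=False,
        scaling_factor=40.0,
        dtype=torch.bfloat16,
        device="cpu",  # this path previously hard-coded "cuda"
    )
    assert rope.cos_sin_cache.device.type == "cpu"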
@@ -1174,3 +1176,111 @@ def get_rope(
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
     return rotary_emb
+
+
+def get_rope_cpu(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
+    partial_rotary_factor: float = 1.0,
+    device: Optional[str] = None,
+) -> RotaryEmbedding:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+    if rope_scaling is not None:
+        # Transforms every value that is a list into a tuple for caching calls
+        rope_scaling_tuple = {
+            k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
+        }
+        rope_scaling_args = tuple(rope_scaling_tuple.items())
+    else:
+        rope_scaling_args = None
+    if partial_rotary_factor < 1.0:
+        rotary_dim = int(rotary_dim * partial_rotary_factor)
+    key = (
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        is_neox_style,
+        rope_scaling_args,
+        dtype,
+    )
+    if key in _ROPE_DICT:
+        return _ROPE_DICT[key]
+
+    assert rope_scaling is not None
+    scaling_type = rope_scaling["rope_type"]
+    assert (
+        scaling_type == "deepseek_yarn"
+    ), "Only deepseek_yarn is supported for CPU for now"
+
+    scaling_factor = rope_scaling["factor"]
+    original_max_position = rope_scaling["original_max_position_embeddings"]
+    extra_kwargs = {
+        k: v
+        for k, v in rope_scaling.items()
+        if k
+        in (
+            "extrapolation_factor",
+            "attn_factor",
+            "beta_fast",
+            "beta_slow",
+            "mscale",
+            "mscale_all_dim",
+        )
+    }
+    extra_kwargs["device"] = device
+    rotary_emb = DeepseekScalingRotaryEmbedding(
+        head_size,
+        rotary_dim,
+        original_max_position,
+        base,
+        is_neox_style,
+        scaling_factor,
+        dtype,
+        **extra_kwargs,
+    )
+
+    _ROPE_DICT[key] = rotary_emb
+    return rotary_emb
+
+
+def get_rope_wrapper(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
+    partial_rotary_factor: float = 1.0,
+    device: Optional[str] = None,
+):
+    if device != "cpu":
+        return get_rope(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            rope_scaling,
+            dtype,
+            partial_rotary_factor,
+        )
+
+    return get_rope_cpu(
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        is_neox_style,
+        rope_scaling,
+        dtype,
+        partial_rotary_factor,
+        device,
+    )
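An aside on the caching preamble of `get_rope_cpu` (mirrored from `get_rope`): `rope_scaling` values read from a model's config.json can be lists, which are unhashable, so they are normalized to tuples before being folded into the `_ROPE_DICT` key. A quick standalone illustration:

    # Lists cannot be dict keys; tuples can.
    cfg = {"rope_type": "deepseek_yarn", "factor": 40.0, "dims": [1, 2]}
    normalized = {k: tuple(v) if isinstance(v, list) else v for k, v in cfg.items()}
    key_part = tuple(normalized.items())
    cache = {key_part: "cached rotary_emb"}  # hashable, so this works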
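`get_rope_wrapper` leaves the CUDA behavior untouched (any non-"cpu" device falls through to `get_rope`) and only reroutes CPU callers, whose path currently asserts a `deepseek_yarn` scaling config, so passing device="cpu" without one will fail. A hypothetical CPU call; the config values are illustrative stand-ins in the DeepSeek-V2 config.json format, not taken from a real checkpoint:

    import torch

    rope_scaling = {
        "rope_type": "deepseek_yarn",
        "factor": 40.0,                             # illustrative
        "original_max_position_embeddings": 4096,   # illustrative
        "beta_fast": 32,
        "beta_slow": 1,
        "mscale": 1.0,
        "mscale_all_dim": 1.0,
    }

    rotary_emb = get_rope_wrapper(
        head_size=64,
        rotary_dim=64,
        max_position=163840,
        base=10000,
        is_neox_style=False,
        rope_scaling=rope_scaling,
        dtype=torch.bfloat16,
        device="cpu",  # anything else dispatches to get_rope()
    )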
@@ -48,7 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -271,7 +271,7 @@ class DeepseekV2Attention(nn.Module):
             quant_config=quant_config,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
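The hunk above is truncated mid-call, so for orientation only: a hypothetical shape of the updated call site. Every argument past the ones shown, in particular where the `device` value comes from, is an assumption for illustration, not part of this diff:

    # Hypothetical continuation of the call shown above; `device` is an
    # assumed value carrying the configured runtime device.
    self.rotary_emb = get_rope_wrapper(
        qk_rope_head_dim,
        rotary_dim=qk_rope_head_dim,
        max_position=max_position_embeddings,
        base=rope_theta,            # assumption
        rope_scaling=rope_scaling,  # assumption
        is_neox_style=False,        # assumption
        device=device,              # assumption
    )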
@@ -39,12 +39,12 @@ from PIL import Image
 from torch import nn
 from torch.nn.init import trunc_normal_
 from transformers import PretrainedConfig
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata

+from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import (
python/sglang/srt/models/olmo2.py: 0 changes (Executable file → Normal file)