fix deepseek v2 with cpu device (#2975)
This commit is contained in:
@@ -664,6 +664,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
beta_slow: int = 1,
|
beta_slow: int = 1,
|
||||||
mscale: float = 1,
|
mscale: float = 1,
|
||||||
mscale_all_dim: float = 0,
|
mscale_all_dim: float = 0,
|
||||||
|
device: Optional[str] = "cuda",
|
||||||
) -> None:
|
) -> None:
|
||||||
self.scaling_factor = scaling_factor
|
self.scaling_factor = scaling_factor
|
||||||
self.extrapolation_factor = extrapolation_factor
|
self.extrapolation_factor = extrapolation_factor
|
||||||
@@ -676,13 +677,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
|
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
|
||||||
* attn_factor
|
* attn_factor
|
||||||
)
|
)
|
||||||
|
self.device = device
|
||||||
super().__init__(
|
super().__init__(
|
||||||
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||||
pos_freqs = self.base ** (
|
pos_freqs = self.base ** (
|
||||||
torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda")
|
torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
|
||||||
/ self.rotary_dim
|
/ self.rotary_dim
|
||||||
)
|
)
|
||||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||||
@@ -710,7 +712,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
inv_freq = self._compute_inv_freq(self.scaling_factor)
|
||||||
t = torch.arange(
|
t = torch.arange(
|
||||||
self.max_position_embeddings * self.scaling_factor,
|
self.max_position_embeddings * self.scaling_factor,
|
||||||
device="cuda",
|
device=self.device,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
@@ -1174,3 +1176,111 @@ def get_rope(
|
|||||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||||
_ROPE_DICT[key] = rotary_emb
|
_ROPE_DICT[key] = rotary_emb
|
||||||
return rotary_emb
|
return rotary_emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_rope_cpu(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    device: Optional[str] = None,
) -> RotaryEmbedding:
    """Create (or fetch from cache) a rotary embedding whose buffers live
    on *device* rather than the default "cuda".

    Only the ``"deepseek_yarn"`` scaling type is supported here; every
    other configuration should go through ``get_rope``.

    Args:
        head_size: Size of each attention head.
        rotary_dim: Number of head dimensions rotated by RoPE (possibly
            reduced by ``partial_rotary_factor``).
        max_position: Maximum position index the model supports.
        base: RoPE frequency base.
        is_neox_style: Whether to use the GPT-NeoX rotation layout.
        rope_scaling: Scaling config; must contain
            ``rope_type == "deepseek_yarn"``, ``factor`` and
            ``original_max_position_embeddings``.
        dtype: Embedding dtype; defaults to ``torch.get_default_dtype()``.
        partial_rotary_factor: Fraction of ``rotary_dim`` to rotate.
        device: Device string (e.g. "cpu") forwarded to the embedding.

    Returns:
        A cached ``DeepseekScalingRotaryEmbedding`` instance.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Lists are unhashable; turn them into tuples so the whole config
        # can participate in the cache key below.
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    # BUGFIX: include `device` in the cache key. Without it, an embedding
    # constructed for one device could be handed back for a request that
    # targets another device. The resulting 8-tuple also cannot collide
    # with the 7-tuple keys that `get_rope` stores in the shared dict.
    key = (
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        rope_scaling_args,
        dtype,
        device,
    )
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    assert rope_scaling is not None
    scaling_type = rope_scaling["rope_type"]
    assert (
        scaling_type == "deepseek_yarn"
    ), "Only deepseek_yarn is supported for CPU for now"

    scaling_factor = rope_scaling["factor"]
    original_max_position = rope_scaling["original_max_position_embeddings"]
    # Forward only the yarn-specific tuning knobs that are present.
    extra_kwargs = {
        k: v
        for k, v in rope_scaling.items()
        if k
        in (
            "extrapolation_factor",
            "attn_factor",
            "beta_fast",
            "beta_slow",
            "mscale",
            "mscale_all_dim",
        )
    }
    extra_kwargs["device"] = device
    rotary_emb = DeepseekScalingRotaryEmbedding(
        head_size,
        rotary_dim,
        original_max_position,
        base,
        is_neox_style,
        scaling_factor,
        dtype,
        **extra_kwargs,
    )

    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
|
||||||
|
|
||||||
|
|
||||||
|
def get_rope_wrapper(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    device: Optional[str] = None,
):
    """Dispatch rotary-embedding construction on the target device.

    CPU requests are routed to ``get_rope_cpu`` (which supports only the
    "deepseek_yarn" scaling type); any other device uses the regular
    ``get_rope`` path, which ignores `device` and builds on CUDA.
    """
    if device == "cpu":
        return get_rope_cpu(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            rope_scaling,
            dtype,
            partial_rotary_factor,
            device,
        )
    return get_rope(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        rope_scaling,
        dtype,
        partial_rotary_factor,
    )
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
@@ -271,7 +271,7 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
rope_scaling["rope_type"] = "deepseek_yarn"
|
rope_scaling["rope_type"] = "deepseek_yarn"
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope_wrapper(
|
||||||
qk_rope_head_dim,
|
qk_rope_head_dim,
|
||||||
rotary_dim=qk_rope_head_dim,
|
rotary_dim=qk_rope_head_dim,
|
||||||
max_position=max_position_embeddings,
|
max_position=max_position_embeddings,
|
||||||
|
|||||||
@@ -39,12 +39,12 @@ from PIL import Image
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.init import trunc_normal_
|
from torch.nn.init import trunc_normal_
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
|
||||||
from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
|
from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
|
||||||
|
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.activation import get_act_fn
|
from sglang.srt.layers.activation import get_act_fn
|
||||||
from sglang.srt.layers.attention.vision import VisionAttention
|
from sglang.srt.layers.attention.vision import VisionAttention
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
|
|||||||
0
python/sglang/srt/models/olmo2.py
Executable file → Normal file
0
python/sglang/srt/models/olmo2.py
Executable file → Normal file
Reference in New Issue
Block a user