[Model] Support DeepSeek-V4

chenxb002
2026-04-24 09:50:34 +08:00
commit b9925203b8
172 changed files with 44780 additions and 0 deletions

View File

@@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import math
from typing import Any
import torch
from vllm.logger import init_logger
import vllm.model_executor.layers.rotary_embedding as rotary_embedding
from vllm.model_executor.layers.rotary_embedding import (
    _ROPE_DICT,
    DualChunkRotaryEmbedding,
    DynamicNTKAlphaRotaryEmbedding,
    DynamicNTKScalingRotaryEmbedding,
    Llama4VisionRotaryEmbedding,
    MRotaryEmbedding,
    NTKScalingRotaryEmbedding,
    Phi3LongRoPEScaledRotaryEmbedding,
    RotaryEmbedding,
    YaRNScalingRotaryEmbedding,
)
from .base import MLURotaryEmbedding
from .deepseek_scaling_rope import MLUDeepseekScalingRotaryEmbedding
from .dynamic_ntk_alpha_rope import MLUDynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import MLUDynamicNTKScalingRotaryEmbedding
from .linear_scaling_rope import MLULinearScalingRotaryEmbedding
from .llama3_rope import MLULlama3RotaryEmbedding
from .mrope import MLUMRotaryEmbedding
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
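# Worked example (hypothetical values): with a user-specified max_model_len
# of 163840 and scaling_factor=40, the helper below returns
# math.ceil(163840 / 40) = 4096 as the adjusted original max position.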
def get_long_max_model_max_position_emb(max_position_embeddings, scaling_factor):
    if MLURotaryEmbedding.max_seq_len is not None and \
            MLURotaryEmbedding.max_seq_len > max_position_embeddings * scaling_factor:
        logger.warning(
            f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) exceeds "
            f"max_position_embeddings ({max_position_embeddings}) * scaling_factor "
            f"({scaling_factor}) from the model's config.json. This may lead to incorrect "
            "model outputs or MLU errors. Make sure the value is correct and within the "
            f"model context size. Setting max_position_embeddings={MLURotaryEmbedding.max_seq_len}.")
        return math.ceil(MLURotaryEmbedding.max_seq_len / scaling_factor)
    return max_position_embeddings
def vllm__model_executor__layers__rotary_embedding__get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: float,
is_neox_style: bool = True,
rope_scaling: dict[str, Any] | None = None,
dtype: torch.dtype | None = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: dict[str, Any] | None = None,
inverse: bool = False
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if dual_chunk_attention_config is not None:
dual_chunk_attention_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in dual_chunk_attention_config.items()
if k != "sparse_attention_config"
}
dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
else:
dual_chunk_attention_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling_args,
dual_chunk_attention_args,
dtype,
inverse,
)
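    # Note: relative to upstream vLLM, the cache key also includes `inverse`,
    # since the MLU rope classes build a sign-flipped sin cache when inverse
    # rotation is requested (see _compute_cos_sin_cache).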
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
for k, v in dual_chunk_attention_config.items()
if k in ("chunk_size", "local_size")
}
rotary_emb = DualChunkRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
**extra_kwargs,
)
elif not rope_scaling:
rotary_emb = MLURotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype,
inverse=inverse,
)
else:
scaling_type = rope_scaling["rope_type"]
if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
rotary_emb = MLULlama3RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MLUMRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
)
else:
rotary_emb = MLURotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
inverse=inverse,
)
elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]
rotary_emb = MLULinearScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "ntk":
scaling_factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get("mixed_b", None)
rotary_emb = NTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
mixed_b,
)
elif scaling_type == "dynamic":
if "alpha" in rope_scaling:
scaling_alpha = rope_scaling["alpha"]
rotary_emb = MLUDynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_alpha,
dtype,
)
elif "factor" in rope_scaling:
scaling_factor = rope_scaling["factor"]
rotary_emb = MLUDynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
else:
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"apply_yarn_scaling",
)
}
if "mrope_section" in rope_scaling:
extra_kwargs.pop("apply_yarn_scaling", None)
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
scaling_factor=scaling_factor,
**extra_kwargs,
)
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: update original_max_position
'''
original_max_position = get_long_max_model_max_position_emb(
original_max_position, scaling_factor,
)
'''
==================
End of MLU Hijack
==================
'''
rotary_emb = YaRNScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
}
'''
=============================
Modify by vllm_mlu
=============================
@brief: update original_max_position
'''
original_max_position = get_long_max_model_max_position_emb(
original_max_position, scaling_factor,
)
'''
==================
End of MLU Hijack
==================
'''
rotary_emb = MLUDeepseekScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
inverse,
**extra_kwargs,
)
elif scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size,
rotary_dim,
max_position,
original_max_position,
base,
is_neox_style,
dtype,
short_factor,
long_factor,
**extra_kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
MluHijackObject.apply_hijack(
rotary_embedding,
rotary_embedding.get_rope,
vllm__model_executor__layers__rotary_embedding__get_rope,
)
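# Usage sketch (hedged): the values below are hypothetical, and a current
# vLLM config must already be set, since MLURotaryEmbedding.__init__ reads
# it via get_current_vllm_config(). After the hijack above, vLLM's
# get_rope() resolves to the MLU-aware version, so a DeepSeek-style
# "deepseek_yarn" scaling config yields MLUDeepseekScalingRotaryEmbedding.
if __name__ == "__main__":
    rope = rotary_embedding.get_rope(
        head_size=64,
        rotary_dim=64,
        max_position=163840,
        base=10000.0,
        is_neox_style=False,
        rope_scaling={
            "rope_type": "deepseek_yarn",
            "factor": 40,
            "original_max_position_embeddings": 4096,
            "mscale": 1.0,
            "mscale_all_dim": 1.0,
        },
    )
    assert isinstance(rope, MLUDeepseekScalingRotaryEmbedding)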

View File

@@ -0,0 +1,302 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.utils import (
get_common_metadata,
MLUCommonAttentionMetadata,
)
from vllm_mlu.v1.attention.backends.mla.flashmla import MLACommonMetadata
from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context
logger = init_logger(__name__)
@CustomOp.register("rotary_embedding_mlu")
class MLURotaryEmbedding(RotaryEmbedding, CustomOp):
    cu_seq_lens: torch.Tensor | None = None
    max_seq_len: int | None = None
    max_model_len: int | None = None
    is_prompt: bool = False
    is_chunked: bool = False
    positions_: torch.Tensor | None = None
    chunked_prefill_enabled: bool = False
    prefill_cu_seq_lens: torch.Tensor | None = None
    prefill_max_seq_len: int | None = None
    decode_cu_seq_lens: torch.Tensor | None = None
    decode_max_seq_len: int | None = None
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
inverse: bool = False,
) -> None:
CustomOp.__init__(self)
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
# TODO(mgoin): disabled for now due to failures
# Flashinfer only supports head_size=64, 128, 256, 512.
# https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
# self.use_flashinfer = (self.enabled()
# and dtype in (torch.float16, torch.bfloat16)
# and current_platform.is_cuda()
# and has_flashinfer()
# and self.head_size in [64, 128, 256, 512])
self.use_flashinfer = False
self.inverse = inverse
        # For VLM v1:
        # 1. MLU RoPE runs in eager mode.
        # 2. All layers share layer 0's RoPE for inference.
        prefix = "global_rope"
        vllm_config = get_current_vllm_config()
        self.use_direct_call = False
        if not self.use_direct_call:
            compilation_config = vllm_config.compilation_config
            if prefix not in compilation_config.static_forward_context:
                compilation_config.static_forward_context[prefix] = self
            self.layer_name = prefix
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import DeepseekScalingRotaryEmbedding
from vllm.model_executor.layers.rotary_embedding.yarn_scaling_rope import YaRNScalingRotaryEmbedding
        if MLURotaryEmbedding.max_seq_len is not None \
                and self.max_position_embeddings < MLURotaryEmbedding.max_seq_len and \
                not isinstance(self, (YaRNScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding)):
            logger.warning(
                f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) differs from "
                f"max_position_embeddings ({max_position_embeddings}) in the model's config.json. "
                "This may lead to incorrect model outputs or MLU errors. Make sure the value is "
                "correct and within the model context size. "
                f"Setting max_position_embeddings={MLURotaryEmbedding.max_seq_len}.")
            self.max_position_embeddings = MLURotaryEmbedding.max_seq_len
cache = self._compute_cos_sin_cache()
from vllm_mlu.model_executor.layers.rotary_embedding.linear_scaling_rope import MLULinearScalingRotaryEmbedding
if isinstance(self, MLULinearScalingRotaryEmbedding):
logger.debug(f"Using mlu defining _compute_cos_sin_cache due to the special tensor composition")
elif is_neox_style:
cache_pos = cache.shape[0]
cache = cache.reshape(cache_pos, 2, -1)
cache = torch.tile(cache, (1, 1, 2)).reshape(cache_pos, -1)
else:
cache = cache.repeat_interleave(2, dim=-1)
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.cos_, self.sin_ = self._get_cos_sin()
@classmethod
def set_mlu_var_v1(
cls,
common_metadata: MLUCommonAttentionMetadata
) -> None:
cls.unset_mlu_var()
cls.cu_seq_lens = common_metadata.query_start_loc
cls.max_seq_len = common_metadata.max_query_len
cls.is_prompt = common_metadata.is_prefill_only
cls.is_chunked = common_metadata.is_chunked
# for MLA
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
_, attn_metadata = next(iter(attn_metadata.items()))
if isinstance(attn_metadata, MLACommonMetadata):
prefill_metadata = attn_metadata.prefill
decode_metadata = attn_metadata.decode
if prefill_metadata:
cls.prefill_max_seq_len = prefill_metadata.max_query_len
cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
else:
cls.prefill_max_seq_len = cls.max_seq_len
cls.prefill_cu_seq_lens = cls.cu_seq_lens
if decode_metadata:
cls.decode_max_seq_len = decode_metadata.max_query_len
cls.decode_cu_seq_lens = decode_metadata.query_start_loc
else:
cls.decode_max_seq_len = cls.max_seq_len
cls.decode_cu_seq_lens = cls.cu_seq_lens
# for sp
sp_context = get_sp_forward_context()
if sp_context is not None and sp_context.is_v32:
prefill_metadata = sp_context.sp_attn_metadata.prefill
cls.is_chunked = True
cls.prefill_max_seq_len = prefill_metadata.max_query_len
cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
@classmethod
def unset_mlu_var(cls):
cls.cu_seq_lens = None
cls.max_seq_len = None
cls.is_prompt = False
cls.is_chunked = False
cls.positions_ = None
cls.chunked_prefill_enabled = False
cls.prefill_cu_seq_lens = None
cls.prefill_max_seq_len = None
cls.decode_cu_seq_lens = None
cls.decode_max_seq_len = None
def _get_cos_sin(self) -> Tuple[torch.Tensor, torch.Tensor]:
cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
sin = sin.view(-1, self.rotary_dim)
cos = cos.view(-1, self.rotary_dim)
return cos, sin
def _get_positions_with_offsets_mlu(
self,
positions: torch.Tensor,
offsets: torch.Tensor
) -> torch.Tensor:
        if offsets.numel() != positions.numel():
            raise ValueError(
                "rope offsets numel mismatches positions: "
                f"positions: {positions.numel()}, offsets: {offsets.numel()}")
        return (positions + offsets).to(torch.int32)
def forward_impl(
self,
positions: torch.Tensor,
x: torch.Tensor,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
common_metadata: MLUCommonAttentionMetadata = get_common_metadata()
if common_metadata is None:
num_tokens, head_num, head_size = x.shape
x = mlu_ops.rotary_embedding(
x.view(1, num_tokens, head_num, head_size),
self.sin_,
self.cos_,
positions,
None,
not self.is_neox_style,
True,
False,
num_tokens
)
return x
else:
cu_seq_lens_ = common_metadata.query_start_loc
if offsets is not None:
if MLURotaryEmbedding.positions_ is None:
MLURotaryEmbedding.positions_ = (
self._get_positions_with_offsets_mlu(positions, offsets))
position_ids = MLURotaryEmbedding.positions_
discrete = True
elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
position_ids = positions
discrete = True
else:
position_ids = None
discrete = False
x = mlu_ops.rotary_embedding(
x,
self.sin_,
self.cos_,
position_ids,
cu_seq_lens_,
not self.is_neox_style,
discrete,
False,
MLURotaryEmbedding.max_seq_len
)
return x
    def get_param(self, positions, discrete=False):
        interleaved = not self.is_neox_style
        if discrete:
            position_ids = positions
        elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
            position_ids = positions
            discrete = True
        else:
            position_ids = None
        return position_ids, interleaved, discrete
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.outer(t, inv_freq)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cos = freqs_cis.real
sin = freqs_cis.imag * (-1 if self.inverse else 1)
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor | None = None,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
        only_prefill: bool = False,
        only_decode: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.forward_impl(positions, query, offsets)
if key is not None:
self.forward_impl(positions, key, offsets)
return query, key
def rope_forward(
positions: torch.Tensor,
x: torch.Tensor,
layer_name: str,
offsets: torch.Tensor | None = None,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_impl(positions, x, offsets)
def rope_forward_fake(
positions: torch.Tensor,
x: torch.Tensor,
layer_name: str,
offsets: torch.Tensor | None = None,
) -> None:
return
direct_register_custom_op(
op_name="rope_forward",
op_func=rope_forward,
mutates_args=["x"],
fake_impl=rope_forward_fake,
dispatch_key=current_platform.dispatch_key,
)
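# Minimal sketch (hedged, self-contained values) of the neox-style cache
# layout built in MLURotaryEmbedding.__init__: _compute_cos_sin_cache()
# returns [cos | sin] of width rotary_dim, which is reshaped and tiled so
# the cos_ and sin_ halves each span the full rotary_dim, matching the
# slices taken by _get_cos_sin().
if __name__ == "__main__":
    rotary_dim, max_pos = 8, 4
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
    cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # (max_pos, rotary_dim)
    cache = cache.reshape(max_pos, 2, -1)  # separate the cos / sin halves
    cache = torch.tile(cache, (1, 1, 2)).reshape(max_pos, -1)  # duplicate each half
    cos, sin = cache.chunk(2, dim=-1)  # as in _get_cos_sin()
    assert cos.shape == (max_pos, rotary_dim) and sin.shape == (max_pos, rotary_dim)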

View File

@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
import torch
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
DeepseekScalingRotaryEmbedding,
yarn_get_mscale,
)
from vllm.model_executor.layers.rotary_embedding.common import (
rotate_gptj,
rotate_neox,
yarn_find_correction_range,
yarn_linear_ramp_mask,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUDeepseekScalingRotaryEmbedding(MLURotaryEmbedding, DeepseekScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
inverse: bool = False,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
self.inverse = inverse
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
    def forward_mlu_rot(self, input, position_ids, interleaved, discrete,
                        cu_seq_lens, max_seq_len):
        """Apply the rotary embedding to a single input tensor."""
if input is None:
return None
if self.rotary_dim < self.head_size:
input_pass = input[..., self.rotary_dim:]
input_rot = input[..., :self.rotary_dim]
input_rot = mlu_ops.rotary_embedding(
input_rot,
self.sin_,
self.cos_,
position_ids,
cu_seq_lens,
interleaved,
discrete,
False,
max_seq_len
)
if self.rotary_dim < self.head_size:
input = torch.cat((input_rot, input_pass), dim=-1)
else:
input = input_rot
return input
def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor | None = None,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
        only_prefill: bool = False,
        only_decode: bool = False,
        discrete: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
position_ids, interleaved, discrete = self.get_param(positions, discrete)
cu_seq_lens = MLURotaryEmbedding.cu_seq_lens
max_seq_len = MLURotaryEmbedding.max_seq_len
# for MLA
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
_, attn_metadata = next(iter(attn_metadata.items()))
if isinstance(attn_metadata, MLACommonMetadata):
if only_prefill:
cu_seq_lens = MLURotaryEmbedding.prefill_cu_seq_lens
max_seq_len = MLURotaryEmbedding.prefill_max_seq_len
elif only_decode:
cu_seq_lens = MLURotaryEmbedding.decode_cu_seq_lens
max_seq_len = MLURotaryEmbedding.decode_max_seq_len
query = self.forward_mlu_rot(query, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len)
key = self.forward_mlu_rot(key, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len)
return query, key
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base ** (
torch.arange(
0,
self.rotary_dim,
2,
dtype=torch.float,
device=current_platform.device_type,
)
/ self.rotary_dim
)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
self.rotary_dim,
self.base,
self.max_position_embeddings,
)
# Get n-d rotational scaling corrected for extrapolation
device = current_platform.device_type
inv_freq_mask = ((
1
- yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
) * self.extrapolation_factor).to(device)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(
self.max_position_embeddings * self.scaling_factor,
device=current_platform.device_type,
dtype=torch.float32,
)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale * (-1 if self.inverse else 1)
cache = torch.cat((cos, sin), dim=-1)
return cache
forward = MLURotaryEmbedding.forward
forward_native = forward_oot
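# Sketch (hedged, hypothetical values) of the mscale ratio computed in
# __init__ above; yarn_get_mscale follows upstream vLLM's definition,
# 0.1 * mscale * log(scale) + 1.0 for scale > 1.
if __name__ == "__main__":
    factor, mscale, mscale_all_dim, attn_factor = 40.0, 1.0, 1.0, 1.0
    ratio = (yarn_get_mscale(factor, mscale)
             / yarn_get_mscale(factor, mscale_all_dim) * attn_factor)
    # Identical mscale and mscale_all_dim cancel, leaving just attn_factor.
    assert ratio == 1.0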

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUDynamicNTKAlphaRotaryEmbedding(MLURotaryEmbedding, DynamicNTKAlphaRotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_alpha: float,
dtype: torch.dtype,
) -> None:
self.scaling_alpha = scaling_alpha
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
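# Note (hedged): the dynamic NTK alpha math itself comes from upstream
# vLLM's DynamicNTKAlphaRotaryEmbedding; this subclass only reroutes
# __init__ through MLURotaryEmbedding so the MLU-specific cos/sin cache
# layout is built on top of the alpha-scaled rope base.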

View File

@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUDynamicNTKScalingRotaryEmbedding(MLURotaryEmbedding, DynamicNTKScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
) -> None:
self.scaling_factor = scaling_factor
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Union
import torch
from vllm.platforms import current_platform
from vllm.model_executor.layers.rotary_embedding.linear_scaling_rope import LinearScalingRotaryEmbedding
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLULinearScalingRotaryEmbedding(MLURotaryEmbedding, LinearScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
scaling_factors: list[float] | float,
dtype: torch.dtype,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors: list[float] = scaling_factors # noqa
MLURotaryEmbedding.__init__(
self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
# Lazy initialized.
self._scaling_factor_to_offset: dict[float, int]
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
device = current_platform.device_type
if self.is_neox_style:
half_dim = self.rotary_dim // 2
inv_freq = 1.0 / (
base
** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device)
% half_dim * 2 / self.rotary_dim)
)
else:
inv_freq = 1.0 / (
base
** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device)
// 2 * 2 / self.rotary_dim
)
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
cache_list: list[torch.Tensor] = []
        # Offsets to the start of each per-factor cache in the concatenated
        # tensor. Each offset corresponds to the same index in scaling_factors.
offsets: list[int] = []
device = current_platform.device_type
for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * scaling_factor
t = torch.arange(max_len, dtype=torch.float, device=device)
t = t / scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
if not cache_list:
offset = 0
else:
last_offset = offsets[-1]
next_max_len = cache_list[-1].shape[0]
offset = last_offset + next_max_len
offsets.append(offset)
cache_list.append(cache)
self._scaling_factor_to_offset = {
float(scaling_factor): offsets[i]
for i, scaling_factor in enumerate(self.scaling_factors)
}
assert len(self.scaling_factors) == len(offsets)
return torch.cat(cache_list, dim=0)
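# Worked example (hypothetical values): with max_position_embeddings=2048
# and scaling_factors=[1.0, 2.0], the per-factor caches have 2048 and 4096
# rows, so _scaling_factor_to_offset maps {1.0: 0, 2.0: 2048} and the
# concatenated cache returned above has 6144 rows.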

View File

@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLULlama3RotaryEmbedding(MLURotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
scaling_factor: float,
low_freq_factor: float,
high_freq_factor: float,
orig_max_position: int,
) -> None:
self.scaling_factor = scaling_factor
self.low_freq_factor = low_freq_factor
self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)

View File

@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.model_executor.layers.rotary_embedding.common import yarn_get_mscale
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding
class MLUMRotaryEmbedding(MLURotaryEmbedding, MRotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: list[int] | None = None,
mrope_interleaved: bool = False,
# YaRN parameters.
*,
scaling_factor: float | None = None,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
if self.scaling_factor is not None:
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
else:
self.mscale = 1.0
        # In Qwen2.5-VL, the maximum index value is related to the duration
        # of the input video. We enlarge max_position_embeddings by 4x to get
        # a larger cos and sin cache.
self.cache_max_position_num = max_position_embeddings * 4
MLURotaryEmbedding.__init__(
self,
head_size,
rotary_dim,
self.cache_max_position_num,
base,
is_neox_style,
dtype,
)
self.mrope_section = mrope_section
self.mrope_interleaved = mrope_interleaved
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
def _apply_mrope(self, positions):
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
num_section = len(self.mrope_section)
mrope_section = self.mrope_section * 2
        def _apply(x):
            return torch.cat(
                [m[i % num_section] for i, m in enumerate(x.split(mrope_section, dim=-1))],
                dim=-1,
            )
return _apply(cos), _apply(sin)
def _apply_interleaved_mrope(self, positions):
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
"""
mrope_section = self.mrope_section
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
def _apply(x):
x_t = x[0].clone()
x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
offset = self.rotary_dim // 2
x_t[..., 1 + offset:mrope_section[1] * 3 + offset:3] = x[1, ..., 1 + offset:mrope_section[1] * 3 + offset:3]
x_t[..., 2 + offset:mrope_section[2] * 3 + offset:3] = x[2, ..., 2 + offset:mrope_section[2] * 3 + offset:3]
return x_t
return _apply(cos), _apply(sin)
    def precompute_sin_cos_cache(
        self,
        positions: torch.Tensor
    ):
        """Precompute the sin/cos cache for MRoPE.

        Call this before running the decoder layers.
        """
if positions.ndim == 1:
return
assert positions.ndim == 2
assert self.mrope_section
if self.mrope_interleaved:
cos, sin = self._apply_interleaved_mrope(positions)
else:
cos, sin = self._apply_mrope(positions)
self.mrope_cos_cache = cos
self.mrope_sin_cache = sin
self.mrope_cu_seq_lens = torch.zeros(2, dtype=torch.int32, device=positions.device)
num_tokens = positions.shape[-1]
self.mrope_cu_seq_lens[1] = num_tokens
def forward_oot(
self,
positions: torch.Tensor,
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
        assert positions.ndim in (1, 2)
        if positions.ndim == 1:
            return MLURotaryEmbedding.forward_oot(self, positions, x)
        assert self.mrope_cos_cache is not None and self.mrope_sin_cache is not None, \
            "call precompute_sin_cos_cache first!"
num_tokens = positions.shape[-1]
x = mlu_ops.rotary_embedding(x,
self.mrope_sin_cache,
self.mrope_cos_cache,
None,
self.mrope_cu_seq_lens,
not self.is_neox_style,
False,
False,
num_tokens)
return x
forward = MLURotaryEmbedding.forward
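# Sketch (hedged, hypothetical values) of the section split in _apply_mrope:
# with rotary_dim=128 and mrope_section=[16, 24, 24] (sum == 64 ==
# rotary_dim // 2), cos and sin of shape (3, num_tokens, 128) are split
# along the last dim into chunks of widths [16, 24, 24, 16, 24, 24];
# chunk i is taken from plane i % 3 (temporal / height / width) and the
# chunks are re-concatenated into one (num_tokens, 128) tensor per cache.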