[Model] Support DeepSeek-V4
342
vllm_mlu/model_executor/layers/rotary_embedding/__init__.py
Normal file
@@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import math
from typing import Any

import torch

from vllm.logger import init_logger
import vllm.model_executor.layers.rotary_embedding as rotary_embedding
from vllm.model_executor.layers.rotary_embedding import (
    _ROPE_DICT,
    DualChunkRotaryEmbedding,
    DynamicNTKAlphaRotaryEmbedding,
    DynamicNTKScalingRotaryEmbedding,
    Llama4VisionRotaryEmbedding,
    MRotaryEmbedding,
    NTKScalingRotaryEmbedding,
    Phi3LongRoPEScaledRotaryEmbedding,
    RotaryEmbedding,
    YaRNScalingRotaryEmbedding,
)

from .base import MLURotaryEmbedding
from .deepseek_scaling_rope import MLUDeepseekScalingRotaryEmbedding
from .dynamic_ntk_alpha_rope import MLUDynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import MLUDynamicNTKScalingRotaryEmbedding
from .linear_scaling_rope import MLULinearScalingRotaryEmbedding
from .llama3_rope import MLULlama3RotaryEmbedding
from .mrope import MLUMRotaryEmbedding
from vllm_mlu.mlu_hijack_utils import MluHijackObject

logger = init_logger(__name__)

def get_long_max_model_max_position_emb(max_position_embeddings, scaling_factor):
    if MLURotaryEmbedding.max_seq_len is not None and \
            MLURotaryEmbedding.max_seq_len > max_position_embeddings * scaling_factor:
        logger.warning(
            f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) differs from "
            f"max_position_embeddings ({max_position_embeddings}) * scaling_factor ({scaling_factor}) "
            "in the model's config.json. This may lead to incorrect model outputs or MLU errors. "
            "Make sure the value is correct and within the model context size. "
            f"Setting max_position_embeddings={MLURotaryEmbedding.max_seq_len}.")
        return math.ceil(MLURotaryEmbedding.max_seq_len / scaling_factor)
    return max_position_embeddings

def vllm__model_executor__layers__rotary_embedding__get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    is_neox_style: bool = True,
    rope_scaling: dict[str, Any] | None = None,
    dtype: torch.dtype | None = None,
    partial_rotary_factor: float = 1.0,
    dual_chunk_attention_config: dict[str, Any] | None = None,
    inverse: bool = False,
) -> RotaryEmbedding:
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None

    if dual_chunk_attention_config is not None:
        dual_chunk_attention_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in dual_chunk_attention_config.items()
            if k != "sparse_attention_config"
        }
        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
    else:
        dual_chunk_attention_args = None

    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    key = (
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        rope_scaling_args,
        dual_chunk_attention_args,
        dtype,
        inverse,
    )
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if dual_chunk_attention_config is not None:
        extra_kwargs = {
            k: v
            for k, v in dual_chunk_attention_config.items()
            if k in ("chunk_size", "local_size")
        }
        rotary_emb = DualChunkRotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            dtype,
            **extra_kwargs,
        )
    elif not rope_scaling:
        rotary_emb = MLURotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype,
            inverse=inverse,
        )
    else:
        scaling_type = rope_scaling["rope_type"]

        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling["original_max_position_embeddings"]
            rotary_emb = MLULlama3RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                dtype,
                scaling_factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position,
            )
        elif scaling_type == "mllama4":
            rotary_emb = Llama4VisionRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style, dtype
            )
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                rotary_emb = MLUMRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_scaling["mrope_section"],
                    mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
                )
            else:
                rotary_emb = MLURotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                    inverse=inverse,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = MLULinearScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
            )
        elif scaling_type == "ntk":
            scaling_factor = rope_scaling["factor"]
            mixed_b = rope_scaling.get("mixed_b", None)
            rotary_emb = NTKScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
                mixed_b,
            )
        elif scaling_type == "dynamic":
            if "alpha" in rope_scaling:
                scaling_alpha = rope_scaling["alpha"]
                rotary_emb = MLUDynamicNTKAlphaRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    scaling_alpha,
                    dtype,
                )
            elif "factor" in rope_scaling:
                scaling_factor = rope_scaling["factor"]
                rotary_emb = MLUDynamicNTKScalingRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    scaling_factor,
                    dtype,
                )
            else:
                raise ValueError(
                    "Dynamic rope scaling must contain either 'alpha' or 'factor' field"
                )
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling["original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k
                in (
                    "extrapolation_factor",
                    "attn_factor",
                    "beta_fast",
                    "beta_slow",
                    "apply_yarn_scaling",
                )
            }
            if "mrope_section" in rope_scaling:
                extra_kwargs.pop("apply_yarn_scaling", None)
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    original_max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_scaling["mrope_section"],
                    mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
                    scaling_factor=scaling_factor,
                    **extra_kwargs,
                )
            else:
                '''
                =============================
                Modify by vllm_mlu
                =============================
                @brief: update original_max_position
                '''
                original_max_position = get_long_max_model_max_position_emb(
                    original_max_position, scaling_factor,
                )
                '''
                ==================
                End of MLU Hijack
                ==================
                '''
                rotary_emb = YaRNScalingRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    original_max_position,
                    base,
                    is_neox_style,
                    scaling_factor,
                    dtype,
                    **extra_kwargs,
                )
        elif scaling_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling["original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k
                in (
                    "extrapolation_factor",
                    "attn_factor",
                    "beta_fast",
                    "beta_slow",
                    "mscale",
                    "mscale_all_dim",
                )
            }
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: update original_max_position
            '''
            original_max_position = get_long_max_model_max_position_emb(
                original_max_position, scaling_factor,
            )
            '''
            ==================
            End of MLU Hijack
            ==================
            '''
            rotary_emb = MLUDeepseekScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                original_max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
                inverse,
                **extra_kwargs,
            )
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling["original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                original_max_position,
                base,
                is_neox_style,
                dtype,
                short_factor,
                long_factor,
                **extra_kwargs,
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb


MluHijackObject.apply_hijack(
    rotary_embedding,
    rotary_embedding.get_rope,
    vllm__model_executor__layers__rotary_embedding__get_rope,
)
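For illustration (not part of this commit): a minimal standalone sketch of the cache-key normalization that the hijacked get_rope performs before consulting _ROPE_DICT. The _rope_cache dict and make_rope_key helper below are invented for this example, and the DeepSeek-style rope_scaling values are placeholders rather than values from any specific checkpoint.

from typing import Any

# Stand-in for vLLM's module-level _ROPE_DICT cache (illustrative only).
_rope_cache: dict[tuple, object] = {}

def make_rope_key(head_size: int, rotary_dim: int, rope_scaling: dict[str, Any] | None) -> tuple:
    # rope_scaling may contain list values (e.g. "mrope_section"), which are not
    # hashable; like the hijacked get_rope above, convert them to tuples first.
    if rope_scaling is None:
        scaling_args = None
    else:
        scaling_args = tuple(
            (k, tuple(v) if isinstance(v, list) else v) for k, v in rope_scaling.items()
        )
    return (head_size, rotary_dim, scaling_args)

deepseek_scaling = {"rope_type": "deepseek_yarn", "factor": 40,
                    "original_max_position_embeddings": 4096,
                    "mscale": 1.0, "mscale_all_dim": 1.0}
key = make_rope_key(192, 64, deepseek_scaling)
_rope_cache.setdefault(key, "rotary-emb-instance")
# An identical config maps to the same key, so a second lookup reuses the cached module.
assert make_rope_key(192, 64, dict(deepseek_scaling)) in _rope_cache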
302
vllm_mlu/model_executor/layers/rotary_embedding/base.py
Normal file
@@ -0,0 +1,302 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Tuple

import torch

from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.utils import (
    get_common_metadata,
    MLUCommonAttentionMetadata,
)
from vllm_mlu.v1.attention.backends.mla.flashmla import MLACommonMetadata
from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context

logger = init_logger(__name__)


@CustomOp.register("rotary_embedding_mlu")
class MLURotaryEmbedding(RotaryEmbedding, CustomOp):

    cu_seq_lens: torch.Tensor = None
    max_seq_len: int = None
    max_model_len: int = None
    is_prompt: bool = False
    is_chunked: bool = False
    positions_: torch.Tensor = None
    chunked_prefill_enabled: bool = False
    prefill_cu_seq_lens: torch.Tensor = None
    prefill_max_seq_len: int = None
    decode_cu_seq_lens: torch.Tensor = None
    decode_max_seq_len: int = None

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        inverse: bool = False,
    ) -> None:
        CustomOp.__init__(self)
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype
        # TODO(mgoin): disabled for now due to failures
        # Flashinfer only supports head_size=64, 128, 256, 512.
        # https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
        # self.use_flashinfer = (self.enabled()
        #                        and dtype in (torch.float16, torch.bfloat16)
        #                        and current_platform.is_cuda()
        #                        and has_flashinfer()
        #                        and self.head_size in [64, 128, 256, 512])
        self.use_flashinfer = False
        self.inverse = inverse

        # For VLM v1:
        # 1. The MLU rope runs in eager mode.
        # 2. All layers use layer 0's rope for inference.
        prefix = "global_rope"
        vllm_config = get_current_vllm_config()
        self.use_direct_call = False
        if not self.use_direct_call:
            compilation_config = vllm_config.compilation_config
            if prefix not in compilation_config.static_forward_context:
                compilation_config.static_forward_context[prefix] = self
            self.layer_name = prefix

        from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import DeepseekScalingRotaryEmbedding
        from vllm.model_executor.layers.rotary_embedding.yarn_scaling_rope import YaRNScalingRotaryEmbedding

        if MLURotaryEmbedding.max_seq_len is not None \
                and self.max_position_embeddings < MLURotaryEmbedding.max_seq_len and \
                not isinstance(self, (YaRNScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding)):
            logger.warning(
                f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) differs from "
                f"max_position_embeddings ({max_position_embeddings}) in the model's config.json. "
                "This may lead to incorrect model outputs or MLU errors. "
                "Make sure the value is correct and within the model context size. "
                f"Setting max_position_embeddings={MLURotaryEmbedding.max_seq_len}.")
            self.max_position_embeddings = MLURotaryEmbedding.max_seq_len
        cache = self._compute_cos_sin_cache()

        from vllm_mlu.model_executor.layers.rotary_embedding.linear_scaling_rope import MLULinearScalingRotaryEmbedding
        if isinstance(self, MLULinearScalingRotaryEmbedding):
            logger.debug("Using the MLU-defined _compute_cos_sin_cache due to the special tensor composition")
        elif is_neox_style:
            cache_pos = cache.shape[0]
            cache = cache.reshape(cache_pos, 2, -1)
            cache = torch.tile(cache, (1, 1, 2)).reshape(cache_pos, -1)
        else:
            cache = cache.repeat_interleave(2, dim=-1)

        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)
        self.cos_, self.sin_ = self._get_cos_sin()

    @classmethod
    def set_mlu_var_v1(
        cls,
        common_metadata: MLUCommonAttentionMetadata
    ) -> None:
        cls.unset_mlu_var()
        cls.cu_seq_lens = common_metadata.query_start_loc
        cls.max_seq_len = common_metadata.max_query_len
        cls.is_prompt = common_metadata.is_prefill_only
        cls.is_chunked = common_metadata.is_chunked

        # for MLA
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            _, attn_metadata = next(iter(attn_metadata.items()))
        if isinstance(attn_metadata, MLACommonMetadata):
            prefill_metadata = attn_metadata.prefill
            decode_metadata = attn_metadata.decode
            if prefill_metadata:
                cls.prefill_max_seq_len = prefill_metadata.max_query_len
                cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
            else:
                cls.prefill_max_seq_len = cls.max_seq_len
                cls.prefill_cu_seq_lens = cls.cu_seq_lens

            if decode_metadata:
                cls.decode_max_seq_len = decode_metadata.max_query_len
                cls.decode_cu_seq_lens = decode_metadata.query_start_loc
            else:
                cls.decode_max_seq_len = cls.max_seq_len
                cls.decode_cu_seq_lens = cls.cu_seq_lens

        # for sp
        sp_context = get_sp_forward_context()
        if sp_context is not None and sp_context.is_v32:
            prefill_metadata = sp_context.sp_attn_metadata.prefill
            cls.is_chunked = True
            cls.prefill_max_seq_len = prefill_metadata.max_query_len
            cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc

    @classmethod
    def unset_mlu_var(cls):
        cls.cu_seq_lens = None
        cls.max_seq_len = None
        cls.is_prompt = False
        cls.is_chunked = False
        cls.positions_ = None
        cls.chunked_prefill_enabled = False
        cls.prefill_cu_seq_lens = None
        cls.prefill_max_seq_len = None
        cls.decode_cu_seq_lens = None
        cls.decode_max_seq_len = None

    def _get_cos_sin(self) -> Tuple[torch.Tensor, torch.Tensor]:
        cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
        sin = sin.view(-1, self.rotary_dim)
        cos = cos.view(-1, self.rotary_dim)
        return cos, sin

    def _get_positions_with_offsets_mlu(
        self,
        positions: torch.Tensor,
        offsets: torch.Tensor
    ) -> torch.Tensor:
        if offsets.numel() != positions.numel():
            raise ValueError("rope offsets numel mismatches positions, "
                             f"positions: {positions.numel()}, offsets: {offsets.numel()}")
        return (positions + offsets).to(torch.int32)

    def forward_impl(
        self,
        positions: torch.Tensor,
        x: torch.Tensor,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        common_metadata: MLUCommonAttentionMetadata = get_common_metadata()
        if common_metadata is None:
            num_tokens, head_num, head_size = x.shape
            x = mlu_ops.rotary_embedding(
                x.view(1, num_tokens, head_num, head_size),
                self.sin_,
                self.cos_,
                positions,
                None,
                not self.is_neox_style,
                True,
                False,
                num_tokens
            )
            return x
        else:
            cu_seq_lens_ = common_metadata.query_start_loc

            if offsets is not None:
                if MLURotaryEmbedding.positions_ is None:
                    MLURotaryEmbedding.positions_ = (
                        self._get_positions_with_offsets_mlu(positions, offsets))
                position_ids = MLURotaryEmbedding.positions_
                discrete = True
            elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
                position_ids = positions
                discrete = True
            else:
                position_ids = None
                discrete = False

            x = mlu_ops.rotary_embedding(
                x,
                self.sin_,
                self.cos_,
                position_ids,
                cu_seq_lens_,
                not self.is_neox_style,
                discrete,
                False,
                MLURotaryEmbedding.max_seq_len
            )
            return x

    def get_param(self, positions, discrete=False):
        interleaved = not self.is_neox_style

        if discrete:
            position_ids = positions
        else:
            if MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
                position_ids = positions
                discrete = True
            else:
                position_ids = None
                discrete = False

        return position_ids, interleaved, discrete

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.outer(t, inv_freq)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        cos = freqs_cis.real
        sin = freqs_cis.imag * (-1 if self.inverse else 1)
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor | None = None,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
        only_prefill: bool | None = False,
        only_decode: bool | None = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self.forward_impl(positions, query, offsets)
        if key is not None:
            self.forward_impl(positions, key, offsets)
        return query, key


def rope_forward(
    positions: torch.Tensor,
    x: torch.Tensor,
    layer_name: str,
    offsets: torch.Tensor | None = None,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]

    self.forward_impl(positions, x, offsets)


def rope_forward_fake(
    positions: torch.Tensor,
    x: torch.Tensor,
    layer_name: str,
    offsets: torch.Tensor | None = None,
) -> None:
    return


direct_register_custom_op(
    op_name="rope_forward",
    op_func=rope_forward,
    mutates_args=["x"],
    fake_impl=rope_forward_fake,
    dispatch_key=current_platform.dispatch_key,
)
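For illustration (not part of this commit): the cache expansion in MLURotaryEmbedding.__init__ stores cos and sin in the layout consumed by the MLU kernel. This toy snippet, on made-up values, shows the difference between the NeoX-style tile and the GPT-J-style repeat_interleave used above.

import torch

positions, half = 2, 3                     # toy sizes: half == rotary_dim // 2
cos = torch.arange(positions * half, dtype=torch.float32).reshape(positions, half)
sin = -cos
cache = torch.cat((cos, sin), dim=-1)      # [positions, rotary_dim], as in _compute_cos_sin_cache

# NeoX style: duplicate each half-block, giving [cos | cos | sin | sin] per row.
neox = torch.tile(cache.reshape(positions, 2, -1), (1, 1, 2)).reshape(positions, -1)

# GPT-J style: duplicate element-wise, giving [c0 c0 c1 c1 ... s0 s0 s1 s1 ...].
gptj = cache.repeat_interleave(2, dim=-1)

assert neox.shape == gptj.shape == (positions, 4 * half)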
166
vllm_mlu/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
Normal file
@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Tuple

import torch

from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
    DeepseekScalingRotaryEmbedding,
    yarn_get_mscale,
)
from vllm.model_executor.layers.rotary_embedding.common import (
    rotate_gptj,
    rotate_neox,
    yarn_find_correction_range,
    yarn_linear_ramp_mask,
)
from vllm.platforms import current_platform

from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding


class MLUDeepseekScalingRotaryEmbedding(MLURotaryEmbedding, DeepseekScalingRotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        inverse: bool = False,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale)) /
            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        self.inverse = inverse
        MLURotaryEmbedding.__init__(
            self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    def forward_mlu_rot(self, input, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len):
        """Apply the MLU rotary kernel to a single input tensor."""
        if input is None:
            return None
        if self.rotary_dim < self.head_size:
            input_pass = input[..., self.rotary_dim:]
            input_rot = input[..., :self.rotary_dim]
        else:
            # When rotary_dim covers the whole head, rotate the full tensor.
            input_rot = input
        input_rot = mlu_ops.rotary_embedding(
            input_rot,
            self.sin_,
            self.cos_,
            position_ids,
            cu_seq_lens,
            interleaved,
            discrete,
            False,
            max_seq_len
        )

        if self.rotary_dim < self.head_size:
            input = torch.cat((input_rot, input_pass), dim=-1)
        else:
            input = input_rot

        return input

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor | None = None,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
        only_prefill: bool | None = False,
        only_decode: bool | None = False,
        discrete: bool | None = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """MLU implementation equivalent to the native forward()."""
        position_ids, interleaved, discrete = self.get_param(positions, discrete)

        cu_seq_lens = MLURotaryEmbedding.cu_seq_lens
        max_seq_len = MLURotaryEmbedding.max_seq_len

        # for MLA
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            _, attn_metadata = next(iter(attn_metadata.items()))
        if isinstance(attn_metadata, MLACommonMetadata):
            if only_prefill:
                cu_seq_lens = MLURotaryEmbedding.prefill_cu_seq_lens
                max_seq_len = MLURotaryEmbedding.prefill_max_seq_len
            elif only_decode:
                cu_seq_lens = MLURotaryEmbedding.decode_cu_seq_lens
                max_seq_len = MLURotaryEmbedding.decode_max_seq_len

        query = self.forward_mlu_rot(query, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len)
        key = self.forward_mlu_rot(key, position_ids, interleaved, discrete, cu_seq_lens, max_seq_len)

        return query, key

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base ** (
            torch.arange(
                0,
                self.rotary_dim,
                2,
                dtype=torch.float,
                device=current_platform.device_type,
            )
            / self.rotary_dim
        )
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
        )
        # Get n-d rotational scaling corrected for extrapolation
        device = current_platform.device_type
        inv_freq_mask = ((
            1
            - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
        ) * self.extrapolation_factor).to(device)
        inv_freq = (
            inv_freq_interpolation * (1 - inv_freq_mask)
            + inv_freq_extrapolation * inv_freq_mask
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(
            self.max_position_embeddings * self.scaling_factor,
            device=current_platform.device_type,
            dtype=torch.float32,
        )
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * self.mscale
        sin = freqs.sin() * self.mscale * (-1 if self.inverse else 1)
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    forward = MLURotaryEmbedding.forward
    forward_native = forward_oot
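For illustration (not part of this commit): how the mscale applied to the DeepSeek YaRN cos/sin cache above is derived. yarn_get_mscale is reproduced from memory of the vLLM helper so the snippet runs standalone; treat it as a sketch, and the config numbers are placeholders rather than values from any specific checkpoint.

import math

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    # Standard YaRN magnitude-scaling rule (assumed to match the vLLM helper).
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

scaling_factor, mscale, mscale_all_dim, attn_factor = 40.0, 1.0, 1.0, 1.0
effective_mscale = (yarn_get_mscale(scaling_factor, mscale)
                    / yarn_get_mscale(scaling_factor, mscale_all_dim)
                    * attn_factor)
# With mscale == mscale_all_dim the ratio is exactly 1.0, so cos/sin stay unscaled.
print(effective_mscale)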
29
vllm_mlu/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py
Normal file
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch

from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding

from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding


class MLUDynamicNTKAlphaRotaryEmbedding(MLURotaryEmbedding, DynamicNTKAlphaRotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_alpha: float,
        dtype: torch.dtype,
    ) -> None:
        self.scaling_alpha = scaling_alpha
        MLURotaryEmbedding.__init__(
            self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )
26
vllm_mlu/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py
Normal file
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch

from vllm.model_executor.layers.rotary_embedding.dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding

from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding


class MLUDynamicNTKScalingRotaryEmbedding(MLURotaryEmbedding, DynamicNTKScalingRotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
    ) -> None:
        self.scaling_factor = scaling_factor
        MLURotaryEmbedding.__init__(
            self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )
86
vllm_mlu/model_executor/layers/rotary_embedding/linear_scaling_rope.py
Normal file
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Union

import torch

from vllm.platforms import current_platform
from vllm.model_executor.layers.rotary_embedding.linear_scaling_rope import LinearScalingRotaryEmbedding

from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding


class MLULinearScalingRotaryEmbedding(MLURotaryEmbedding, LinearScalingRotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factors: list[float] | float,
        dtype: torch.dtype,
    ) -> None:
        if isinstance(scaling_factors, float):
            scaling_factors = [scaling_factors]
        self.scaling_factors: list[float] = scaling_factors  # noqa
        MLURotaryEmbedding.__init__(
            self, head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )
        # Lazy initialized.
        self._scaling_factor_to_offset: dict[float, int]

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        device = current_platform.device_type
        if self.is_neox_style:
            half_dim = self.rotary_dim // 2
            inv_freq = 1.0 / (
                base
                ** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device)
                    % half_dim * 2 / self.rotary_dim)
            )
        else:
            inv_freq = 1.0 / (
                base
                ** (torch.arange(0, self.rotary_dim, 1, dtype=torch.float32, device=device)
                    // 2 * 2 / self.rotary_dim
                )
            )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: list[torch.Tensor] = []
        # offsets to the next cache in a tensor.
        # Each offset corresponds to the same index in scaling_factors.
        offsets: list[int] = []
        device = current_platform.device_type
        for scaling_factor in self.scaling_factors:
            # NOTE(woosuk): self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float, device=device)
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            if not cache_list:
                offset = 0
            else:
                last_offset = offsets[-1]
                next_max_len = cache_list[-1].shape[0]
                offset = last_offset + next_max_len
            offsets.append(offset)
            cache_list.append(cache)
        self._scaling_factor_to_offset = {
            float(scaling_factor): offsets[i]
            for i, scaling_factor in enumerate(self.scaling_factors)
        }
        assert len(self.scaling_factors) == len(offsets)
        return torch.cat(cache_list, dim=0)
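For illustration (not part of this commit): the offset bookkeeping in _compute_cos_sin_cache above concatenates one cache per scaling factor and records the row where each factor's cache begins. A toy sketch with made-up sizes:

max_position_embeddings = 4
scaling_factors = [1.0, 2.0, 4.0]

offsets, total = [], 0
for factor in scaling_factors:
    offsets.append(total)                            # row where this factor's cache begins
    total += int(max_position_embeddings * factor)   # rows contributed by this factor

print(dict(zip(scaling_factors, offsets)))           # {1.0: 0, 2.0: 4, 4.0: 12}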
30
vllm_mlu/model_executor/layers/rotary_embedding/llama3_rope.py
Normal file
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch

from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding


class MLULlama3RotaryEmbedding(MLURotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )
140
vllm_mlu/model_executor/layers/rotary_embedding/mrope.py
Normal file
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch

from vllm.model_executor.layers.rotary_embedding.common import yarn_get_mscale
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.rotary_embedding.base import MLURotaryEmbedding


class MLUMRotaryEmbedding(MLURotaryEmbedding, MRotaryEmbedding):
    """Rotary Embedding with Multimodal Sections."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: list[int] | None = None,
        mrope_interleaved: bool = False,
        # YaRN parameters.
        *,
        scaling_factor: float | None = None,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        if self.scaling_factor is not None:
            # Get n-d magnitude scaling corrected for interpolation
            self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
        else:
            self.mscale = 1.0

        # In Qwen2.5-VL, the maximum index value is related to the duration of
        # the input video. We enlarge max_position_embeddings by 4x to get a
        # larger cos and sin cache.
        self.cache_max_position_num = max_position_embeddings * 4
        MLURotaryEmbedding.__init__(
            self,
            head_size,
            rotary_dim,
            self.cache_max_position_num,
            base,
            is_neox_style,
            dtype,
        )

        self.mrope_section = mrope_section
        self.mrope_interleaved = mrope_interleaved
        if self.mrope_section:
            assert sum(self.mrope_section) == rotary_dim // 2

    def _apply_mrope(self, positions):
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        num_section = len(self.mrope_section)
        mrope_section = self.mrope_section * 2

        def _apply(x):
            x = torch.cat(
                [m[i % num_section] for i, m in enumerate(x.split(mrope_section, dim=-1))],
                dim=-1,
            )
            return x
        return _apply(cos), _apply(sin)

    def _apply_interleaved_mrope(self, positions):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THTHWHTHW...TT], preserving frequency continuity.
        """
        mrope_section = self.mrope_section
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)

        def _apply(x):
            x_t = x[0].clone()
            x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
            x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
            offset = self.rotary_dim // 2
            x_t[..., 1 + offset:mrope_section[1] * 3 + offset:3] = x[1, ..., 1 + offset:mrope_section[1] * 3 + offset:3]
            x_t[..., 2 + offset:mrope_section[2] * 3 + offset:3] = x[2, ..., 2 + offset:mrope_section[2] * 3 + offset:3]
            return x_t
        return _apply(cos), _apply(sin)

    def precompute_sin_cos_cache(
        self,
        positions: torch.Tensor
    ):
        """Precompute the sin/cos cache for MRoPE.

        Call this before running the decoder layers.
        """
        if positions.ndim == 1:
            return
        assert positions.ndim == 2
        assert self.mrope_section
        if self.mrope_interleaved:
            cos, sin = self._apply_interleaved_mrope(positions)
        else:
            cos, sin = self._apply_mrope(positions)
        self.mrope_cos_cache = cos
        self.mrope_sin_cache = sin
        self.mrope_cu_seq_lens = torch.zeros(2, dtype=torch.int32, device=positions.device)
        num_tokens = positions.shape[-1]
        self.mrope_cu_seq_lens[1] = num_tokens

    def forward_oot(
        self,
        positions: torch.Tensor,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        assert positions.ndim == 1 or positions.ndim == 2
        if positions.ndim == 1:
            return MLURotaryEmbedding.forward_oot(self, positions, x)
        assert self.mrope_cos_cache is not None and self.mrope_sin_cache is not None, \
            "call precompute_sin_cos_cache first!"
        num_tokens = positions.shape[-1]
        x = mlu_ops.rotary_embedding(x,
                                     self.mrope_sin_cache,
                                     self.mrope_cos_cache,
                                     None,
                                     self.mrope_cu_seq_lens,
                                     not self.is_neox_style,
                                     False,
                                     False,
                                     num_tokens)
        return x

    forward = MLURotaryEmbedding.forward
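For illustration (not part of this commit): the section selection in MLUMRotaryEmbedding._apply_mrope, shown on toy tensors. The doubled section list mirrors the [cos | cos] layout of the cache rows; the widths below are made up and much smaller than a real mrope_section.

import torch

mrope_section = [2, 1, 1]                  # toy T/H/W widths, sum == rotary_dim // 2
num_section = len(mrope_section)
rotary_dim = 2 * sum(mrope_section)
num_tokens = 3

# One cos tensor per position axis: temporal (0), height (1), width (2).
cos = torch.stack(
    [torch.full((num_tokens, rotary_dim), float(axis)) for axis in range(3)])

chunks = cos.split(mrope_section * 2, dim=-1)
merged = torch.cat([m[i % num_section] for i, m in enumerate(chunks)], dim=-1)

# Each output column is taken from the axis that owns that section:
# [T T H W T T H W] for the widths above.
print(merged[0])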