# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Tuple
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.utils import (
get_common_metadata,
MLUCommonAttentionMetadata,
)
from vllm_mlu.v1.attention.backends.mla.flashmla import MLACommonMetadata
from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context
logger = init_logger(__name__)
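# Rotary embedding backed by the fused mlu_ops.rotary_embedding kernel. It
# subclasses the upstream RotaryEmbedding and supplies an out-of-tree forward
# (forward_oot) plus a custom op ("rope_forward") for compiled graphs.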
@CustomOp.register("rotary_embedding_mlu")
class MLURotaryEmbedding(RotaryEmbedding, CustomOp):
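    # Class-level metadata shared by all rope instances for the current forward
    # step; populated by set_mlu_var_v1() and cleared by unset_mlu_var().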
    cu_seq_lens: torch.Tensor | None = None
    max_seq_len: int | None = None
    max_model_len: int | None = None
    is_prompt: bool = False
    is_chunked: bool = False
    positions_: torch.Tensor | None = None
    chunked_prefill_enabled: bool = False
    prefill_cu_seq_lens: torch.Tensor | None = None
    prefill_max_seq_len: int | None = None
    decode_cu_seq_lens: torch.Tensor | None = None
    decode_max_seq_len: int | None = None
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
inverse: bool = False,
) -> None:
CustomOp.__init__(self)
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
# TODO(mgoin): disabled for now due to failures
# Flashinfer only supports head_size=64, 128, 256, 512.
# https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
# self.use_flashinfer = (self.enabled()
# and dtype in (torch.float16, torch.bfloat16)
# and current_platform.is_cuda()
# and has_flashinfer()
# and self.head_size in [64, 128, 256, 512])
self.use_flashinfer = False
self.inverse = inverse
        # For vLLM v1:
        # 1. MLU rope runs in eager mode.
        # 2. All layers use layer 0's rope instance for inference.
prefix = "global_rope"
vllm_config = get_current_vllm_config()
self.use_direct_call = False
if not self.use_direct_call:
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
pass
else:
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import DeepseekScalingRotaryEmbedding
from vllm.model_executor.layers.rotary_embedding.yarn_scaling_rope import YaRNScalingRotaryEmbedding
        if (MLURotaryEmbedding.max_seq_len is not None
                and self.max_position_embeddings < MLURotaryEmbedding.max_seq_len
                and not isinstance(self, (YaRNScalingRotaryEmbedding,
                                          DeepseekScalingRotaryEmbedding))):
            logger.warning(
                f"User-specified max_model_len ({MLURotaryEmbedding.max_seq_len}) differs from "
                f"max_position_embeddings ({max_position_embeddings}) in the model's config.json. "
                "This may lead to incorrect model outputs or MLU errors. "
                "Make sure the value is correct and within the model context size. "
                f"Setting max_position_embeddings={MLURotaryEmbedding.max_seq_len}.")
            self.max_position_embeddings = MLURotaryEmbedding.max_seq_len
cache = self._compute_cos_sin_cache()
from vllm_mlu.model_executor.layers.rotary_embedding.linear_scaling_rope import MLULinearScalingRotaryEmbedding
if isinstance(self, MLULinearScalingRotaryEmbedding):
logger.debug(f"Using mlu defining _compute_cos_sin_cache due to the special tensor composition")
elif is_neox_style:
cache_pos = cache.shape[0]
cache = cache.reshape(cache_pos, 2, -1)
cache = torch.tile(cache, (1, 1, 2)).reshape(cache_pos, -1)
else:
cache = cache.repeat_interleave(2, dim=-1)
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.cos_, self.sin_ = self._get_cos_sin()
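    # Record the per-step attention metadata (cu_seq_lens, max query lengths,
    # prefill/decode splits for MLA, sequence-parallel context) as class state
    # so every rope layer sees the same values during this forward pass.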
@classmethod
def set_mlu_var_v1(
cls,
common_metadata: MLUCommonAttentionMetadata
) -> None:
cls.unset_mlu_var()
cls.cu_seq_lens = common_metadata.query_start_loc
cls.max_seq_len = common_metadata.max_query_len
cls.is_prompt = common_metadata.is_prefill_only
cls.is_chunked = common_metadata.is_chunked
# for MLA
attn_metadata = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
_, attn_metadata = next(iter(attn_metadata.items()))
if isinstance(attn_metadata, MLACommonMetadata):
prefill_metadata = attn_metadata.prefill
decode_metadata = attn_metadata.decode
if prefill_metadata:
cls.prefill_max_seq_len = prefill_metadata.max_query_len
cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
else:
cls.prefill_max_seq_len = cls.max_seq_len
cls.prefill_cu_seq_lens = cls.cu_seq_lens
if decode_metadata:
cls.decode_max_seq_len = decode_metadata.max_query_len
cls.decode_cu_seq_lens = decode_metadata.query_start_loc
else:
cls.decode_max_seq_len = cls.max_seq_len
cls.decode_cu_seq_lens = cls.cu_seq_lens
# for sp
sp_context = get_sp_forward_context()
if sp_context is not None and sp_context.is_v32:
prefill_metadata = sp_context.sp_attn_metadata.prefill
cls.is_chunked = True
cls.prefill_max_seq_len = prefill_metadata.max_query_len
cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
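    # Reset the shared metadata so stale values from a previous step never leak
    # into the next forward pass.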
@classmethod
def unset_mlu_var(cls):
cls.cu_seq_lens = None
cls.max_seq_len = None
cls.is_prompt = False
cls.is_chunked = False
cls.positions_ = None
cls.chunked_prefill_enabled = False
cls.prefill_cu_seq_lens = None
cls.prefill_max_seq_len = None
cls.decode_cu_seq_lens = None
cls.decode_max_seq_len = None
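    # cos and sin are stored concatenated along the last dim of the cache;
    # split them into separate (num_positions, rotary_dim) views for the kernel.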
def _get_cos_sin(self) -> Tuple[torch.Tensor, torch.Tensor]:
cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
sin = sin.view(-1, self.rotary_dim)
cos = cos.view(-1, self.rotary_dim)
return cos, sin
def _get_positions_with_offsets_mlu(
self,
positions: torch.Tensor,
offsets: torch.Tensor
) -> torch.Tensor:
        if offsets.numel() != positions.numel():
            raise ValueError(
                "rope offsets numel does not match positions, "
                f"positions: {positions.numel()}, offsets: {offsets.numel()}")
        return (positions + offsets).to(torch.int32)
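    # Apply rotary embedding through the fused MLU kernel. When no attention
    # metadata is available for this step, fall back to a batched call driven
    # purely by positions; otherwise pass cu_seq_lens so the kernel handles the
    # variable-length sequences packed into x.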
def forward_impl(
self,
positions: torch.Tensor,
x: torch.Tensor,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
common_metadata: MLUCommonAttentionMetadata = get_common_metadata()
if common_metadata is None:
num_tokens, head_num, head_size = x.shape
x = mlu_ops.rotary_embedding(
x.view(1, num_tokens, head_num, head_size),
self.sin_,
self.cos_,
positions,
None,
not self.is_neox_style,
True,
False,
num_tokens
)
return x
else:
cu_seq_lens_ = common_metadata.query_start_loc
if offsets is not None:
if MLURotaryEmbedding.positions_ is None:
MLURotaryEmbedding.positions_ = (
self._get_positions_with_offsets_mlu(positions, offsets))
position_ids = MLURotaryEmbedding.positions_
discrete = True
elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
position_ids = positions
discrete = True
else:
position_ids = None
discrete = False
x = mlu_ops.rotary_embedding(
x,
self.sin_,
self.cos_,
position_ids,
cu_seq_lens_,
not self.is_neox_style,
discrete,
False,
MLURotaryEmbedding.max_seq_len
)
return x
    def get_param(self, positions, discrete=False):
        interleaved = not self.is_neox_style
        if discrete:
            position_ids = positions
        elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
            position_ids = positions
            discrete = True
        else:
            position_ids = None
            discrete = False
        return position_ids, interleaved, discrete
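    # Build the cos/sin cache via torch.polar; sin is negated when self.inverse
    # is set, and __init__ re-tiles the result into the layout expected by
    # mlu_ops.rotary_embedding.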
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.outer(t, inv_freq)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cos = freqs_cis.real
sin = freqs_cis.imag * (-1 if self.inverse else 1)
cache = torch.cat((cos, sin), dim=-1)
return cache
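    # Out-of-tree forward used on the MLU platform: the kernel updates query
    # (and key, when given) in place, so the original tensors are returned.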
def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor | None = None,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
only_prefill: bool | None = False,
only_decode: bool | None = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.forward_impl(positions, query, offsets)
if key is not None:
self.forward_impl(positions, key, offsets)
return query, key
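# Wrappers registered below as a custom op: the op looks up the shared rope
# layer from the forward context by name and applies it to x.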
def rope_forward(
positions: torch.Tensor,
x: torch.Tensor,
layer_name: str,
offsets: torch.Tensor | None = None,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_impl(positions, x, offsets)
def rope_forward_fake(
positions: torch.Tensor,
x: torch.Tensor,
layer_name: str,
offsets: torch.Tensor | None = None,
) -> None:
return
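# The real op mutates x in place; the fake implementation is a no-op used only
# for tracing and shape inference during compilation.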
direct_register_custom_op(
op_name="rope_forward",
op_func=rope_forward,
mutates_args=["x"],
fake_impl=rope_forward_fake,
dispatch_key=current_platform.dispatch_key,
)