# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
from typing import Tuple
|
|
import torch
|
|
|
|
from vllm.config import get_current_vllm_config
|
|
from vllm.forward_context import ForwardContext, get_forward_context
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.custom_op import CustomOp
|
|
from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils.torch_utils import direct_register_custom_op
|
|
|
|
from vllm_mlu import _mlu_ops as mlu_ops
|
|
from vllm_mlu.v1.attention.backends.utils import (
|
|
get_common_metadata,
|
|
MLUCommonAttentionMetadata,
|
|
)
|
|
from vllm_mlu.v1.attention.backends.mla.flashmla import MLACommonMetadata
|
|
from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
@CustomOp.register("rotary_embedding_mlu")
class MLURotaryEmbedding(RotaryEmbedding, CustomOp):
    """Rotary positional embedding backed by the fused MLU kernel.

    The cos/sin cache is precomputed at construction time and rearranged
    into the layout expected by ``mlu_ops.rotary_embedding``.  Per-step
    scheduling state (cumulative query lengths, prefill/decode split) is
    stored on the *class* so that every rope layer in the model shares the
    metadata published by :meth:`set_mlu_var_v1`.
    """

    # Per-step metadata shared by all rope layers; written by
    # set_mlu_var_v1() and cleared by unset_mlu_var().
    cu_seq_lens: torch.Tensor | None = None      # cumulative query lengths
    max_seq_len: int | None = None               # max query length this step
    max_model_len: int | None = None
    is_prompt: bool = False                      # prefill-only step
    is_chunked: bool = False                     # chunked-prefill step
    positions_: torch.Tensor | None = None       # cached positions + offsets
    chunked_prefill_enabled: bool = False
    prefill_cu_seq_lens: torch.Tensor | None = None
    prefill_max_seq_len: int | None = None
    decode_cu_seq_lens: torch.Tensor | None = None
    decode_max_seq_len: int | None = None

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        inverse: bool = False,
    ) -> None:
        """Build the layer and precompute the MLU-layout cos/sin cache.

        Args:
            head_size: Size of each attention head.
            rotary_dim: Number of head dimensions rope is applied to.
            max_position_embeddings: Model's maximum position; may be raised
                to the user-specified max_model_len (see below).
            base: Rope frequency base.
            is_neox_style: True for GPT-NeoX (half-split) rotation, False
                for GPT-J (pairwise interleaved) rotation.
            dtype: dtype the cached cos/sin values are stored in.
            inverse: If True, negate the sin table (inverse rotation).
        """
        CustomOp.__init__(self)
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype
        # TODO(mgoin): disabled for now due to failures
        # Flashinfer only supports head_size=64, 128, 256, 512; the MLU
        # fused kernel is used unconditionally instead.
        self.use_flashinfer = False
        self.inverse = inverse

        # For vlm v1:
        # 1. MLU rope runs in eager mode.
        # 2. All layers reuse a single registered rope layer ("global_rope")
        #    for inference, so only the first constructed instance is
        #    registered into the static forward context.
        prefix = "global_rope"
        vllm_config = get_current_vllm_config()
        self.use_direct_call = False
        if not self.use_direct_call:
            compilation_config = vllm_config.compilation_config
            if prefix not in compilation_config.static_forward_context:
                compilation_config.static_forward_context[prefix] = self
            self.layer_name = prefix

        # Imported lazily to avoid circular imports at module load time.
        from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import DeepseekScalingRotaryEmbedding
        from vllm.model_executor.layers.rotary_embedding.yarn_scaling_rope import YaRNScalingRotaryEmbedding

        # If the user asked for a longer context than the model config
        # advertises, grow the cache to cover it.  Scaling ropes (YaRN,
        # DeepSeek) handle long context themselves and are excluded.
        if (MLURotaryEmbedding.max_seq_len is not None
                and self.max_position_embeddings < MLURotaryEmbedding.max_seq_len
                and not isinstance(self, (YaRNScalingRotaryEmbedding,
                                          DeepseekScalingRotaryEmbedding))):
            logger.warning(
                "User-specified max_model_len (%s) is different from "
                "max_position_embeddings (%s) in the model's config.json. "
                "This may lead to incorrect model outputs or MLU errors. "
                "Make sure the value is correct and within the model "
                "context size. Setting max_position_embeddings=%s.",
                MLURotaryEmbedding.max_seq_len, max_position_embeddings,
                MLURotaryEmbedding.max_seq_len)
            self.max_position_embeddings = MLURotaryEmbedding.max_seq_len
        cache = self._compute_cos_sin_cache()

        # Rearrange [cos | sin] into the layout the MLU kernel expects.
        from vllm_mlu.model_executor.layers.rotary_embedding.linear_scaling_rope import MLULinearScalingRotaryEmbedding
        if isinstance(self, MLULinearScalingRotaryEmbedding):
            # That subclass builds its special tensor composition itself.
            logger.debug("Using mlu defining _compute_cos_sin_cache due to "
                         "the special tensor composition")
        elif is_neox_style:
            # Duplicate each half along the last dim:
            # [cos | sin] -> [cos cos | sin sin].
            cache_pos = cache.shape[0]
            cache = cache.reshape(cache_pos, 2, -1)
            cache = torch.tile(cache, (1, 1, 2)).reshape(cache_pos, -1)
        else:
            # GPT-J style: interleave every element pairwise.
            cache = cache.repeat_interleave(2, dim=-1)

        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)
        self.cos_, self.sin_ = self._get_cos_sin()

    @classmethod
    def set_mlu_var_v1(
        cls,
        common_metadata: MLUCommonAttentionMetadata,
    ) -> None:
        """Publish this step's scheduling metadata to all rope layers.

        Clears the previous step's state first, then records the query
        offsets/lengths and — for MLA backends — the separate prefill and
        decode splits.
        """
        cls.unset_mlu_var()
        cls.cu_seq_lens = common_metadata.query_start_loc
        cls.max_seq_len = common_metadata.max_query_len
        cls.is_prompt = common_metadata.is_prefill_only
        cls.is_chunked = common_metadata.is_chunked

        # For MLA: the backend splits the batch into prefill and decode
        # halves that carry their own query offsets.
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            # Per-layer metadata dict; take the first entry's value.
            attn_metadata = next(iter(attn_metadata.values()))
        if isinstance(attn_metadata, MLACommonMetadata):
            prefill_metadata = attn_metadata.prefill
            decode_metadata = attn_metadata.decode
            if prefill_metadata:
                cls.prefill_max_seq_len = prefill_metadata.max_query_len
                cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
            else:
                # No prefill half: fall back to the whole-batch values.
                cls.prefill_max_seq_len = cls.max_seq_len
                cls.prefill_cu_seq_lens = cls.cu_seq_lens

            if decode_metadata:
                cls.decode_max_seq_len = decode_metadata.max_query_len
                cls.decode_cu_seq_lens = decode_metadata.query_start_loc
            else:
                # No decode half: fall back to the whole-batch values.
                cls.decode_max_seq_len = cls.max_seq_len
                cls.decode_cu_seq_lens = cls.cu_seq_lens

        # For sequence parallelism (v3.2): prefill metadata comes from the
        # SP context and forces the chunked path.
        sp_context = get_sp_forward_context()
        if sp_context is not None and sp_context.is_v32:
            prefill_metadata = sp_context.sp_attn_metadata.prefill
            cls.is_chunked = True
            cls.prefill_max_seq_len = prefill_metadata.max_query_len
            cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc

    @classmethod
    def unset_mlu_var(cls) -> None:
        """Reset all shared per-step metadata back to its defaults."""
        cls.cu_seq_lens = None
        cls.max_seq_len = None
        cls.is_prompt = False
        cls.is_chunked = False
        cls.positions_ = None
        cls.chunked_prefill_enabled = False
        cls.prefill_cu_seq_lens = None
        cls.prefill_max_seq_len = None
        cls.decode_cu_seq_lens = None
        cls.decode_max_seq_len = None

    def _get_cos_sin(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split the fused cache into (cos, sin) views shaped
        (num_positions, rotary_dim)."""
        cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
        sin = sin.view(-1, self.rotary_dim)
        cos = cos.view(-1, self.rotary_dim)
        return cos, sin

    def _get_positions_with_offsets_mlu(
        self,
        positions: torch.Tensor,
        offsets: torch.Tensor,
    ) -> torch.Tensor:
        """Return int32 position ids shifted by per-token offsets.

        Raises:
            ValueError: if ``offsets`` and ``positions`` disagree in
                element count.
        """
        if offsets.numel() != positions.numel():
            raise ValueError(
                "rope offsets numel mismatch with positions, "
                f"positions: {positions.numel()}, offsets: {offsets.numel()}")
        return (positions + offsets).to(torch.int32)

    def forward_impl(
        self,
        positions: torch.Tensor,
        x: torch.Tensor,
        offsets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply rotary embedding to ``x`` with the fused MLU kernel.

        When no attention metadata is available (e.g. warmup/profiling),
        every token is treated as one discrete position; otherwise the
        batch's cumulative query lengths drive the kernel.
        Returns the rotated tensor.
        """
        common_metadata: MLUCommonAttentionMetadata = get_common_metadata()
        if common_metadata is None:
            num_tokens, head_num, head_size = x.shape
            x = mlu_ops.rotary_embedding(
                x.view(1, num_tokens, head_num, head_size),
                self.sin_,
                self.cos_,
                positions,
                None,
                not self.is_neox_style,
                True,   # discrete: one position id per token
                False,
                num_tokens
            )
            return x
        else:
            cu_seq_lens_ = common_metadata.query_start_loc

            if offsets is not None:
                # Compute positions+offsets once per step; every layer
                # reuses the cached class-level result.
                if MLURotaryEmbedding.positions_ is None:
                    MLURotaryEmbedding.positions_ = (
                        self._get_positions_with_offsets_mlu(positions, offsets))
                position_ids = MLURotaryEmbedding.positions_
                discrete = True
            elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
                # Chunked prefill / decode: positions are not contiguous,
                # so explicit position ids are required.
                position_ids = positions
                discrete = True
            else:
                # Pure prefill: the kernel derives positions from the
                # cumulative sequence lengths.
                position_ids = None
                discrete = False

            x = mlu_ops.rotary_embedding(
                x,
                self.sin_,
                self.cos_,
                position_ids,
                cu_seq_lens_,
                not self.is_neox_style,
                discrete,
                False,
                MLURotaryEmbedding.max_seq_len
            )
            return x

    def get_param(self, positions, discrete=False):
        """Return ``(position_ids, interleaved, discrete)`` kernel args.

        ``interleaved`` is the inverse of neox style; when ``discrete`` is
        not forced by the caller it is derived from the shared step state
        (chunked prefill and decode need explicit position ids).
        """
        interleaved = not self.is_neox_style

        if discrete:
            position_ids = positions
        elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
            position_ids = positions
            discrete = True
        else:
            position_ids = None
            discrete = False

        return position_ids, interleaved, discrete

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the concatenated ``[cos | sin]`` cache in float32.

        The sin half is negated when ``self.inverse`` is set.
        """
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.outer(t, inv_freq)
        # polar(1, theta) = cos(theta) + i*sin(theta)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        cos = freqs_cis.real
        sin = freqs_cis.imag * (-1 if self.inverse else 1)
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor | None = None,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
        only_prefill: bool | None = False,
        only_decode: bool | None = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Out-of-tree forward: rotate ``query`` (and ``key``, if given).

        ``only_prefill``/``only_decode`` are accepted for interface
        compatibility but are not used by the MLU kernel path.
        """
        self.forward_impl(positions, query, offsets)
        if key is not None:
            self.forward_impl(positions, key, offsets)
        return query, key
|
|
|
|
|
|
def rope_forward(
    positions: torch.Tensor,
    x: torch.Tensor,
    layer_name: str,
    offsets: torch.Tensor | None = None,
) -> None:
    """Custom-op entry point: rotate ``x`` in place.

    Looks up the rope layer registered under ``layer_name`` in the current
    forward context and delegates to its ``forward_impl``.
    """
    ctx: ForwardContext = get_forward_context()
    rope_layer = ctx.no_compile_layers[layer_name]
    rope_layer.forward_impl(positions, x, offsets)
|
|
|
|
|
|
def rope_forward_fake(
|
|
positions: torch.Tensor,
|
|
x: torch.Tensor,
|
|
layer_name: str,
|
|
offsets: torch.Tensor | None = None,
|
|
) -> None:
|
|
return
|
|
|
|
|
|
# Register rope_forward as a torch custom op so it survives torch.compile
# graph capture on this platform's dispatch key.  `x` is declared as
# mutated (the kernel rotates it in place) and the fake impl is a no-op
# used during meta/tracing runs.
direct_register_custom_op(
    op_name="rope_forward",
    op_func=rope_forward,
    mutates_args=["x"],
    fake_impl=rope_forward_fake,
    dispatch_key=current_platform.dispatch_key,
)
|