[Model] Support DeepSeek-V4
302
vllm_mlu/model_executor/layers/rotary_embedding/base.py
Normal file
@@ -0,0 +1,302 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch

from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.utils import (
    get_common_metadata,
    MLUCommonAttentionMetadata,
)
from vllm_mlu.v1.attention.backends.mla.flashmla import MLACommonMetadata
from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context

logger = init_logger(__name__)


@CustomOp.register("rotary_embedding_mlu")
class MLURotaryEmbedding(RotaryEmbedding, CustomOp):
    """Rotary embedding on MLU.

    Per-step attention metadata is stashed on the class via set_mlu_var_v1()
    so every rope layer can read it during the forward pass.
    """

    cu_seq_lens: torch.Tensor | None = None
    max_seq_len: int | None = None
    max_model_len: int | None = None
    is_prompt: bool = False
    is_chunked: bool = False
    positions_: torch.Tensor | None = None
    chunked_prefill_enabled: bool = False
    prefill_cu_seq_lens: torch.Tensor | None = None
    prefill_max_seq_len: int | None = None
    decode_cu_seq_lens: torch.Tensor | None = None
    decode_max_seq_len: int | None = None

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        inverse: bool = False,
    ) -> None:
        CustomOp.__init__(self)
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype
        # TODO(mgoin): disabled for now due to failures
        # Flashinfer only supports head_size=64, 128, 256, 512.
        # https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
        # self.use_flashinfer = (self.enabled()
        #                        and dtype in (torch.float16, torch.bfloat16)
        #                        and current_platform.is_cuda()
        #                        and has_flashinfer()
        #                        and self.head_size in [64, 128, 256, 512])
        self.use_flashinfer = False
        self.inverse = inverse

        # For vLLM v1:
        # 1. MLU rope runs in eager mode.
        # 2. All layers reuse layer 0's rope for inference.
        prefix = "global_rope"
        vllm_config = get_current_vllm_config()
        self.use_direct_call = False
        if not self.use_direct_call:
            compilation_config = vllm_config.compilation_config
            if prefix not in compilation_config.static_forward_context:
                compilation_config.static_forward_context[prefix] = self
            self.layer_name = prefix

        # Deferred imports: these rope subclasses are only needed for the
        # isinstance check below.
        from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
            DeepseekScalingRotaryEmbedding)
        from vllm.model_executor.layers.rotary_embedding.yarn_scaling_rope import (
            YaRNScalingRotaryEmbedding)

        if (MLURotaryEmbedding.max_seq_len is not None
                and self.max_position_embeddings < MLURotaryEmbedding.max_seq_len
                and not isinstance(self, (YaRNScalingRotaryEmbedding,
                                          DeepseekScalingRotaryEmbedding))):
            logger.warning(
                "User-specified max_model_len (%d) differs from "
                "max_position_embeddings (%d) in the model's config.json. "
                "This may lead to incorrect model outputs or MLU errors. "
                "Make sure the value is correct and within the model context "
                "size. Setting max_position_embeddings=%d.",
                MLURotaryEmbedding.max_seq_len, max_position_embeddings,
                MLURotaryEmbedding.max_seq_len)
            self.max_position_embeddings = MLURotaryEmbedding.max_seq_len
        cache = self._compute_cos_sin_cache()

        from vllm_mlu.model_executor.layers.rotary_embedding.linear_scaling_rope import (
            MLULinearScalingRotaryEmbedding)
        if isinstance(self, MLULinearScalingRotaryEmbedding):
            logger.debug("Using the MLU-specific _compute_cos_sin_cache "
                         "because of its special tensor layout.")
        elif is_neox_style:
            # Widen [T, rotary_dim] to [T, 2 * rotary_dim] by duplicating the
            # cos and sin halves.
            cache_pos = cache.shape[0]
            cache = cache.reshape(cache_pos, 2, -1)
            cache = torch.tile(cache, (1, 1, 2)).reshape(cache_pos, -1)
        else:
            # GPT-J style: interleave instead of duplicating halves.
            cache = cache.repeat_interleave(2, dim=-1)

        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)
        self.cos_, self.sin_ = self._get_cos_sin()

    @classmethod
    def set_mlu_var_v1(
        cls,
        common_metadata: MLUCommonAttentionMetadata,
    ) -> None:
        cls.unset_mlu_var()
        cls.cu_seq_lens = common_metadata.query_start_loc
        cls.max_seq_len = common_metadata.max_query_len
        cls.is_prompt = common_metadata.is_prefill_only
        cls.is_chunked = common_metadata.is_chunked

        # For MLA: keep separate prefill and decode views of the metadata.
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            _, attn_metadata = next(iter(attn_metadata.items()))
        if isinstance(attn_metadata, MLACommonMetadata):
            prefill_metadata = attn_metadata.prefill
            decode_metadata = attn_metadata.decode
            if prefill_metadata:
                cls.prefill_max_seq_len = prefill_metadata.max_query_len
                cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc
            else:
                cls.prefill_max_seq_len = cls.max_seq_len
                cls.prefill_cu_seq_lens = cls.cu_seq_lens

            if decode_metadata:
                cls.decode_max_seq_len = decode_metadata.max_query_len
                cls.decode_cu_seq_lens = decode_metadata.query_start_loc
            else:
                cls.decode_max_seq_len = cls.max_seq_len
                cls.decode_cu_seq_lens = cls.cu_seq_lens

        # For sequence parallelism.
        sp_context = get_sp_forward_context()
        if sp_context is not None and sp_context.is_v32:
            prefill_metadata = sp_context.sp_attn_metadata.prefill
            cls.is_chunked = True
            cls.prefill_max_seq_len = prefill_metadata.max_query_len
            cls.prefill_cu_seq_lens = prefill_metadata.query_start_loc

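    # Per-step call pattern (a sketch; where these hooks are invoked is an
    # assumption about the MLU model runner, not shown in this file):
    #     MLURotaryEmbedding.set_mlu_var_v1(common_metadata)  # before forward
    #     ...                                                 # model forward
    #     MLURotaryEmbedding.unset_mlu_var()                  # after forward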
    @classmethod
    def unset_mlu_var(cls):
        cls.cu_seq_lens = None
        cls.max_seq_len = None
        cls.is_prompt = False
        cls.is_chunked = False
        cls.positions_ = None
        cls.chunked_prefill_enabled = False
        cls.prefill_cu_seq_lens = None
        cls.prefill_max_seq_len = None
        cls.decode_cu_seq_lens = None
        cls.decode_max_seq_len = None

    def _get_cos_sin(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Split the packed cos/sin cache into [num_pos, rotary_dim] views."""
        cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
        sin = sin.view(-1, self.rotary_dim)
        cos = cos.view(-1, self.rotary_dim)
        return cos, sin

    def _get_positions_with_offsets_mlu(
        self,
        positions: torch.Tensor,
        offsets: torch.Tensor,
    ) -> torch.Tensor:
        if offsets.numel() != positions.numel():
            raise ValueError(
                "rope offsets numel mismatches positions: "
                f"positions: {positions.numel()}, offsets: {offsets.numel()}")
        return (positions + offsets).to(torch.int32)
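
    # Example: positions=[0, 1, 2] with offsets=[10, 10, 10] yields int32
    # position ids [10, 11, 12].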

    def forward_impl(
        self,
        positions: torch.Tensor,
        x: torch.Tensor,
        offsets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        common_metadata: MLUCommonAttentionMetadata = get_common_metadata()
        if common_metadata is None:
            # No attention metadata: treat x as one flat batch of tokens.
            num_tokens, head_num, head_size = x.shape
            x = mlu_ops.rotary_embedding(
                x.view(1, num_tokens, head_num, head_size),
                self.sin_,
                self.cos_,
                positions,
                None,
                not self.is_neox_style,
                True,
                False,
                num_tokens,
            )
            return x
        else:
            cu_seq_lens_ = common_metadata.query_start_loc

            if offsets is not None:
                # Compute offset positions once per step and cache them.
                if MLURotaryEmbedding.positions_ is None:
                    MLURotaryEmbedding.positions_ = (
                        self._get_positions_with_offsets_mlu(positions, offsets))
                position_ids = MLURotaryEmbedding.positions_
                discrete = True
            elif MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
                position_ids = positions
                discrete = True
            else:
                # Pure prefill without chunking: no explicit (discrete)
                # position ids are passed.
                position_ids = None
                discrete = False

            x = mlu_ops.rotary_embedding(
                x,
                self.sin_,
                self.cos_,
                position_ids,
                cu_seq_lens_,
                not self.is_neox_style,
                discrete,
                False,
                MLURotaryEmbedding.max_seq_len,
            )
            return x
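
    # Dispatch sketch: without MLU attention metadata, forward_impl takes the
    # flat [1, num_tokens, heads, head_size] path above; with metadata, the
    # varlen path uses the cu_seq_lens taken from query_start_loc instead.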

    def get_param(self, positions, discrete=False):
        interleaved = not self.is_neox_style

        if discrete:
            position_ids = positions
        else:
            if MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
                position_ids = positions
                discrete = True
            else:
                position_ids = None
                discrete = False

        return position_ids, interleaved, discrete

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.outer(t, inv_freq)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        cos = freqs_cis.real
        sin = freqs_cis.imag * (-1 if self.inverse else 1)
        cache = torch.cat((cos, sin), dim=-1)
        return cache
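
    # Shape sketch (assuming rotary_dim=128, max_position_embeddings=4096):
    # inv_freq is [64], freqs is [4096, 64], and the returned cache is
    # [4096, 128] (cos and sin halves concatenated). __init__ then widens it
    # to [4096, 256] before _get_cos_sin splits it back into cos_/sin_ views
    # of shape [4096, 128].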

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor | None = None,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
        only_prefill: bool | None = False,
        only_decode: bool | None = False,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        self.forward_impl(positions, query, offsets)
        if key is not None:
            self.forward_impl(positions, key, offsets)
        return query, key


def rope_forward(
    positions: torch.Tensor,
    x: torch.Tensor,
    layer_name: str,
    offsets: torch.Tensor | None = None,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]

    # Applied in place: the MLU kernel mutates x (see mutates_args below).
    self.forward_impl(positions, x, offsets)


def rope_forward_fake(
    positions: torch.Tensor,
    x: torch.Tensor,
    layer_name: str,
    offsets: torch.Tensor | None = None,
) -> None:
    return


direct_register_custom_op(
    op_name="rope_forward",
    op_func=rope_forward,
    mutates_args=["x"],
    fake_impl=rope_forward_fake,
    dispatch_key=current_platform.dispatch_key,
)
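
# Usage sketch (illustrative; the parameter values below are assumptions, and
# a vLLM config context must be active for get_current_vllm_config() in
# __init__ to succeed):
#
#     rope = MLURotaryEmbedding(
#         head_size=128,
#         rotary_dim=128,
#         max_position_embeddings=4096,
#         base=10000.0,
#         is_neox_style=True,
#         dtype=torch.bfloat16,
#     )
#     # query/key: [num_tokens, num_heads, head_size], resident on the MLU.
#     query, key = rope.forward_oot(positions, query, key)
#
# Under compilation, the same computation is reached through the registered
# rope_forward custom op, which mutates x in place (mutates_args=["x"]).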