Sync from v0.13

vllm/model_executor/layers/rotary_embedding/xdrope.py (new file, 160 lines)
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding


class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
    """DynamicNTKAlphaRotaryEmbedding extended with MultiModal (XD) sections.

    Based on the original DynamicNTKAlphaRotaryEmbedding implementation.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_alpha: float,
        dtype: torch.dtype,
        xdrope_section: list[int],
    ) -> None:
        self.xdrope_section = xdrope_section
        super().__init__(
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            scaling_alpha,
            dtype,
        )

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            positions:
                [4, num_tokens] (P/W/H/T positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
        """
        assert positions.ndim == 2
        assert key is not None

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = torch.cat(
            [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
        )
        sin = torch.cat(
            [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
        )
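        # Shape walkthrough (illustrative, assuming rotary_dim // 2 == 64 and
        # xdrope_section == [16, 16, 16, 16]): indexing cos_sin_cache with a
        # [4, num_tokens] positions tensor gives cos and sin of shape
        # [4, num_tokens, 64]. split() yields four [4, num_tokens, 16] chunks;
        # m[i] keeps only the row that belongs to position axis i (P/W/H/T),
        # and the concat rebuilds a [num_tokens, 64] tensor in which each
        # 16-channel section is rotated by a different position stream.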

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = self.apply_rotary_emb.forward_native(
            query_rot,
            cos,
            sin,
        )
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = self.apply_rotary_emb.forward_native(
            key_rot,
            cos,
            sin,
        )
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward().
|
||||
|
||||
Args:
|
||||
positions:
|
||||
[4, num_tokens] (P/W/H/T positions with multimodal inputs)
|
||||
query: [num_tokens, num_heads * head_size]
|
||||
key: [num_tokens, num_kv_heads * head_size]
|
||||
"""
        assert positions.ndim == 2
        assert key is not None

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = torch.cat(
            [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
        )
        sin = torch.cat(
            [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
        )

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = self.apply_rotary_emb(
            query_rot,
            cos,
            sin,
        )
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = self.apply_rotary_emb(
            key_rot,
            cos,
            sin,
        )
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    @staticmethod
    def get_next_input_positions(
        context_len: int,
        seq_len: int,
        xd_sections: int = 4,
    ) -> list[list[int]]:
        return [list(range(context_len, seq_len)) for _ in range(xd_sections)]
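
    # Illustrative example: get_next_input_positions(context_len=10, seq_len=13)
    # returns [[10, 11, 12], [10, 11, 12], [10, 11, 12], [10, 11, 12]] -- the
    # same text positions repeated for every XD section, presumably because
    # text-only continuation advances all position streams in lockstep.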

    @staticmethod
    def get_next_input_positions_tensor(
        out: np.ndarray,
        out_offset: int,
        context_len: int,
        num_new_tokens: int,
    ):
        values = np.arange(
            context_len,
            context_len + num_new_tokens,
            dtype=out.dtype,
        )
        out[:, out_offset : out_offset + num_new_tokens] = values
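
    # Illustrative example: with out = np.zeros((4, 8), dtype=np.int64),
    # get_next_input_positions_tensor(out, out_offset=2, context_len=10,
    # num_new_tokens=3) writes [10, 11, 12] into out[:, 2:5], i.e. the same
    # new positions are broadcast into all four XD rows in place.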