# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import math
|
|
from importlib.util import find_spec
|
|
|
|
import torch
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.custom_op import CustomOp
|
|
from vllm.utils.torch_utils import direct_register_custom_op
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
# common functions
|
|
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dim Neox-style: split in half, negate-and-swap halves."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
|
|
|
|
|
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dim GPT-J-style: negate-and-swap adjacent (even, odd) pairs."""
    even = x[..., 0::2]
    odd = x[..., 1::2]
    # Interleave (-odd, even) back into the original layout.
    return torch.stack((-odd, even), dim=-1).flatten(-2)
|
|
|
|
|
|
# yarn functions
|
|
# Inverse dim formula to find dim based on number of rotations
|
|
def yarn_find_correction_dim(
    num_rotations: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> float:
    """Invert the RoPE wavelength formula.

    Returns the (fractional) dimension index whose frequency completes
    ``num_rotations`` full rotations over ``max_position_embeddings``
    positions.
    """
    ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    return dim * math.log(ratio) / (2 * math.log(base))
|
|
|
|
|
|
# Find dim range bounds based on rotations
|
|
def yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
    truncate: bool = True,
) -> tuple[float | int, float | int]:
    """Return the [low, high] dimension-index band between ``low_rot`` and
    ``high_rot`` rotations, clamped into [0, dim - 1].

    When ``truncate`` is set, the bounds are widened to whole indices
    (floor the low end, ceil the high end).
    """
    lo = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    hi = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    if truncate:
        lo, hi = math.floor(lo), math.ceil(hi)
    # Clamp values just in case
    return max(lo, 0), min(hi, dim - 1)
|
|
|
|
|
|
def yarn_linear_ramp_mask(
    low: float, high: float, dim: int, dtype: torch.dtype
) -> torch.Tensor:
    """Linear ramp over ``dim`` indices: 0 at/below ``low``, 1 at/above ``high``."""
    if low == high:
        # Nudge to avoid a zero denominator when the band collapses.
        high += 0.001

    ramp = (torch.arange(dim, dtype=dtype) - low) / (high - low)
    return torch.clamp(ramp, 0, 1)
|
|
|
|
|
|
def yarn_get_mscale(scale: float = 1) -> float:
    """YaRN attention magnitude scale: identity below 1, logarithmic above."""
    return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0
|
|
|
|
|
|
def _flashinfer_rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Custom op wrapper for flashinfer's rotary embedding.

    This is an in-place operation that modifies query and key tensors directly.
    Returns None; callers read the mutated ``query``/``key``.
    """
    # Imported lazily so flashinfer is only required when this op actually runs.
    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace

    apply_rope_with_cos_sin_cache_inplace(
        positions=positions,
        query=query,
        key=key,
        head_size=head_size,
        cos_sin_cache=cos_sin_cache,
        is_neox=is_neox,
    )
|
|
|
|
|
|
def _flashinfer_rotary_embedding_fake(
|
|
positions: torch.Tensor,
|
|
query: torch.Tensor,
|
|
key: torch.Tensor,
|
|
head_size: int,
|
|
cos_sin_cache: torch.Tensor,
|
|
is_neox: bool,
|
|
) -> None:
|
|
return
|
|
|
|
|
|
# Register flashinfer rotary embedding custom op
|
|
# Registering the op (rather than calling the Python function directly)
# lets torch.compile trace it; the fake impl provides the no-op meta path.
direct_register_custom_op(
    op_name="flashinfer_rotary_embedding",
    op_func=_flashinfer_rotary_embedding,
    mutates_args=["query", "key"],  # These tensors are modified in-place
    fake_impl=_flashinfer_rotary_embedding_fake,
)
|
|
|
|
|
|
@CustomOp.register("apply_rotary_emb")
class ApplyRotaryEmb(CustomOp):
    """Apply rotary positional embedding (RoPE) to a tensor.

    Dispatches to a platform-specific kernel when available
    (vllm_flash_attn's ``apply_rotary_emb`` on CUDA, flash_attn's Triton
    rotary kernel on HIP) and falls back to the pure-PyTorch
    ``forward_static`` otherwise.
    """

    def __init__(
        self,
        enforce_enable: bool = False,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> None:
        """
        Args:
            enforce_enable: Forwarded to CustomOp to force-enable the op.
            is_neox_style: Neox-style (rotate-half) vs GPT-J-style
                (interleaved) rotation.
            enable_fp32_compute: Temporarily compute in FP32 for higher
                accuracy, casting the result back to the input dtype.
        """
        super().__init__(enforce_enable)
        self.is_neox_style = is_neox_style
        self.enable_fp32_compute = enable_fp32_compute

        # Optional Triton rotary kernel from flash_attn, used by forward_hip.
        self.apply_rotary_emb_flash_attn = None
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary

            self.apply_rotary_emb_flash_attn = apply_rotary

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> torch.Tensor:
        """Pure-PyTorch RoPE.

        Args:
            x: [batch_size (optional), seq_len, num_heads, head_size]
            cos: [seq_len, head_size // 2]
            sin: [seq_len, head_size // 2]
            is_neox_style: Whether to use the Neox-style or GPT-J-style.
            enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
                for higher accuracy.
        """
        origin_dtype = x.dtype
        if enable_fp32_compute:
            x = x.float()

        # unsqueeze(-2) broadcasts cos/sin over the num_heads dimension.
        cos = cos.unsqueeze(-2).to(x.dtype)
        sin = sin.unsqueeze(-2).to(x.dtype)

        if is_neox_style:
            # Neox: first half of the head dim is paired with the second half.
            x1, x2 = torch.chunk(x, 2, dim=-1)
        else:
            # GPT-J: even positions are paired with odd positions.
            x1 = x[..., ::2]
            x2 = x[..., 1::2]

        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin

        if is_neox_style:
            output = torch.cat((o1, o2), dim=-1)
        else:
            output = torch.stack((o1, o2), dim=-1).flatten(-2)

        if enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_native(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Reference implementation; see forward_static for tensor shapes."""
        output = self.forward_static(
            x, cos, sin, self.is_neox_style, self.enable_fp32_compute
        )
        return output

    def forward_cuda(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """CUDA path using vllm_flash_attn's fused rotary kernel."""
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x = x.float()
            cos = cos.float()
            sin = sin.float()

        origin_shape = x.shape
        if len(origin_shape) == 3:
            # x: [seq_len, num_heads, head_size] -> add a batch dim of 1,
            # since the kernel expects a 4-D input.
            x = x.unsqueeze(0)

        """
        Arguments of apply_rotary_emb() in vllm_flash_attn:
            x: [batch_size, seq_len, nheads, headdim]
            cos, sin: [seqlen_rotary, rotary_dim / 2]
            interleaved: default as False (Neox-style).
            ...
        """
        interleaved = not self.is_neox_style
        output = apply_rotary_emb(x, cos, sin, interleaved)

        if len(origin_shape) == 3:
            output = output.squeeze(0)
        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_hip(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """ROCm path using flash_attn's Triton rotary kernel when present,
        falling back to the native implementation otherwise."""
        if self.apply_rotary_emb_flash_attn is not None:
            origin_dtype = x.dtype
            if self.enable_fp32_compute:
                x = x.float()
                cos = cos.float()
                sin = sin.float()

            origin_shape = x.shape
            if len(origin_shape) == 3:
                # x: [seq_len, num_heads, head_size] -> add a batch dim of 1.
                x = x.unsqueeze(0)

            """
            Arguments of apply_rotary() in flash_attn:
                x: [batch_size, seq_len, nheads, headdim]
                cos, sin: [seqlen_rotary, rotary_dim / 2]
                interleaved: default as False (Neox-style).
                ...
            """
            interleaved = not self.is_neox_style
            output = self.apply_rotary_emb_flash_attn(
                x, cos, sin, interleaved=interleaved
            ).type_as(x)

            if len(origin_shape) == 3:
                output = output.squeeze(0)
            if self.enable_fp32_compute:
                output = output.to(origin_dtype)
        else:
            # Falling back to PyTorch native implementation.
            output = self.forward_native(x, cos, sin)

        return output

    def forward_cpu(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        # TODO (bigPYJ1151): need to enable fused CPU ROPE here
        return self.forward_native(x, cos, sin)

    def extra_repr(self) -> str:
        # Fix: a ", " separator was missing between the two fields, which
        # produced e.g. "is_neox_style=Trueenable_fp32_compute=False".
        s = f"is_neox_style={self.is_neox_style}"
        s += f", enable_fp32_compute={self.enable_fp32_compute}"
        return s
|