Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -227,6 +227,7 @@ class RotaryEmbedding(RotaryEmbeddingBase):
self.head_size,
cos_sin_cache,
self.is_neox_style,
self.rotary_dim,
)
return query, key

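The new `self.rotary_dim` argument presumably lets the fused op handle partial-rotary models, where only the first `rotary_dim` channels of each head are rotated and the rest pass through unchanged. A minimal PyTorch sketch of that behavior, not the corex kernel itself (function name and shapes are illustrative):

```python
import torch

def apply_partial_rotary_neox(q: torch.Tensor, cos: torch.Tensor,
                              sin: torch.Tensor, rotary_dim: int) -> torch.Tensor:
    """Rotate only the first `rotary_dim` channels of each head (Neox pairing).

    q:        [num_tokens, num_heads, head_size]
    cos, sin: [num_tokens, rotary_dim // 2]
    """
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    x1, x2 = q_rot.chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)  # broadcast over heads
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    # Channels past rotary_dim (q_pass) are concatenated back untouched.
    return torch.cat((rotated, q_pass), dim=-1)
```

When `rotary_dim == head_size` the pass-through slice is empty and this reduces to full rotary embedding.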
View File

@@ -229,22 +229,7 @@ class ApplyRotaryEmb(CustomOp):
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
# from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
return self.forward_native(x, cos, sin)
x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
"""
Arguments of apply_rotary_emb() in vllm_flash_attn:
x: [batch_size, seq_len, nheads, headdim]
cos, sin: [seqlen_rotary, rotary_dim / 2]
interleaved: defaults to False (Neox-style).
...
"""
interleaved = not self.is_neox_style
output = apply_rotary_emb(x, cos, sin, interleaved)
output = self._post_process(output, origin_shape, origin_dtype)
return output
def forward_hip(
self,

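The removed path documents the key convention: flash-attn's `interleaved` flag is the inverse of `is_neox_style`. A standalone sketch of what `forward_native` computes, with the `rotate_neox`/`rotate_gptj` helpers reimplemented here to match the ones this repo imports from `.common`:

```python
import torch

def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # Neox / half-split pairing: channel i pairs with channel i + rotary_dim // 2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # GPT-J / interleaved pairing: channel 2i pairs with channel 2i + 1.
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rotary_emb_native(x: torch.Tensor, cos: torch.Tensor,
                            sin: torch.Tensor, is_neox_style: bool) -> torch.Tensor:
    """x: [..., rotary_dim]; cos, sin: [..., rotary_dim // 2]."""
    if is_neox_style:
        # Each half reuses the same angles.
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
    else:
        # Adjacent channel pairs share an angle.
        cos = cos.repeat_interleave(2, dim=-1)
        sin = sin.repeat_interleave(2, dim=-1)
    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
    return x * cos + rotate_fn(x) * sin
```

Routing everything through `forward_native` trades the fused flash-attn kernel for portability, presumably because `vllm_flash_attn` is not available on this platform.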
View File

@@ -8,7 +8,7 @@ import torch
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from .base import RotaryEmbeddingBase
from .base import RotaryEmbedding
from .common import (
rotate_gptj,
rotate_neox,
@@ -23,7 +23,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0
class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
@@ -110,73 +110,3 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
sin = freqs.sin() * self.mscale
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward()."""
assert key is not None
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
query_rot = query[..., : self.rotary_dim]
key_rot = key[..., : self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim :]
key_pass = key[..., self.rotary_dim :]
cos_sin = cos_sin_cache[
torch.add(positions, offsets) if offsets is not None else positions
]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key
def forward_hip(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.use_flashinfer:
torch.ops.vllm.flashinfer_rotary_embedding(
torch.add(positions, offsets) if offsets is not None else positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
else:
return self.forward_native(positions, query, key, offsets)

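For reference, the YaRN pieces visible in the hunks above fit together as follows: `yarn_get_mscale` supplies a magnitude correction, and the cache builder scales both cos and sin by it before concatenation. A sketch assuming the usual guard that the correction is identity for `scale <= 1` (the guard itself falls outside the lines shown):

```python
import math
import torch

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    # Assumed guard: no correction below the original context length.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

def build_yarn_cos_sin_cache(inv_freq: torch.Tensor, max_len: int,
                             mscale: float) -> torch.Tensor:
    # Same layout as the cache above: [max_len, rotary_dim],
    # cos in the first half of the last dim, sin in the second half.
    t = torch.arange(max_len, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    cos = freqs.cos() * mscale
    sin = freqs.sin() * mscale
    return torch.cat((cos, sin), dim=-1)
```

Because the correction is baked into the cache, the class no longer needs its own `forward_native`/`forward_hip`/`forward_cuda`: the 70-line deletion above lets the inherited `RotaryEmbedding` forwards consume the pre-scaled cache unchanged.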
View File

@@ -330,48 +330,46 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
from vllm import _custom_ops as ops
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
num_tokens = positions.shape[-1]
cos_sin = cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
if positions.ndim == 2:
assert self.mrope_section
self._match_cos_sin_cache_dtype(query)
if self.mrope_interleaved:
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
if positions.ndim == 2:
assert self.mrope_section
q, k = triton_mrope(
query,
key,
cos,
sin,
self.mrope_section,
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
)
q, k = triton_mrope(
query,
key,
cos,
sin,
self.mrope_section,
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
)
return q.reshape(query_shape), k.reshape(key_shape)
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = self.apply_rotary_emb(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = self.apply_rotary_emb(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return q.reshape(query_shape), k.reshape(key_shape)
if positions.ndim == 1:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
else:
if self.is_neox_style:
ops.m_rotary_embedding(positions.contiguous(), query, key, self.head_size,
self.cos_sin_cache,
torch.tensor(self.mrope_section, dtype=torch.int),
self.is_neox_style)
else:
query, key = self.forward_native(
positions, query, key
)
return query, key
def forward_cpu(

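`mrope_section` is what distinguishes this path from plain RoPE: with 2-D positions (one stream each for temporal, height, and width indices), each contiguous group of rotary channels draws its angles from a different stream. A native sketch of that cos/sin gather, mirroring what both `triton_mrope` and `ops.m_rotary_embedding` consume (the function name is illustrative):

```python
import torch

def mrope_gather_cos_sin(cos_sin_cache: torch.Tensor,
                         positions: torch.Tensor,
                         mrope_section: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
    """positions: [3, num_tokens]; cos_sin_cache: [max_pos, rotary_dim].

    mrope_section holds three group sizes summing to rotary_dim // 2;
    group i takes its rotation angles from position stream i.
    """
    cos_sin = cos_sin_cache[positions]       # [3, num_tokens, rotary_dim]
    cos, sin = cos_sin.chunk(2, dim=-1)      # each [3, num_tokens, rotary_dim // 2]
    cos = torch.cat(
        [m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
    sin = torch.cat(
        [m[i] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
    return cos, sin                          # each [num_tokens, rotary_dim // 2]
```

Once cos/sin are flattened to one value per token and channel, the rotation itself is identical to the 1-D case, which is why the non-Neox branch above can simply fall back to `forward_native`.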
View File

@@ -12,6 +12,7 @@ from .common import rotate_neox
logger = init_logger(__name__)
import ixformer.inference.functions as ixops
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.
@@ -133,27 +134,18 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
if self.use_long_rope:
k = self.original_max_position_embeddings
long_prompt_offset = torch.full_like(positions, k).long()
idx = torch.add(positions, long_prompt_offset)
else:
idx = positions
idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
k = self.original_max_position_embeddings
long_prompt_offset = torch.any(positions > k)
ixops.vllm_rotary_embedding_phi(
positions,
query,
key,
self.head_size,
self.long_short_cos_sin_cache,
long_prompt_offset,
k,
offsets
)
cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = query_rot * cos + rotate_neox(query_rot) * sin
query = torch.cat((query_rot, query_pass), dim=-1)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = key_rot * cos + rotate_neox(key_rot) * sin
key = torch.cat((key_rot, key_pass), dim=-1)
return query.flatten(-2), key.flatten(-2)
return query, key
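
Both the removed native path and the new fused `ixops.vllm_rotary_embedding_phi` call lean on the same layout: the short (original context) and long (extended context) caches are stacked into one `long_short_cos_sin_cache`, and long prompts are routed into the second half by offsetting positions by `original_max_position_embeddings`. A minimal sketch of the selection (illustrative name; assumes the short cache occupies the first `original_max_pos` rows):

```python
import torch

def select_long_short_cos_sin(long_short_cache: torch.Tensor,
                              positions: torch.Tensor,
                              original_max_pos: int) -> torch.Tensor:
    """long_short_cache: short-cache rows first, long-cache rows after.

    If any position exceeds the original context window, every lookup is
    shifted past the short rows, matching `torch.any(positions > k)` above."""
    idx = positions.long()
    if bool(torch.any(idx > original_max_pos)):
        idx = idx + original_max_pos  # jump into the long-cache half
    return torch.index_select(long_short_cache, 0, idx)
```

The fused kernel instead receives the boolean `long_prompt_offset` and the split point `k`, presumably so the same decision can happen once on device rather than in Python.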