Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -330,48 +330,46 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
from vllm import _custom_ops as ops
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
num_tokens = positions.shape[-1]
cos_sin = cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
if positions.ndim == 2:
assert self.mrope_section
self._match_cos_sin_cache_dtype(query)
if self.mrope_interleaved:
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
if positions.ndim == 2:
assert self.mrope_section
q, k = triton_mrope(
query,
key,
cos,
sin,
self.mrope_section,
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
)
q, k = triton_mrope(
query,
key,
cos,
sin,
self.mrope_section,
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
)
return q.reshape(query_shape), k.reshape(key_shape)
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = self.apply_rotary_emb(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = self.apply_rotary_emb(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return q.reshape(query_shape), k.reshape(key_shape)
if positions.ndim == 1:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
else:
if self.is_neox_style:
ops.m_rotary_embedding(positions.contiguous(), query, key, self.head_size,
self.cos_sin_cache,
torch.tensor(self.mrope_section, dtype=torch.int),
self.is_neox_style)
else:
query, key = self.forward_native(
positions, query, key
)
return query, key
def forward_cpu(