Upgrade to vllm 0.17.0 corex v4.1 overlay
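Rework MRotaryEmbedding.forward_cuda for the CoreX backend: the interleaved M-RoPE path keeps the triton_mrope kernel, while the remaining paths now dispatch to the _custom_ops kernels (ops.rotary_embedding for 1-D positions, ops.m_rotary_embedding for 2-D neox-style positions), with forward_native as the fallback for 2-D non-neox positions.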
@@ -330,48 +330,46 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None
+        from vllm import _custom_ops as ops
 
-        cos_sin_cache = self._match_cos_sin_cache_dtype(query)
-        num_tokens = positions.shape[-1]
-        cos_sin = cos_sin_cache[positions]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        query_shape = query.shape
-        key_shape = key.shape
-        if positions.ndim == 2:
-            assert self.mrope_section
+        self._match_cos_sin_cache_dtype(query)
 
-            q, k = triton_mrope(
-                query,
-                key,
-                cos,
-                sin,
-                self.mrope_section,
-                self.head_size,
-                self.rotary_dim,
-                self.mrope_interleaved,
-            )
+        if self.mrope_interleaved:
+            num_tokens = positions.shape[-1]
+            cos_sin = self.cos_sin_cache[positions]
+            cos, sin = cos_sin.chunk(2, dim=-1)
+            query_shape = query.shape
+            key_shape = key.shape
+            if positions.ndim == 2:
+                assert self.mrope_section
+                q, k = triton_mrope(
+                    query,
+                    key,
+                    cos,
+                    sin,
+                    self.mrope_section,
+                    self.head_size,
+                    self.rotary_dim,
+                    self.mrope_interleaved,
+                )
 
-            return q.reshape(query_shape), k.reshape(key_shape)
+                return q.reshape(query_shape), k.reshape(key_shape)
 
-        query = query.view(num_tokens, -1, self.head_size)
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        query_rot = self.apply_rotary_emb(
-            query_rot,
-            cos,
-            sin,
-        )
-        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
-        key = key.view(num_tokens, -1, self.head_size)
-        key_rot = key[..., : self.rotary_dim]
-        key_pass = key[..., self.rotary_dim :]
-        key_rot = self.apply_rotary_emb(
-            key_rot,
-            cos,
-            sin,
-        )
-        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        if positions.ndim == 1:
+            ops.rotary_embedding(positions, query, key, self.head_size,
+                                 self.cos_sin_cache, self.is_neox_style)
+        else:
+            if self.is_neox_style:
+                ops.m_rotary_embedding(positions.contiguous(), query, key, self.head_size,
+                                       self.cos_sin_cache,
+                                       torch.tensor(self.mrope_section, dtype=torch.int),
+                                       self.is_neox_style)
+            else:
+                query, key = self.forward_native(
+                    positions, query, key
+                )
+
+
         return query, key
 
     def forward_cpu(
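For context, the 2-D (multimodal) position branch that both triton_mrope and ops.m_rotary_embedding implement merges three position streams (temporal, height, width) into a single cos/sin row per token, taking mrope_section[i] of the rotary half-dims from stream i, then applies a neox-style rotation. Below is a minimal pure-PyTorch sketch of those semantics for the non-interleaved case; mrope_reference and rotate_half are illustrative names, not part of this commit, and the [3, num_tokens] position layout is the upstream M-RoPE convention assumed here.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style pairing: rotate the first half of the rotary dims into the second.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def mrope_reference(
    positions: torch.Tensor,      # [3, num_tokens]: temporal/height/width streams
    query: torch.Tensor,          # [num_tokens, num_heads, rotary_dim]
    cos_sin_cache: torch.Tensor,  # [max_position, rotary_dim]: cos half, sin half
    mrope_section: list[int],     # per-stream half-dim widths, sum == rotary_dim // 2
) -> torch.Tensor:
    cos_sin = cos_sin_cache[positions]   # [3, num_tokens, rotary_dim]
    cos, sin = cos_sin.chunk(2, dim=-1)  # each [3, num_tokens, rotary_dim // 2]
    # Stitch one cos/sin row per token: section i comes from position stream i.
    cos = torch.cat([m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
    sin = torch.cat([m[i] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
    cos = cos.repeat(1, 2).unsqueeze(1)  # [num_tokens, 1, rotary_dim]
    sin = sin.repeat(1, 2).unsqueeze(1)
    return query * cos + rotate_half(query) * sin


if __name__ == "__main__":
    # Toy shapes: 16 tokens, 2 heads, rotary_dim 8, sections [2, 1, 1] (sum = 4).
    cache = torch.randn(64, 8)
    pos = torch.randint(0, 64, (3, 16))
    q = torch.randn(16, 2, 8)
    out = mrope_reference(pos, q, cache, [2, 1, 1])
    assert out.shape == q.shape

The ops.m_rotary_embedding call presumably fuses this gather and rotation into a single in-place kernel over query and key, which is why the Python-level cos/sin prologue is only needed on the Triton (interleaved) path after this change.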