Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -227,6 +227,7 @@ class RotaryEmbedding(RotaryEmbeddingBase):
             self.head_size,
             cos_sin_cache,
             self.is_neox_style,
+            self.rotary_dim,
         )
         return query, key
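Note: the `self.rotary_dim` argument added above matters for models that rotate only a prefix of each head (partial rotary embeddings). A minimal sketch of the split the kernel has to respect; the helper name and shapes are illustrative assumptions, not this file's API:

    import torch

    # Hedged sketch: only the first rotary_dim channels of each head are
    # rotated by cos/sin; the remaining channels pass through unchanged.
    def split_rotary(q: torch.Tensor, rotary_dim: int):
        q_rot = q[..., :rotary_dim]    # rotated channels
        q_pass = q[..., rotary_dim:]   # untouched channels
        return q_rot, q_pass

    q = torch.randn(4, 8, 128)         # [num_tokens, num_heads, head_size]
    q_rot, q_pass = split_rotary(q, rotary_dim=64)
    assert q_rot.shape[-1] + q_pass.shape[-1] == q.shape[-1]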
@@ -229,22 +229,7 @@ class ApplyRotaryEmb(CustomOp):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ) -> torch.Tensor:
-        x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
-
-        """
-        Arguments of apply_rotary_emb() in vllm_flash_attn:
-            x: [batch_size, seq_len, nheads, headdim]
-            cos, sin: [seqlen_rotary, rotary_dim / 2]
-            interleaved: default False (NeoX-style).
-            ...
-        """
-        interleaved = not self.is_neox_style
-        output = apply_rotary_emb(x, cos, sin, interleaved)
-
-        output = self._post_process(output, origin_shape, origin_dtype)
-        return output
+        # from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+        return self.forward_native(x, cos, sin)

     def forward_hip(
         self,
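The removed docstring distinguishes NeoX-style rotation (interleaved=False) from interleaved GPT-J-style rotation. A sketch of the two conventions, written to be consistent with the rotate_neox/rotate_gptj helpers this package imports from .common:

    import torch

    def rotate_neox(x: torch.Tensor) -> torch.Tensor:
        # NeoX style: pair element i with element i + d/2 (contiguous halves).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
        # GPT-J / interleaved style: pair adjacent even and odd elements.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    x = torch.arange(8.0)
    print(rotate_neox(x))  # tensor([-4., -5., -6., -7.,  0.,  1.,  2.,  3.])
    print(rotate_gptj(x))  # tensor([-1.,  0., -3.,  2., -5.,  4., -7.,  6.])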
@@ -8,7 +8,7 @@ import torch
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer

-from .base import RotaryEmbeddingBase
+from .base import RotaryEmbedding
 from .common import (
     rotate_gptj,
     rotate_neox,
@@ -23,7 +23,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0


-class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
+class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
     """RotaryEmbedding extended with YaRN method.

     Credits to Peng et al. github.com/jquesnelle/yarn
@@ -110,73 +110,3 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
         sin = freqs.sin() * self.mscale
         cache = torch.cat((cos, sin), dim=-1)
         return cache
-
-    def forward_native(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        """PyTorch-native implementation equivalent to forward()."""
-        assert key is not None
-        cos_sin_cache = self._match_cos_sin_cache_dtype(query)
-        query_rot = query[..., : self.rotary_dim]
-        key_rot = key[..., : self.rotary_dim]
-        if self.rotary_dim < self.head_size:
-            query_pass = query[..., self.rotary_dim :]
-            key_pass = key[..., self.rotary_dim :]
-
-        cos_sin = cos_sin_cache[
-            torch.add(positions, offsets) if offsets is not None else positions
-        ]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        if self.is_neox_style:
-            # NOTE(woosuk): Here we assume that the positions tensor has the
-            # shape [batch_size, seq_len].
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
-        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
-        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
-
-        if self.rotary_dim < self.head_size:
-            query = torch.cat((query_rot, query_pass), dim=-1)
-            key = torch.cat((key_rot, key_pass), dim=-1)
-        else:
-            query = query_rot
-            key = key_rot
-        return query, key
-
-    def forward_hip(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
-
-    def forward_cuda(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        if self.use_flashinfer:
-            torch.ops.vllm.flashinfer_rotary_embedding(
-                torch.add(positions, offsets) if offsets is not None else positions,
-                query,
-                key,
-                self.head_size,
-                self.cos_sin_cache,
-                self.is_neox_style,
-            )
-            return query, key
-        else:
-            return self.forward_native(positions, query, key, offsets)
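The surviving context above (`sin = freqs.sin() * self.mscale`; `cache = torch.cat((cos, sin), dim=-1)`) is the tail of the YaRN cache builder. A self-contained sketch of that construction; the function name and defaults are illustrative assumptions, not the class's actual API:

    import torch

    # Hedged sketch: rotary cache = [cos | sin] per position, with both halves
    # scaled by the YaRN attention factor (mscale).
    def build_cos_sin_cache(max_pos: int, rotary_dim: int,
                            base: float = 10000.0, mscale: float = 1.0) -> torch.Tensor:
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        freqs = torch.einsum("i,j->ij", torch.arange(max_pos).float(), inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        return torch.cat((cos, sin), dim=-1)  # [max_pos, rotary_dim]

    cache = build_cos_sin_cache(max_pos=4096, rotary_dim=64, mscale=1.2)
    assert cache.shape == (4096, 64)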
@@ -330,48 +330,46 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None
+        from vllm import _custom_ops as ops

-        cos_sin_cache = self._match_cos_sin_cache_dtype(query)
-        num_tokens = positions.shape[-1]
-        cos_sin = cos_sin_cache[positions]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        query_shape = query.shape
-        key_shape = key.shape
-        if positions.ndim == 2:
-            assert self.mrope_section
-
-            q, k = triton_mrope(
-                query,
-                key,
-                cos,
-                sin,
-                self.mrope_section,
-                self.head_size,
-                self.rotary_dim,
-                self.mrope_interleaved,
-            )
-
-            return q.reshape(query_shape), k.reshape(key_shape)
-
-        query = query.view(num_tokens, -1, self.head_size)
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        query_rot = self.apply_rotary_emb(
-            query_rot,
-            cos,
-            sin,
-        )
-        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
-        key = key.view(num_tokens, -1, self.head_size)
-        key_rot = key[..., : self.rotary_dim]
-        key_pass = key[..., self.rotary_dim :]
-        key_rot = self.apply_rotary_emb(
-            key_rot,
-            cos,
-            sin,
-        )
-        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        self._match_cos_sin_cache_dtype(query)
+
+        if self.mrope_interleaved:
+            num_tokens = positions.shape[-1]
+            cos_sin = self.cos_sin_cache[positions]
+            cos, sin = cos_sin.chunk(2, dim=-1)
+            query_shape = query.shape
+            key_shape = key.shape
+            if positions.ndim == 2:
+                assert self.mrope_section
+                q, k = triton_mrope(
+                    query,
+                    key,
+                    cos,
+                    sin,
+                    self.mrope_section,
+                    self.head_size,
+                    self.rotary_dim,
+                    self.mrope_interleaved,
+                )
+                return q.reshape(query_shape), k.reshape(key_shape)
+
+        if positions.ndim == 1:
+            ops.rotary_embedding(positions, query, key, self.head_size,
+                                 self.cos_sin_cache, self.is_neox_style)
+        else:
+            if self.is_neox_style:
+                ops.m_rotary_embedding(positions.contiguous(), query, key, self.head_size,
+                                       self.cos_sin_cache,
+                                       torch.tensor(self.mrope_section, dtype=torch.int),
+                                       self.is_neox_style)
+            else:
+                query, key = self.forward_native(
+                    positions, query, key
+                )
+
         return query, key

     def forward_cpu(
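For readers unfamiliar with M-RoPE: when positions is 2-D it carries one row per position component (e.g. temporal/height/width in Qwen2-VL-style models), and mrope_section says how many rotary channels each component owns. A sketch of that channel selection, independent of the triton kernel above:

    import torch

    # Hedged sketch: take each channel group's cos values from the position
    # component that owns that group.
    def select_mrope_cos(cos: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
        # cos: [3, num_tokens, rotary_dim // 2], one leading row per component
        chunks = torch.split(cos, mrope_section, dim=-1)
        return torch.cat([chunk[i] for i, chunk in enumerate(chunks)], dim=-1)

    cos = torch.randn(3, 10, 32)   # 3 position components, 10 tokens
    out = select_mrope_cos(cos, mrope_section=[16, 8, 8])
    assert out.shape == (10, 32)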
@@ -12,6 +12,7 @@ from .common import rotate_neox

 logger = init_logger(__name__)

+import ixformer.inference.functions as ixops

 class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
     """Phi3 family of models scaled rotary embedding.
@@ -133,27 +134,18 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         query = query.view(*query.shape[:-1], -1, self.head_size)
         key = key.view(*key.shape[:-1], -1, self.head_size)

-        if self.use_long_rope:
-            k = self.original_max_position_embeddings
-            long_prompt_offset = torch.full_like(positions, k).long()
-            idx = torch.add(positions, long_prompt_offset)
-        else:
-            idx = positions
-        idx = torch.add(idx, offsets) if offsets is not None else idx
-        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        cos = cos.repeat(1, 2).unsqueeze(-2)
-        sin = sin.repeat(1, 2).unsqueeze(-2)
-
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        query_rot = query_rot * cos + rotate_neox(query_rot) * sin
-        query = torch.cat((query_rot, query_pass), dim=-1)
-
-        key_rot = key[..., : self.rotary_dim]
-        key_pass = key[..., self.rotary_dim :]
-        key_rot = key_rot * cos + rotate_neox(key_rot) * sin
-        key = torch.cat((key_rot, key_pass), dim=-1)
-
-        return query.flatten(-2), key.flatten(-2)
+        k = self.original_max_position_embeddings
+        long_prompt_offset = torch.any(positions > k)
+
+        ixops.vllm_rotary_embedding_phi(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.long_short_cos_sin_cache,
+            long_prompt_offset,
+            k,
+            offsets
+        )
+
+        return query, key
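Both the removed native path and the new ixops call implement the same long/short cache selection: the first original_max_position_embeddings rows of long_short_cos_sin_cache hold the short-context cos/sin values, the rows after them hold the long-context values, and a prompt that extends past the boundary is shifted into the long half. A sketch of that indexing logic, with the concrete numbers assumed for illustration:

    import torch

    k = 4096                          # original_max_position_embeddings (assumed)
    positions = torch.arange(5000)    # prompt longer than the short-context limit

    # Hedged sketch: rows [0, k) of long_short_cos_sin_cache are the short cache,
    # rows from k onward are the long cache, so exceeding k shifts the index.
    if torch.any(positions > k):
        idx = positions + k           # select from the long-context half
    else:
        idx = positions               # select from the short-context half
    assert idx[0].item() == k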