[main] Optimize rope in Qwen Models (#2571)

### What this PR does / why we need it?
Optimize rope by caching sin and cos at the first layer in Qwen Models.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.


- vLLM version: v0.10.1.1
- vLLM main:
562663a044

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: ZYang6263 <51255902183@stu.ecnu.edu.cn>
Co-authored-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
rjg-lyh
2025-09-09 14:28:14 +08:00
committed by GitHub
parent 5bcb4c1528
commit 7a205dbaa8
4 changed files with 136 additions and 47 deletions

View File

@@ -20,6 +20,7 @@ from typing import Optional, Tuple
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@@ -37,19 +38,16 @@ def _rope_forward_oot(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
is_neox_style: bool,
offsets: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
neox_style = self.is_neox_style
if is_neox_style_override is not None:
neox_style = is_neox_style_override
# adopt custom kernel path for rotary_embedding
if _custom_rotary_embedding_enabled(query, neox_style,
if _custom_rotary_embedding_enabled(query, is_neox_style,
self.head_size) and not is_310p():
query, key = torch.ops._C.rotary_embedding(
positions,
@@ -57,14 +55,22 @@ def _rope_forward_oot(
key,
self.head_size,
self.cos_sin_cache,
neox_style,
is_neox_style,
)
return query.view(query_shape), key.view(key_shape)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
if self.rotary_dim < self.head_size:
if self.cos is not None and \
self.sin is not None:
# If cos and sin are generated outside the layers, use npu_apply_rotary_pos_emb to avoid redundant calculation.
# This method requires head_size and rotary_dim to equal 128 and neox_style to be True.
query = query.contiguous().view(1, query.shape[0], -1,
self.head_size)
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin)
elif self.rotary_dim < self.head_size:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
@@ -80,25 +86,26 @@ def _rope_forward_oot(
k_rot,
self.head_size,
self.cos_sin_cache,
neox_style,
is_neox_style,
)
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
return q, k
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
)
return query.view(query_shape), key.view(key_shape)
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
is_neox_style,
)
return query.view(query_shape), key.view(key_shape)
class AscendRotaryEmbedding(RotaryEmbedding):
@@ -112,6 +119,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
self.cos = None
self.sin = None
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
@@ -123,14 +132,25 @@ class AscendRotaryEmbedding(RotaryEmbedding):
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
):
return _rope_forward_oot(
self,
positions,
query,
key,
offsets,
is_neox_style_override,
)
is_neox_style = self.is_neox_style
if is_neox_style_override is not None:
is_neox_style = is_neox_style_override
forward_context = get_forward_context()
is_first_layer = forward_context.is_first_layer
# Generate cos and sin outside layers to avoid repeated calculation.
if is_neox_style and \
self.head_size == 128:
if is_first_layer:
cos_sin = self.cos_sin_cache.index_select(0, positions)
last_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(
1, 1, 2).chunk(2, dim=-2)
# BSNH
self.cos = cos.view(1, -1, 1, last_dim).contiguous()
self.sin = sin.view(1, -1, 1, last_dim).contiguous()
forward_context.is_first_layer = False
return _rope_forward_oot(self, positions, query, key, is_neox_style,
offsets)
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
@@ -322,7 +342,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
# Note: we implement the non neox_style method with shuffle the last dim and neox style
# calculation method which is also more compute friendly to the ascend machine
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
neox_style = True
is_neox_style = True
if self.is_neox_style is False:
b, h_q, d = query.shape
query = query.view(b, h_q, d // 2,
@@ -330,6 +350,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3,
2).reshape(b, h_k, d)
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
neox_style)
q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
is_neox_style, offsets)
return q_pe, k_pe