[main] Optimize rope in Qwen Models (#2571)

### What this PR does / why we need it?
Optimize rope by caching sin and cos at the first layer in Qwen Models.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new added/existing test.


- vLLM version: v0.10.1.1
- vLLM main:
562663a044

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: ZYang6263 <51255902183@stu.ecnu.edu.cn>
Co-authored-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
rjg-lyh
2025-09-09 14:28:14 +08:00
committed by GitHub
parent 5bcb4c1528
commit 7a205dbaa8
4 changed files with 136 additions and 47 deletions

View File

@@ -3,12 +3,18 @@ import unittest
from unittest.mock import MagicMock, PropertyMock, patch
import torch
from transformers.configuration_utils import PretrainedConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
MODEL = "Qwen3-0.6B"
MAX_NUM_BATCHED_TOKEND = 10000
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
@@ -93,6 +99,10 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=True)
@patch('torch.ops._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled, mock_is_310p,
mock__c):
@@ -102,9 +112,15 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
# Setup mock for custom kernel path
mock__c.rotary_embedding.return_value = self.query, self.key
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
mock__c.rotary_embedding.assert_called_once()
self.assertEqual(result_q.shape, self.query.shape)
@@ -113,6 +129,10 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
@@ -121,15 +141,25 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
result_q, result_k = self.layer.forward(self.positions,
non_contig_query,
non_contig_key)
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(self.positions,
non_contig_query,
non_contig_key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, non_contig_query.shape)
self.assertEqual(result_k.shape, non_contig_key.shape)
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_with_offsets(self):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
@@ -137,22 +167,41 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
self.layer.forward(self.positions, self.query, self.key, offsets)
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
self.layer.forward(self.positions, self.query, self.key,
offsets)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test neox_style override
result_q, result_k = self.layer.forward(self.positions,
self.query,
self.key,
is_neox_style_override=False)
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(
self.positions,
self.query,
self.key,
is_neox_style_override=False)
# Check that neox_style=False was passed to the NPU function
args, kwargs = mock_npu_rotary.call_args
self.assertFalse(args[-1])
@@ -160,6 +209,10 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
def test_rope_forward_oot_rotary_dim_less_than_head_size(
self, mock_npu_rotary, mock_custom_enabled):
mock_config = MagicMock()
@@ -169,8 +222,15 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
org_rotary_dim = self.layer.rotary_dim
self.layer.rotary_dim = self.layer.head_size // 2
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
vllm_config = VllmConfig()
model_config = ModelConfig(MODEL,
tokenizer=MODEL,
max_model_len=MAX_NUM_BATCHED_TOKEND)
model_config.hf_config = PretrainedConfig()
vllm_config.model_config = model_config
with set_ascend_forward_context(None, vllm_config):
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, self.query.shape)

View File

@@ -119,6 +119,9 @@ def set_ascend_forward_context(
forward_context.flashcomm_v1_enabled = flashcomm_v1_enabled
# set this for rope forward_oot using
forward_context.is_first_layer = True
if num_tokens is None and attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens

View File

@@ -20,6 +20,7 @@
from typing import Optional, Union
import torch
import torch_npu
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
@@ -280,6 +281,11 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
# Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache) may cause performance degradation at runtime.
x = torch.rand((2, 4), dtype=torch.float16).npu()
weight = torch.rand((2, 4), dtype=torch.float16).npu()
c = torch.rand((4, 4), dtype=torch.float32).npu()
torch_npu._npu_matmul_add_fp32(x, weight, c)
def forward(
self,

View File

@@ -20,6 +20,7 @@ from typing import Optional, Tuple
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@@ -37,19 +38,16 @@ def _rope_forward_oot(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
is_neox_style: bool,
offsets: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
neox_style = self.is_neox_style
if is_neox_style_override is not None:
neox_style = is_neox_style_override
# adopt custom kernel path for rotary_embedding
if _custom_rotary_embedding_enabled(query, neox_style,
if _custom_rotary_embedding_enabled(query, is_neox_style,
self.head_size) and not is_310p():
query, key = torch.ops._C.rotary_embedding(
positions,
@@ -57,14 +55,22 @@ def _rope_forward_oot(
key,
self.head_size,
self.cos_sin_cache,
neox_style,
is_neox_style,
)
return query.view(query_shape), key.view(key_shape)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
if self.rotary_dim < self.head_size:
if self.cos is not None and \
self.sin is not None:
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
# This method requires head_size and rotary_dim equal 128 and neox_style is True
query = query.contiguous().view(1, query.shape[0], -1,
self.head_size)
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin)
elif self.rotary_dim < self.head_size:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
@@ -80,25 +86,26 @@ def _rope_forward_oot(
k_rot,
self.head_size,
self.cos_sin_cache,
neox_style,
is_neox_style,
)
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
return q, k
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
)
return query.view(query_shape), key.view(key_shape)
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
is_neox_style,
)
return query.view(query_shape), key.view(key_shape)
class AscendRotaryEmbedding(RotaryEmbedding):
@@ -112,6 +119,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
self.cos = None
self.sin = None
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
@@ -123,14 +132,25 @@ class AscendRotaryEmbedding(RotaryEmbedding):
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
):
return _rope_forward_oot(
self,
positions,
query,
key,
offsets,
is_neox_style_override,
)
is_neox_style = self.is_neox_style
if is_neox_style_override is not None:
is_neox_style = is_neox_style_override
forward_context = get_forward_context()
is_first_layer = forward_context.is_first_layer
# Generate cos and sin outside layers to avoid repeated calculation.
if is_neox_style and \
self.head_size == 128:
if is_first_layer:
cos_sin = self.cos_sin_cache.index_select(0, positions)
last_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(
1, 1, 2).chunk(2, dim=-2)
# BSNH
self.cos = cos.view(1, -1, 1, last_dim).contiguous()
self.sin = sin.view(1, -1, 1, last_dim).contiguous()
forward_context.is_first_layer = False
return _rope_forward_oot(self, positions, query, key, is_neox_style,
offsets)
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
@@ -322,7 +342,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
# Note: we implement the non neox_style method with shuffle the last dim and neox style
# calculation method which is also more compute friendly to the ascend machine
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
neox_style = True
is_neox_style = True
if self.is_neox_style is False:
b, h_q, d = query.shape
query = query.view(b, h_q, d // 2,
@@ -330,6 +350,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3,
2).reshape(b, h_k, d)
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
neox_style)
q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
is_neox_style, offsets)
return q_pe, k_pe