diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py
index 3a796ae..21d95bb 100644
--- a/tests/ut/ops/test_rotary_embedding.py
+++ b/tests/ut/ops/test_rotary_embedding.py
@@ -6,14 +6,13 @@ import torch
 from transformers.configuration_utils import PretrainedConfig
 from vllm.config import ModelConfig, VllmConfig
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding)
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
 
 MODEL = "Qwen3-0.6B"
-MODEL_VL = "Qwen/Qwen2.5-VL-3B-Instruct"
 MAX_NUM_BATCHED_TOKEND = 10000
 
 
@@ -377,86 +376,3 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
                 expected,
                 places=6,
                 msg=f"Failed for scale={scale}, mscale={mscale}")
-
-
-class TestAscendMRotaryEmbedding(unittest.TestCase):
-
-    def setUp(self):
-        # Common setup for tests
-        self.number_tokens = 3
-        self.num_head = 8
-        self.num_kvhead = 8
-        self.head_size = 128
-        self.max_position_embeddings = 128000
-        self.is_neox_style = True
-        self.rope_theta = 1000000.0
-        self.positions_1d = torch.tensor([1, 2, 3])
-        self.positions_2d = torch.randint(1, 10, (3, self.number_tokens))
-
-        self.query = torch.randn(
-            (self.number_tokens, self.num_head * self.head_size),
-            dtype=torch.bfloat16)
-        self.key = torch.randn(
-            (self.number_tokens, self.num_kvhead * self.head_size),
-            dtype=torch.bfloat16)
-
-        # Qwen2.5-VL mrope section case
-        self.mrope_section = [16, 24, 24]
-
-        self.layer = MRotaryEmbedding(self.head_size,
-                                      self.head_size,
-                                      self.max_position_embeddings,
-                                      base=self.rope_theta,
-                                      is_neox_style=self.is_neox_style,
-                                      dtype=torch.bfloat16,
-                                      mrope_section=self.mrope_section)
-
-        self.mock_config = MagicMock()
-        self.mock_config.torchair_graph_config.enabled = False
-
-    def _create_vllm_config(self):
-        vllm_config = VllmConfig()
-        model_config = ModelConfig(MODEL_VL,
-                                   tokenizer=MODEL_VL,
-                                   max_model_len=MAX_NUM_BATCHED_TOKEND)
-        model_config.hf_config = PretrainedConfig()
-        vllm_config.model_config = model_config
-        return vllm_config
-
-    @patch('torch_npu.npu_mrope')
-    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
-    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
-    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
-    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
-    def test_forward_oot_1d_positions(self, mock_npu_mrope):
-        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
-                                       torch.zeros_like(self.key))
-
-        vllm_config = self._create_vllm_config()
-        with set_ascend_forward_context(None, vllm_config):
-            result_q, result_k = self.layer.forward_oot(
-                self.positions_1d, self.query, self.key)
-
-        mock_npu_mrope.assert_called_once()
-        self.assertFalse(torch.isnan(result_q).any().item())
-        self.assertFalse(torch.isnan(result_k).any().item())
-        self.assertEqual(result_q.shape, self.query.shape)
-
-    @patch('torch_npu.npu_mrope')
-    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
-    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
-    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
-    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
-    def test_forward_oot_2d_positions(self, mock_npu_mrope):
-        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
-                                       torch.zeros_like(self.key))
-
-        vllm_config = self._create_vllm_config()
-        with set_ascend_forward_context(None, vllm_config):
-            result_q, result_k = self.layer.forward_oot(
-                self.positions_2d, self.query, self.key)
-
-        mock_npu_mrope.assert_called_once()
-        self.assertFalse(torch.isnan(result_q).any().item())
-        self.assertFalse(torch.isnan(result_k).any().item())
-        self.assertEqual(result_q.shape, self.query.shape)
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index fddc523..69102f3 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -22,7 +22,7 @@ import torch
 import torch_npu
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
 
 from vllm_ascend.platform import NPUPlatform
@@ -395,37 +395,3 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
         q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
                                        is_neox_style, offsets)
         return q_pe, k_pe
-
-
-class AscendMRotaryEmbedding(MRotaryEmbedding):
-
-    def forward_oot(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-    ):
-        if self.mrope_section != [16, 24, 24]:
-            return super().forward_oot(positions, query, key)
-
-        import torch_npu
-        mrope_section = [0, 0, 0
-                         ] if positions.ndim == 1 else self.mrope_section
-
-        if self.cos_sin_cache.device != query.device:  # type: ignore
-            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
-                query.device)  # type: ignore
-
-        if self.cos_sin_cache.dtype != query.dtype:  # type: ignore
-            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
-                query.dtype)  # type: ignore
-
-        query, key = torch_npu.npu_mrope(positions,
-                                         query.contiguous(),
-                                         key.contiguous(),
-                                         self.cos_sin_cache.contiguous(),
-                                         self.head_size,
-                                         mrope_section=mrope_section,
-                                         rotary_mode='half')
-
-        return query, key
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index f824662..0929e40 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -517,8 +517,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
                                         AscendReplicatedLinear,
                                         AscendRowParallelLinear)
     from vllm_ascend.ops.rotary_embedding import (
-        AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
-        AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
+        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
+        AscendYaRNRotaryEmbedding)
     from vllm_ascend.ops.vocab_parallel_embedding import (
         AscendLogitsProcessor, AscendParallelLMHead,
         AscendVocabParallelEmbedding)
@@ -528,7 +528,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "QuickGELU": AscendQuickGELU,
         "SiluAndMul": AscendSiluAndMul,
         "RotaryEmbedding": AscendRotaryEmbedding,
-        "MRotaryEmbedding": AscendMRotaryEmbedding,
         "ColumnParallelLinear": AscendColumnParallelLinear,
         "RowParallelLinear": AscendRowParallelLinear,
         "YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,