Move mla to ops module (#4575)
Move mla custom op to correct module.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -7,8 +7,7 @@ from vllm.forward_context import ForwardContext
 from vllm.model_executor.layers.mla import MLAModules
 
 from tests.ut.base import TestBase
-from vllm_ascend.models.layers.mla import (AscendMultiHeadLatentAttention,
-                                           IndexerWrapper)
+from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention, IndexerWrapper
 
 
 class TestIndexerWrapper(TestBase):
@@ -78,15 +77,13 @@ class TestAscendMultiHeadLatentAttention(TestBase):
         self.mock_cache_config = MagicMock(spec=CacheConfig)
         self.mock_quant_config = MagicMock()
 
-    @patch("vllm_ascend.models.layers.mla.get_current_vllm_config")
-    @patch("vllm_ascend.models.layers.mla.get_ascend_config")
-    @patch(
-        "vllm_ascend.models.layers.mla.get_tensor_model_parallel_world_size")
+    @patch("vllm_ascend.ops.mla.get_current_vllm_config")
+    @patch("vllm_ascend.ops.mla.get_ascend_config")
+    @patch("vllm_ascend.ops.mla.get_tensor_model_parallel_world_size")
     def test_initialization(self, mock_tp_size, mock_ascend_config,
                             mock_get_vllm_config):
 
-        with patch("vllm_ascend.models.layers.mla.MLAAttention",
-                   return_value=True):
+        with patch("vllm_ascend.ops.mla.MLAAttention", return_value=True):
             mock_tp_size.return_value = 2
             mock_ascend_config.return_value.enable_shared_expert_dp = True
             mock_vllm_config = MagicMock(spec=VllmConfig)
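The decorator changes above follow unittest.mock's "patch where it is looked up" rule: the class now resolves its collaborators inside vllm_ascend.ops.mla, so the patch targets must name that module rather than the old models.layers.mla path. A minimal sketch of the pattern, assuming vllm_ascend is importable on the test machine (the assertion body is illustrative, not the repo's actual test):

from unittest.mock import patch

# Target the consuming module's namespace; where get_ascend_config is
# defined is irrelevant to where the mock takes effect.
with patch("vllm_ascend.ops.mla.get_ascend_config") as mock_cfg:
    mock_cfg.return_value.enable_shared_expert_dp = True
    from vllm_ascend.ops import mla
    assert mla.get_ascend_config().enable_shared_expert_dp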
@@ -114,12 +111,11 @@ class TestAscendMultiHeadLatentAttention(TestBase):
             self.assertTrue(attn.enable_shared_expert_dp)
             self.assertIsNotNone(attn.mla_attn)
 
-    @patch("vllm_ascend.models.layers.mla.torch.ops.vllm.mla_forward")
-    @patch("vllm_ascend.models.layers.mla.get_current_vllm_config")
-    @patch("vllm_ascend.models.layers.mla.get_ascend_config")
-    @patch(
-        "vllm_ascend.models.layers.mla.get_tensor_model_parallel_world_size")
-    @patch("vllm_ascend.models.layers.mla.get_forward_context")
+    @patch("vllm_ascend.ops.mla.torch.ops.vllm.mla_forward")
+    @patch("vllm_ascend.ops.mla.get_current_vllm_config")
+    @patch("vllm_ascend.ops.mla.get_ascend_config")
+    @patch("vllm_ascend.ops.mla.get_tensor_model_parallel_world_size")
+    @patch("vllm_ascend.ops.mla.get_forward_context")
     def test_forward(self, mock_get_forward_context, mock_tp_size,
                      mock_ascend_config, mock_get_vllm_config,
                      mock_mla_forward):
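The first decorator in this hunk reaches through the module's torch reference: mock imports vllm_ascend.ops.mla, walks the torch.ops.vllm attribute chain, and swaps mla_forward on that shared namespace for the duration of the test. A self-contained, stdlib-only illustration of chained-attribute patching (the pkg module here is hypothetical scaffolding):

import math
import sys
import types
from unittest.mock import patch

# "pkg" stands in for vllm_ascend.ops.mla; its math attribute plays the
# role of the op module's torch reference.
pkg = types.ModuleType("pkg")
pkg.math = math
sys.modules["pkg"] = pkg

# patch() imports "pkg", follows the attribute chain, and replaces sqrt on
# the shared math module, just as the test mocks torch.ops.vllm.mla_forward.
with patch("pkg.math.sqrt", return_value=3.0):
    assert math.sqrt(81) == 3.0   # effect is global while the patch is active
assert math.sqrt(81) == 9.0       # original restored on exit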
@@ -130,8 +126,7 @@ class TestAscendMultiHeadLatentAttention(TestBase):
                 num_hidden_layers=32, first_k_dense_replace=False)
             mock_get_vllm_config.return_value = mock_vllm_config
             mock_vllm_config.compilation_config = CompilationConfig()
-            with patch("vllm_ascend.models.layers.mla.MLAAttention",
-                       return_value=True):
+            with patch("vllm_ascend.ops.mla.MLAAttention", return_value=True):
                 attn = AscendMultiHeadLatentAttention(
                     hidden_size=self.hidden_size,
                     num_heads=self.num_heads,
@@ -648,7 +648,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         return
     from vllm.model_executor.custom_op import CustomOp
 
-    from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
     from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
     from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
                                                      AscendSharedFusedMoE)
@@ -658,6 +657,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
                                         AscendQKVParallelLinear,
                                         AscendReplicatedLinear,
                                         AscendRowParallelLinear)
+    from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention
     from vllm_ascend.ops.rotary_embedding import (
         AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
         AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
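With the import relocated, register_ascend_customop picks the op up from its new home. A hedged sketch of how the registration is likely wired (CustomOp.register_oot and its keyword names are an assumption based on vLLM's out-of-tree custom-op mechanism; the diff only confirms the imports):

from vllm.model_executor.custom_op import CustomOp

from vllm_ascend.ops.mla import AscendMultiHeadLatentAttention

# Assumed shape: register the Ascend implementation under the name vLLM's
# model code looks up, so it transparently replaces the upstream layer.
CustomOp.register_oot(_decorated_op_cls=AscendMultiHeadLatentAttention,
                      name="MultiHeadLatentAttention")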