There is a lot of hack code for v0.11.0, which makes the codebase hard to
upgrade to newer vLLM versions. Since v0.11.2 will be released soon, let's
drop v0.11.0 support first; we'll then upgrade to v0.11.2.
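For context, the "hack code" referred to above is version-gated compatibility branching. The sketch below is purely illustrative, with hypothetical names that are not taken from this repository; it only shows the shape of the gates that dropping v0.11.0 support removes.

```python
# Illustrative sketch only: names are hypothetical, not code from this patch.
def vllm_version_is(target: str, current: str = "0.11.2") -> bool:
    """Stand-in for a version-check helper; a real project would compare
    against the installed vLLM version instead of a hard-coded string."""
    return current == target


def build_attn_metadata(common_metadata: dict) -> dict:
    # Before this change: two runtime-selected paths per touched module.
    if vllm_version_is("0.11.0"):
        # Compatibility branch kept only for the old release ("hack code").
        return {"layout": "legacy-v0.11.0", **common_metadata}
    # After dropping v0.11.0 support, only the path tracking vLLM main remains.
    return {"layout": "main", **common_metadata}
```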
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
from unittest.mock import MagicMock, patch

import torch
from torch import nn
from vllm.config import CacheConfig, CompilationConfig, VllmConfig
from vllm.forward_context import ForwardContext
from vllm.model_executor.layers.mla import MLAModules

from tests.ut.base import TestBase
from vllm_ascend.models.layers.mla import (AscendMultiHeadLatentAttention,
                                            IndexerWrapper)


class TestIndexerWrapper(TestBase):

    def test_initialization(self):
        """IndexerWrapper should copy the indexer's attributes and release
        the buffers held by the wrapped indexer."""
        mock_indexer = MagicMock()
        mock_indexer.n_head = 64
        mock_indexer.head_dim = 128
        mock_indexer.topk_tokens = 2048
        mock_indexer.q_lora_rank = 1536
        mock_indexer.wq_b = nn.Linear(128, 128)
        mock_indexer.wk = nn.Linear(128, 128)
        mock_indexer.weights_proj = nn.Linear(128, 128)
        mock_indexer.k_norm = nn.LayerNorm(128)
        mock_indexer.softmax_scale = 0.123
        mock_indexer.topk_indices_buffer = torch.randn(10)
        mock_indexer.k_cache = torch.randn(10)

        wrapper = IndexerWrapper(mock_indexer)

        self.assertEqual(wrapper.n_head, 64)
        self.assertEqual(wrapper.head_dim, 128)
        self.assertEqual(wrapper.topk_tokens, 2048)
        self.assertEqual(wrapper.q_lora_rank, 1536)
        self.assertIs(wrapper.wq_b, mock_indexer.wq_b)
        self.assertIs(wrapper.wk, mock_indexer.wk)
        self.assertIs(wrapper.weights_proj, mock_indexer.weights_proj)
        self.assertIs(wrapper.k_norm, mock_indexer.k_norm)
        self.assertEqual(wrapper.softmax_scale, 0.123)

        self.assertIsNone(mock_indexer.topk_indices_buffer)
        self.assertIsNone(mock_indexer.k_cache)

    def test_forward(self):
        """IndexerWrapper.forward is a no-op that returns None."""
        mock_indexer = MagicMock()
        wrapper = IndexerWrapper(mock_indexer)
        result = wrapper.forward()
        self.assertIsNone(result)


class TestAscendMultiHeadLatentAttention(TestBase):

    def setUp(self):
        # Common MLA dimensions and mocked sub-modules shared by the tests.
        self.hidden_size = 4096
        self.num_heads = 32
        self.scale = 0.123
        self.qk_nope_head_dim = 64
        self.qk_rope_head_dim = 64
        self.v_head_dim = 128
        self.q_lora_rank = 1536
        self.kv_lora_rank = 128
        self.prefix = "model.layers.0.mla"

        self.mock_mla_modules = MagicMock(spec=MLAModules)
        self.mock_mla_modules.indexer = MagicMock()
        self.mock_mla_modules.is_sparse = False
        self.mock_mla_modules.rotary_emb = MagicMock()
        self.mock_mla_modules.fused_qkv_a_proj = MagicMock()
        self.mock_mla_modules.q_b_proj = MagicMock()
        self.mock_mla_modules.q_a_layernorm = MagicMock()
        self.mock_mla_modules.q_proj = MagicMock()
        self.mock_mla_modules.kv_a_proj_with_mqa = MagicMock()
        self.mock_mla_modules.kv_a_layernorm = MagicMock()
        self.mock_mla_modules.kv_b_proj = MagicMock()
        self.mock_mla_modules.o_proj = MagicMock()

        self.mock_cache_config = MagicMock(spec=CacheConfig)
        self.mock_quant_config = MagicMock()

@patch("vllm_ascend.models.layers.mla.get_current_vllm_config")
|
|
@patch("vllm_ascend.models.layers.mla.get_ascend_config")
|
|
@patch(
|
|
"vllm_ascend.models.layers.mla.get_tensor_model_parallel_world_size")
|
|
def test_initialization(self, mock_tp_size, mock_ascend_config,
|
|
mock_get_vllm_config):
|
|
|
|
with patch("vllm_ascend.models.layers.mla.MLAAttention",
|
|
return_value=True):
|
|
mock_tp_size.return_value = 2
|
|
mock_ascend_config.return_value.enable_shared_expert_dp = True
|
|
mock_vllm_config = MagicMock(spec=VllmConfig)
|
|
mock_vllm_config.model_config.hf_config = MagicMock(
|
|
num_hidden_layers=32, first_k_dense_replace=True)
|
|
mock_get_vllm_config.return_value = mock_vllm_config
|
|
mock_vllm_config.compilation_config = CompilationConfig()
|
|
|
|
attn = AscendMultiHeadLatentAttention(
|
|
hidden_size=self.hidden_size,
|
|
num_heads=self.num_heads,
|
|
scale=self.scale,
|
|
qk_nope_head_dim=self.qk_nope_head_dim,
|
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
|
v_head_dim=self.v_head_dim,
|
|
q_lora_rank=self.q_lora_rank,
|
|
kv_lora_rank=self.kv_lora_rank,
|
|
mla_modules=self.mock_mla_modules,
|
|
cache_config=self.mock_cache_config,
|
|
quant_config=self.mock_quant_config,
|
|
prefix=self.prefix,
|
|
)
|
|
|
|
self.assertEqual(attn.tp_size, 2)
|
|
self.assertTrue(attn.enable_shared_expert_dp)
|
|
self.assertIsNotNone(attn.mla_attn)
|
|
|
|
@patch("vllm_ascend.models.layers.mla.torch.ops.vllm.mla_forward")
|
|
@patch("vllm_ascend.models.layers.mla.get_current_vllm_config")
|
|
@patch("vllm_ascend.models.layers.mla.get_ascend_config")
|
|
@patch(
|
|
"vllm_ascend.models.layers.mla.get_tensor_model_parallel_world_size")
|
|
@patch("vllm_ascend.models.layers.mla.get_forward_context")
|
|
def test_forward(self, mock_get_forward_context, mock_tp_size,
|
|
mock_ascend_config, mock_get_vllm_config,
|
|
mock_mla_forward):
|
|
mock_tp_size.return_value = 1
|
|
mock_ascend_config.return_value.enable_shared_expert_dp = False
|
|
mock_vllm_config = MagicMock(spec=VllmConfig)
|
|
mock_vllm_config.model_config.hf_config = MagicMock(
|
|
num_hidden_layers=32, first_k_dense_replace=False)
|
|
mock_get_vllm_config.return_value = mock_vllm_config
|
|
mock_vllm_config.compilation_config = CompilationConfig()
|
|
with patch("vllm_ascend.models.layers.mla.MLAAttention",
|
|
return_value=True):
|
|
attn = AscendMultiHeadLatentAttention(
|
|
hidden_size=self.hidden_size,
|
|
num_heads=self.num_heads,
|
|
scale=self.scale,
|
|
qk_nope_head_dim=self.qk_nope_head_dim,
|
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
|
v_head_dim=self.v_head_dim,
|
|
q_lora_rank=self.q_lora_rank,
|
|
kv_lora_rank=self.kv_lora_rank,
|
|
mla_modules=self.mock_mla_modules,
|
|
cache_config=self.mock_cache_config,
|
|
quant_config=self.mock_quant_config,
|
|
prefix=self.prefix,
|
|
)
|
|
positions = torch.tensor([0, 1, 2])
|
|
hidden_states = torch.randn(3, self.hidden_size)
|
|
|
|
mock_forward_context = MagicMock(spec=ForwardContext)
|
|
mock_forward_context.sp_enabled = False
|
|
mock_get_forward_context.return_value = mock_forward_context
|
|
|
|
mock_mla_forward.return_value = (3, self.hidden_size)
|
|
|
|
output = attn.forward(positions, hidden_states)
|
|
|
|
self.assertEqual(output.shape, (3, self.hidden_size))
|