[CI] Add mla ut (#4280)
### What this PR does / why we need it?
add mla_v1.py and mla.py ut
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
`pytest tests/ut/attention/test_mla_v1.py`
`pytest tests/ut/models/test_mla.py`
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: GDzhu01 <809721801@qq.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
@@ -11,6 +12,7 @@ from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
|
||||
AscendMLAImpl, AscendMLAMetadata,
|
||||
AscendMLAMetadataBuilder,
|
||||
AscendMLAPrefillMetadata)
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
|
||||
|
||||
class TestAscendMLABackend(TestBase):
|
||||
@@ -306,6 +308,264 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
input_batch.swap_states.assert_called_once_with(1, 2)
|
||||
|
||||
|
||||
class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_vllm_config = MagicMock(spec=VllmConfig)
|
||||
self.mock_vllm_config.model_config = ModelConfig(max_model_len=2048)
|
||||
self.mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 32
|
||||
self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
|
||||
self.mock_vllm_config.scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=8, chunked_prefill_enabled=True)
|
||||
self.mock_vllm_config.speculative_config = None
|
||||
self.mock_device = torch.device("cpu")
|
||||
|
||||
self.kv_cache_spec = MagicMock()
|
||||
self.kv_cache_spec.num_layers = 32
|
||||
self.kv_cache_spec.head_size = 128
|
||||
self.kv_cache_spec.num_heads = 32
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
||||
)
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
def test_build_prefix_no_cache_metadata(self, mock_get_ascend_config,
|
||||
mock_dcp_world_size):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=torch.tensor([0, 3, 7]),
|
||||
query_start_loc_cpu=torch.tensor([0, 3, 7]),
|
||||
seq_lens_cpu=torch.tensor([5, 6]),
|
||||
num_reqs=2,
|
||||
num_actual_tokens=10,
|
||||
max_query_len=5,
|
||||
decode_token_per_req=torch.tensor([1, 1]),
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((10, 10)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
|
||||
base_inputs = {
|
||||
"num_actual_tokens": 10,
|
||||
"slot_mapping": torch.tensor(range(10)),
|
||||
"query_start_loc": torch.tensor([0, 3, 7]),
|
||||
"seq_lens": torch.tensor([5, 6]),
|
||||
"block_tables": torch.zeros((10, 10)),
|
||||
"num_prefills": 2,
|
||||
}
|
||||
|
||||
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
|
||||
layer_names=["layer_0", "layer_1"],
|
||||
vllm_config=self.mock_vllm_config,
|
||||
device=self.mock_device)
|
||||
|
||||
mock_model = MagicMock()
|
||||
metadata = builder.build(1, common_attn_metadata, mock_model)
|
||||
|
||||
self.assertIsInstance(metadata, AscendMLAMetadata)
|
||||
self.assertEqual(metadata.num_actual_tokens,
|
||||
base_inputs["num_actual_tokens"])
|
||||
self.assertTrue(
|
||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
||||
)
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
def test_build_chunked_prefix_metadata(self, mock_get_ascend_config,
|
||||
mock_dcp_world_size):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=torch.tensor([0, 2, 5, 9]),
|
||||
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
|
||||
seq_lens_cpu=torch.tensor([4, 5, 6]),
|
||||
num_reqs=3,
|
||||
num_actual_tokens=15,
|
||||
max_query_len=6,
|
||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((15, 15)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
|
||||
base_inputs = {
|
||||
"num_actual_tokens": 15,
|
||||
"slot_mapping": torch.tensor(range(15)),
|
||||
"query_start_loc": torch.tensor([0, 2, 5, 9]),
|
||||
"seq_lens": torch.tensor([4, 5, 6]),
|
||||
"block_tables": torch.zeros((10, 10)),
|
||||
"num_prefills": 3,
|
||||
}
|
||||
|
||||
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
|
||||
layer_names=["layer_0", "layer_1"],
|
||||
vllm_config=self.mock_vllm_config,
|
||||
device=self.mock_device)
|
||||
|
||||
mock_model = MagicMock()
|
||||
metadata = builder.build(1, common_attn_metadata, mock_model)
|
||||
|
||||
self.assertIsInstance(metadata, AscendMLAMetadata)
|
||||
self.assertEqual(metadata.num_actual_tokens,
|
||||
base_inputs["num_actual_tokens"])
|
||||
self.assertTrue(
|
||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
||||
)
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
def test_build_decode_only_metadata(self, mock_get_ascend_config,
|
||||
mock_dcp_world_size):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
||||
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
|
||||
seq_lens_cpu=torch.tensor([4, 5, 6]),
|
||||
num_reqs=3,
|
||||
num_actual_tokens=3,
|
||||
max_query_len=1,
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping=torch.tensor(range(3)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((3, 3)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
|
||||
base_inputs = {
|
||||
"num_actual_tokens": 3,
|
||||
"slot_mapping": torch.tensor(range(3)),
|
||||
"query_start_loc": torch.tensor([0, 1, 2, 3]),
|
||||
"seq_lens": torch.tensor([4, 5, 6]),
|
||||
"num_decodes": 3,
|
||||
}
|
||||
|
||||
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
|
||||
layer_names=["layer_0", "layer_1"],
|
||||
vllm_config=self.mock_vllm_config,
|
||||
device=self.mock_device)
|
||||
|
||||
mock_model = MagicMock()
|
||||
metadata = builder.build(1, common_attn_metadata, mock_model)
|
||||
|
||||
self.assertIsInstance(metadata, AscendMLAMetadata)
|
||||
self.assertEqual(metadata.num_actual_tokens,
|
||||
base_inputs["num_actual_tokens"])
|
||||
self.assertTrue(
|
||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
||||
)
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config,
|
||||
mock_dcp_world_size):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
||||
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
|
||||
seq_lens_cpu=torch.tensor([4, 5, 6]),
|
||||
num_reqs=3,
|
||||
num_actual_tokens=3,
|
||||
max_query_len=1,
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping=torch.tensor(range(3)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((3, 3)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
|
||||
base_inputs = {
|
||||
"num_actual_tokens": 3,
|
||||
"slot_mapping": torch.tensor(range(3)),
|
||||
"query_start_loc": torch.tensor([0, 1, 2, 3]),
|
||||
"seq_lens": torch.tensor([4, 5, 6]),
|
||||
"num_decodes": 3,
|
||||
}
|
||||
|
||||
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
|
||||
layer_names=["layer_0", "layer_1"],
|
||||
vllm_config=self.mock_vllm_config,
|
||||
device=self.mock_device)
|
||||
|
||||
mock_model = MagicMock()
|
||||
metadata = builder.build_for_graph_capture(
|
||||
common_attn_metadata, AscendAttentionState.DecodeOnly, mock_model)
|
||||
|
||||
self.assertIsInstance(metadata, AscendMLAMetadata)
|
||||
self.assertEqual(metadata.num_actual_tokens,
|
||||
base_inputs["num_actual_tokens"])
|
||||
self.assertTrue(
|
||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
||||
)
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
def test_build_for_graph_capture_prefill(self, mock_get_ascend_config,
|
||||
mock_dcp_world_size):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=torch.tensor([0, 3, 7]),
|
||||
query_start_loc_cpu=torch.tensor([0, 3, 7]),
|
||||
seq_lens_cpu=torch.tensor([5, 6]),
|
||||
num_reqs=2,
|
||||
num_actual_tokens=10,
|
||||
max_query_len=5,
|
||||
decode_token_per_req=torch.tensor([1, 1]),
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((10, 10)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
|
||||
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
|
||||
layer_names=["layer_0", "layer_1"],
|
||||
vllm_config=self.mock_vllm_config,
|
||||
device=self.mock_device)
|
||||
|
||||
mock_model = MagicMock()
|
||||
|
||||
with self.assertRaises(NotImplementedError) as ctx:
|
||||
builder.build_for_graph_capture(
|
||||
common_attn_metadata, AscendAttentionState.PrefillNoCache,
|
||||
mock_model)
|
||||
self.assertIn(
|
||||
"Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state",
|
||||
str(ctx.exception))
|
||||
|
||||
|
||||
class TestAscendMLAImpl(TestBase):
|
||||
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
|
||||
215
tests/ut/models/test_mla.py
Normal file
215
tests/ut/models/test_mla.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.config import CacheConfig, CompilationConfig, VllmConfig
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.model_executor.layers.mla import MLAModules
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.models.layers.mla import (AscendMultiHeadLatentAttention,
|
||||
IndexerWrapper)
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
class TestIndexerWrapper(TestBase):
|
||||
|
||||
def test_initialization(self):
|
||||
mock_indexer = MagicMock()
|
||||
mock_indexer.n_head = 64
|
||||
mock_indexer.head_dim = 128
|
||||
mock_indexer.topk_tokens = 2048
|
||||
mock_indexer.q_lora_rank = 1536
|
||||
mock_indexer.wq_b = nn.Linear(128, 128)
|
||||
mock_indexer.wk = nn.Linear(128, 128)
|
||||
mock_indexer.weights_proj = nn.Linear(128, 128)
|
||||
mock_indexer.k_norm = nn.LayerNorm(128)
|
||||
mock_indexer.softmax_scale = 0.123
|
||||
mock_indexer.topk_indices_buffer = torch.randn(10)
|
||||
mock_indexer.k_cache = torch.randn(10)
|
||||
|
||||
wrapper = IndexerWrapper(mock_indexer)
|
||||
|
||||
self.assertEqual(wrapper.n_head, 64)
|
||||
self.assertEqual(wrapper.head_dim, 128)
|
||||
self.assertEqual(wrapper.topk_tokens, 2048)
|
||||
self.assertEqual(wrapper.q_lora_rank, 1536)
|
||||
self.assertIs(wrapper.wq_b, mock_indexer.wq_b)
|
||||
self.assertIs(wrapper.wk, mock_indexer.wk)
|
||||
self.assertIs(wrapper.weights_proj, mock_indexer.weights_proj)
|
||||
self.assertIs(wrapper.k_norm, mock_indexer.k_norm)
|
||||
self.assertEqual(wrapper.softmax_scale, 0.123)
|
||||
|
||||
self.assertIsNone(mock_indexer.topk_indices_buffer)
|
||||
self.assertIsNone(mock_indexer.k_cache)
|
||||
|
||||
def test_forward(self):
|
||||
mock_indexer = MagicMock()
|
||||
wrapper = IndexerWrapper(mock_indexer)
|
||||
result = wrapper.forward()
|
||||
self.assertIsNone(result)
|
||||
|
||||
|
||||
class TestAscendMultiHeadLatentAttention(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.hidden_size = 4096
|
||||
self.num_heads = 32
|
||||
self.scale = 0.123
|
||||
self.qk_nope_head_dim = 64
|
||||
self.qk_rope_head_dim = 64
|
||||
self.v_head_dim = 128
|
||||
self.q_lora_rank = 1536
|
||||
self.kv_lora_rank = 128
|
||||
self.prefix = "model.layers.0.mla"
|
||||
|
||||
self.mock_mla_modules = MagicMock(spec=MLAModules)
|
||||
self.mock_mla_modules.indexer = MagicMock()
|
||||
self.mock_mla_modules.is_sparse = False
|
||||
self.mock_mla_modules.rotary_emb = MagicMock()
|
||||
self.mock_mla_modules.fused_qkv_a_proj = MagicMock()
|
||||
self.mock_mla_modules.q_b_proj = MagicMock()
|
||||
self.mock_mla_modules.q_a_layernorm = MagicMock()
|
||||
self.mock_mla_modules.q_proj = MagicMock()
|
||||
self.mock_mla_modules.kv_a_proj_with_mqa = MagicMock()
|
||||
self.mock_mla_modules.kv_a_layernorm = MagicMock()
|
||||
self.mock_mla_modules.kv_b_proj = MagicMock()
|
||||
self.mock_mla_modules.o_proj = MagicMock()
|
||||
|
||||
self.mock_cache_config = MagicMock(spec=CacheConfig)
|
||||
self.mock_quant_config = MagicMock()
|
||||
|
||||
@patch("vllm_ascend.models.layers.mla.get_current_vllm_config")
|
||||
@patch("vllm_ascend.models.layers.mla.get_ascend_config")
|
||||
@patch(
|
||||
"vllm_ascend.models.layers.mla.get_tensor_model_parallel_world_size")
|
||||
def test_initialization(self, mock_tp_size, mock_ascend_config,
|
||||
mock_get_vllm_config):
|
||||
if vllm_version_is("0.11.0"):
|
||||
with patch("vllm_ascend.models.layers.mla.Attention",
|
||||
return_value=True):
|
||||
mock_tp_size.return_value = 1
|
||||
mock_ascend_config.return_value.enable_shared_expert_dp = False
|
||||
mock_vllm_config = MagicMock(spec=VllmConfig)
|
||||
mock_vllm_config.model_config.hf_config = MagicMock(
|
||||
num_hidden_layers=32, first_k_dense_replace=False)
|
||||
mock_get_vllm_config.return_value = mock_vllm_config
|
||||
mock_vllm_config.compilation_config = CompilationConfig()
|
||||
|
||||
attn = AscendMultiHeadLatentAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
mla_modules=self.mock_mla_modules,
|
||||
cache_config=self.mock_cache_config,
|
||||
quant_config=self.mock_quant_config,
|
||||
prefix=self.prefix,
|
||||
)
|
||||
|
||||
self.assertEqual(attn.hidden_size, self.hidden_size)
|
||||
self.assertEqual(attn.kv_lora_rank, self.kv_lora_rank)
|
||||
self.assertEqual(attn.debug_layer_idx, 0)
|
||||
self.assertIsNotNone(attn.mla_attn)
|
||||
self.assertIn(
|
||||
self.prefix,
|
||||
mock_vllm_config.compilation_config.static_forward_context)
|
||||
else:
|
||||
with patch("vllm_ascend.models.layers.mla.MLAAttention",
|
||||
return_value=True):
|
||||
mock_tp_size.return_value = 2
|
||||
mock_ascend_config.return_value.enable_shared_expert_dp = True
|
||||
mock_vllm_config = MagicMock(spec=VllmConfig)
|
||||
mock_vllm_config.model_config.hf_config = MagicMock(
|
||||
num_hidden_layers=32, first_k_dense_replace=True)
|
||||
mock_get_vllm_config.return_value = mock_vllm_config
|
||||
mock_vllm_config.compilation_config = CompilationConfig()
|
||||
|
||||
attn = AscendMultiHeadLatentAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
mla_modules=self.mock_mla_modules,
|
||||
cache_config=self.mock_cache_config,
|
||||
quant_config=self.mock_quant_config,
|
||||
prefix=self.prefix,
|
||||
)
|
||||
|
||||
self.assertEqual(attn.tp_size, 2)
|
||||
self.assertTrue(attn.enable_shared_expert_dp)
|
||||
self.assertIsNotNone(attn.mla_attn)
|
||||
|
||||
@patch("vllm_ascend.models.layers.mla.torch.ops.vllm.mla_forward")
|
||||
@patch("vllm_ascend.models.layers.mla.get_current_vllm_config")
|
||||
@patch("vllm_ascend.models.layers.mla.get_ascend_config")
|
||||
@patch(
|
||||
"vllm_ascend.models.layers.mla.get_tensor_model_parallel_world_size")
|
||||
@patch("vllm_ascend.models.layers.mla.get_forward_context")
|
||||
def test_forward(self, mock_get_forward_context, mock_tp_size,
|
||||
mock_ascend_config, mock_get_vllm_config,
|
||||
mock_mla_forward):
|
||||
mock_tp_size.return_value = 1
|
||||
mock_ascend_config.return_value.enable_shared_expert_dp = False
|
||||
mock_vllm_config = MagicMock(spec=VllmConfig)
|
||||
mock_vllm_config.model_config.hf_config = MagicMock(
|
||||
num_hidden_layers=32, first_k_dense_replace=False)
|
||||
mock_get_vllm_config.return_value = mock_vllm_config
|
||||
mock_vllm_config.compilation_config = CompilationConfig()
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
with patch("vllm_ascend.models.layers.mla.Attention",
|
||||
return_value=True):
|
||||
attn = AscendMultiHeadLatentAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
mla_modules=self.mock_mla_modules,
|
||||
cache_config=self.mock_cache_config,
|
||||
quant_config=self.mock_quant_config,
|
||||
prefix=self.prefix,
|
||||
)
|
||||
else:
|
||||
with patch("vllm_ascend.models.layers.mla.MLAAttention",
|
||||
return_value=True):
|
||||
attn = AscendMultiHeadLatentAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
mla_modules=self.mock_mla_modules,
|
||||
cache_config=self.mock_cache_config,
|
||||
quant_config=self.mock_quant_config,
|
||||
prefix=self.prefix,
|
||||
)
|
||||
positions = torch.tensor([0, 1, 2])
|
||||
hidden_states = torch.randn(3, self.hidden_size)
|
||||
|
||||
mock_forward_context = MagicMock(spec=ForwardContext)
|
||||
mock_forward_context.sp_enabled = False
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_mla_forward.return_value = (3, self.hidden_size)
|
||||
|
||||
output = attn.forward(positions, hidden_states)
|
||||
|
||||
self.assertEqual(output.shape, (3, self.hidden_size))
|
||||
self.assertTrue(
|
||||
torch.allclose(output, output.view(-1, self.hidden_size)))
|
||||
Reference in New Issue
Block a user