[BugFix][MLA] Fix attn_mask bug for ring mla (#2704)

This PR fix a bug related to attention mask used in ring mla. Current
ring mla has supported compressed mask, so we can directly use a 512 *
512 attention mask.

- vLLM version: v0.10.1.1
- vLLM main:
b5ee1e3261

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-09-04 10:22:46 +08:00
committed by GitHub
parent e11a1bbfc1
commit a58013440a
2 changed files with 15 additions and 11 deletions

View File

@@ -238,13 +238,19 @@ class TestAscendMLAImpl(TestBase):
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
return_value=2)
@patch("vllm.config.get_current_vllm_config")
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
mock_tp):
mock_tp.world_size = 2
vllm_config = MagicMock()
speculative_config = MagicMock()
model_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
model_config.dtype = torch.float16
vllm_config.model_config = model_config
get_current_vllm_config.return_value = vllm_config
num_heads = 256
head_size = 1024
@@ -622,4 +628,4 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], HD)
self.assertEqual(result.shape[2], HD)