[BugFix][MLA] Fix attn_mask bug for ring mla (#2704)
This PR fixes a bug related to the attention mask used in ring MLA. Ring
MLA now supports a compressed mask, so we can directly use a 512 x 512
attention mask.
- vLLM version: v0.10.1.1
- vLLM main:
b5ee1e3261
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -238,13 +238,19 @@ class TestAscendMLAImpl(TestBase):
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
|
||||
return_value=2)
|
||||
@patch("vllm.config.get_current_vllm_config")
|
||||
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
|
||||
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
|
||||
mock_tp):
|
||||
mock_tp.world_size = 2
|
||||
vllm_config = MagicMock()
|
||||
speculative_config = MagicMock()
|
||||
model_config = MagicMock()
|
||||
speculative_config.num_speculative_tokens = 4
|
||||
vllm_config.speculative_config = speculative_config
|
||||
model_config.dtype = torch.float16
|
||||
vllm_config.model_config = model_config
|
||||
get_current_vllm_config.return_value = vllm_config
|
||||
|
||||
num_heads = 256
|
||||
head_size = 1024
|
||||
@@ -622,4 +628,4 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
self.assertEqual(result.shape[0], B)
|
||||
self.assertEqual(result.shape[1], N)
|
||||
self.assertEqual(result.shape[2], HD)
|
||||
self.assertEqual(result.shape[2], HD)
|
||||
|
||||
@@ -483,10 +483,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||
self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.ring_mla_mask_size = 512
|
||||
self.prefill_mask = None
|
||||
|
||||
# Adapt torch air graph mode with spec decoding.
|
||||
speculative_config = get_current_vllm_config().speculative_config
|
||||
speculative_config = vllm_config.speculative_config
|
||||
if speculative_config is not None:
|
||||
self.spec_token_num = speculative_config.num_speculative_tokens
|
||||
assert self.spec_token_num > 0
|
||||
@@ -681,14 +683,10 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
device=q_nope.device)
|
||||
if self.prefill_mask is None:
|
||||
self.prefill_mask = torch.triu(
|
||||
torch.ones(512,
|
||||
512,
|
||||
torch.ones(self.ring_mla_mask_size,
|
||||
self.ring_mla_mask_size,
|
||||
device=q_nope.device,
|
||||
dtype=q_nope.dtype),
|
||||
1) # 512: mask only support 512
|
||||
if attn_metadata.num_prefills > 1:
|
||||
self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(
|
||||
attn_metadata.num_prefills, 1, 1)
|
||||
dtype=q_nope.dtype), 1)
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
|
||||
Reference in New Issue
Block a user