[refactor](UT,PCP,DCP) refactor PCP & DCP patches in UTs (#5505)
### What this PR does / why we need it?
Refactor the PCP & DCP patches in the UTs: merge and reuse the
communication-group and communication-function patches to reduce code
duplication.
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.13.0
- vLLM main: 45c1ca1ca1
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
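The pattern being consolidated is a parametrized decorator that stacks `unittest.mock.patch` decorators and pre-configures the injected mocks before invoking the test body. A minimal, self-contained sketch of just the mechanism (the `os.getcwd` target and the `patch_group` name are stand-ins for illustration; the real helper below patches vLLM's `_DCP`/`_PCP` coordinators and `torch.distributed.all_to_all_single`):

```python
from functools import wraps
from unittest.mock import patch


def patch_group(world_size=1, rank=0):
    """Illustrative only: configure a patched target, then run the test."""

    def decorator(func):

        @wraps(func)
        @patch('os.getcwd')  # stand-in target for the patched group object
        def wrapper(self, mock_group, *args, **kwargs):
            # Pre-configure the mock so every test gets the same setup.
            mock_group.world_size = world_size
            mock_group.rank_in_group = rank
            return func(self, mock_group, *args, **kwargs)

        return wrapper

    return decorator
```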
@@ -1,71 +1,19 @@
from functools import wraps
from typing import List
from unittest.mock import MagicMock, patch

import torch
from vllm.distributed.parallel_state import GroupCoordinator, all_gather_fake

from tests.ut.attention.utils import patch_distributed_groups
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
from vllm_ascend.attention.attention_v1 import (AscendMetadata,
                                                AscendMetadataForPrefill)


def patch_distributed_groups(dcp_size=1, dcp_rank=0, pcp_size=1, pcp_rank=0):
    """
    Decorator to patch common distributed group mocks with configuration

    Args:
        dcp_size: DCP world size (default: 1)
        dcp_rank: DCP rank (default: 0)
        pcp_size: PCP world size (default: 1)
        pcp_rank: PCP rank (default: 0)
    """

    def decorator(func):

        @wraps(func)
        @patch('torch.distributed.all_to_all_single')
        @patch('vllm.distributed.parallel_state._PCP')
        def wrapper(self, mock_pcp, mock_all_to_all_single, *args, **kwargs):
            mock_pcp.rank_in_group = pcp_rank
            mock_pcp.world_size = pcp_size
            mock_pcp.device_group = MagicMock()
            mock_pcp.all_gather = MagicMock()
            mock_pcp.all_gather.side_effect = lambda input_, dim: all_gather_fake(
                input_, dim, pcp_size, "mock")

            mock_all_to_all_single.side_effect = lambda output, input, *a, **kw: output.copy_(
                input)

            return func(self, mock_all_to_all_single, mock_pcp, *args,
                        **kwargs)

        return wrapper

    return decorator


class TestAscendAttentionCPImpl(TestBase):

    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    def setUp(self, mock_get_dcp_size, mock_dcp, mock_pcp):
        mock_dcp.world_size = 2
        mock_dcp.rank_in_group = 0
        mock_dcp.device_group = MagicMock()

        mock_pcp.world_size = 2
        mock_pcp.rank_in_group = 0
        mock_pcp.device_group = MagicMock()

    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def setUp(self):
        self.layer = MagicMock()
        self.layer.layer_name = "test_layer"
        self.layer._k_scale_float = 1.0
@@ -134,41 +82,12 @@ class TestAscendAttentionCPImpl(TestBase):
        self.assertEqual(output.shape[1], 4)
        self.assertEqual(output.shape[2], 128)

    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP')
    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP')
    @patch("torch_npu.npu_fused_infer_attention_score")
    @patch("torch.distributed.all_gather")
    @patch("torch.distributed.all_to_all_single")
    @patch('vllm_ascend.attention.attention_cp.get_forward_context')
    def test_forward_decode_pcp_dcp(self, mock_get_forward_context,
                                    mock_all_to_all_single, mock_all_gather,
                                    mock_npu_fused_infer_attention_score,
                                    mock_dcp, mock_get_dcp_group, mock_pcp,
                                    mock_pcp_group):

        def mock_all_gather_func(tensor, dim):
            return torch.cat([tensor, tensor], dim=dim)

        mock_dcp.world_size = 2
        mock_dcp.rank_in_group = 0
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 2
        dcp_group.device_group = MagicMock()
        dcp_group.all_gather = mock_all_gather_func
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 2
        mock_pcp.rank_in_group = 0
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 2
        pcp_group.all_gather = mock_all_gather_func
        mock_pcp_group.return_value = pcp_group

    @patch_distributed_groups(dcp_size=2, pcp_size=2)
    def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,
                                    mock_get_forward_context,
                                    mock_npu_fused_infer_attention_score):
        query = torch.randn(2, 4, 128)
        self.impl.key_cache = torch.randn(100, 128, 1, 128)
        self.impl.value_cache = torch.randn(100, 128, 1, 128)
@@ -185,15 +104,6 @@ class TestAscendAttentionCPImpl(TestBase):

        mock_get_forward_context.return_value = MagicMock(capturing=False)

        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
            input)

        def mock_all_gather_func1(tensor_list, tensor, group=None):
            tensor_list[0] = tensor
            tensor_list[1] = tensor.clone()

        mock_all_gather.side_effect = mock_all_gather_func1

        def mock_npu_fused_infer_attention_score_func(query, k_nope, value,
                                                      **common_kwargs):
            mock_output = torch.randn_like(query)
@@ -213,25 +123,10 @@ class TestAscendAttentionCPImpl(TestBase):
        self.assertEqual(output.shape[1], 4)
        self.assertEqual(output.shape[2], 128)

    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP')
    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP')
    def test_prefill_query_all_gather(self, mock_dcp, mock_get_dcp_group,
                                      mock_pcp, mock_get_pcp_group):
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_prefill_query_all_gather(self):
        query = torch.randn(2, 4, 128)

        def mock_all_gather_func(tensor, dim):
            return torch.cat([tensor, tensor], dim=dim)

        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.all_gather = mock_all_gather_func
        mock_get_dcp_group.return_value = dcp_group

        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.all_gather = mock_all_gather_func
        mock_get_pcp_group.return_value = pcp_group

        attn_metadata = MagicMock()
        attn_metadata.prefill = MagicMock()
        attn_metadata.prefill.chunked_context = MagicMock()
@@ -243,9 +138,9 @@ class TestAscendAttentionCPImpl(TestBase):
        self.assertEqual(output.shape[1], 8)
        self.assertEqual(output.shape[2], 128)

    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('torch.ops.npu.npu_fused_infer_attention_score')
    def test_compute_prefill_context(self, mock_npu_attention, mock_pcp_group):
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_compute_prefill_context(self, mock_npu_attention):

        block_num = 100
        block_size = 128
@@ -284,13 +179,6 @@ class TestAscendAttentionCPImpl(TestBase):
        self.impl._load_kv_for_chunk = MagicMock()
        self.impl._load_kv_for_chunk.side_effect = mock_load_kv_for_chunk

        def mock_all_gather_func(tensor, dim):
            return torch.cat([tensor, tensor], dim=dim)

        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.all_gather = mock_all_gather_func
        mock_pcp_group.return_value = pcp_group

        mock_npu_attention.return_value = torch.randn(batch_size, num_heads,
                                                      head_size), torch.randn(
                                                          batch_size,
@@ -341,11 +229,9 @@ class TestAscendAttentionCPImpl(TestBase):
        self.assertEqual(value.shape[1], num_heads)
        self.assertEqual(value.shape[2], head_size)

    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP')
    @patch('torch_npu._npu_reshape_and_cache')
    def test_reshape_and_cache(self, mock_npu_reshape_and_cache, mock_pcp,
                               mock_get_pcp_group):
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_reshape_and_cache(self, mock_npu_reshape_and_cache):
        num_tokens = 4
        block_num = 100
        block_size = 128
@@ -369,13 +255,6 @@ class TestAscendAttentionCPImpl(TestBase):
        key = torch.randn(num_tokens, num_heads, head_size)
        value = torch.randn(num_tokens, num_heads, head_size)

        def mock_all_gather_func(tensor, dim):
            return torch.cat([tensor, tensor], dim=dim)

        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.all_gather = mock_all_gather_func
        mock_get_pcp_group.return_value = pcp_group

        key, value = self.impl.reshape_and_cache(key, value, kv_cache,
                                                 attn_metadata)
        self.assertEqual(key.shape[0], num_tokens * self.impl.pcp_size)
@@ -388,30 +267,8 @@ class TestAscendAttentionCPImpl(TestBase):

class TestUpdateNpuAttnOutLse(TestBase):

    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
              mock_get_pcp_group):
        mock_dcp.world_size = 1
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 1
        dcp_group.device_group = MagicMock()
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 1
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 1
        pcp_group.device_group = MagicMock()
        mock_get_pcp_group.return_value = pcp_group

    @patch_distributed_groups(needs_mocks=False)
    def setUp(self):
        self.layer = MagicMock()
        self.layer.layer_name = "test_layer"
        self.layer._k_scale_float = 1.0
@@ -443,6 +300,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
            kv_sharing_target_layer_name=None)

        self.impl.pcp_size = 1
        self.impl.dcp_size = 1

        self.batch_size = 2

@@ -730,7 +588,7 @@ class TestUpdateNpuAttnOutLse(TestBase):

    @patch_distributed_groups(dcp_size=2, pcp_size=3)
    def test_update_chunk_attn_out_lse_dcp2_pcp3(self, mock_all_to_all_single,
                                                 mock_pcp):
                                                 mock_dcp, mock_pcp):
        # Mock input data
        prefix_chunk_output = torch.randn(2, 4, 8)
        prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -759,7 +617,7 @@ class TestUpdateNpuAttnOutLse(TestBase):

    @patch_distributed_groups(dcp_size=2)
    def test_update_chunk_attn_out_lse_dcp2_pcp1(self, mock_all_to_all_single,
                                                 mock_pcp):
                                                 mock_dcp, mock_pcp):
        # Mock input data
        prefix_chunk_output = torch.randn(2, 4, 8)
        prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -789,7 +647,7 @@ class TestUpdateNpuAttnOutLse(TestBase):

    @patch_distributed_groups(pcp_size=2)
    def test_update_chunk_attn_out_lse_dcp1_pcp2(self, mock_all_to_all_single,
                                                 mock_pcp):
                                                 mock_dcp, mock_pcp):
        # Mock input data
        prefix_chunk_output = torch.randn(2, 4, 8)
        prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -1,7 +1,6 @@
from unittest.mock import MagicMock, patch

import torch
from vllm.distributed.parallel_state import GroupCoordinator

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
@@ -50,30 +49,7 @@ class TestAscendAttentionBackend(TestBase):

class TestAscendAttentionMetadataBuilder(TestBase):

    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
              mock_get_pcp_group):
        mock_dcp.world_size = 1
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 1
        dcp_group.device_group = MagicMock()
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 1
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 1
        pcp_group.device_group = MagicMock()
        mock_get_pcp_group.return_value = pcp_group

    def setUp(self):
        self.mock_vllm_config = MagicMock()
        self.mock_vllm_config.speculative_config = None
        self.mock_vllm_config.model_config.max_model_len = 640
@@ -126,30 +102,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):

class TestAscendAttentionBackendImpl(TestBase):

    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
              mock_get_pcp_group):
        mock_dcp.world_size = 1
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 1
        dcp_group.device_group = MagicMock()
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 1
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 1
        pcp_group.device_group = MagicMock()
        mock_get_pcp_group.return_value = pcp_group

    def setUp(self):
        self.layer = MagicMock()
        self.layer.layer_name = "test_layer"
        self.layer._k_scale_float = 1.0
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
import torch
from vllm.distributed.parallel_state import GroupCoordinator

from tests.ut.attention.utils import patch_distributed_groups
from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -159,29 +160,18 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,

class TestAscendMLAImpl(TestBase):

    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    @patch('vllm.distributed.parallel_state._TP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_tensor_model_parallel_world_size",
           return_value=2)
    @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
              mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp):
              mock_tp):
        mock_tp.world_size = 2
        mock_tp.rank_in_group = MagicMock()
        mock_tp.device_group = MagicMock()
        mock_dcp.world_size = 2
        mock_dcp.rank_in_group = MagicMock()
        mock_dcp.device_group = MagicMock()
        mock_pcp.world_size = 2
        mock_pcp.rank_in_group = MagicMock()
        mock_pcp.device_group = MagicMock()
        vllm_config = MagicMock()
        speculative_config = MagicMock()
        model_config = MagicMock()
@@ -252,12 +242,11 @@ class TestAscendMLAImpl(TestBase):
        self.assertEqual(self.impl.pcp_size, 2)
        self.assertEqual(self.impl.dcp_size, 2)

    @patch('vllm_ascend.attention.mla_cp.get_dcp_group')
    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_mla_preprocess_dcp(self, magic_npu_fetch,
                                mock_maybe_all_gather_and_maybe_unpad,
                                mock_get_dcp_group):
                                mock_maybe_all_gather_and_maybe_unpad):

        self.impl.num_kv_heads = 1
        self.impl.num_heads = 16
@@ -278,14 +267,6 @@ class TestAscendMLAImpl(TestBase):
                               self.impl.qk_rope_head_dim)
        kv_cache = (kv_cache0, kv_cache1)

        mock_dcp_group = MagicMock()

        def mock_all_gather_func(tensor, dim):
            return torch.cat([tensor, tensor], dim=dim)

        mock_dcp_group.all_gather = mock_all_gather_func
        mock_get_dcp_group.return_value = mock_dcp_group

        attn_metadata = MagicMock()
        attn_metadata.num_decodes = 2
        attn_metadata.num_prefills = 0
@@ -337,12 +318,11 @@ class TestAscendMLAImpl(TestBase):
        self.assertIsNone(prefill_res)

    @patch('torch_npu._npu_reshape_and_cache')
    @patch('vllm_ascend.attention.mla_cp.get_pcp_group')
    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_mla_preprocess_pcp(self, magic_npu_fetch,
                                mock_maybe_all_gather_and_maybe_unpad,
                                mock_get_pcp_group,
                                mock_npu_reshape_and_cache):
        self.impl.num_kv_heads = 1
        self.impl.num_heads = 16
@@ -363,14 +343,6 @@ class TestAscendMLAImpl(TestBase):
                               self.impl.qk_rope_head_dim)
        kv_cache = (kv_cache0, kv_cache1)

        mock_pcp_group = MagicMock()

        def mock_all_gather_func(tensor, dim):
            return torch.cat([tensor, tensor], dim=dim)

        mock_pcp_group.all_gather = mock_all_gather_func
        mock_get_pcp_group.return_value = mock_pcp_group

        attn_metadata = MagicMock()
        attn_metadata.num_decodes = 0
        attn_metadata.num_prefills = 2
@@ -451,10 +423,8 @@ class TestAscendMLAImpl(TestBase):
        self.assertIsNone(decode_res)
        self.assertIsNotNone(prefill_res)

    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("torch.distributed.all_to_all_single")
    def test_process_attn_out_lse(self, mock_all_to_all_single, mock_pcp):
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_process_attn_out_lse(self):
        self.impl.dcp_size = 2
        self.impl.pcp_size = 2

@@ -465,14 +435,6 @@ class TestAscendMLAImpl(TestBase):
        attn_output = torch.randn(B, N, self.impl.kv_lora_rank)
        softmax_lse = torch.randn(B, N, 1)

        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
            input)

        def make_all_gather(ws):
            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)

        mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))

        decode_metadata = MagicMock()
        decode_metadata.actual_seq_lengths_q = MagicMock()
        decode_metadata.seq_lens_list = MagicMock()
@@ -486,16 +448,13 @@ class TestAscendMLAImpl(TestBase):
        self.assertEqual(result.shape[1], N)
        self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)

    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("torch.distributed.all_to_all_single")
    @patch('vllm_ascend.attention.mla_cp.get_forward_context')
    @patch("torch_npu.atb.npu_multi_head_latent_attention")
    @patch('torch_npu.npu_attention_update')
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
                                    mock_npu_multi_head_latent_attention,
                                    mock_get_forward_context,
                                    mock_all_to_all_single, mock_pcp):
                                    mock_get_forward_context):
        self.impl.dcp_size = 2
        self.impl.pcp_size = 2
        self.impl.num_kv_heads = 1
@@ -531,14 +490,6 @@ class TestAscendMLAImpl(TestBase):
        ]
        mock_get_forward_context.return_value = MagicMock(capturing=False)

        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
            input)

        def make_all_gather(ws):
            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)

        mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))

        self.impl._v_up_proj = MagicMock()
        self.impl._v_up_proj.return_value = torch.randn(
            B, self.impl.v_head_dim)
@@ -549,17 +500,12 @@ class TestAscendMLAImpl(TestBase):
        self.assertEqual(result.shape[0], B)
        self.assertEqual(result.shape[1], self.impl.v_head_dim)

    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("torch_npu.atb.npu_paged_cache_load")
    @patch("torch_npu.atb.npu_ring_mla")
    def test_compute_prefill_context_with_dcp_pcp(self, mock_ring, mock_load,
                                                  mock_dcp, mock_pcp):

        def mock_all_gather(ws):
            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
    @patch_distributed_groups(dcp_size=2, pcp_size=2)
    def test_compute_prefill_context_with_dcp_pcp(self, mock_all2all, mock_dcp,
                                                  mock_pcp, mock_ring,
                                                  mock_load):

        def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen,
                           head_num, kv_head_num, pre_out, prev_lse, qk_scale,
@@ -620,10 +566,8 @@ class TestAscendMLAImpl(TestBase):
            torch.ones(10, 10, dtype=torch.float16), 1)
        for test_case in test_cases:
            pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
            mock_dcp.all_gather = MagicMock(
                side_effect=mock_all_gather(dcp_size))
            mock_pcp.all_gather = MagicMock(
                side_effect=mock_all_gather(pcp_size))
            mock_dcp.world_size = dcp_size
            mock_pcp.world_size = pcp_size
            assert len(nums_tokens_per_rank) == len(nums_all_rank_context)
            nums_context_per_rank = []
            for num_all_rank_context in nums_all_rank_context:
@@ -687,18 +631,9 @@ class TestAscendMLAImpl(TestBase):
            self.assertEqual(out.shape, prefix_out.shape)
            self.assertEqual(lse.shape, prefix_lse.shape)

    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_reorg_kvcache_with_dcp_pcp(self, mock_dcp, mock_get_dcp_group,
                                        mock_pcp, mock_get_pcp_group):

        def mock_all_gather(ws):
            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)

    @patch_distributed_groups(dcp_size=2, pcp_size=2)
    def test_reorg_kvcache_with_dcp_pcp(self, mock_all2all, mock_dcp,
                                        mock_pcp):
        BLOCK_SIZE = 128  # fixed
        max_model_len = 4096
        max_num_seqs = 25
@@ -714,11 +649,12 @@ class TestAscendMLAImpl(TestBase):
            if pcp_size * dcp_size == 1:
                continue
            self.impl.dcp_size = dcp_size
            mock_dcp.world_size = dcp_size
            mock_dcp.all_gather.reset_mock()
            self.impl.pcp_size = pcp_size
            mock_dcp.all_gather = MagicMock(
                side_effect=mock_all_gather(dcp_size))
            mock_pcp.all_gather = MagicMock(
                side_effect=mock_all_gather(pcp_size))
            mock_pcp.world_size = pcp_size
            mock_pcp.all_gather.reset_mock()

            chunked_prefill_workspace_size = min(
                max(8 * max_model_len, 4 * max_num_seqs * BLOCK_SIZE),
                128 * 1024)
@@ -918,17 +854,17 @@ class TestAscendMLAImpl(TestBase):
                1 + (kv_with_q_tail_nomask_idx.shape[0] != 0))
            mock_npu_ring_mla.reset_mock()

    @patch("torch.distributed.all_to_all_single")
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_process_attn_out_lse_with_dcp_pcp(self, mock_pcp,
                                               mock_all_to_all):
    @patch_distributed_groups(dcp_size=2, pcp_size=2)
    def test_process_attn_out_lse_with_dcp_pcp(self, mock_all_to_all, mock_dcp,
                                               mock_pcp):
        B, H, D = 4, self.impl.num_heads, self.impl.v_head_dim  # total: [4, 4, 8]
        test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (4, 4)]
        for test_case in test_cases:
            print(test_case)
            self.impl.dcp_size = test_case[0]
            self.impl.pcp_size = test_case[1]
            mock_dcp.world_size = test_case[0]
            mock_pcp.world_size = test_case[1]
            # Inputs
            attn_output = torch.randn(B, H, D)
            softmax_lse = torch.randn(B, H, 1)
@@ -936,17 +872,6 @@ class TestAscendMLAImpl(TestBase):
            decode_meta = MagicMock()
            decode_meta.batch_seq_mask = batch_seq_mask

            def mock_all_to_all_side_effect(output, input, group=None):
                output.copy_(input)

            mock_all_to_all.side_effect = mock_all_to_all_side_effect

            def mock_all_gather(ws):
                return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)

            mock_pcp.all_gather = MagicMock(
                side_effect=mock_all_gather(self.impl.pcp_size))

            result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
                                                     decode_meta)
            # [PCP * S, DCP * H, D + 1]
@@ -190,15 +190,7 @@ class TestAscendMLAMetadata(TestBase):

class TestAscendMLAMetadataBuilder(TestBase):

    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_ascend_mla_metadata_builder_default(self, mock_dcp,
                                                 mock_get_dcp_group, mock_pcp,
                                                 mock_get_pcp_group):
    def test_ascend_mla_metadata_builder_default(self):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -209,22 +201,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
        mock_device = 'cpu'

        mock_dcp.world_size = 2
        mock_dcp.rank_in_group = 0
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 2
        dcp_group.device_group = MagicMock()
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 2
        mock_pcp.rank_in_group = 0
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 2
        pcp_group.device_group = MagicMock()
        mock_get_pcp_group.return_value = pcp_group

        mock_vllm_config.speculative_config = None

        ascend_config = MagicMock()
@@ -239,16 +215,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
            builder.chunked_prefill_enabled,
            mock_vllm_config.scheduler_config.enable_chunked_prefill)

    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp,
                                                     mock_get_dcp_group,
                                                     mock_pcp,
                                                     mock_get_pcp_group):
    def test_ascend_mla_metadata_builder_spec_decode(self):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -259,20 +226,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
        mock_device = 'cpu'

        mock_dcp.world_size = 1
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 1
        dcp_group.device_group = MagicMock()
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 1
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 1
        pcp_group.device_group = MagicMock()
        mock_get_pcp_group.return_value = pcp_group

        mock_spec_config = MagicMock()
        mock_spec_config.num_speculative_tokens = 3
        mock_vllm_config.speculative_config = mock_spec_config
@@ -290,15 +243,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
            mock_vllm_config.scheduler_config.enable_chunked_prefill)

    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_ascend_mla_metadata_builder_build_full_graph(
            self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group,
            mock_get_cos_and_sin_mla):
            self, mock_get_cos_and_sin_mla):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -310,20 +256,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
        mock_device = 'cpu'
        torch.Tensor.pin_memory = lambda x: x  # noqa

        mock_dcp.world_size = 1
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 1
        dcp_group.device_group = MagicMock()
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 1
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 1
        pcp_group.device_group = MagicMock()
        mock_get_pcp_group.return_value = pcp_group

        mock_spec_config = MagicMock()
        mock_spec_config.num_speculative_tokens = 1
        mock_spec_config.disable_padded_drafter_batch = True
@@ -352,14 +284,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
                         [1, 2, 4, 5, 6, 6, 7, 8])
        self.assertEqual(metadata.decode.block_table.shape[0], 8)

    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_reorder_batch(self, mock_dcp, mock_get_dcp_group, mock_pcp,
                           mock_get_pcp_group):
    def test_reorder_batch(self):
        ascend_config = MagicMock()

        mock_vllm_config = MagicMock()
@@ -370,20 +295,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
        mock_device = 'cpu'

        mock_dcp.world_size = 1
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 1
        dcp_group.device_group = MagicMock()
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 1
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 1
        pcp_group.device_group = MagicMock()
        mock_get_pcp_group.return_value = pcp_group

        mock_vllm_config.speculative_config = None

        with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
@@ -411,16 +322,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
        self.assertTrue(modified)
        input_batch.swap_states.assert_called_once_with(1, 2)

    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_dcp,
                                                   mock_get_dcp_group,
                                                   mock_pcp,
                                                   mock_get_pcp_group):
    def test_pad_actual_seq_lens_q_mtp_disable_pad(self):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -432,20 +334,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
        mock_device = 'cpu'
        mock_vllm_config.speculative_config = None

        mock_dcp.world_size = 1
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 1
        dcp_group.device_group = MagicMock()
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 1
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 1
        pcp_group.device_group = MagicMock()
        mock_get_pcp_group.return_value = pcp_group

        builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
                                           mock_device)
        input_seq_lens = [1, 2, 4, 5]
@@ -456,15 +344,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
            num_reqs_pad_size, num_reqs, input_seq_lens)
        self.assertEqual(output_seq_lens, expect_output)

    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_dcp,
                                                  mock_get_dcp_group, mock_pcp,
                                                  mock_get_pcp_group):
    def test_pad_actual_seq_lens_q_mtp_enable_pad(self):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -476,20 +356,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
        mock_device = 'cpu'
        mock_vllm_config.speculative_config = None

        mock_dcp.world_size = 1
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 1
        dcp_group.device_group = MagicMock()
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 1
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 1
        pcp_group.device_group = MagicMock()
        mock_get_pcp_group.return_value = pcp_group

        common_metadata = MagicMock()
        common_metadata.actual_seq_lengths_q = [2, 4, 6, 8]

@@ -530,22 +396,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
        self.kv_cache_spec.num_heads = 32

    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    @patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
    @patch("torch.Tensor.npu", new=lambda self: self)
    @patch("torch.npu.is_available")
    def test_build_prefix_no_cache_metadata(self, mock_npu_available,
                                            mock_zeros, mock_dcp_world_size,
                                            mock_get_pcp_group,
                                            mock_zeros,
                                            mock_get_cos_and_sin_mla):
        mock_npu_available.return_value = False
        mock_dcp_world_size.return_value = 1
        torch.Tensor.pin_memory = lambda x: x  # noqa
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.world_size = 1
        mock_get_pcp_group.return_value = pcp_group

        def zeros_override(*args, **kwargs):
            kwargs.pop('pin_memory', None)
@@ -596,22 +454,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
        self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    @patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
    @patch("torch.Tensor.npu", new=lambda self: self)
    @patch("torch.npu.is_available")
    def test_build_chunked_prefix_metadata(self, mock_npu_available,
                                           mock_zeros, mock_dcp_world_size,
                                           mock_get_pcp_group,
                                           mock_zeros,
                                           mock_get_cos_and_sin_mla):
        mock_npu_available.return_value = False
        mock_dcp_world_size.return_value = 1
        torch.Tensor.pin_memory = lambda x: x  # noqa
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.world_size = 1
        mock_get_pcp_group.return_value = pcp_group

        def zeros_override(*args, **kwargs):
            kwargs.pop('pin_memory', None)
@@ -663,18 +513,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
        self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    def test_build_decode_only_metadata(self, mock_dcp_world_size,
                                        mock_get_pcp_group,
                                        mock_get_cos_and_sin_mla):
        mock_dcp_world_size.return_value = 1
    def test_build_decode_only_metadata(self, mock_get_cos_and_sin_mla):
        torch.Tensor.pin_memory = lambda x: x  # noqa

        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.world_size = 1
        mock_get_pcp_group.return_value = pcp_group
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=torch.tensor([0, 1, 2, 3]),
            query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
@@ -718,18 +559,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
        self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
                                                 mock_get_pcp_group,
    def test_build_for_graph_capture_decode_only(self,
                                                 mock_get_cos_and_sin_mla):
        mock_dcp_world_size.return_value = 1
        torch.Tensor.pin_memory = lambda x: x  # noqa

        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.world_size = 1
        mock_get_pcp_group.return_value = pcp_group
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=torch.tensor([0, 1, 2, 3]),
            query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
@@ -774,17 +607,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
        self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
                                             mock_get_pcp_group,
                                             mock_get_cos_and_sin_mla):
        mock_dcp_world_size.return_value = 1
    def test_build_for_graph_capture_prefill(self, mock_get_cos_and_sin_mla):
        torch.Tensor.pin_memory = lambda x: x  # noqa
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.world_size = 1
        mock_get_pcp_group.return_value = pcp_group
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=torch.tensor([0, 3, 7]),
            query_start_loc_cpu=torch.tensor([0, 3, 7]),
@@ -820,23 +644,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):

class TestAscendMLAImpl(TestBase):

    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state._TP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
    def setUp(self, get_current_vllm_config, mock_tp, mock_dcp, mock_pcp):
    def setUp(self, get_current_vllm_config, mock_tp):
        mock_tp.world_size = 2
        mock_tp.rank_in_group = MagicMock()
        mock_tp.device_group = MagicMock()
        mock_dcp.world_size = 1
        mock_dcp.rank_in_group = MagicMock()
        mock_dcp.device_group = MagicMock()
        mock_pcp.world_size = 1
        mock_pcp.rank_in_group = MagicMock()
        mock_pcp.device_group = MagicMock()
        vllm_config = MagicMock()
        speculative_config = MagicMock()
        model_config = MagicMock()
tests/ut/attention/utils.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from functools import wraps
from unittest.mock import MagicMock, patch

from vllm.distributed.parallel_state import all_gather_fake


def patch_distributed_groups(dcp_size=1,
                             dcp_rank=0,
                             pcp_size=1,
                             pcp_rank=0,
                             needs_mocks=True):
    """
    Decorator to patch common distributed group mocks with configuration

    Args:
        dcp_size: DCP world size (default: 1)
        dcp_rank: DCP rank (default: 0)
        pcp_size: PCP world size (default: 1)
        pcp_rank: PCP rank (default: 0)
        needs_mocks: Whether to pass mock objects as the first arguments
            after 'self' to the decorated function.
            If True, the decorated function receives:
                func(self, mock_all_to_all_single, mock_dcp, mock_pcp, *args, **kwargs)
            If False, mocks are not passed and the function receives:
                func(self, *args, **kwargs)
            (default: True)
    """

    def decorator(func):

        @wraps(func)
        @patch('torch.distributed.all_to_all_single')
        @patch('vllm.distributed.parallel_state._PCP')
        @patch('vllm.distributed.parallel_state._DCP')
        def wrapper(self, mock_dcp, mock_pcp, mock_all_to_all_single, *args,
                    **kwargs):
            mock_dcp.rank_in_group = dcp_rank
            mock_dcp.world_size = dcp_size
            mock_dcp.device_group = MagicMock()

            mock_dcp.all_gather = MagicMock()
            mock_dcp.all_gather.side_effect = lambda input_, dim: all_gather_fake(
                input_, dim, mock_dcp.world_size, "mock_dcp_group")

            mock_pcp.rank_in_group = pcp_rank
            mock_pcp.world_size = pcp_size
            mock_pcp.device_group = MagicMock()

            mock_pcp.all_gather = MagicMock()
            mock_pcp.all_gather.side_effect = lambda input_, dim: all_gather_fake(
                input_, dim, mock_pcp.world_size, "mock_pcp_group")

            mock_all_to_all_single.side_effect = lambda output, input, *a, **kw: output.copy_(
                input)

            if needs_mocks:
                return func(self, mock_all_to_all_single, mock_dcp, mock_pcp,
                            *args, **kwargs)
            else:
                return func(self, *args, **kwargs)

        return wrapper

    return decorator
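For reference, a hypothetical usage sketch of the helper above (it assumes the vllm-ascend UT environment; `TestExample` and its test names are illustrative only). Note that the injected argument order matches the wrapper's call: `mock_all_to_all_single`, then `mock_dcp`, then `mock_pcp`:

```python
from tests.ut.attention.utils import patch_distributed_groups
from tests.ut.base import TestBase


class TestExample(TestBase):  # hypothetical test class

    @patch_distributed_groups(dcp_size=2, pcp_size=2)
    def test_with_mocks(self, mock_all_to_all_single, mock_dcp, mock_pcp):
        # _DCP/_PCP are patched with world_size 2 and a fake all_gather;
        # all_to_all_single simply copies input into output.
        self.assertEqual(mock_dcp.world_size, 2)
        self.assertEqual(mock_pcp.world_size, 2)

    @patch_distributed_groups(pcp_size=2, needs_mocks=False)
    def test_without_mocks(self):
        # Patches are active for the duration of the test, but the mock
        # objects are not injected as arguments.
        pass
```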
@@ -12,7 +12,6 @@ from unittest.mock import MagicMock, patch

import msgspec
import zmq
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.utils.network_utils import make_zmq_path

fake_engine = types.ModuleType("mooncake.engine")
@@ -23,6 +22,7 @@ _mock_ascend_config = MagicMock(enable_kv_nz=False)
_mock_pp_group = MagicMock(rank_in_group=0, world_size=1)
_mock_tp_group = MagicMock(rank_in_group=0, world_size=4)
_mock_pcp_group = MagicMock(rank_in_group=0, world_size=1)
_mock_dcp_group = MagicMock(rank_in_group=0, world_size=1)
patch('vllm_ascend.distributed.mooncake_connector.get_pp_group',
      return_value=_mock_pp_group).start()
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
@@ -35,6 +35,7 @@ patch(
      return_value=0).start()
patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
      return_value=_mock_pcp_group).start()
patch('vllm.distributed.parallel_state._DCP', _mock_dcp_group).start()

from vllm_ascend.distributed.mooncake_connector import (  # noqa: E402
    KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
@@ -1098,17 +1099,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
        self.mock_transfer_engine.get_rpc_port.return_value = 9090
        self.mock_transfer_engine.initialize.return_value = 0
        self.mock_transfer_engine.register_memory.return_value = 0
        self.mock_dcp_group = MagicMock(spec=GroupCoordinator)
        self.mock_dcp_group.rank_in_group = 0
        self.mock_dcp_group.world_size = 1
        self.mock_dcp_group.device_group = MagicMock()
        self.mock_dcp = MagicMock()
        self.mock_dcp.world_size = 1

        self.mock_pcp_group = MagicMock(spec=GroupCoordinator)
        self.mock_pcp_group.rank_in_group = 0
        self.mock_pcp_group.world_size = 1
        self.mock_pcp_group.device_group = MagicMock()

        self.patches = [
            patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
@@ -1143,13 +1133,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
                  MagicMock()),
            patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
                  MagicMock()),
            patch('vllm.distributed.parallel_state.get_dcp_group',
                  return_value=self.mock_dcp_group),
            patch('vllm.distributed.parallel_state._DCP',
                  return_value=self.mock_dcp),
            patch(
                'vllm_ascend.distributed.mooncake_connector.get_decode_context_model_parallel_world_size',
                return_value=1),
            patch(
                'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
                return_value=MagicMock()),