[Perf][PCP][DCP] add multi-stream for GQA to enable computation-communication overlap (#5382)
### What this PR does / why we need it?
This PR adds multi-stream for GQA to enable computation-communication
overlap. For chunked prefill, we reduce TTFT by approximately 4%.
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: release/v0.13.0
- vLLM main:
bc0a5a0c08
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from functools import wraps
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.distributed.parallel_state import GroupCoordinator, all_gather_fake
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
|
||||
@@ -10,33 +11,60 @@ from vllm_ascend.attention.attention_v1 import (AscendMetadata,
|
||||
AscendMetadataForPrefill)
|
||||
|
||||
|
||||
def patch_distributed_groups(dcp_size=1, dcp_rank=0, pcp_size=1, pcp_rank=0):
|
||||
"""
|
||||
Decorator to patch common distributed group mocks with configuration
|
||||
|
||||
Args:
|
||||
dcp_size: DCP world size (default: 1)
|
||||
dcp_rank: DCP rank (default: 0)
|
||||
pcp_size: PCP world size (default: 1)
|
||||
pcp_rank: PCP rank (default: 0)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
|
||||
@wraps(func)
|
||||
@patch('torch.distributed.all_to_all_single')
|
||||
@patch('vllm.distributed.parallel_state._PCP')
|
||||
def wrapper(self, mock_pcp, mock_all_to_all_single, *args, **kwargs):
|
||||
mock_pcp.world_size = pcp_size
|
||||
mock_pcp.rank_in_group = pcp_rank
|
||||
|
||||
mock_pcp.rank_in_group = pcp_rank
|
||||
mock_pcp.world_size = pcp_size
|
||||
mock_pcp.device_group = MagicMock()
|
||||
mock_pcp.all_gather = MagicMock()
|
||||
mock_pcp.all_gather.side_effect = lambda input_, dim: all_gather_fake(
|
||||
input_, dim, pcp_size, "mock")
|
||||
|
||||
mock_all_to_all_single.side_effect = lambda output, input, *a, **kw: output.copy_(
|
||||
input)
|
||||
|
||||
return func(self, mock_all_to_all_single, mock_pcp, *args,
|
||||
**kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class TestAscendAttentionCPImpl(TestBase):
|
||||
|
||||
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm_ascend.attention.attention_cp.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
def setUp(self, mock_get_dcp_size, mock_dcp, mock_pcp):
|
||||
mock_dcp.world_size = 2
|
||||
mock_dcp.rank_in_group = 0
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 2
|
||||
dcp_group.device_group = MagicMock()
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
mock_dcp.device_group = MagicMock()
|
||||
|
||||
mock_pcp.world_size = 2
|
||||
mock_pcp.rank_in_group = 0
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.rank_in_group = 0
|
||||
pcp_group.world_size = 2
|
||||
pcp_group.device_group = MagicMock()
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
mock_pcp.device_group = MagicMock()
|
||||
|
||||
self.layer = MagicMock()
|
||||
self.layer.layer_name = "test_layer"
|
||||
@@ -237,10 +265,14 @@ class TestAscendAttentionCPImpl(TestBase):
|
||||
attn_metadata = MagicMock()
|
||||
attn_metadata.prefill = MagicMock()
|
||||
attn_metadata.prefill.chunked_context = MagicMock()
|
||||
attn_metadata.prefill.chunked_context.local_context_lens_allranks = torch.tensor(
|
||||
[[[256, 256], [256, 256]]])
|
||||
local_context_lens_allranks = torch.tensor([[[256, 256], [256, 256]]])
|
||||
attn_metadata.prefill.chunked_context.local_context_lens_allranks = local_context_lens_allranks
|
||||
attn_metadata.prefill.chunked_context.batch_chunk_seq_mask = torch.randint(
|
||||
0, 2, (1024, ), dtype=torch.bool)
|
||||
attn_metadata.prefill.chunked_context.local_total_toks = local_context_lens_allranks[:,
|
||||
0,
|
||||
0].sum(
|
||||
)
|
||||
|
||||
def mock_load_kv_for_chunk(attn_metadata, kv_cache,
|
||||
local_chunked_kv_lens_rank, query,
|
||||
@@ -264,8 +296,17 @@ class TestAscendAttentionCPImpl(TestBase):
|
||||
batch_size,
|
||||
num_heads, 1)
|
||||
|
||||
result_output, result_lse = self.impl._compute_prefill_context(
|
||||
context_output = self.impl._compute_prefill_context(
|
||||
query, kv_cache, attn_metadata)
|
||||
local_context_output = torch.cat(context_output,
|
||||
dim=-1).permute([1, 2,
|
||||
0]).contiguous()
|
||||
global_context_output = self.impl._gather_global_context_output(
|
||||
local_context_output)
|
||||
global_context_output = global_context_output.permute([2, 0, 1
|
||||
]).contiguous()
|
||||
result_output, result_lse = self.impl._update_global_context_output(
|
||||
global_context_output)
|
||||
|
||||
self.assertEqual(result_output.shape[0], batch_size)
|
||||
self.assertEqual(result_output.shape[1], self.impl.num_heads)
|
||||
@@ -687,37 +728,25 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
self.assertIsInstance(out_final, torch.Tensor)
|
||||
self.assertIsInstance(lse_final, torch.Tensor)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('torch.cat')
|
||||
@patch('torch.distributed.all_to_all_single')
|
||||
@patch('torch.distributed.all_gather')
|
||||
@patch('torch.stack')
|
||||
@patch('torch.split')
|
||||
def test_update_chunk_attn_out_lse_dcp_pcp_both_greater_than_1(
|
||||
self, mock_split, mock_stack, mock_all_gather,
|
||||
mock_all_to_all_single, mock_cat, mock_pcp, mock_get_pcp_group):
|
||||
@patch_distributed_groups(dcp_size=2, pcp_size=3)
|
||||
def test_update_chunk_attn_out_lse_dcp2_pcp3(self, mock_all_to_all_single,
|
||||
mock_pcp):
|
||||
# Mock input data
|
||||
prefix_chunk_output = torch.randn(2, 4, 8)
|
||||
prefix_chunk_lse = torch.randn(2, 4, 1)
|
||||
self.impl.dcp_size = 2
|
||||
self.impl.pcp_size = 3
|
||||
self.impl.head_size = 8
|
||||
# Mock output
|
||||
mock_cat.return_value = torch.randn(2, 4, 9)
|
||||
mock_all_to_all_single.return_value = torch.randn(4, 9, 2)
|
||||
mock_all_gather.return_value = [(2, 4, 9), (2, 4, 9), (2, 4, 9)]
|
||||
mock_stack.return_value = torch.randn(6, 2, 2, 9)
|
||||
mock_split.return_value = (torch.randn(6, 2, 2,
|
||||
8), torch.randn(6, 2, 2, 1))
|
||||
mock_pcp_group = MagicMock()
|
||||
mock_pcp_group.all_gather.return_value = torch.randn(6, 4, 9)
|
||||
mock_get_pcp_group.return_value = mock_pcp_group
|
||||
|
||||
# Call the method under test
|
||||
output, lse = self.impl._update_chunk_attn_out_lse(
|
||||
prefix_chunk_output, prefix_chunk_lse)
|
||||
chunk_data = torch.cat([prefix_chunk_output, prefix_chunk_lse],
|
||||
dim=-1).permute([1, 2, 0]).contiguous()
|
||||
global_context_output = self.impl._gather_global_context_output(
|
||||
chunk_data)
|
||||
global_context_output = global_context_output.permute([2, 0, 1
|
||||
]).contiguous()
|
||||
output, lse = self.impl._update_global_context_output(
|
||||
global_context_output)
|
||||
|
||||
# Assert the method call
|
||||
self.assertIsInstance(output, torch.Tensor)
|
||||
@@ -725,21 +754,12 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
self.assertEqual(output.shape, (2, 2, 8))
|
||||
self.assertEqual(lse.shape, (2, 2, 1))
|
||||
|
||||
self.assertEqual(mock_cat.call_count, 1)
|
||||
mock_all_to_all_single.assert_called_once()
|
||||
self.assertEqual(mock_get_pcp_group.call_count, 1)
|
||||
mock_pcp.all_gather.assert_called_once()
|
||||
|
||||
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP')
|
||||
@patch('torch.cat')
|
||||
@patch('torch.chunk')
|
||||
@patch('torch.stack')
|
||||
@patch('torch.split')
|
||||
@patch('torch.distributed.all_to_all_single')
|
||||
@patch('torch.distributed.all_gather')
|
||||
def test_update_chunk_attn_out_lse_dcp_greater_than_1_only(
|
||||
self, mock_all_gather, mock_all_to_all_single, mock_split,
|
||||
mock_stack, mock_chunk, mock_cat, mock_pcp, mock_pcp_group):
|
||||
@patch_distributed_groups(dcp_size=2)
|
||||
def test_update_chunk_attn_out_lse_dcp2_pcp1(self, mock_all_to_all_single,
|
||||
mock_pcp):
|
||||
# Mock input data
|
||||
prefix_chunk_output = torch.randn(2, 4, 8)
|
||||
prefix_chunk_lse = torch.randn(2, 4, 1)
|
||||
@@ -748,20 +768,15 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
self.impl.pcp_size = 1
|
||||
self.impl.head_size = 8
|
||||
|
||||
# Mock output
|
||||
mock_cat.return_value = torch.randn(2, 4, 9)
|
||||
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||
input)
|
||||
mock_chunk.return_value = [torch.randn(2, 2, 9), torch.randn(2, 2, 9)]
|
||||
mock_stack.return_value = torch.randn(2, 2, 2, 9)
|
||||
mock_split.return_value = [
|
||||
torch.randn(2, 2, 2, 8),
|
||||
torch.randn(2, 2, 2, 1)
|
||||
]
|
||||
|
||||
# Call the method under test
|
||||
output, lse = self.impl._update_chunk_attn_out_lse(
|
||||
prefix_chunk_output, prefix_chunk_lse)
|
||||
chunk_data = torch.cat([prefix_chunk_output, prefix_chunk_lse],
|
||||
dim=-1).permute([1, 2, 0]).contiguous()
|
||||
global_context_output = self.impl._gather_global_context_output(
|
||||
chunk_data)
|
||||
global_context_output = global_context_output.permute([2, 0, 1
|
||||
]).contiguous()
|
||||
output, lse = self.impl._update_global_context_output(
|
||||
global_context_output)
|
||||
|
||||
# Assert the method call
|
||||
self.assertIsInstance(output, torch.Tensor)
|
||||
@@ -769,24 +784,12 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
self.assertEqual(output.shape, (2, 2, 8))
|
||||
self.assertEqual(lse.shape, (2, 2, 1))
|
||||
|
||||
self.assertEqual(mock_cat.call_count, 1)
|
||||
mock_all_to_all_single.assert_called_once()
|
||||
mock_all_gather.assert_not_called()
|
||||
mock_pcp.all_gather.assert_not_called()
|
||||
|
||||
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP')
|
||||
@patch('torch.cat')
|
||||
@patch('torch.stack')
|
||||
@patch('torch.split')
|
||||
@patch('torch.distributed.all_to_all_single')
|
||||
@patch('torch.distributed.all_gather')
|
||||
@patch(
|
||||
'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._update_out_and_lse'
|
||||
)
|
||||
def test_update_chunk_attn_out_lse_pcp_greater_than_1_only(
|
||||
self, mock_update_out_and_lse, mock_all_gather,
|
||||
mock_all_to_all_single, mock_split, mock_stack, mock_cat, mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
@patch_distributed_groups(pcp_size=2)
|
||||
def test_update_chunk_attn_out_lse_dcp1_pcp2(self, mock_all_to_all_single,
|
||||
mock_pcp):
|
||||
# Mock input data
|
||||
prefix_chunk_output = torch.randn(2, 4, 8)
|
||||
prefix_chunk_lse = torch.randn(2, 4, 1)
|
||||
@@ -795,30 +798,21 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
self.impl.pcp_size = 2
|
||||
self.impl.head_size = 8
|
||||
|
||||
# Mock output
|
||||
mock_cat.return_value = torch.randn(2, 4, 9)
|
||||
mock_pcp_group = MagicMock()
|
||||
mock_pcp_group.all_gather.return_value = torch.randn(4, 4, 9)
|
||||
mock_get_pcp_group.return_value = mock_pcp_group
|
||||
mock_stack.return_value = torch.randn(2, 2, 4, 9)
|
||||
mock_split.return_value = [
|
||||
torch.randn(2, 2, 4, 8),
|
||||
torch.randn(2, 2, 4, 1)
|
||||
]
|
||||
mock_update_out_and_lse.return_value = torch.randn(2, 4,
|
||||
8), torch.randn(
|
||||
2, 4, 1)
|
||||
# Call the method under test
|
||||
output, lse = self.impl._update_chunk_attn_out_lse(
|
||||
prefix_chunk_output, prefix_chunk_lse)
|
||||
chunk_data = torch.cat([prefix_chunk_output, prefix_chunk_lse],
|
||||
dim=-1).permute([1, 2, 0]).contiguous()
|
||||
global_context_output = self.impl._gather_global_context_output(
|
||||
chunk_data)
|
||||
global_context_output = global_context_output.permute([2, 0, 1
|
||||
]).contiguous()
|
||||
output, lse = self.impl._update_global_context_output(
|
||||
global_context_output)
|
||||
|
||||
# Assert the method call
|
||||
self.assertIsInstance(output, torch.Tensor)
|
||||
self.assertIsInstance(lse, torch.Tensor)
|
||||
self.assertEqual(output.shape, (2, 4, 8))
|
||||
self.assertEqual(lse.shape, (2, 4, 1))
|
||||
self.impl._update_out_and_lse.assert_called_once()
|
||||
|
||||
self.assertEqual(mock_cat.call_count, 1)
|
||||
mock_all_to_all_single.assert_not_called()
|
||||
mock_get_pcp_group.assert_called_once()
|
||||
mock_pcp.all_gather.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user