[feature]dcp&pcp support mlapo (#5672)
### What this PR does / why we need it?
mlapo in deepseek is a huge performance improvement in decode, this pr
support pcp & dcp with mlapo
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
This commit is contained in:
@@ -16,8 +16,8 @@ To learn more about the theory and implementation details of context parallel, p
|
|||||||
Currently context parallel can be used together with most other features, supported features are as follows:
|
Currently context parallel can be used together with most other features, supported features are as follows:
|
||||||
| | Eager | Graph | Prefix <br> Cache | Chunked <br> Prefill | SpecDecode <br> (MTP) | PD <br> disaggregation | MLAPO |
|
| | Eager | Graph | Prefix <br> Cache | Chunked <br> Prefill | SpecDecode <br> (MTP) | PD <br> disaggregation | MLAPO |
|
||||||
| ------- | ----- | ----- | ------ | ------ | ----- | ----- | ----- |
|
| ------- | ----- | ----- | ------ | ------ | ----- | ----- | ----- |
|
||||||
| **PCP** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| **PCP** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅|
|
||||||
| **DCP** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| **DCP** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
|
|
||||||
## How to use Context Parallel
|
## How to use Context Parallel
|
||||||
You can enable `PCP` and `DCP` by `prefill_context_parallel_size` and `decode_context_parallel_size`, refer to the following example:
|
You can enable `PCP` and `DCP` by `prefill_context_parallel_size` and `decode_context_parallel_size`, refer to the following example:
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ def test_qwen3_next_mtp_acceptance_tp4(model_name):
|
|||||||
for num_accepted_tokens in num_accepted_tokens_per_pos
|
for num_accepted_tokens in num_accepted_tokens_per_pos
|
||||||
]
|
]
|
||||||
|
|
||||||
match = all(abs(a - b) < 0.05 for a, b in zip(acceptance_per_pos, golden))
|
match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
|
||||||
if not match:
|
if not match:
|
||||||
print(f"acceptance_per_pos: {acceptance_per_pos}")
|
print(f"acceptance_per_pos: {acceptance_per_pos}")
|
||||||
print(f"golden: {golden}")
|
print(f"golden: {golden}")
|
||||||
|
|||||||
@@ -278,6 +278,7 @@ class TestMtpProposer:
|
|||||||
[0, 8, 16, 24], dtype=torch.int32)
|
[0, 8, 16, 24], dtype=torch.int32)
|
||||||
mock_common_attn_metadata.seq_lens = torch.tensor([8, 8, 8],
|
mock_common_attn_metadata.seq_lens = torch.tensor([8, 8, 8],
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
|
mock_common_attn_metadata.num_actual_tokens = 24
|
||||||
mock_common_attn_metadata.num_reqs = 3
|
mock_common_attn_metadata.num_reqs = 3
|
||||||
mock_common_attn_metadata.num_computed_tokens_cpu = torch.tensor(
|
mock_common_attn_metadata.num_computed_tokens_cpu = torch.tensor(
|
||||||
[5, 6, 7], dtype=torch.int32)
|
[5, 6, 7], dtype=torch.int32)
|
||||||
@@ -293,10 +294,12 @@ class TestMtpProposer:
|
|||||||
mock_runner.actual_seq_lengths_q = MagicMock()
|
mock_runner.actual_seq_lengths_q = MagicMock()
|
||||||
mock_runner.attn_state = MagicMock()
|
mock_runner.attn_state = MagicMock()
|
||||||
mock_runner.graph_pad_size = 0
|
mock_runner.graph_pad_size = 0
|
||||||
|
mock_runner.pcp_size = 1
|
||||||
mock_runner.decode_token_per_req = MagicMock()
|
mock_runner.decode_token_per_req = MagicMock()
|
||||||
|
|
||||||
proposer = MagicMock(spec=MtpProposer)
|
proposer = MagicMock(spec=MtpProposer)
|
||||||
proposer.runner = mock_runner
|
proposer.runner = mock_runner
|
||||||
|
proposer.pcp_size = 1
|
||||||
proposer.arange = torch.arange(100, dtype=torch.int32)
|
proposer.arange = torch.arange(100, dtype=torch.int32)
|
||||||
proposer.prepare_inputs_padded = MtpProposer.prepare_inputs_padded.__get__(
|
proposer.prepare_inputs_padded = MtpProposer.prepare_inputs_padded.__get__(
|
||||||
proposer)
|
proposer)
|
||||||
|
|||||||
@@ -70,6 +70,26 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
|||||||
dtype=torch.uint8,
|
dtype=torch.uint8,
|
||||||
device=device)
|
device=device)
|
||||||
|
|
||||||
|
def build(
|
||||||
|
self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
|
fast_build: bool = False,
|
||||||
|
) -> AscendMLAMetadata:
|
||||||
|
metadata_cls = super().build(common_prefix_len, common_attn_metadata)
|
||||||
|
if self.num_prefills == 0 and self.pcp_size > 1:
|
||||||
|
self.slot_mapping[:self.
|
||||||
|
num_decode_tokens] = self.slot_mapping[:self.
|
||||||
|
num_decode_tokens
|
||||||
|
* self.
|
||||||
|
pcp_size:
|
||||||
|
self.
|
||||||
|
pcp_size]
|
||||||
|
self.slot_mapping[self.num_decode_tokens:self.num_decode_tokens *
|
||||||
|
self.pcp_size].fill_(-1)
|
||||||
|
metadata_cls.slot_mapping = self.slot_mapping
|
||||||
|
return metadata_cls
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cudagraph_support(
|
def get_cudagraph_support(
|
||||||
cls: type["AscendMlaCPMetadataBuilder"],
|
cls: type["AscendMlaCPMetadataBuilder"],
|
||||||
@@ -363,8 +383,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
|||||||
decode_ql_nope, decode_q_pe = self.reorg_decode_q(
|
decode_ql_nope, decode_q_pe = self.reorg_decode_q(
|
||||||
decode_ql_nope, decode_q_pe)
|
decode_ql_nope, decode_q_pe)
|
||||||
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
||||||
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens *
|
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens]
|
||||||
self.pcp_size:self.pcp_size]
|
|
||||||
decode_kv_no_split = kv_no_split[:num_decode_tokens]
|
decode_kv_no_split = kv_no_split[:num_decode_tokens]
|
||||||
decode_k_pe, decode_k_nope = self.exec_kv_decode(
|
decode_k_pe, decode_k_nope = self.exec_kv_decode(
|
||||||
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
|
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
|
||||||
|
|||||||
@@ -438,7 +438,6 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
|||||||
if self.num_decodes > 0:
|
if self.num_decodes > 0:
|
||||||
decode_metadata = self.build_decode_metadata(
|
decode_metadata = self.build_decode_metadata(
|
||||||
common_prefix_len, common_attn_metadata)
|
common_prefix_len, common_attn_metadata)
|
||||||
|
|
||||||
return self.metadata_cls( # type: ignore
|
return self.metadata_cls( # type: ignore
|
||||||
num_actual_tokens_pcp_padded=self.num_actual_tokens,
|
num_actual_tokens_pcp_padded=self.num_actual_tokens,
|
||||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||||
@@ -1330,7 +1329,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.W_UK_T,
|
self.W_UK_T,
|
||||||
decode_k_nope,
|
decode_k_nope,
|
||||||
decode_k_pe,
|
decode_k_pe,
|
||||||
attn_metadata.slot_mapping[:bsz].flatten(),
|
attn_metadata.slot_mapping[:bsz],
|
||||||
quant_scale0=self.quant_scale0,
|
quant_scale0=self.quant_scale0,
|
||||||
quant_offset0=self.quant_offset0,
|
quant_offset0=self.quant_offset0,
|
||||||
bias0=self.quant_bias_qkv,
|
bias0=self.quant_bias_qkv,
|
||||||
|
|||||||
@@ -800,7 +800,8 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
query_start_loc_cpu=query_start_loc_cpu,
|
query_start_loc_cpu=query_start_loc_cpu,
|
||||||
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
|
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
|
||||||
num_reqs=common_attn_metadata.num_reqs,
|
num_reqs=common_attn_metadata.num_reqs,
|
||||||
num_actual_tokens=total_num_tokens,
|
num_actual_tokens=common_attn_metadata.num_actual_tokens
|
||||||
|
if self.pcp_size > 1 else total_num_tokens,
|
||||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||||
max_query_len=new_query_len_per_req.max().item(),
|
max_query_len=new_query_len_per_req.max().item(),
|
||||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||||
|
|||||||
@@ -912,12 +912,15 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
self.input_batch)
|
self.input_batch)
|
||||||
blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
|
blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
slot_mapping = self.pcp_manager.get_padded_slot_mapping(
|
slot_mapping_pcp = self.pcp_manager.get_padded_slot_mapping(
|
||||||
total_num_scheduled_tokens,
|
total_num_scheduled_tokens,
|
||||||
slot_mapping,
|
slot_mapping,
|
||||||
)
|
)
|
||||||
blk_table.slot_mapping.gpu[:self.pcp_manager.
|
blk_table.slot_mapping.gpu[:self.pcp_manager.
|
||||||
num_actual_tokens_pcp_padded] = slot_mapping
|
num_actual_tokens_pcp_padded] = slot_mapping_pcp
|
||||||
|
slot_mapping = blk_table.slot_mapping.gpu[:self.
|
||||||
|
pcp_manager.
|
||||||
|
num_actual_tokens_pcp_padded]
|
||||||
|
|
||||||
# NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs
|
# NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs
|
||||||
# has been split to multiple parts, and there are 3 parts that is related to this
|
# has been split to multiple parts, and there are 3 parts that is related to this
|
||||||
|
|||||||
Reference in New Issue
Block a user