[v0.18.0][Bugfix] fix ds3.2 dcp mtp (#7681)
### What this PR does / why we need it?
Fixes an issue that occurs when DCP is used together with the MTP scenario on ds3.2.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
cherry-pick from: https://github.com/vllm-project/vllm-ascend/pull/7617

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
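At its core, the change threads a host-side copy of the sequence lengths (`seq_lens_cpu`) through the Ascend attention metadata so the MTP proposer can read request lengths without touching the device tensor, and it guards a couple of PCP/DCP-only code paths. A minimal sketch of that pattern, with `AttnMetadataSketch` and `build_metadata` as simplified, hypothetical stand-ins for the real classes shown in the diff below:

```python
from dataclasses import dataclass

import torch


@dataclass
class AttnMetadataSketch:
    # Device-resident sequence lengths consumed by the attention kernels.
    seq_lens: torch.Tensor
    # Host-side copy added by this fix, so callers (e.g. the MTP proposer)
    # can read lengths without forcing a device-to-host sync.
    seq_lens_cpu: torch.Tensor


def build_metadata(seq_lens_cpu: torch.Tensor, device: str) -> AttnMetadataSketch:
    # The builders in the diff pass both tensors side by side.
    return AttnMetadataSketch(seq_lens=seq_lens_cpu.to(device),
                              seq_lens_cpu=seq_lens_cpu)


# Consumer side, mirroring the proposer change: snapshot the host copy
# instead of the device tensor.
metadata = build_metadata(torch.tensor([5, 9, 3]), device="cpu")
batch_size = 2
ori_seq_len = metadata.seq_lens_cpu[:batch_size].clone()
print(ori_seq_len)  # tensor([5, 9])
```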
@@ -36,7 +36,7 @@ deployment:
 --no-enable-prefix-caching
 --gpu-memory-utilization 0.85
 --trust-remote-code
---speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
+--speculative-config '{"num_speculative_tokens": 3, "method":"deepseek_mtp"}'
 --compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}'
 --additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}'
 --tokenizer-mode deepseek_v32
@@ -62,7 +62,7 @@ deployment:
 --no-enable-prefix-caching
 --gpu-memory-utilization 0.85
 --trust-remote-code
---speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
+--speculative-config '{"num_speculative_tokens": 3, "method":"deepseek_mtp"}'
 --compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}'
 --additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}'
 --tokenizer-mode deepseek_v32
@@ -182,7 +182,7 @@ class TestAscendMLAMetadata(TestBase):
 
         metadata = AscendMLAMetadata(
             num_actual_tokens_pcp_padded, num_actual_tokens, slot_mapping,
-            query_start_loc, seq_lens, block_tables, num_decodes,
+            query_start_loc, seq_lens, seq_lens, block_tables, num_decodes,
             num_decode_tokens, num_prefills, num_input_tokens, query_lens,
             head_dim, attn_mask, attn_state, decode, prefill)
 
@@ -58,6 +58,7 @@ class TestAscendSFAMetadata(TestBase):
             num_actual_tokens=num_actual_tokens,
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens,
             cum_query_lens=cum_query_lens,
             block_table=block_table,
             sin=sin,
@@ -803,6 +803,7 @@ class TestPCPDCPGraphParams(TestBase):
             slot_mapping,
             query_start_loc,
             seq_lens,
+            seq_lens,
             block_tables,
             4,
             4,
@@ -168,6 +168,7 @@ class AscendMetadata:
     # should simplified these parameters once attention schema in vLLM-Ascend
     # is unified.
     seq_lens: torch.Tensor = None
+    seq_lens_cpu: torch.Tensor = None
     seq_lens_list: list[int] = None  # type: ignore
     actual_seq_lengths_q: list[int] = None  # type: ignore
 
@@ -307,6 +308,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
             block_tables=block_table,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_query_len=common_attn_metadata.max_query_len,
             actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
@@ -232,6 +232,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
             block_tables=block_table,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_query_len=common_attn_metadata.max_query_len,
             actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
@@ -257,8 +257,8 @@ class AscendSFACPImpl(AscendSFAImpl):
         return self._align_to_graph_bucket_tokens(attn_output, attn_metadata)
 
     def _align_to_graph_bucket_tokens(self, attn_output: torch.Tensor | None, attn_metadata: M) -> torch.Tensor | None:
-        if attn_output is None:
-            return None
+        if attn_output is None or self.pcp_size == 1:
+            return attn_output
         # In graph/piecewise mode, output buffer uses graph bucket token size
         # (forward_context.num_tokens), while PCP path may compute only valid
         # tokens. Align to the larger one to avoid later write-back mismatch.
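The early return added above skips the alignment entirely when `pcp_size == 1`, where the attention output already matches the graph bucket. A rough sketch of the alignment idea the comments describe, assuming a 2-D `[num_tokens, hidden]` output and with `num_bucket_tokens` standing in for `forward_context.num_tokens` (the real method body is not shown in this diff):

```python
import torch
import torch.nn.functional as F


def align_to_bucket(attn_output: torch.Tensor | None,
                    num_bucket_tokens: int,
                    pcp_size: int) -> torch.Tensor | None:
    # Without prefill context parallelism there is nothing to align.
    if attn_output is None or pcp_size == 1:
        return attn_output
    num_valid = attn_output.shape[0]
    if num_valid >= num_bucket_tokens:
        return attn_output
    # PCP may have produced only the valid tokens; pad the token dimension
    # up to the graph bucket size so later write-back shapes line up.
    return F.pad(attn_output, (0, 0, 0, num_bucket_tokens - num_valid))
```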
@@ -174,6 +174,7 @@ class AscendMLAMetadata:
     slot_mapping: torch.Tensor
     query_start_loc: torch.Tensor
     seq_lens: torch.Tensor
+    seq_lens_cpu: torch.Tensor
     block_tables: torch.Tensor
 
     # New for MLA (compared to FlashAttention)
@@ -457,6 +458,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             query_start_loc=query_start_loc,
             block_tables=self.block_table,
             seq_lens=self.seq_lens,
+            seq_lens_cpu=self.seq_lens,
         )
 
     def build_chunked_metadata(
@@ -130,6 +130,7 @@ class AscendSFAMetadata:
     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
     seq_lens: torch.Tensor
+    seq_lens_cpu: torch.Tensor
     cum_query_lens: torch.Tensor
     block_table: torch.Tensor
     sin: torch.Tensor
@@ -233,6 +234,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
 
         cum_query_lens = common_attn_metadata.query_start_loc[1 : num_reqs + 1]
         seq_lens = common_attn_metadata.seq_lens[:num_reqs]
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_reqs]
 
         cos, sin = get_cos_and_sin_mla(input_positions, True)
 
@@ -320,6 +322,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
             num_actual_tokens=num_actual_tokens,
             cum_query_lens=cum_query_lens,
             seq_lens=seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             slot_mapping=slot_mapping,
             head_dim=self.model_config.get_head_size(),
             attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),
@@ -600,7 +600,7 @@ class SpecDecodeBaseProposer(EagleProposer):
             - 1
         )
         num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens
-        ori_seq_len = attn_metadata_i.seq_lens[:batch_size].clone()
+        ori_seq_len = attn_metadata_i.seq_lens_cpu[:batch_size].clone()
         mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad
 
         # slot_mapping index base offset:
@@ -1247,6 +1247,7 @@ class SpecDecodeBaseProposer(EagleProposer):
 
         if self.pcp_size * self.dcp_size > 1:
             if self.vllm_config.model_config.use_mla:
-                attn_metadata.decode.cp_seq_len = cp_seq_len
+                if getattr(attn_metadata, "decode", None):
+                    attn_metadata.decode.cp_seq_len = cp_seq_len
             else:
                 attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp