revert TND modify when dcp pcp (#3948)

### What this PR does / why we need it?
1. Revert the TND-layout modification used when DCP/PCP (decode/prefill context parallelism) is enabled, which was introduced by f57bdb09fc
2. Fix an ACL graph padding boundary issue

- vLLM version: v0.11.0
- vLLM main: 83f478bb19

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Commit: 5453033a41 (parent: cc2cd42ad3)
Author: weiguihua2 (committed via GitHub)
Date: 2025-11-03 22:22:17 +08:00

3 changed files with 27 additions and 17 deletions

Changes in `AscendAttentionBackendImpl` (attention backend):

```diff
@@ -869,6 +869,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         else:
             num_heads = self.num_heads
+        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -879,7 +880,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "TND",
+            "BSND",
             'atten_mask':
             None,
             'scale':
```
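For context on the layout switch: under `BSND` the decode query is viewed as a 4-D tensor (batch, sequence=1, num_heads, head_dim) instead of the flattened token-major `TND` form. A minimal sketch of that reshape, with illustrative sizes that are not taken from the PR:

```python
import torch

# Hypothetical decode batch: 8 tokens, 16 query heads, head_dim 128.
num_tokens, num_heads, head_dim = 8, 16, 128
query = torch.randn(num_tokens, num_heads, head_dim)

# TND: token-major, no explicit batch/sequence split.
q_tnd = query  # shape (T, N, D) = (8, 16, 128)

# BSND: each decode token becomes a batch entry with sequence length 1.
q_bsnd = query.view(num_tokens, 1, num_heads, head_dim)  # (B, S, N, D)

assert q_bsnd.shape == (8, 1, 16, 128)
```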
```diff
@@ -895,14 +896,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
             'block_size':
             self.key_cache.shape[1],
             'actual_seq_lengths_kv':
-            attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
-            'actual_seq_lengths':
-            attn_metadata.actual_seq_lengths_q[:attn_metadata.
-                                               num_decode_tokens]
+            attn_metadata.decode_meta.
+            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
-        num_tokens = query.shape[0]
+        num_tokens = q_nope.shape[0]
         if forward_context.capturing:
             stream = torch_npu.npu.current_stream()
```
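The reverted path reads per-request KV lengths for this (PCP, DCP) rank pair out of a 3-D bookkeeping tensor. A rough sketch of that indexing, with made-up sizes and values:

```python
import torch

# Hypothetical bookkeeping: KV tokens already computed per request, split
# across 2 prefill-context-parallel x 4 decode-context-parallel ranks.
num_reqs, pcp_size, dcp_size = 3, 2, 4
num_computed_tokens_of_pcp_dcp = torch.randint(
    0, 512, (num_reqs, pcp_size, dcp_size))

pcp_rank, dcp_rank = 1, 2
# Each rank passes only its own shard's lengths as actual_seq_lengths_kv.
actual_seq_lengths_kv = num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank]
print(actual_seq_lengths_kv.shape)  # torch.Size([3]) -- one length per request
```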
```diff
@@ -914,16 +913,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    query, k_nope, value, **common_kwargs)
+                    q_nope, k_nope, value, **common_kwargs)
                 update_graph_params_workspaces(num_tokens,
                                                weak_ref_tensors(workspace))
-            attn_out = torch.empty_like(query)
+            attn_out = torch.empty_like(q_nope)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
-                                   device=query.device)
+                                   device=q_nope.device)
             graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(query), weak_ref_tensors(k_nope),
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
                  weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
                  self.key_cache.shape[1], attn_metadata.decode_meta.
@@ -933,7 +932,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                  self.pcp_rank, self.dcp_rank, self.dcp_size))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
-                query,
+                q_nope,
                 k_nope,
                 value,
                 **common_kwargs,
@@ -943,7 +942,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                query, k_nope, value, **common_kwargs)
+                q_nope, k_nope, value, **common_kwargs)
+        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
+                                 attn_out.shape[3])
+        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
         attn_out_lse_list = []
         # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
```
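The comment at the end of the hunk above describes the cross-rank combine: attention output and its log-sum-exp are packed into one tensor per rank before being merged. A small illustration of that concat, assuming only the shapes named in the comment:

```python
import torch

bs, num_heads, v_head_dim = 8, 16, 128
attn_out = torch.randn(bs, num_heads, v_head_dim)
attn_lse = torch.randn(bs, num_heads, 1)

# Pack output and LSE side by side so one collective can move both.
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
assert attn_out_lse.shape == (bs, num_heads, v_head_dim + 1)

# Unpack after the exchange.
out, lse = attn_out_lse.split([v_head_dim, 1], dim=-1)
```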
```diff
@@ -1017,7 +1020,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 prefill_query, key, value, attn_metadata,
                 output[num_decode_tokens:], prefill_query.shape[0])
             attn_metadata.seq_lens = seq_lens_back
-            output[num_decode_tokens:] = output_prefill
+            output[num_decode_tokens:output_prefill.shape[0] +
+                   num_decode_tokens] = output_prefill
         return output

     def forward(
```
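This is the ACL graph padding boundary fix from the description: when `output` is padded past the real token count, an open-ended slice no longer matches `output_prefill`, so the end index must be computed explicitly. A toy reproduction with invented sizes:

```python
import torch

num_decode_tokens = 4
output = torch.zeros(16, 8)        # padded to the captured graph size
output_prefill = torch.ones(6, 8)  # only the real prefill tokens

# output[num_decode_tokens:] = output_prefill would raise a shape
# mismatch: 12 rows on the left vs 6 on the right.
end = num_decode_tokens + output_prefill.shape[0]
output[num_decode_tokens:end] = output_prefill
```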
```diff
@@ -1089,7 +1093,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
         if has_prefill:
             if self.pcp_size > 1:
                 kv = torch.cat([key, value], dim=-1)
-                all_kv = get_pcp_group().all_gather(kv, dim=0)
+                num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
+                all_kv = get_pcp_group().all_gather(
+                    kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
                 pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
                 all_kv = torch.index_select(all_kv, 0,
                                             pcp_allgather_restore_idx)
```
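The prefill KV all-gather now ships only each rank's real tokens and then restores global token order via an index map. A self-contained, single-process emulation of that pattern (the fake "ranks" and the stand-in permutation are illustrative, not the PR's actual collective):

```python
import torch

pcp_size = 2
# Each "rank" holds padded KV rows; only the first rows are real.
per_rank_padded = [torch.randn(8, 4) for _ in range(pcp_size)]
real_rows_per_rank = 5

# Gather only the unpadded slice from each rank (emulating all_gather dim=0).
all_kv = torch.cat(
    [kv[:real_rows_per_rank] for kv in per_rank_padded], dim=0)

# Restore original token order with a precomputed permutation
# (pcp_allgather_restore_idx in the PR; here just a random stand-in).
restore_idx = torch.randperm(all_kv.shape[0])
all_kv = torch.index_select(all_kv, 0, restore_idx)
```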

Changes in `update_attn_dcp_pcp_params` (ACL graph parameter update):

```diff
@@ -301,9 +301,9 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     ):
         (q_nope, k_nope, value, num_heads, num_kv_heads, scale,
          block_table, block_size, actual_seq_lengths_kv, attn_output,
-         softmax_lse, cp_rank, dcp_rank, dcp_size) = param
+         softmax_lse, pcp_rank, dcp_rank, dcp_size) = param
         actual_seq_lengths_kv = forward_context.attn_metadata[
-            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
+            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank,
                                                             dcp_rank]
         pad_length = runtime_shape - len(actual_seq_lengths_kv)
         pad_tensor = np.zeros(pad_length,
```
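When replaying a captured graph at a fixed `runtime_shape`, the per-request KV lengths must be padded out to that shape, as the end of this hunk shows. A minimal sketch of the padding step (dtype and concatenation details assumed):

```python
import numpy as np

runtime_shape = 8  # batch size the ACL graph was captured with
actual_seq_lengths_kv = np.array([17, 33, 5], dtype=np.int32)

# Pad with zeros so the argument always matches the captured shape.
pad_length = runtime_shape - len(actual_seq_lengths_kv)
pad_tensor = np.zeros(pad_length, dtype=np.int32)
actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
assert len(actual_seq_lengths_kv) == runtime_shape
```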

Changes in `NPUModelRunner` (model runner):

```diff
@@ -476,6 +476,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
                                                    dtype=torch.int32,
                                                    device=self.device)
+        self.num_actual_tokens_pcp_padded = 0
         if self.speculative_config and self.pcp_size > 1:
             self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
                                                   dtype=torch.int32,
@@ -1915,7 +1916,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             hidden_states = hidden_states[:-pad_size, :]
         if self.pcp_size > 1:
-            hidden_states = get_pcp_group().all_gather(hidden_states, 0)
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states[:self.num_actual_tokens_pcp_padded //
+                              self.pcp_size], 0)
             hidden_states = torch.index_select(
                 hidden_states, 0,
                 self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
@@ -4304,6 +4307,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                           >= self.input_batch.num_prompt_tokens[:num_reqs])
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
+        self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = AscendPrefillContextParallelMetadata(
```
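Tying the runner changes together: the scheduled token count is multiplied by `pcp_size` to get the padded total, cached on the runner, and later divided back to bound each rank's share before an all-gather. A few lines of arithmetic (values invented) to make the invariant explicit:

```python
pcp_size = 2
total_num_scheduled_tokens = 24  # real tokens this step, per rank

# Recorded once per step (model runner side).
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * pcp_size

# Used later to bound the all-gather input (attention / hidden-states side).
per_rank_share = num_actual_tokens_pcp_padded // pcp_size
assert per_rank_share == total_num_scheduled_tokens
```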