revert TND modify when dcp pcp (#3948)
### What this PR does / why we need it? 1、revert TND modify when dcp pcp, which is introduced byf57bdb09fc2、deal aclgraph pad border issue - vLLM version: v0.11.0 - vLLM main:83f478bb19Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -869,6 +869,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
else:
|
||||
num_heads = self.num_heads
|
||||
|
||||
q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
|
||||
k_nope = self.key_cache.view(self.key_cache.shape[0],
|
||||
self.key_cache.shape[1], -1)
|
||||
value = self.value_cache.view(self.key_cache.shape[0],
|
||||
@@ -879,7 +880,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
'num_key_value_heads':
|
||||
self.num_kv_heads,
|
||||
'input_layout':
|
||||
"TND",
|
||||
"BSND",
|
||||
'atten_mask':
|
||||
None,
|
||||
'scale':
|
||||
@@ -895,14 +896,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
'block_size':
|
||||
self.key_cache.shape[1],
|
||||
'actual_seq_lengths_kv':
|
||||
attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
|
||||
'actual_seq_lengths':
|
||||
attn_metadata.actual_seq_lengths_q[:attn_metadata.
|
||||
num_decode_tokens]
|
||||
attn_metadata.decode_meta.
|
||||
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
|
||||
}
|
||||
graph_params = get_graph_params()
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
num_tokens = query.shape[0]
|
||||
num_tokens = q_nope.shape[0]
|
||||
if forward_context.capturing:
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
@@ -914,16 +913,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
query, k_nope, value, **common_kwargs)
|
||||
q_nope, k_nope, value, **common_kwargs)
|
||||
update_graph_params_workspaces(num_tokens,
|
||||
weak_ref_tensors(workspace))
|
||||
attn_out = torch.empty_like(query)
|
||||
attn_out = torch.empty_like(q_nope)
|
||||
attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
|
||||
dtype=torch.float,
|
||||
device=query.device)
|
||||
device=q_nope.device)
|
||||
|
||||
graph_params.attn_params[num_tokens].append(
|
||||
(weak_ref_tensors(query), weak_ref_tensors(k_nope),
|
||||
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
|
||||
weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
|
||||
self.scale, attn_metadata.block_tables,
|
||||
self.key_cache.shape[1], attn_metadata.decode_meta.
|
||||
@@ -933,7 +932,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.pcp_rank, self.dcp_rank, self.dcp_size))
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
query,
|
||||
q_nope,
|
||||
k_nope,
|
||||
value,
|
||||
**common_kwargs,
|
||||
@@ -943,7 +942,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
|
||||
query, k_nope, value, **common_kwargs)
|
||||
q_nope, k_nope, value, **common_kwargs)
|
||||
|
||||
attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
|
||||
attn_out.shape[3])
|
||||
attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
|
||||
|
||||
attn_out_lse_list = []
|
||||
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
||||
@@ -1017,7 +1020,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
prefill_query, key, value, attn_metadata,
|
||||
output[num_decode_tokens:], prefill_query.shape[0])
|
||||
attn_metadata.seq_lens = seq_lens_back
|
||||
output[num_decode_tokens:] = output_prefill
|
||||
output[num_decode_tokens:output_prefill.shape[0] +
|
||||
num_decode_tokens] = output_prefill
|
||||
return output
|
||||
|
||||
def forward(
|
||||
@@ -1089,7 +1093,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
if has_prefill:
|
||||
if self.pcp_size > 1:
|
||||
kv = torch.cat([key, value], dim=-1)
|
||||
all_kv = get_pcp_group().all_gather(kv, dim=0)
|
||||
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||
all_kv = get_pcp_group().all_gather(
|
||||
kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
|
||||
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
|
||||
all_kv = torch.index_select(all_kv, 0,
|
||||
pcp_allgather_restore_idx)
|
||||
|
||||
@@ -301,9 +301,9 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
):
|
||||
(q_nope, k_nope, value, num_heads, num_kv_heads, scale,
|
||||
block_table, block_size, actual_seq_lengths_kv, attn_output,
|
||||
softmax_lse, cp_rank, dcp_rank, dcp_size) = param
|
||||
softmax_lse, pcp_rank, dcp_rank, dcp_size) = param
|
||||
actual_seq_lengths_kv = forward_context.attn_metadata[
|
||||
key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
|
||||
key].decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank,
|
||||
dcp_rank]
|
||||
pad_length = runtime_shape - len(actual_seq_lengths_kv)
|
||||
pad_tensor = np.zeros(pad_length,
|
||||
|
||||
@@ -476,6 +476,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.num_actual_tokens_pcp_padded = 0
|
||||
if self.speculative_config and self.pcp_size > 1:
|
||||
self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
@@ -1915,7 +1916,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states = hidden_states[:-pad_size, :]
|
||||
|
||||
if self.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
|
||||
hidden_states = get_pcp_group().all_gather(
|
||||
hidden_states[:self.num_actual_tokens_pcp_padded //
|
||||
self.pcp_size], 0)
|
||||
hidden_states = torch.index_select(
|
||||
hidden_states, 0,
|
||||
self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
|
||||
@@ -4304,6 +4307,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
|
||||
>= self.input_batch.num_prompt_tokens[:num_reqs])
|
||||
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
|
||||
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
|
||||
long_seq_metadata = None
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
long_seq_metadata = AscendPrefillContextParallelMetadata(
|
||||
|
||||
Reference in New Issue
Block a user