revert TND modify when dcp pcp (#3948)
### What this PR does / why we need it? 1. Revert the TND modification for dcp/pcp, which was introduced by f57bdb09fc. 2. Deal with the aclgraph pad border issue. - vLLM version: v0.11.0 - vLLM main: 83f478bb19 Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -869,6 +869,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
else:
|
else:
|
||||||
num_heads = self.num_heads
|
num_heads = self.num_heads
|
||||||
|
|
||||||
|
q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
|
||||||
k_nope = self.key_cache.view(self.key_cache.shape[0],
|
k_nope = self.key_cache.view(self.key_cache.shape[0],
|
||||||
self.key_cache.shape[1], -1)
|
self.key_cache.shape[1], -1)
|
||||||
value = self.value_cache.view(self.key_cache.shape[0],
|
value = self.value_cache.view(self.key_cache.shape[0],
|
||||||
@@ -879,7 +880,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
'num_key_value_heads':
|
'num_key_value_heads':
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
'input_layout':
|
'input_layout':
|
||||||
"TND",
|
"BSND",
|
||||||
'atten_mask':
|
'atten_mask':
|
||||||
None,
|
None,
|
||||||
'scale':
|
'scale':
|
||||||
@@ -895,14 +896,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
'block_size':
|
'block_size':
|
||||||
self.key_cache.shape[1],
|
self.key_cache.shape[1],
|
||||||
'actual_seq_lengths_kv':
|
'actual_seq_lengths_kv':
|
||||||
attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
|
attn_metadata.decode_meta.
|
||||||
'actual_seq_lengths':
|
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
|
||||||
attn_metadata.actual_seq_lengths_q[:attn_metadata.
|
|
||||||
num_decode_tokens]
|
|
||||||
}
|
}
|
||||||
graph_params = get_graph_params()
|
graph_params = get_graph_params()
|
||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
num_tokens = query.shape[0]
|
num_tokens = q_nope.shape[0]
|
||||||
if forward_context.capturing:
|
if forward_context.capturing:
|
||||||
stream = torch_npu.npu.current_stream()
|
stream = torch_npu.npu.current_stream()
|
||||||
|
|
||||||
@@ -914,16 +913,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
workspace = graph_params.workspaces.get(num_tokens)
|
workspace = graph_params.workspaces.get(num_tokens)
|
||||||
if workspace is None:
|
if workspace is None:
|
||||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||||
query, k_nope, value, **common_kwargs)
|
q_nope, k_nope, value, **common_kwargs)
|
||||||
update_graph_params_workspaces(num_tokens,
|
update_graph_params_workspaces(num_tokens,
|
||||||
weak_ref_tensors(workspace))
|
weak_ref_tensors(workspace))
|
||||||
attn_out = torch.empty_like(query)
|
attn_out = torch.empty_like(q_nope)
|
||||||
attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
|
attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
device=query.device)
|
device=q_nope.device)
|
||||||
|
|
||||||
graph_params.attn_params[num_tokens].append(
|
graph_params.attn_params[num_tokens].append(
|
||||||
(weak_ref_tensors(query), weak_ref_tensors(k_nope),
|
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
|
||||||
weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
|
weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
|
||||||
self.scale, attn_metadata.block_tables,
|
self.scale, attn_metadata.block_tables,
|
||||||
self.key_cache.shape[1], attn_metadata.decode_meta.
|
self.key_cache.shape[1], attn_metadata.decode_meta.
|
||||||
@@ -933,7 +932,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
self.pcp_rank, self.dcp_rank, self.dcp_size))
|
self.pcp_rank, self.dcp_rank, self.dcp_size))
|
||||||
torch.npu.graph_task_group_begin(stream)
|
torch.npu.graph_task_group_begin(stream)
|
||||||
torch_npu.npu_fused_infer_attention_score.out(
|
torch_npu.npu_fused_infer_attention_score.out(
|
||||||
query,
|
q_nope,
|
||||||
k_nope,
|
k_nope,
|
||||||
value,
|
value,
|
||||||
**common_kwargs,
|
**common_kwargs,
|
||||||
@@ -943,7 +942,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
graph_params.handles[num_tokens].append(handle)
|
graph_params.handles[num_tokens].append(handle)
|
||||||
else:
|
else:
|
||||||
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
|
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
|
||||||
query, k_nope, value, **common_kwargs)
|
q_nope, k_nope, value, **common_kwargs)
|
||||||
|
|
||||||
|
attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
|
||||||
|
attn_out.shape[3])
|
||||||
|
attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
|
||||||
|
|
||||||
attn_out_lse_list = []
|
attn_out_lse_list = []
|
||||||
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
||||||
@@ -1017,7 +1020,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
prefill_query, key, value, attn_metadata,
|
prefill_query, key, value, attn_metadata,
|
||||||
output[num_decode_tokens:], prefill_query.shape[0])
|
output[num_decode_tokens:], prefill_query.shape[0])
|
||||||
attn_metadata.seq_lens = seq_lens_back
|
attn_metadata.seq_lens = seq_lens_back
|
||||||
output[num_decode_tokens:] = output_prefill
|
output[num_decode_tokens:output_prefill.shape[0] +
|
||||||
|
num_decode_tokens] = output_prefill
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1089,7 +1093,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
if has_prefill:
|
if has_prefill:
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
kv = torch.cat([key, value], dim=-1)
|
kv = torch.cat([key, value], dim=-1)
|
||||||
all_kv = get_pcp_group().all_gather(kv, dim=0)
|
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||||
|
all_kv = get_pcp_group().all_gather(
|
||||||
|
kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
|
||||||
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
|
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
|
||||||
all_kv = torch.index_select(all_kv, 0,
|
all_kv = torch.index_select(all_kv, 0,
|
||||||
pcp_allgather_restore_idx)
|
pcp_allgather_restore_idx)
|
||||||
|
|||||||
@@ -301,9 +301,9 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
|||||||
):
|
):
|
||||||
(q_nope, k_nope, value, num_heads, num_kv_heads, scale,
|
(q_nope, k_nope, value, num_heads, num_kv_heads, scale,
|
||||||
block_table, block_size, actual_seq_lengths_kv, attn_output,
|
block_table, block_size, actual_seq_lengths_kv, attn_output,
|
||||||
softmax_lse, cp_rank, dcp_rank, dcp_size) = param
|
softmax_lse, pcp_rank, dcp_rank, dcp_size) = param
|
||||||
actual_seq_lengths_kv = forward_context.attn_metadata[
|
actual_seq_lengths_kv = forward_context.attn_metadata[
|
||||||
key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
|
key].decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank,
|
||||||
dcp_rank]
|
dcp_rank]
|
||||||
pad_length = runtime_shape - len(actual_seq_lengths_kv)
|
pad_length = runtime_shape - len(actual_seq_lengths_kv)
|
||||||
pad_tensor = np.zeros(pad_length,
|
pad_tensor = np.zeros(pad_length,
|
||||||
|
|||||||
@@ -476,6 +476,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
|
self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
self.num_actual_tokens_pcp_padded = 0
|
||||||
if self.speculative_config and self.pcp_size > 1:
|
if self.speculative_config and self.pcp_size > 1:
|
||||||
self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
|
self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
@@ -1915,7 +1916,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
hidden_states = hidden_states[:-pad_size, :]
|
hidden_states = hidden_states[:-pad_size, :]
|
||||||
|
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
|
hidden_states = get_pcp_group().all_gather(
|
||||||
|
hidden_states[:self.num_actual_tokens_pcp_padded //
|
||||||
|
self.pcp_size], 0)
|
||||||
hidden_states = torch.index_select(
|
hidden_states = torch.index_select(
|
||||||
hidden_states, 0,
|
hidden_states, 0,
|
||||||
self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
|
self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
|
||||||
@@ -4304,6 +4307,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
|
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
|
||||||
>= self.input_batch.num_prompt_tokens[:num_reqs])
|
>= self.input_batch.num_prompt_tokens[:num_reqs])
|
||||||
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
|
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
|
||||||
|
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
|
||||||
long_seq_metadata = None
|
long_seq_metadata = None
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
long_seq_metadata = AscendPrefillContextParallelMetadata(
|
long_seq_metadata = AscendPrefillContextParallelMetadata(
|
||||||
|
|||||||
Reference in New Issue
Block a user