diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 17ca56b7..41476ccc 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -194,26 +194,25 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            seq_lens = forward_context.attn_metadata[key].seq_lens
             torch.npu.graph_task_update_begin(update_stream, handle)
             torch_npu._npu_paged_attention(
                 query=query,
@@ -236,30 +235,32 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
-         spec_attn_mask, sparse_mode, scale, block_table, block_size,
-         seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
-        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            actual_seq_lengths = forward_context.attn_metadata[
-                key].decode.actual_seq_lengths_q
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_lens_list = seq_lens_list + [0] * (
-                runtime_shape // spec_multiple - len(seq_lens_list))
-            actual_seq_lengths = [
-                spec_multiple * (i + 1)
-                for i in range(runtime_shape // spec_multiple)
-            ]
-        else:
-            seq_lens_list = seq_lens_list + [0] * (runtime_shape -
-                                                   len(seq_lens_list))
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+             spec_attn_mask, sparse_mode, scale, block_table, block_size,
+             seq_lens_list, actual_seq_lengths, attn_output,
+             softmax_lse) = param
+            seq_lens_list = forward_context.attn_metadata[
+                key].decode.seq_lens_list
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                actual_seq_lengths = forward_context.attn_metadata[
+                    key].decode.actual_seq_lengths_q
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_lens_list = seq_lens_list + [0] * (
+                    runtime_shape // spec_multiple - len(seq_lens_list))
+                actual_seq_lengths = [
+                    spec_multiple * (i + 1)
+                    for i in range(runtime_shape // spec_multiple)
+                ]
+            else:
+                seq_lens_list = seq_lens_list + [0] * (runtime_shape -
+                                                       len(seq_lens_list))
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.npu_fused_infer_attention_score.out(
@@ -291,26 +292,27 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
-         block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
-         dcp_rank, dcp_size) = param
-        actual_seq_lengths_kv = forward_context.attn_metadata[
-            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
-                                                            dcp_rank]
-        pad_length = runtime_shape - len(actual_seq_lengths_kv)
-        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
-        actual_seq_lengths_kv = np.concatenate(
-            [actual_seq_lengths_kv, pad_tensor])
-        if dcp_size > 1:
-            num_heads = num_heads * dcp_size
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, value, num_heads, num_kv_heads, scale,
+             block_table, block_size, actual_seq_lengths_kv, attn_output,
+             softmax_lse, cp_rank, dcp_rank, dcp_size) = param
+            actual_seq_lengths_kv = forward_context.attn_metadata[
+                key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
+                                                                dcp_rank]
+            pad_length = runtime_shape - len(actual_seq_lengths_kv)
+            pad_tensor = np.zeros(pad_length,
+                                  dtype=actual_seq_lengths_kv.dtype)
+            actual_seq_lengths_kv = np.concatenate(
+                [actual_seq_lengths_kv, pad_tensor])
+            if dcp_size > 1:
+                num_heads = num_heads * dcp_size

-        with torch.npu.stream(update_stream):
             torch.npu.graph_task_update_begin(update_stream, handle)

             torch_npu.npu_fused_infer_attention_score.out(
@@ -340,30 +342,30 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
    # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale,
-         num_kv_heads, attn_output, softmax_lse) = param
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
+             scale, num_kv_heads, attn_output, softmax_lse) = param

-        decode_meta = forward_context.attn_metadata[key].decode
-        seq_len = decode_meta.cp_seq_len
+            decode_meta = forward_context.attn_metadata[key].decode
+            seq_len = decode_meta.cp_seq_len

-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
-                                       len(seq_len))
-        else:
-            pad_length = runtime_shape - len(seq_len)
-            pad_tensor = torch.zeros(pad_length,
-                                     dtype=seq_len.dtype,
-                                     device=seq_len.device)
-            seq_len = torch.cat([seq_len, pad_tensor], dim=0)
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
+                                           len(seq_len))
+            else:
+                pad_length = runtime_shape - len(seq_len)
+                pad_tensor = torch.zeros(pad_length,
+                                         dtype=seq_len.dtype,
+                                         device=seq_len.device)
+                seq_len = torch.cat([seq_len, pad_tensor], dim=0)

-        with torch.npu.stream(update_stream):
             torch.npu.graph_task_update_begin(update_stream, handle)
             torch_npu.atb.npu_multi_head_latent_attention(