diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 7522f8b..92f5532 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -189,11 +189,20 @@ class ACLGraphWrapper: return entry.output -def update_attn_params(update_stream, forward_context, runtime_shape): +def update_attn_params(update_stream, + forward_context, + runtime_shape, + kv_transfer_config=None): graph_params = get_graph_params() - # FIXME: Behold! We are using a temporary hack here to update the args - # for each layer's attention op in the graph. - with torch.npu.stream(update_stream): + + # NOTE(Angazenn): By moving the npu-stream context ahead, + # (see https://github.com/vllm-project/vllm-ascend/pull/3985) + # we can reduce host overhead introduced by stream initialization. + # However, we find that this might cause potential accuracy problems + # with pd-disaggregation. Therefore, this optimization is only enabled + # without pd-disaggregation. We are working to solve this problem + # directly in the future. + if kv_transfer_config is not None: for key, param, handle, event in zip( forward_context.attn_metadata, graph_params.attn_params[runtime_shape], @@ -215,10 +224,9 @@ def update_attn_params(update_stream, forward_context, runtime_shape): # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY # mode with GQA. This is triggered by getting workspace for _npu_paged_attention - # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens - # might encounter a bigger workspace, while currently we use max_model_len to - # calculate max workspace in capturing. So additional get_workspace is added - # here to avoid such bugs. + # in torch_npu. In some cases, _npu_paged_attention requires different workspace + # among various seq_lens. So additional get_workspace is added here + # to avoid such bugs. 
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully # replaced by npu_fused_infer_attention_score which does not contain such bugs. workspace = torch_npu._npu_paged_attention_get_workspace( @@ -231,20 +239,67 @@ def update_attn_params(update_stream, forward_context, runtime_shape): block_table=block_table, context_lens=seq_lens, out=output) - torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu._npu_paged_attention(query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - num_heads=num_heads, - scale_value=scale, - block_table=block_table, - context_lens=seq_lens, - out=output, - workspace=workspace) - torch.npu.graph_task_update_end(update_stream) - event.record(update_stream) + with torch.npu.stream(update_stream): + torch.npu.graph_task_update_begin(update_stream, handle) + torch_npu._npu_paged_attention(query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output, + workspace=workspace) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + else: + with torch.npu.stream(update_stream): + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + ( + query, + key_cache, + value_cache, + num_kv_heads, + num_heads, + scale, + block_table, + seq_lens, + output, + ) = param + seq_lens = forward_context.attn_metadata[key].seq_lens + + workspace = torch_npu._npu_paged_attention_get_workspace( + query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output) + torch.npu.graph_task_update_begin(update_stream, handle) + torch_npu._npu_paged_attention(query=query, 
+ key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output, + workspace=workspace) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) def update_mla_attn_params(update_stream, forward_context, runtime_shape, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6fcb93b..238f494 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1598,7 +1598,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.speculative_config) else: update_attn_params(self.update_stream, forward_context, - maybe_padded_num_tokens) + maybe_padded_num_tokens, + self.vllm_config.kv_transfer_config) if get_forward_context().sp_enabled: hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) @@ -2359,7 +2360,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_tokens, self.speculative_config) else: update_attn_params(self.update_stream, forward_context, - num_tokens) + num_tokens, + self.vllm_config.kv_transfer_config) if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, _ = hidden_states