diff --git a/docs/source/developer_guide/evaluation/index.md b/docs/source/developer_guide/evaluation/index.md index 6bb9044..324c2e2 100644 --- a/docs/source/developer_guide/evaluation/index.md +++ b/docs/source/developer_guide/evaluation/index.md @@ -12,4 +12,5 @@ using_evalscope :caption: Performance :maxdepth: 1 performance_benchmark +profile_execute_duration ::: \ No newline at end of file diff --git a/docs/source/developer_guide/evaluation/profile_execute_duration.md b/docs/source/developer_guide/evaluation/profile_execute_duration.md new file mode 100644 index 0000000..8989bf9 --- /dev/null +++ b/docs/source/developer_guide/evaluation/profile_execute_duration.md @@ -0,0 +1,34 @@ +# Profile Execute Duration + +The execution duration of each stage (including pre/post-processing, model forward, etc.) usually needs to be captured during a complete inference process. Typically, this is done by using `torch.npu.synchronize()` and obtaining CPU timestamps, which increases the performance overhead of host/device synchronization. + +**To reduce the performance overhead, we add this feature, using the NPU event timestamp mechanism to observe the device execution time asynchronously.** + +## Usage +* Use the environment variable `VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE` to enable this feature. +* Use the non-blocking API `ProfileExecuteDuration().capture_async` to set observation points asynchronously when you need to observe the execution duration. +* Use the blocking API `ProfileExecuteDuration().pop_captured_sync` at an appropriate time to get and print the execution durations of all observed stages. 
+ +## Example Output + +``` +5691:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.17ms [prepare input and forward]:9.57ms [forward]:4.14ms +5695:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.29ms [prepare input and forward]:10.19ms [forward]:4.14ms +5697:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.81ms [prepare input and forward]:10.29ms [forward]:3.99ms +5701:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.10ms [prepare input and forward]:10.62ms [forward]:4.33ms +5705:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.65ms [prepare input and forward]:9.58ms [forward]:4.20ms +5709:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.43ms [prepare input and forward]:9.88ms [forward]:4.20ms +5711:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.89ms [prepare input and forward]:10.49ms [forward]:4.19ms +5715:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.14ms [prepare input and forward]:11.21ms [forward]:4.18ms +5719:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.71ms [prepare input and forward]:10.15ms [forward]:4.42ms +5723:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.31ms [forward]:4.25ms +5725:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.12ms [prepare input and forward]:10.33ms [forward]:4.24ms +5729:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.58ms [prepare input and forward]:10.85ms [forward]:4.32ms +5733:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.32ms [prepare input and forward]:9.79ms [forward]:4.28ms +5737:(IntegratedWorker pid=1502462) Profile 
execute duration [Decode]: [post process]:15.06ms [prepare input and forward]:9.89ms [forward]:4.32ms +5739:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.48ms [forward]:4.27ms +5743:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.60ms [prepare input and forward]:10.71ms [forward]:4.61ms +5747:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.21ms [prepare input and forward]:10.10ms [forward]:4.52ms +5751:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:15.03ms [prepare input and forward]:10.00ms [forward]:4.42ms + +``` \ No newline at end of file diff --git a/tests/singlecard/test_profile_execute_duration.py b/tests/singlecard/test_profile_execute_duration.py new file mode 100644 index 0000000..449526e --- /dev/null +++ b/tests/singlecard/test_profile_execute_duration.py @@ -0,0 +1,62 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import time +from unittest.mock import patch + +import torch +import vllm # noqa: F401 + +from vllm_ascend.utils import ProfileExecuteDuration + + +@patch.dict(os.environ, {"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": "1"}) +def test_execute_duration_enabled_discrepancy(): + a = torch.randn(10000, 10000).npu() + b = torch.randn(10000, 10000).npu() + + # warmup + torch.matmul(a, b) + torch.npu.synchronize() + + cpu_start = time.perf_counter() + with ProfileExecuteDuration().capture_async("forward"): + torch.matmul(a, b) + torch.npu.synchronize() + cpu_duration = (time.perf_counter() - cpu_start) * 1000 + npu_durations = ProfileExecuteDuration().pop_captured_sync() + assert npu_durations and 'forward' in npu_durations + assert not ProfileExecuteDuration._observations + + # Assert discrepancy between CPU and NPU duration is within 50% roughly + diff = abs(cpu_duration - npu_durations['forward']) / max( + cpu_duration, npu_durations['forward']) + assert diff <= 0.5, ( + f"CPU={cpu_duration:.2f}ms, NPU={npu_durations['forward']:.2f}ms") + + +def test_execute_duration_disabled(): + a = torch.randn(100, 100).npu() + b = torch.randn(100, 100).npu() + + with ProfileExecuteDuration().capture_async("forward"): + torch.matmul(a, b) + torch.npu.synchronize() + npu_durations = ProfileExecuteDuration().pop_captured_sync() + assert not npu_durations diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 9378e6f..52e50fb 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -70,6 +70,9 @@ env_variables: Dict[str, Callable[[], Any]] = { lambda: os.getenv("VLLM_VERSION", None), "VLLM_ASCEND_TRACE_RECOMPILES": lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))), + "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": + lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0')) + ), } # end-env-vars-definition diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 34b8da1..7d40938 100644 --- a/vllm_ascend/utils.py +++ 
b/vllm_ascend/utils.py @@ -17,11 +17,15 @@ # Adapted from vllm-project/vllm/vllm/worker/worker.py # +import atexit import math -from typing import TYPE_CHECKING +from contextlib import contextmanager +from threading import Lock +from typing import TYPE_CHECKING, List, Tuple import torch from packaging.version import InvalidVersion, Version +from torch_npu.npu.streams import Event from vllm.logger import logger import vllm_ascend.envs as envs @@ -175,3 +179,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: def dispose_tensor(x: torch.Tensor): x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype)) + + +class ProfileExecuteDuration: + _instance = None + _observations: List[Tuple[str, Event, Event]] = [] + _lock = Lock() + + def __new__(cls): + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + atexit.register(cls._instance.destroy) + return cls._instance + + def destroy(self): + with self._lock: + self._observations.clear() + + @contextmanager + def capture_async(self, duration_tag: str): + if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE: + yield + return + + observe_start = Event(enable_timing=True) + observe_start.record() + try: + yield + finally: + observe_end = Event(enable_timing=True) + observe_end.record() + with self._lock: + self._observations.append( + (duration_tag, observe_start, observe_end)) + + def pop_captured_sync(self) -> dict: + """Pop and synchronize all events in the observation list""" + durations: dict[str, float] = {} + if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE: + return durations + + while self._observations: + with self._lock: + tag, observe_start, observe_end = self._observations.pop() + observe_end.synchronize() + durations[tag] = observe_start.elapsed_time(observe_end) + + return durations diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 58c7350..647176c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ 
b/vllm_ascend/worker/model_runner_v1.py @@ -67,7 +67,7 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler -from vllm_ascend.utils import vllm_version_is +from vllm_ascend.utils import ProfileExecuteDuration, vllm_version_is from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer if TYPE_CHECKING: @@ -707,27 +707,28 @@ class NPUModelRunner(LoRAModelRunnerMixin): with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens): - model_kwargs = {} - if self.torchair_graph_enabled: - model_kwargs["kv_caches"] = self.kv_caches - model_kwargs["attn_metadata"] = attn_metadata - if self.torchair_graph_enabled and not with_prefill: - hidden_states = self.compile_model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=None, - **model_kwargs, - ) - else: - assert self.model is not None - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=None, - **model_kwargs, - ) + with ProfileExecuteDuration().capture_async("forward"): + model_kwargs = {} + if self.torchair_graph_enabled: + model_kwargs["kv_caches"] = self.kv_caches + model_kwargs["attn_metadata"] = attn_metadata + if self.torchair_graph_enabled and not with_prefill: + hidden_states = self.compile_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=None, + **model_kwargs, + ) + else: + assert self.model is not None + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=None, + **model_kwargs, + ) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -920,103 +921,121 @@ class 
NPUModelRunner(LoRAModelRunnerMixin): scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: - self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - # Return empty ModelRunnerOuptut if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - (attn_metadata, hidden_states, spec_decode_metadata, positions, - num_scheduled_tokens, - sample_indices) = (self._process_reqs(scheduler_output, - intermediate_tensors)) - logits = self.model.compute_logits(hidden_states[sample_indices], None) + with ProfileExecuteDuration().capture_async( + "prepare input and forward"): + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + # Return empty ModelRunnerOuptut if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + (attn_metadata, hidden_states, spec_decode_metadata, positions, + num_scheduled_tokens, + sample_indices) = (self._process_reqs(scheduler_output, + intermediate_tensors)) - # Apply structured output bitmasks if present - if scheduler_output.grammar_bitmask is not None: - logits = self.apply_grammar_bitmask(scheduler_output, logits) + with ProfileExecuteDuration().capture_async("post process"): + logits = self.model.compute_logits(hidden_states[sample_indices], + None) - # Sample the next token and get logprobs if needed. - sampling_metadata = self.input_batch.sampling_metadata - if spec_decode_metadata is None: - sampler_output = self.sampler( - logits=logits, - sampling_metadata=sampling_metadata, - ) - else: - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. 
- bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - ) - bonus_token_ids = sampler_output.sampled_token_ids + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + logits = self.apply_grammar_bitmask(scheduler_output, logits) - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - target_logits, - bonus_token_ids, + # Sample the next token and get logprobs if needed. + sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + bonus_logits = logits[ + spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. 
+ target_logits = logits[ + spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len < req_state.num_tokens: + # Ignore the sampled token. + # Rewind the generator state as if the token was not sampled. + generator = self.input_batch.generators.get(i) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + + # NOTE: NPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + # Get the valid generated tokens. + sampled_token_ids = sampler_output.sampled_token_ids + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + + spec_token_ids = self._get_spec_token_ids( + valid_sampled_token_ids, sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids - - # TODO(woosuk): The following loop can be slow since it iterates over - # the requests one by one. Optimize. 
- for i, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - if seq_len < req_state.num_tokens: - # Ignore the sampled token. - # Rewind the generator state as if the token was not sampled. - generator = self.input_batch.generators.get(i) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - - # NOTE: NPU -> CPU Sync happens here. - # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None - - # Get the valid generated tokens. - sampled_token_ids = sampler_output.sampled_token_ids - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. - valid_sampled_token_ids = sampled_token_ids.tolist() - else: - # Includes spec decode tokens. - valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, + scheduler_output, + spec_decode_metadata, + positions, + num_scheduled_tokens, + hidden_states, + attn_metadata, ) - spec_token_ids = self._get_spec_token_ids( - valid_sampled_token_ids, - sampling_metadata, - scheduler_output, - spec_decode_metadata, - positions, - num_scheduled_tokens, - hidden_states, - attn_metadata, - ) + model_runner_output = ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict={}, + ) + + durations = ProfileExecuteDuration().pop_captured_sync() + if durations: + dr_str = [ + f"[{tag}]:{duration:.2f}ms" + for tag, duration in durations.items() + ] + captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill" + print(f"Profile execute duration 
[{captured_name}]:", + " ".join(dr_str)) - model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict={}, - ) return model_runner_output def _profile_multimodal(self) -> None: