[ModelRunner] Add profile execute duration observation (#1013)
### What this PR does / why we need it?
We need to **observe the time consumed in each stage of inference (including pre-processing, model forward, etc.) without any performance loss**. To do this, we use the NPU's event timestamp mechanism to mark any stage during device execution; this marking is performed asynchronously, so it adds no overhead. We also provide a blocking synchronization API, `pop_captured_sync`, to be called at an appropriate time to print the time consumed in all observed stages.

**`model_runner_v1.py` only changed 5 lines, all of which are `ProfileExecuteDuration()` calls; the larger diff is due to re-indentation of the wrapped blocks.**

### Does this PR introduce _any_ user-facing change?
Set the environment variable `VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE` to enable this feature.

### How was this patch tested?
Tested with a DeepSeek model; the output looks like this:

```
5691:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.17ms [prepare input and forward]:9.57ms [forward]:4.14ms
5695:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.29ms [prepare input and forward]:10.19ms [forward]:4.14ms
5697:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.81ms [prepare input and forward]:10.29ms [forward]:3.99ms
5701:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.10ms [prepare input and forward]:10.62ms [forward]:4.33ms
5705:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.65ms [prepare input and forward]:9.58ms [forward]:4.20ms
5709:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.43ms [prepare input and forward]:9.88ms [forward]:4.20ms
5711:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.89ms [prepare input and forward]:10.49ms [forward]:4.19ms
5715:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.14ms [prepare input and forward]:11.21ms [forward]:4.18ms
5719:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.71ms [prepare input and forward]:10.15ms [forward]:4.42ms
5723:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.31ms [forward]:4.25ms
5725:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.12ms [prepare input and forward]:10.33ms [forward]:4.24ms
5729:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.58ms [prepare input and forward]:10.85ms [forward]:4.32ms
5733:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.32ms [prepare input and forward]:9.79ms [forward]:4.28ms
5737:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:15.06ms [prepare input and forward]:9.89ms [forward]:4.32ms
5739:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.48ms [forward]:4.27ms
5743:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.60ms [prepare input and forward]:10.71ms [forward]:4.61ms
5747:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.21ms [prepare input and forward]:10.10ms [forward]:4.52ms
5751:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:15.03ms [prepare input and forward]:10.00ms [forward]:4.42ms
```

---------

Signed-off-by: depeng1994 <depengzhang@foxmail.com>
@@ -12,4 +12,5 @@ using_evalscope
:caption: Performance
:maxdepth: 1
performance_benchmark
profile_execute_duration
:::
@@ -0,0 +1,34 @@
# Profile Execute Duration

The execution duration of each stage (including pre/post-processing, model forward, etc.) usually needs to be captured during a complete inference process. Typically, this is done by using `torch.npu.synchronize()` and obtaining CPU timestamps, which increases the performance overhead of host/device synchronization.

**To reduce the performance overhead, we add this feature, using the NPU event timestamp mechanism to observe the device execution time asynchronously.**

## Usage

* Use the environment variable `VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE` to enable this feature.
* Use the non-blocking API `ProfileExecuteDuration().capture_async` to set observation points asynchronously when you need to observe the execution duration.
* Use the blocking API `ProfileExecuteDuration().pop_captured_sync` at an appropriate time to get and print the execution durations of all observed stages.

## Example Output

```
5691:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.17ms [prepare input and forward]:9.57ms [forward]:4.14ms
5695:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.29ms [prepare input and forward]:10.19ms [forward]:4.14ms
5697:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.81ms [prepare input and forward]:10.29ms [forward]:3.99ms
5701:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.10ms [prepare input and forward]:10.62ms [forward]:4.33ms
5705:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.65ms [prepare input and forward]:9.58ms [forward]:4.20ms
5709:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.43ms [prepare input and forward]:9.88ms [forward]:4.20ms
5711:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.89ms [prepare input and forward]:10.49ms [forward]:4.19ms
5715:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.14ms [prepare input and forward]:11.21ms [forward]:4.18ms
5719:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.71ms [prepare input and forward]:10.15ms [forward]:4.42ms
5723:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.31ms [forward]:4.25ms
5725:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.12ms [prepare input and forward]:10.33ms [forward]:4.24ms
5729:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.58ms [prepare input and forward]:10.85ms [forward]:4.32ms
5733:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.32ms [prepare input and forward]:9.79ms [forward]:4.28ms
5737:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:15.06ms [prepare input and forward]:9.89ms [forward]:4.32ms
5739:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.48ms [forward]:4.27ms
5743:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.60ms [prepare input and forward]:10.71ms [forward]:4.61ms
5747:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.21ms [prepare input and forward]:10.10ms [forward]:4.52ms
5751:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:15.03ms [prepare input and forward]:10.00ms [forward]:4.42ms
```
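The per-stage figures in the sample output above can be aggregated offline. As a minimal sketch (the `average_stage_durations` helper and the regex are ours, not part of the PR), this parses `[tag]:<ms>ms` pairs from "Profile execute duration" log lines and averages them per stage:

```python
import re
from collections import defaultdict

# Matches one "[tag]:<value>ms" pair; the "[Decode]"/"[Prefill]" phase marker
# is naturally skipped because it is not followed by a number.
STAGE_RE = re.compile(r"\[([^\]]+)\]:([0-9.]+)ms")

def average_stage_durations(log_lines):
    """Collect per-stage durations (ms) from profile log lines and average them."""
    totals = defaultdict(list)
    for line in log_lines:
        if "Profile execute duration" not in line:
            continue
        for tag, ms in STAGE_RE.findall(line):
            totals[tag].append(float(ms))
    return {tag: sum(vals) / len(vals) for tag, vals in totals.items()}

lines = [
    "5691:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: "
    "[post process]:14.17ms [prepare input and forward]:9.57ms [forward]:4.14ms",
    "5695:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: "
    "[post process]:14.29ms [prepare input and forward]:10.19ms [forward]:4.14ms",
]
print(average_stage_durations(lines))
```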
62  tests/singlecard/test_profile_execute_duration.py  Normal file
@@ -0,0 +1,62 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import time
from unittest.mock import patch

import torch
import vllm  # noqa: F401

from vllm_ascend.utils import ProfileExecuteDuration


@patch.dict(os.environ, {"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": "1"})
def test_execue_duration_enabled_discrepancy():
    a = torch.randn(10000, 10000).npu()
    b = torch.randn(10000, 10000).npu()

    # warmup
    torch.matmul(a, b)
    torch.npu.synchronize()

    cpu_start = time.perf_counter()
    with ProfileExecuteDuration().capture_async("forward"):
        torch.matmul(a, b)
        torch.npu.synchronize()
    cpu_duration = (time.perf_counter() - cpu_start) * 1000
    npu_durations = ProfileExecuteDuration().pop_captured_sync()
    assert npu_durations and 'forward' in npu_durations
    assert not ProfileExecuteDuration._observations

    # Assert discrepancy between CPU and NPU duration is within 50% roughly
    diff = abs(cpu_duration - npu_durations['forward']) / max(
        cpu_duration, npu_durations['forward'])
    assert diff <= 0.5, (
        f"CPU={cpu_duration:.2f}ms, NPU={npu_durations['forward']:.2f}ms")


def test_execue_duration_disabled():
    a = torch.randn(100, 100).npu()
    b = torch.randn(100, 100).npu()

    with ProfileExecuteDuration().capture_async("forward"):
        torch.matmul(a, b)
        torch.npu.synchronize()
    npu_durations = ProfileExecuteDuration().pop_captured_sync()
    assert not npu_durations
@@ -70,6 +70,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
    lambda: os.getenv("VLLM_VERSION", None),
    "VLLM_ASCEND_TRACE_RECOMPILES":
    lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
    "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
    lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))),
}

# end-env-vars-definition
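The `env_variables` dict above maps each flag name to a zero-argument lambda, so the environment is read when the flag is accessed rather than once at import time. A self-contained sketch of that pattern (the dict here is a stand-in for the real `vllm_ascend.envs` module, not an import of it):

```python
import os
from typing import Any, Callable, Dict

# Each entry is a thunk: calling it re-reads the environment, so a flag
# flipped after import is still observed.
env_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
    lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))),
}

os.environ["VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE"] = "1"
print(env_variables["VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE"]())
```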
@@ -17,11 +17,15 @@
# Adapted from vllm-project/vllm/vllm/worker/worker.py
#

import atexit
import math
from typing import TYPE_CHECKING
from contextlib import contextmanager
from threading import Lock
from typing import TYPE_CHECKING, List, Tuple

import torch
from packaging.version import InvalidVersion, Version
from torch_npu.npu.streams import Event
from vllm.logger import logger

import vllm_ascend.envs as envs

@@ -175,3 +179,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:

def dispose_tensor(x: torch.Tensor):
    x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))


class ProfileExecuteDuration:
    _instance = None
    _observations: List[Tuple[str, Event, Event]] = []
    _lock = Lock()

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                atexit.register(cls._instance.destroy)
            return cls._instance

    def destroy(self):
        with self._lock:
            self._observations.clear()

    @contextmanager
    def capture_async(self, duration_tag: str):
        if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
            yield
            return

        observe_start = Event(enable_timing=True)
        observe_start.record()
        try:
            yield
        finally:
            observe_end = Event(enable_timing=True)
            observe_end.record()
            with self._lock:
                self._observations.append(
                    (duration_tag, observe_start, observe_end))

    def pop_captured_sync(self) -> dict:
        """Pop and synchronize all events in the observation list"""
        durations: dict[str, float] = {}
        if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
            return durations

        while self._observations:
            with self._lock:
                tag, observe_start, observe_end = self._observations.pop()
            observe_end.synchronize()
            durations[tag] = observe_start.elapsed_time(observe_end)

        return durations
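The key idea in `ProfileExecuteDuration` is that `capture_async` only records cheap device events, and the host pays the synchronization cost once, later, in `pop_captured_sync`. The same singleton + deferred-resolution shape can be illustrated without NPU hardware; this CPU-only sketch (our stand-in class, with `time.perf_counter` timestamps in place of the NPU `Event` objects) is an illustration of the pattern, not the PR's implementation:

```python
import time
from contextlib import contextmanager
from threading import Lock
from typing import Dict, List, Tuple


class ProfileExecuteDurationSketch:
    """CPU-only stand-in: same capture_async/pop_captured_sync shape,
    with perf_counter timestamps where the real class records NPU events."""
    _instance = None
    _observations: List[Tuple[str, float, float]] = []
    _lock = Lock()

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    @contextmanager
    def capture_async(self, duration_tag: str):
        start = time.perf_counter()  # real code: Event(enable_timing=True).record()
        try:
            yield
        finally:
            end = time.perf_counter()  # real code records an end Event here
            with self._lock:
                self._observations.append((duration_tag, start, end))

    def pop_captured_sync(self) -> Dict[str, float]:
        # Real code calls observe_end.synchronize() and Event.elapsed_time();
        # here the timestamps are already resolved.
        durations: Dict[str, float] = {}
        while self._observations:
            with self._lock:
                tag, start, end = self._observations.pop()
            durations[tag] = (end - start) * 1000  # ms
        return durations


with ProfileExecuteDurationSketch().capture_async("forward"):
    time.sleep(0.01)  # stand-in for model forward
print(ProfileExecuteDurationSketch().pop_captured_sync())
```

Because both call sites go through the singleton, observation points scattered across a file all land in one list and can be drained with a single `pop_captured_sync` call.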
@@ -67,7 +67,7 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.utils import vllm_version_is
from vllm_ascend.utils import ProfileExecuteDuration, vllm_version_is
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer

if TYPE_CHECKING:
@@ -707,27 +707,28 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        with set_forward_context(attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            model_kwargs = {}
            if self.torchair_graph_enabled:
                model_kwargs["kv_caches"] = self.kv_caches
                model_kwargs["attn_metadata"] = attn_metadata
            if self.torchair_graph_enabled and not with_prefill:
                hidden_states = self.compile_model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=None,
                    **model_kwargs,
                )
            else:
                assert self.model is not None
                hidden_states = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=None,
                    **model_kwargs,
                )
            with ProfileExecuteDuration().capture_async("forward"):
                model_kwargs = {}
                if self.torchair_graph_enabled:
                    model_kwargs["kv_caches"] = self.kv_caches
                    model_kwargs["attn_metadata"] = attn_metadata
                if self.torchair_graph_enabled and not with_prefill:
                    hidden_states = self.compile_model(
                        input_ids=input_ids,
                        positions=positions,
                        intermediate_tensors=intermediate_tensors,
                        inputs_embeds=None,
                        **model_kwargs,
                    )
                else:
                    assert self.model is not None
                    hidden_states = self.model(
                        input_ids=input_ids,
                        positions=positions,
                        intermediate_tensors=intermediate_tensors,
                        inputs_embeds=None,
                        **model_kwargs,
                    )

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -920,103 +921,121 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, torch.Tensor]:
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Return empty ModelRunnerOuptut if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT
        (attn_metadata, hidden_states, spec_decode_metadata, positions,
         num_scheduled_tokens,
         sample_indices) = (self._process_reqs(scheduler_output,
                                               intermediate_tensors))
        logits = self.model.compute_logits(hidden_states[sample_indices], None)

        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            logits = self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on bonus_logits
            # won't affect the original logits tensor.
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.sampler(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)

        # NOTE: NPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
                sampled_token_ids,
                self.input_batch.vocab_size,
            )

        spec_token_ids = self._get_spec_token_ids(
            valid_sampled_token_ids,
            sampling_metadata,
            scheduler_output,
            spec_decode_metadata,
            positions,
            num_scheduled_tokens,
            hidden_states,
            attn_metadata,
        )

        model_runner_output = ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict={},
        )
        with ProfileExecuteDuration().capture_async(
                "prepare input and forward"):
            self._update_states(scheduler_output)
            if not scheduler_output.total_num_scheduled_tokens:
                # Return empty ModelRunnerOuptut if there's no work to do.
                return EMPTY_MODEL_RUNNER_OUTPUT
            (attn_metadata, hidden_states, spec_decode_metadata, positions,
             num_scheduled_tokens,
             sample_indices) = (self._process_reqs(scheduler_output,
                                                   intermediate_tensors))

        with ProfileExecuteDuration().capture_async("post process"):
            logits = self.model.compute_logits(hidden_states[sample_indices],
                                               None)

            # Apply structured output bitmasks if present
            if scheduler_output.grammar_bitmask is not None:
                logits = self.apply_grammar_bitmask(scheduler_output, logits)

            # Sample the next token and get logprobs if needed.
            sampling_metadata = self.input_batch.sampling_metadata
            if spec_decode_metadata is None:
                sampler_output = self.sampler(
                    logits=logits,
                    sampling_metadata=sampling_metadata,
                )
            else:
                # When indexing with a tensor (bonus_logits_indices), PyTorch
                # creates a new tensor with separate storage from the original
                # logits tensor. This means any in-place operations on bonus_logits
                # won't affect the original logits tensor.
                bonus_logits = logits[
                    spec_decode_metadata.bonus_logits_indices]
                sampler_output = self.sampler(
                    logits=bonus_logits,
                    sampling_metadata=sampling_metadata,
                )
                bonus_token_ids = sampler_output.sampled_token_ids

                # Just like `bonus_logits`, `target_logits` is a new tensor with
                # separate storage from the original `logits` tensor. Therefore,
                # it is safe to update `target_logits` in place.
                target_logits = logits[
                    spec_decode_metadata.target_logits_indices]
                output_token_ids = self.rejection_sampler(
                    spec_decode_metadata,
                    None,  # draft_probs
                    target_logits,
                    bonus_token_ids,
                    sampling_metadata,
                )
                sampler_output.sampled_token_ids = output_token_ids

            # TODO(woosuk): The following loop can be slow since it iterates over
            # the requests one by one. Optimize.
            for i, req_id in enumerate(self.input_batch.req_ids):
                req_state = self.requests[req_id]
                seq_len = (req_state.num_computed_tokens +
                           scheduler_output.num_scheduled_tokens[req_id])
                if seq_len < req_state.num_tokens:
                    # Ignore the sampled token.
                    # Rewind the generator state as if the token was not sampled.
                    generator = self.input_batch.generators.get(i)
                    if generator is not None:
                        generator.set_offset(generator.get_offset() - 4)

            # NOTE: NPU -> CPU Sync happens here.
            # Move as many CPU operations as possible before this sync point.
            logprobs_tensors = sampler_output.logprobs_tensors
            logprobs_lists = logprobs_tensors.tolists() \
                if logprobs_tensors is not None else None

            # Get the valid generated tokens.
            sampled_token_ids = sampler_output.sampled_token_ids
            max_gen_len = sampled_token_ids.shape[-1]
            if max_gen_len == 1:
                # No spec decode tokens.
                valid_sampled_token_ids = sampled_token_ids.tolist()
            else:
                # Includes spec decode tokens.
                valid_sampled_token_ids = self.rejection_sampler.parse_output(
                    sampled_token_ids,
                    self.input_batch.vocab_size,
                )

            spec_token_ids = self._get_spec_token_ids(
                valid_sampled_token_ids,
                sampling_metadata,
                scheduler_output,
                spec_decode_metadata,
                positions,
                num_scheduled_tokens,
                hidden_states,
                attn_metadata,
            )

        durations = ProfileExecuteDuration().pop_captured_sync()
        if durations:
            dr_str = [
                f"[{tag}]:{duration:.2f}ms"
                for tag, duration in durations.items()
            ]
            captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
            print(f"Profile execute duration [{captured_name}]:",
                  " ".join(dr_str))

        model_runner_output = ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict={},
        )
        return model_runner_output

    def _profile_multimodal(self) -> None: