[Perf][V1] Fully overlap model execution (#2783)
This PR is based on top of
[#23569](https://github.com/vllm-project/vllm/pull/23569) and
[#24219](https://github.com/vllm-project/vllm/pull/24219).
### What this PR does / why we need it?
This PR allows the model runner to function asynchronously when using
async scheduling. This allows full overlap of the CPU operations
(including prepare_inputs) and the model forward pass. This diff is
functional and does not support speculative decoding, PP, or guided
decoding.
Expected speedup is 5-10% over the current async scheduling.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
server
```
python -m vllm.entrypoints.openai.api_server --model=Qwen3-32B \
--trust-remote-code --enforce-eager \
--distributed-executor-backend=mp \
-tp=4 \
--port 8006 \
--max-model-len 32000 \
--block-size 128 \
--gpu-memory-utilization 0.99
```
client
```
python $TEST_PY --backend vllm --trust-remote-code --model Qwen3-32B \
--dataset-name random --random-input-len 2048 --random-output-len 2048 \
--ignore-eos \
--num-prompts 48 --max-concurrency 48 --request-rate inf --temperature 0 \
--metric-percentiles 90 --base-url http://localhost:8006 --save-result \
--result-dir $PROFILER_DIR
```
benchmark test based on Qwen3-32B TPOT result:
||forward async| scheduler async |sync|
|-|-|-|-|
|avg|41.73|41.86|44.20|
|improve0|0.3%|0|0|
|improve1|5.58%|0|0|
benchmark test based on Qwen2___5-VL-7B-Instruct TPOT result:
||forward async|sync|
|-|-|-|
|avg|23.22|29.16|
|improve|20.3%|0|
- vLLM version: main
- vLLM main:
e93f4cc9e3
Signed-off-by: jiangpeng36 <jiangpeng36@huawei.com>
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: jiangpeng36 <jiangpeng36@huawei.com>
Co-authored-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import pytest
|
import pytest
|
||||||
|
from vllm import SamplingParams
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner
|
||||||
from tests.e2e.model_utils import check_outputs_equal
|
from tests.e2e.model_utils import check_outputs_equal
|
||||||
@@ -86,3 +87,25 @@ def test_chunked_prefill_with_ascend_scheduler(
|
|||||||
name_0="vllm_output",
|
name_0="vllm_output",
|
||||||
name_1="chunked_prefill_output",
|
name_1="chunked_prefill_output",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_scheduling() -> None:
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
] * 10
|
||||||
|
sampling_params = SamplingParams(temperature=0.2,
|
||||||
|
max_tokens=10,
|
||||||
|
stop_token_ids=None)
|
||||||
|
|
||||||
|
with VllmRunner(
|
||||||
|
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||||
|
max_model_len=4096,
|
||||||
|
max_num_seqs=50,
|
||||||
|
dtype="bfloat16",
|
||||||
|
gpu_memory_utilization=0.9,
|
||||||
|
async_scheduling=True,
|
||||||
|
) as vllm_model:
|
||||||
|
vllm_model.generate(prompts, sampling_params=sampling_params)
|
||||||
|
|||||||
@@ -63,8 +63,8 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
|||||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||||
LogprobsTensors, ModelRunnerOutput)
|
DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
|
||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
@@ -156,6 +156,53 @@ def graph_capture(device: torch.device):
|
|||||||
yield graph_capture_context
|
yield graph_capture_context
|
||||||
|
|
||||||
|
|
||||||
|
# Wrapper for ModelRunnerOutput to support overlapped execution.
|
||||||
|
class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_runner_output: ModelRunnerOutput,
|
||||||
|
sampled_token_ids: torch.Tensor,
|
||||||
|
invalid_req_indices: list[int],
|
||||||
|
async_output_copy_stream: torch.npu.Stream,
|
||||||
|
):
|
||||||
|
self._model_runner_output = model_runner_output
|
||||||
|
self._invalid_req_indices = invalid_req_indices
|
||||||
|
|
||||||
|
# Event on the copy stream so we can synchronize the non-blocking copy.
|
||||||
|
self._async_copy_ready_event = torch.npu.Event()
|
||||||
|
|
||||||
|
# Keep a reference to the device tensor to avoid it being
|
||||||
|
# deallocated until we finish copying it to the host.
|
||||||
|
self._sampled_token_ids = sampled_token_ids
|
||||||
|
|
||||||
|
# Initiate the copy on a separate stream, but do not synchronize it.
|
||||||
|
default_stream = torch.npu.current_stream()
|
||||||
|
with torch.npu.stream(async_output_copy_stream):
|
||||||
|
async_output_copy_stream.wait_stream(default_stream)
|
||||||
|
self._sampled_token_ids_cpu = self._sampled_token_ids.to(
|
||||||
|
'cpu', non_blocking=True)
|
||||||
|
self._async_copy_ready_event.record()
|
||||||
|
|
||||||
|
def get_output(self) -> ModelRunnerOutput:
|
||||||
|
"""Copy the device tensors to the host and return a ModelRunnerOutput.
|
||||||
|
|
||||||
|
This function blocks until the copy is finished.
|
||||||
|
"""
|
||||||
|
self._async_copy_ready_event.synchronize()
|
||||||
|
|
||||||
|
# Release the device tensor once the copy has completed
|
||||||
|
del self._sampled_token_ids
|
||||||
|
|
||||||
|
valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
|
||||||
|
for i in self._invalid_req_indices:
|
||||||
|
valid_sampled_token_ids[i].clear()
|
||||||
|
|
||||||
|
output = self._model_runner_output
|
||||||
|
output.sampled_token_ids = valid_sampled_token_ids
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class NPUModelRunner(LoRAModelRunnerMixin):
|
class NPUModelRunner(LoRAModelRunnerMixin):
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||||
@@ -358,6 +405,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||||
|
self.async_output_copy_stream = torch.npu.Stream() if \
|
||||||
|
self.use_async_scheduling else None
|
||||||
|
|
||||||
def _use_aclgraph(self) -> bool:
|
def _use_aclgraph(self) -> bool:
|
||||||
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
|
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
|
||||||
|
|
||||||
@@ -845,6 +896,76 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
return cu_num_tokens, arange
|
return cu_num_tokens, arange
|
||||||
|
|
||||||
|
def _prepare_input_ids(self, total_num_scheduled_tokens: int,
|
||||||
|
cu_num_tokens: np.ndarray) -> None:
|
||||||
|
"""Prepare the input IDs for the current batch.
|
||||||
|
|
||||||
|
Carefully handles the `prev_sampled_token_ids` which can be cached
|
||||||
|
from the previous engine iteration, in which case those tokens on the
|
||||||
|
NPU need to be copied into the corresponding slots into input_ids."""
|
||||||
|
|
||||||
|
if self.input_batch.prev_sampled_token_ids is None:
|
||||||
|
# Normal scheduling case
|
||||||
|
self.input_ids[:total_num_scheduled_tokens].copy_(
|
||||||
|
self.input_ids_cpu[:total_num_scheduled_tokens],
|
||||||
|
non_blocking=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Async scheduling case, where some decode requests from the previous
|
||||||
|
# iteration won't have entries in input_ids_cpu and need to be copied
|
||||||
|
# on the NPU from prev_sampled_token_ids.
|
||||||
|
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
|
||||||
|
assert prev_req_id_to_index is not None
|
||||||
|
flattened_indices = []
|
||||||
|
prev_common_req_indices = []
|
||||||
|
indices_match = True
|
||||||
|
max_flattened_index = -1
|
||||||
|
for req_id, cur_index in self.input_batch.req_id_to_index.items():
|
||||||
|
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
|
||||||
|
prev_common_req_indices.append(prev_index)
|
||||||
|
# We need to compute the flattened input_ids index of the
|
||||||
|
# last token in each common request.
|
||||||
|
flattened_index = cu_num_tokens[cur_index].item() - 1
|
||||||
|
flattened_indices.append(flattened_index)
|
||||||
|
indices_match &= (prev_index == flattened_index)
|
||||||
|
max_flattened_index = max(max_flattened_index, flattened_index)
|
||||||
|
num_commmon_tokens = len(flattened_indices)
|
||||||
|
if num_commmon_tokens < total_num_scheduled_tokens:
|
||||||
|
# If not all requests are decodes from the last iteration,
|
||||||
|
# We need to copy the input_ids_cpu to the NPU first.
|
||||||
|
self.input_ids[:total_num_scheduled_tokens].copy_(
|
||||||
|
self.input_ids_cpu[:total_num_scheduled_tokens],
|
||||||
|
non_blocking=True)
|
||||||
|
if num_commmon_tokens == 0:
|
||||||
|
# No requests in common with the previous iteration
|
||||||
|
# So input_ids_cpu will have all the input ids.
|
||||||
|
return
|
||||||
|
if indices_match and max_flattened_index == (num_commmon_tokens - 1):
|
||||||
|
# Common-case optimization: the batch is unchanged
|
||||||
|
# and no reordering happened.
|
||||||
|
# The indices are both the same permutation of 0..N-1 so
|
||||||
|
# we can copy directly using a single slice.
|
||||||
|
self.input_ids[:num_commmon_tokens].copy_(
|
||||||
|
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
|
||||||
|
0],
|
||||||
|
non_blocking=True)
|
||||||
|
return
|
||||||
|
# Upload the index tensors asynchronously
|
||||||
|
# so the scatter can be non-blocking.
|
||||||
|
input_ids_index_tensor = torch.tensor(flattened_indices,
|
||||||
|
dtype=torch.int64,
|
||||||
|
pin_memory=self.pin_memory).to(
|
||||||
|
self.device,
|
||||||
|
non_blocking=True)
|
||||||
|
prev_common_req_indices_tensor = torch.tensor(
|
||||||
|
prev_common_req_indices,
|
||||||
|
dtype=torch.int64,
|
||||||
|
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
|
||||||
|
self.input_ids.scatter_(dim=0,
|
||||||
|
index=input_ids_index_tensor,
|
||||||
|
src=self.input_batch.prev_sampled_token_ids[
|
||||||
|
prev_common_req_indices_tensor, 0])
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
@@ -1033,6 +1154,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla:
|
||||||
attn_metadata.num_input_tokens = num_input_tokens
|
attn_metadata.num_input_tokens = num_input_tokens
|
||||||
|
|
||||||
|
# Prepare input_ids
|
||||||
|
token_indices = (positions_np +
|
||||||
|
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
||||||
|
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||||||
|
0,
|
||||||
|
torch.from_numpy(token_indices),
|
||||||
|
out=self.input_ids_cpu[:total_num_scheduled_tokens])
|
||||||
|
# Copy the tensors to the NPU.
|
||||||
|
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
|
||||||
|
|
||||||
# _prepare_inputs may reorder the batch, so we must gather
|
# _prepare_inputs may reorder the batch, so we must gather
|
||||||
# multi-modal outputs after that to ensure the correct order
|
# multi-modal outputs after that to ensure the correct order
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
@@ -1382,11 +1513,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
2. If expert parallel is enabled, we need to consider the soc version and the
|
2. If expert parallel is enabled, we need to consider the soc version and the
|
||||||
number of tokens. This is based on the observation that all-gather is more
|
number of tokens. This is based on the observation that all-gather is more
|
||||||
efficient than all-to-all when running on A2.
|
efficient than all-to-all when running on A2.
|
||||||
|
|
||||||
a. For A2, we choose from MC2 and all-gather.
|
a. For A2, we choose from MC2 and all-gather.
|
||||||
|
|
||||||
b. For A3, we choose from MC2 and all-to-all.
|
b. For A3, we choose from MC2 and all-to-all.
|
||||||
|
|
||||||
In both cases, we use MC2 when the number of tokens is smaller than
|
In both cases, we use MC2 when the number of tokens is smaller than
|
||||||
its capacity threshold.
|
its capacity threshold.
|
||||||
|
|
||||||
@@ -1424,7 +1555,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
) -> Union[ModelRunnerOutput, torch.Tensor]:
|
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
|
||||||
with ProfileExecuteDuration().capture_async("prepare input"):
|
with ProfileExecuteDuration().capture_async("prepare input"):
|
||||||
self._update_states(scheduler_output)
|
self._update_states(scheduler_output)
|
||||||
if not scheduler_output.total_num_scheduled_tokens:
|
if not scheduler_output.total_num_scheduled_tokens:
|
||||||
@@ -1580,6 +1711,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
generator.set_offset(generator.get_offset() - 4)
|
generator.set_offset(generator.get_offset() - 4)
|
||||||
discard_sampled_tokens_req_indices.append(i)
|
discard_sampled_tokens_req_indices.append(i)
|
||||||
|
|
||||||
|
# Copy some objects so they don't get modified after returning.
|
||||||
|
# This is important when using async scheduling.
|
||||||
|
req_ids_output_copy = self.input_batch.req_ids.copy()
|
||||||
|
req_id_to_index_output_copy = \
|
||||||
|
self.input_batch.req_id_to_index.copy()
|
||||||
|
|
||||||
# NOTE: NPU -> CPU Sync happens here.
|
# NOTE: NPU -> CPU Sync happens here.
|
||||||
# Move as many CPU operations as possible before this sync point.
|
# Move as many CPU operations as possible before this sync point.
|
||||||
logprobs_tensors = sampler_output.logprobs_tensors
|
logprobs_tensors = sampler_output.logprobs_tensors
|
||||||
@@ -1592,27 +1729,52 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
scheduler_output,
|
scheduler_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the valid generated tokens.
|
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
|
||||||
sampled_token_ids = sampler_output.sampled_token_ids
|
sampled_token_ids = sampler_output.sampled_token_ids
|
||||||
max_gen_len = sampled_token_ids.shape[-1]
|
if not self.use_async_scheduling:
|
||||||
if max_gen_len == 1:
|
# Get the valid generated tokens.
|
||||||
# No spec decode tokens.
|
max_gen_len = sampled_token_ids.shape[-1]
|
||||||
valid_sampled_token_ids = sampled_token_ids.tolist()
|
if max_gen_len == 1:
|
||||||
|
# No spec decode tokens.
|
||||||
|
valid_sampled_token_ids = sampled_token_ids.tolist()
|
||||||
|
else:
|
||||||
|
# Includes spec decode tokens.
|
||||||
|
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
||||||
|
sampled_token_ids,
|
||||||
|
self.input_batch.vocab_size,
|
||||||
|
)
|
||||||
|
# Mask out the sampled tokens that should not be sampled.
|
||||||
|
for i in discard_sampled_tokens_req_indices:
|
||||||
|
valid_sampled_token_ids[i].clear()
|
||||||
else:
|
else:
|
||||||
# Includes spec decode tokens.
|
valid_sampled_token_ids = []
|
||||||
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
invalid_req_indices = list(discard_sampled_tokens_req_indices)
|
||||||
sampled_token_ids,
|
invalid_req_indices_set = set(invalid_req_indices)
|
||||||
self.input_batch.vocab_size,
|
assert sampled_token_ids.shape[-1] == 1
|
||||||
)
|
|
||||||
|
|
||||||
for i in discard_sampled_tokens_req_indices:
|
# Cache the sampled tokens on the NPU and avoid CPU sync.
|
||||||
valid_sampled_token_ids[i].clear()
|
# These will be copied into input_ids in the next step
|
||||||
# Cache the sampled tokens in the model runner, so that the scheduler
|
# when preparing inputs.
|
||||||
|
self.input_batch.prev_sampled_token_ids = \
|
||||||
|
sampled_token_ids
|
||||||
|
self.input_batch.prev_sampled_token_ids_invalid_indices = \
|
||||||
|
invalid_req_indices_set
|
||||||
|
self.input_batch.prev_req_id_to_index = {
|
||||||
|
req_id: i
|
||||||
|
for i, req_id in enumerate(self.input_batch.req_ids)
|
||||||
|
if i not in invalid_req_indices_set
|
||||||
|
}
|
||||||
|
# Cache the sampled tokens in the model runner, so that the scheduler
|
||||||
# doesn't need to send them back.
|
# doesn't need to send them back.
|
||||||
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
||||||
# the sampled tokens back, because there's no direct communication
|
# the sampled tokens back, because there's no direct communication
|
||||||
# between the first-stage worker and the last-stage worker.
|
# between the first-stage worker and the last-stage worker.
|
||||||
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
|
for req_idx in range(num_sampled_tokens):
|
||||||
|
if self.use_async_scheduling:
|
||||||
|
sampled_ids = [-1] * 1 if \
|
||||||
|
req_idx not in invalid_req_indices_set else None
|
||||||
|
else:
|
||||||
|
sampled_ids = valid_sampled_token_ids[req_idx]
|
||||||
if not sampled_ids:
|
if not sampled_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -1650,8 +1812,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
extra_args = ({"kv_connector_output": kv_connector_output})
|
extra_args = ({"kv_connector_output": kv_connector_output})
|
||||||
|
|
||||||
model_runner_output = ModelRunnerOutput(
|
model_runner_output = ModelRunnerOutput(
|
||||||
req_ids=self.input_batch.req_ids,
|
req_ids=req_ids_output_copy,
|
||||||
req_id_to_index=self.input_batch.req_id_to_index,
|
req_id_to_index=req_id_to_index_output_copy,
|
||||||
sampled_token_ids=valid_sampled_token_ids,
|
sampled_token_ids=valid_sampled_token_ids,
|
||||||
logprobs=logprobs_lists,
|
logprobs=logprobs_lists,
|
||||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||||
@@ -1669,7 +1831,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
logger.info("Profile execute duration [%s]:%s", captured_name,
|
logger.info("Profile execute duration [%s]:%s", captured_name,
|
||||||
" ".join(dr_str))
|
" ".join(dr_str))
|
||||||
|
|
||||||
return model_runner_output
|
if not self.use_async_scheduling:
|
||||||
|
return model_runner_output
|
||||||
|
|
||||||
|
return AsyncNPUModelRunnerOutput(
|
||||||
|
model_runner_output=model_runner_output,
|
||||||
|
sampled_token_ids=sampled_token_ids,
|
||||||
|
invalid_req_indices=invalid_req_indices,
|
||||||
|
async_output_copy_stream=self.async_output_copy_stream,
|
||||||
|
)
|
||||||
|
|
||||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||||
if self._draft_token_ids is None:
|
if self._draft_token_ids is None:
|
||||||
|
|||||||
@@ -263,6 +263,11 @@ class InputBatch:
|
|||||||
|
|
||||||
self.pooling_params: dict[str, PoolingParams] = {}
|
self.pooling_params: dict[str, PoolingParams] = {}
|
||||||
|
|
||||||
|
# Cached reference to the GPU tensor of previously sampled tokens
|
||||||
|
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
|
||||||
|
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
||||||
|
self.prev_req_id_to_index: Optional[dict[str, int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def req_ids(self) -> list[str]:
|
def req_ids(self) -> list[str]:
|
||||||
# None elements should only be present transiently
|
# None elements should only be present transiently
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -38,8 +38,8 @@ from vllm.tasks import SupportedTask
|
|||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||||
ModelRunnerOutput)
|
DraftTokenIds, ModelRunnerOutput)
|
||||||
from vllm.v1.worker.worker_base import WorkerBase
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import init_ascend_config
|
from vllm_ascend.ascend_config import init_ascend_config
|
||||||
@@ -191,7 +191,7 @@ class NPUWorker(WorkerBase):
|
|||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
) -> Optional[ModelRunnerOutput]:
|
) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
|
||||||
intermediate_tensors = None
|
intermediate_tensors = None
|
||||||
if not get_pp_group().is_first_rank:
|
if not get_pp_group().is_first_rank:
|
||||||
intermediate_tensors = IntermediateTensors(
|
intermediate_tensors = IntermediateTensors(
|
||||||
@@ -220,7 +220,7 @@ class NPUWorker(WorkerBase):
|
|||||||
new_output.kv_connector_output = kv_connector_output
|
new_output.kv_connector_output = kv_connector_output
|
||||||
return new_output
|
return new_output
|
||||||
|
|
||||||
assert isinstance(output, ModelRunnerOutput)
|
assert isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user