support async mtp (#4511)

### What this PR does / why we need it?
This PR adds async_scheduling support for MTP, following the vLLM PR
https://github.com/vllm-project/vllm/pull/24799.
It also fixes some synchronization issues in vllm-ascend.
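
For readers who want to try this out, a minimal usage sketch (not part of this PR; the model name and `speculative_config` values are illustrative placeholders, so check the vllm-ascend docs for the settings that match your checkpoint):

```python
from vllm import LLM, SamplingParams

# Hypothetical example: MTP speculative decoding with async scheduling enabled.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",      # placeholder MTP-capable checkpoint
    async_scheduling=True,                # overlap CPU scheduling with NPU execution
    speculative_config={
        "method": "mtp",                  # draft with the model's MTP head
        "num_speculative_tokens": 1,
    },
)
print(llm.generate(["Hello, my name is"],
                   SamplingParams(temperature=0.0, max_tokens=32)))
```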
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Commit 3480094d7c by Ronald, committed via GitHub on 2025-12-06 17:15:57 +08:00
Parent: f067623afd
8 changed files with 477 additions and 83 deletions

View File

@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from itertools import repeat
from typing import Any
import pytest
import torch._dynamo.config as dynamo_config
from vllm import SamplingParams
from vllm.v1.metrics.reader import Metric
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
MODEL = "Qwen/Qwen3-0.6B"
first_prompt = ("The following numbers of the sequence " +
", ".join(str(i) for i in range(10)) + " are:")
example_prompts = [first_prompt, "In one word, the capital of France is "
] + [f"Tell me about the number {i}: " for i in range(32)]
default_params = dict(
temperature=0.0, # greedy
max_tokens=23,
min_tokens=18,
)
def test_without_spec_decoding(monkeypatch: pytest.MonkeyPatch, ):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor, prefill chunking."""
test_sampling_params: list[dict[str, Any]] = [
dict(),
]
# test_preemption, executor, async_scheduling,
# spec_config, test_prefill_chunking
test_configs = [
(False, "mp", False, None, False),
(False, "mp", True, None, False),
(False, "uni", True, None, False),
]
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
@dynamo_config.patch(cache_size_limit=16)
def run_tests(
monkeypatch: pytest.MonkeyPatch,
model: str,
test_configs: list[tuple],
test_sampling_params: list[dict[str, Any]],
):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor with spec decoding."""
with monkeypatch.context():
# avoid precision errors
outputs: list[tuple[str, list, list]] = []
for n, (
test_preemption,
executor,
async_scheduling,
spec_config,
test_prefill_chunking,
) in enumerate(test_configs, 1):
test_str = f"{n}/{len(test_configs)}"
test_results = run_test(
model,
test_str,
test_sampling_params,
test_preemption,
executor,
async_scheduling,
spec_config,
test_prefill_chunking=test_prefill_chunking,
)
outputs.append(test_results)
baseline_config, baseline_tests, _ = outputs[0]
_, _, baseline_acceptances = next((o for o in outputs if o[2] is not None),
(None, None, None))
print(
f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}"
)
failure = None
for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
for base_outs, base_acceptance_rate, test_outs, test_acceptance_rate, params in zip(
baseline_tests,
baseline_acceptances or repeat(None),
test_outputs,
test_acceptance_rates or repeat(None),
test_sampling_params,
):
try:
check_outputs_equal(
outputs_0_lst=base_outs,
outputs_1_lst=test_outs,
name_0=f"baseline=[{baseline_config}], params={params}",
name_1=f"config=[{test_config}], params={params}",
)
if (base_acceptance_rate is not None
and test_acceptance_rate is not None):
if "spec_mml=None" in test_config:
assert (test_acceptance_rate > base_acceptance_rate
or test_acceptance_rate == pytest.approx(
base_acceptance_rate, rel=5e-2))
else:
# Currently the reported acceptance rate is expected to be
# lower when we sometimes skip drafting altogether.
assert test_acceptance_rate > 0.1
print(f"PASSED: config=[{test_config}], params={params}"
f" accept_rate={test_acceptance_rate}")
except AssertionError as e:
print(f"FAILED: config=[{test_config}], params={params}"
f" accept_rate={test_acceptance_rate}")
if failure is None:
failure = e
if failure is not None:
raise failure
def run_test(
model: str,
test_str: str,
sampling_param_tests: list[dict[str, Any]],
test_preemption: bool,
executor: str,
async_scheduling: bool,
spec_config: dict[str, Any] | None,
test_prefill_chunking: bool,
):
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
spec_decoding = spec_config is not None
cache_arg: dict[str, Any] = (
# Force preemptions
dict(num_gpu_blocks_override=2) if test_preemption else dict(
gpu_memory_utilization=0.9))
spec_mml = (spec_config or {}).get("max_model_len")
test_config = (f"executor={executor}, preemption={test_preemption}, "
f"async_sched={async_scheduling}, "
f"chunk_prefill={test_prefill_chunking}, "
f"spec_decoding={spec_decoding}, spec_mml={spec_mml}")
print("-" * 80)
print(f"---- TESTING {test_str}: {test_config}")
print("-" * 80)
with VllmRunner(
model,
max_model_len=512,
enable_chunked_prefill=test_prefill_chunking,
# Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None,
enforce_eager=True,
async_scheduling=async_scheduling,
distributed_executor_backend=executor,
dtype="float16", # avoid precision errors
speculative_config=spec_config,
disable_log_stats=False,
**cache_arg,
) as vllm_model:
results = []
acceptance_rates: list[float] | None = [] if spec_decoding else None
for override_params in sampling_param_tests:
metrics_before = vllm_model.model.get_metrics()
print(f"----------- RUNNING PARAMS: {override_params}")
results.append(
vllm_model.generate(
example_prompts,
sampling_params=SamplingParams(**default_params,
**override_params),
))
metrics_after = vllm_model.model.get_metrics()
if acceptance_rates is not None:
acceptance_rate = _get_acceptance_rate(metrics_before,
metrics_after)
acceptance_rates.append(acceptance_rate)
print(f"ACCEPTANCE RATE {acceptance_rate}")
if test_preemption:
preemptions = _get_count(metrics_before, metrics_after,
"vllm:num_preemptions")
assert preemptions > 0, "preemption test had no preemptions"
if len(results) > 1:
# First check that the different parameter configs
# actually result in different output.
for other_test_outs, params in zip(results[1:],
sampling_param_tests[1:]):
with pytest.raises(AssertionError):
check_outputs_equal(
outputs_0_lst=results[0][0],
outputs_1_lst=other_test_outs,
name_0=f"baseline params={params}",
name_1=f"other params={params}",
)
return test_config, results, acceptance_rates
def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float:
draft = _get_count(before, after, "vllm:spec_decode_num_draft_tokens")
accept = _get_count(before, after, "vllm:spec_decode_num_accepted_tokens")
return accept / draft if draft > 0 else 0.0
def _get_count(before: list[Metric], after: list[Metric], name: str) -> int:
before_val = next(m.value for m in before if m.name == name)
after_val = next(m.value for m in after if m.name == name)
return after_val - before_val

View File

@@ -87,6 +87,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
self.mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
self.mock_device = 'cpu:0'
torch.Tensor.pin_memory = lambda x: x # noqa
self.builder = AscendAttentionMetadataBuilder(None, None,
self.mock_vllm_config,
self.mock_device)

View File

@@ -299,6 +299,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
torch.Tensor.pin_memory = lambda x: x # noqa
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
@@ -534,6 +535,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
mock_get_pcp_group):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
@@ -599,6 +601,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
mock_get_pcp_group):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
@@ -660,6 +663,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
@@ -713,6 +718,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
@@ -767,6 +774,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group

View File

@@ -317,8 +317,8 @@ class AscendAttentionMetadataBuilder:
query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
])
query_start_loc = query_start_loc_cpu.to(self.device,
non_blocking=True)
query_start_loc = query_start_loc_cpu.pin_memory().to(
self.device, non_blocking=True)
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
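
For context (not part of this diff): the hunk above switches to pinned host memory because a `non_blocking=True` copy from ordinary pageable memory can silently degrade to a synchronous copy. A minimal sketch of the pattern, assuming an Ascend NPU device is available through torch_npu:

```python
import torch
import torch_npu  # assumption: provides the "npu" device on Ascend

query_start_loc_cpu = torch.tensor([0, 4, 9, 15], dtype=torch.int32)
# Pin the host tensor first so the copy can actually run asynchronously
# and overlap with other work queued on the NPU stream.
query_start_loc = query_start_loc_cpu.pin_memory().to("npu", non_blocking=True)
```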

View File

@@ -556,35 +556,43 @@ class AscendMLAMetadataBuilder:
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32,
)
chunked_context_metadata = \
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=local_chunk_starts.to(device, non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
starts=local_chunk_starts.pin_memory().to(
device, non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(
dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(),
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
device, non_blocking=True
),
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.
npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens
.tolist(),
local_context_lens_allranks=local_context_lens_allranks
.tolist(),
padded_local_cu_seq_lens=
padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
)
else:
chunked_context_metadata = \
chunked_context_metadata = (
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
)
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
starts=chunk_starts.pin_memory().to(
device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(
dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
))
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
@@ -616,7 +624,8 @@ class AscendMLAMetadataBuilder:
cos = common_attn_metadata.cos
sin = common_attn_metadata.sin
# Note that num_decodes != num_decode_tokens in the spec decoding scenario
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
1].tolist()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decode_tokens]
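
For context (not part of this diff): the last hunk reads actual_seq_lengths_q from query_start_loc_cpu because calling `.tolist()` on a device tensor forces a blocking device-to-host copy, while reading the already-available CPU copy does not. A small sketch with illustrative tensors:

```python
import torch
import torch_npu  # assumption: provides the "npu" device on Ascend

query_start_loc_cpu = torch.tensor([0, 2, 4, 6], dtype=torch.int32)
query_start_loc_npu = query_start_loc_cpu.to("npu")

lengths = query_start_loc_npu[1:].tolist()  # implicit sync: blocking D2H copy
lengths = query_start_loc_cpu[1:].tolist()  # no device sync: data already on host
```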

View File

@@ -142,6 +142,9 @@ class MtpProposer(Proposer):
self.arange = torch.arange(max_num_slots_for_arange,
device=device,
dtype=torch.int32)
self.arange_cpu = torch.arange(max_num_slots_for_arange,
device="cpu",
dtype=torch.int32)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
@@ -157,6 +160,7 @@ class MtpProposer(Proposer):
)
self.use_sparse = hasattr(vllm_config.model_config.hf_config,
"index_topk")
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config)
@@ -351,6 +355,8 @@ class MtpProposer(Proposer):
self.runner.discard_request_indices.gpu,
self.runner.num_discarded_requests
)
self._copy_valid_sampled_token_count(next_token_ids,
valid_sampled_tokens_count)
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size > 1:
@@ -430,6 +436,28 @@ class MtpProposer(Proposer):
return draft_token_ids
def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor,
valid_sampled_tokens_count: torch.Tensor) -> None:
if self.runner.valid_sampled_token_count_event is not None:
default_stream = torch.npu.current_stream()
# Use a separate stream to overlap the copy operation with
# prepare_input of the draft model.
with torch.npu.stream(
self.runner.valid_sampled_token_count_copy_stream):
self.runner.valid_sampled_token_count_copy_stream.wait_stream(
default_stream) # type: ignore
self.runner.valid_sampled_token_count_cpu[:
valid_sampled_tokens_count
.shape[0]].copy_(
valid_sampled_tokens_count,
non_blocking=True
)
self.runner.valid_sampled_token_count_event.record()
self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
1)
def _init_mtp_model(self):
architecture = self.vllm_config.model_config.architecture
target_device = self.vllm_config.device_config.device
@@ -696,6 +724,11 @@ class MtpProposer(Proposer):
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
aclgraph_runtime_mode, batch_descriptor = \
self.runner.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
if self.use_async_scheduling:
# There is a synchronization between MTP steps when aclgraph is enabled;
# disable aclgraph when async scheduling is used to avoid that
# synchronization overhead.
aclgraph_runtime_mode = CUDAGraphMode.NONE
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
@@ -822,7 +855,7 @@ class MtpProposer(Proposer):
# When disable_padded_drafter_batch=False, these params may not need to be updated.
if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \
aclgraph_runtime_mode != CUDAGraphMode.FULL):
decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
decode_metadata.actual_seq_lengths_q = self.arange_cpu[
1:batch_size + 1].tolist()
if aclgraph_runtime_mode == CUDAGraphMode.FULL:
decode_metadata.actual_seq_lengths_q = \
@@ -847,7 +880,9 @@ class MtpProposer(Proposer):
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions[:batch_size])
# Increment the sequence lengths.
attn_metadata_i.seq_lens[:batch_size] += 1
# This is an out-of-place operation to avoid modifying the original tensor
# when async_scheduling is enabled.
attn_metadata_i.seq_lens = attn_metadata_i.seq_lens + 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
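
For context (not part of this diff): _copy_valid_sampled_token_count above follows the usual side-stream pattern, issuing the device-to-host copy on a secondary stream and recording an event so the host blocks only when it actually needs the values. A standalone sketch, assuming torch.npu streams and events mirror the torch.cuda API:

```python
import torch
import torch_npu  # assumption: torch.npu.Stream/Event behave like their CUDA counterparts

copy_stream = torch.npu.Stream()
copy_done = torch.npu.Event()
counts_cpu = torch.empty(8, dtype=torch.int64, pin_memory=True)

def launch_async_copy(counts_npu: torch.Tensor) -> None:
    # Order the copy after work already queued on the default stream,
    # then let it run concurrently with whatever is enqueued next.
    copy_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(copy_stream):
        counts_cpu[:counts_npu.shape[0]].copy_(counts_npu, non_blocking=True)
        copy_done.record()

def read_counts(n: int) -> list[int]:
    copy_done.synchronize()  # block only when the host actually needs the data
    return counts_cpu[:n].tolist()
```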

View File

@@ -97,6 +97,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
make_empty_encoder_model_runner_output)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
@@ -213,6 +214,7 @@ class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
sampled_token_ids: torch.Tensor,
invalid_req_indices: list[int],
async_output_copy_stream: torch.npu.Stream,
vocab_size: int,
):
self._model_runner_output = model_runner_output
self._invalid_req_indices = invalid_req_indices
@@ -223,7 +225,7 @@ class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
# Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host.
self._sampled_token_ids = sampled_token_ids
self.vocab_size = vocab_size
# Initiate the copy on a separate stream, but do not synchronize it.
default_stream = torch.npu.current_stream()
with torch.npu.stream(async_output_copy_stream):
@@ -242,10 +244,17 @@ class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
# Release the device tensor once the copy has completed
del self._sampled_token_ids
valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
max_gen_len = self._sampled_token_ids_cpu.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
else:
valid_sampled_token_ids, _ = RejectionSampler.parse_output(
self._sampled_token_ids_cpu,
self.vocab_size,
self._invalid_req_indices,
return_cu_num_tokens=False)
output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
return output
@@ -567,6 +576,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.async_output_copy_stream = torch.npu.Stream() if \
self.use_async_scheduling else None
self.num_spec_tokens = 0
if self.speculative_config:
self.num_spec_tokens = self.speculative_config.num_speculative_tokens # noqa
self.valid_sampled_token_count_event: torch.npu.Event | None = None
self.valid_sampled_token_count_copy_stream: torch.npu.Stream | None = None
if self.use_async_scheduling and self.num_spec_tokens:
self.valid_sampled_token_count_event = torch.npu.Event()
self.valid_sampled_token_count_copy_stream = torch.npu.Stream()
self.valid_sampled_token_count_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)
# Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside
# `initialize_kv_cache` based on the kv cache config. However, as in
@@ -791,13 +814,40 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs
# Wait until valid_sampled_tokens_count is copied to the CPU,
# then use it to update the actual num_computed_tokens of each request.
valid_sampled_token_count = self._get_valid_sampled_token_count()
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
resumed_from_preemption = req_id in req_data.resumed_req_ids
num_output_tokens = req_data.num_output_tokens[i]
req_index = self.input_batch.req_id_to_index.get(req_id)
# prev_num_draft_len is used in async scheduling mode with
# spec decode. It indicates whether num_computed_tokens of the
# request needs to be updated. For example:
# first step: num_computed_tokens = 0, spec_tokens = [],
# prev_num_draft_len = 0.
# second step: num_computed_tokens = 100 (prompt length),
# spec_tokens = [a, b], prev_num_draft_len = 0.
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c, d],
# prev_num_draft_len = 2.
# num_computed_tokens in the first and second steps does not include
# the spec token length, but in the third step it does. We only need
# to update num_computed_tokens when prev_num_draft_len > 0.
if req_state.prev_num_draft_len:
if req_index is None:
req_state.prev_num_draft_len = 0
else:
assert self.input_batch.prev_req_id_to_index is not None
prev_req_index = self.input_batch.prev_req_id_to_index[
req_id]
num_accepted = valid_sampled_token_count[prev_req_index] - 1
num_rejected = req_state.prev_num_draft_len - num_accepted
num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted)
req_state.num_computed_tokens = num_computed_tokens
if not is_last_rank:
@@ -828,12 +878,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
# The request was either preempted and resumed later, or was
# not scheduled in the previous step and needs to be added
# again.
if self.use_async_scheduling and num_output_tokens > 0:
# We must recover the output token ids for resumed requests
# in the async scheduling case, so that correct input_ids
# are obtained.
resumed_token_ids = req_data.all_token_ids[req_id]
req_state.output_token_ids = resumed_token_ids[
-num_output_tokens:]
req_ids_to_add.append(req_id)
continue
@@ -860,8 +918,10 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
if spec_token_ids:
num_spec_tokens = len(spec_token_ids)
num_spec_tokens = len(spec_token_ids)
if self.use_async_scheduling:
req_state.prev_num_draft_len = num_spec_tokens
if num_spec_tokens:
start_index = self.input_batch.num_tokens_no_spec[req_index]
end_token_index = start_index + num_spec_tokens
self.input_batch.token_ids_cpu[
@@ -882,6 +942,17 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
def _get_valid_sampled_token_count(self) -> list[int]:
# Wait until valid_sampled_tokens_count is copied to the CPU.
prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
if (self.valid_sampled_token_count_event is None
or prev_sampled_token_ids is None):
return []
counts_cpu = self.valid_sampled_token_count_cpu
self.valid_sampled_token_count_event.synchronize()
return counts_cpu[:prev_sampled_token_ids.shape[0]].tolist()
def _init_mrope_positions(self, req_state: CachedRequestState):
assert supports_mrope(self.model), "MROPE is not supported"
req_state.mrope_positions, req_state.mrope_position_delta = \
@@ -901,26 +972,25 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# immediately once the other two flags are no longer needed.
if self.dp_size == 1:
return num_tokens, None, with_prefill
# Sync num_tokens, with_prefill across dp ranks
num_tokens_tensor = torch.tensor([
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
],
dtype=torch.int32,
device="npu")
device="cpu")
flags_tensor = torch.tensor([int(with_prefill)],
dtype=torch.int32,
device="npu")
device="cpu")
packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
# Use cpu_group to avoid a CPU synchronization issue;
# the all_reduce can then be overlapped with the main model execution on the NPU.
dist.all_reduce(packed_tensor, group=get_dp_group().cpu_group)
# Unpack the results
num_tokens_across_dp = packed_tensor[:-1]
synced_flags = packed_tensor[-1:]
max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
global_with_prefill = bool(synced_flags[0])
@@ -1195,7 +1265,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
return cu_num_tokens, arange
def _prepare_input_ids(self, total_num_scheduled_tokens: int,
def _prepare_input_ids(self, scheduler_output: "SchedulerOutput",
total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray) -> None:
"""Prepare the input IDs for the current batch.
@@ -1218,21 +1289,44 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# on the NPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
assert prev_req_id_to_index is not None
flattened_indices = []
prev_common_req_indices = []
sample_flattened_indices: list[int] = []
spec_flattened_indices: list[int] = []
prev_common_req_indices: list[int] = []
prev_draft_token_indices: list[int] = []
indices_match = True
max_flattened_index = -1
total_num_spec_tokens = 0
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id, cur_index in self.input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index)
# We need to compute the flattened input_ids index of the
# last token in each common request.
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1
flattened_indices.append(flattened_index)
indices_match &= (prev_index == flattened_index)
# example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
# sample_flattened_indices = [0, 2, 5]
# spec_flattened_indices = [1, 3, 4, 6, 7]
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(
range(flattened_index - draft_len + 1,
flattened_index + 1))
start = prev_index * self.num_spec_tokens
# prev_draft_token_indices is used to find which draft token ids
# should be copied into input_ids.
# example: prev draft_token_ids [[1, 2], [3, 4], [5, 6]]
# flattened draft_token_ids [1, 2, 3, 4, 5, 6]
# draft_len of each request [1, 2, 1]
# then prev_draft_token_indices is [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start,
start + draft_len))
indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
num_commmon_tokens = len(flattened_indices)
if num_commmon_tokens < total_num_scheduled_tokens:
num_commmon_tokens = len(sample_flattened_indices)
total_without_spec = (total_num_scheduled_tokens -
total_num_spec_tokens)
if num_commmon_tokens < total_without_spec:
# If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the NPU first.
self.input_ids[:total_num_scheduled_tokens].copy_(
@@ -1256,21 +1350,45 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
non_blocking=True)
self.is_token_ids.gpu[:num_commmon_tokens] = True
return
# Upload the index tensors asynchronously
# so the scatter can be non-blocking.
input_ids_index_tensor = torch.tensor(flattened_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(
self.device,
non_blocking=True)
# Upload the index tensors asynchronously so the scatter can be non-blocking.
sampled_tokens_index_tensor = torch.tensor(
sample_flattened_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
self.input_ids.scatter_(dim=0,
index=input_ids_index_tensor,
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0])
self.input_ids.scatter_(
dim=0,
index=sampled_tokens_index_tensor,
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0],
)
# scatter the draft tokens after the sampled tokens are scattered.
if self._draft_token_ids is None or not spec_flattened_indices:
return
assert isinstance(self._draft_token_ids, torch.Tensor)
draft_tokens_index_tensor = torch.tensor(
spec_flattened_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
prev_draft_token_indices_tensor = torch.tensor(
prev_draft_token_indices,
dtype=torch.int64,
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
# input_ids dtype is torch.int32, so convert
# draft_token_ids to torch.int32 here.
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
self._draft_token_ids = None
self.input_ids.scatter_(
dim=0,
index=draft_tokens_index_tensor,
src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
)
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
@@ -1544,7 +1662,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
self.query_lens = torch.from_numpy(num_scheduled_tokens)
# Copy the tensors to the NPU.
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens,
cu_num_tokens)
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:num_input_tokens].copy_(
self.positions_cpu[:num_input_tokens], non_blocking=True)
@@ -1993,8 +2112,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
cu_num_scheduled_tokens - num_sampled_tokens,
num_sampled_tokens)
logits_indices_pcp += arange
logits_indices_pcp = torch.from_numpy(logits_indices_pcp).to(
self.device, non_blocking=True)
logits_indices_pcp = torch.from_numpy(
logits_indices_pcp).pin_memory().to(self.device,
non_blocking=True)
# Compute the bonus logits indices.
bonus_logits_indices = cu_num_sampled_tokens - 1
@@ -2015,16 +2135,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
target_logits_indices += arange
# TODO: Optimize the CPU -> NPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True)
cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True)
cu_num_draft_tokens = (
torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
self.device, non_blocking=True))
cu_num_sampled_tokens = (
torch.from_numpy(cu_num_sampled_tokens).pin_memory().to(
self.device, non_blocking=True))
logits_indices = (torch.from_numpy(logits_indices).pin_memory().to(
self.device, non_blocking=True))
target_logits_indices = (
torch.from_numpy(target_logits_indices).pin_memory().to(
self.device, non_blocking=True))
bonus_logits_indices = torch.from_numpy(
bonus_logits_indices).pin_memory().to(self.device,
non_blocking=True)
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
@@ -2466,7 +2590,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
sampler_output.sampled_token_ids = output_token_ids
if self.need_accepted_tokens:
self._update_states_after_model_execute(output_token_ids)
discard_sampled_tokens_req_indices = \
self.discard_request_indices.np[:self.num_discarded_requests]
for i in discard_sampled_tokens_req_indices:
@@ -2494,6 +2617,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
@@ -2514,13 +2638,14 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
)
invalid_req_indices_set = set(invalid_req_indices)
assert sampled_token_ids.shape[-1] == 1
if self.num_spec_tokens <= 0:
assert sampled_token_ids.shape[-1] == 1
# Cache the sampled tokens on the NPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
self.input_batch.prev_sampled_token_ids = sampled_token_ids
# Cache the sampled tokens on the NPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
self.input_batch.prev_sampled_token_ids = \
sampled_token_ids
self.input_batch.prev_sampled_token_ids_invalid_indices = \
invalid_req_indices_set
self.input_batch.prev_req_id_to_index = {
@@ -2629,6 +2754,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
sampled_token_ids=sampled_token_ids,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
vocab_size=self.input_batch.vocab_size,
)
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
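
For context (not part of this diff): the index bookkeeping added to _prepare_input_ids can be checked in isolation; this reproduces the example from the inline comments with the per-request loop reduced to plain Python:

```python
cu_num_tokens = [2, 5, 8]   # cumulative scheduled tokens per request
draft_lens = [1, 2, 2]      # scheduled draft tokens per request

sample_flattened_indices: list[int] = []
spec_flattened_indices: list[int] = []
for cu, draft_len in zip(cu_num_tokens, draft_lens):
    last = cu - 1                                   # flattened index of the request's last token
    sample_flattened_indices.append(last - draft_len)
    spec_flattened_indices.extend(range(last - draft_len + 1, last + 1))

assert sample_flattened_indices == [0, 2, 5]
assert spec_flattened_indices == [1, 3, 4, 6, 7]
```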

View File

@@ -68,6 +68,8 @@ class CachedRequestState:
lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None
prev_num_draft_len: int = 0  # number of draft tokens scheduled in the previous step
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
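
For context (not part of this diff): prev_num_draft_len feeds the num_computed_tokens correction in _update_states; a small worked example of that arithmetic, using the numbers from the runner's inline comment:

```python
# Third async-scheduling step of the inline example:
prev_num_draft_len = 2          # draft tokens scheduled in the previous step
valid_sampled_token_count = 2   # 1 bonus token + 1 accepted draft from the last step
num_computed_tokens = 102       # 100 prompt tokens + 2 speculated tokens

num_accepted = valid_sampled_token_count - 1      # drafts actually accepted
num_rejected = prev_num_draft_len - num_accepted  # drafts to roll back
num_computed_tokens -= num_rejected

assert num_computed_tokens == 101
```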