support async mtp (#4511)
### What this PR does / why we need it?
This PR adds async_scheduling support for MTP, following the vLLM PR
https://github.com/vllm-project/vllm/pull/24799.
It also fixes some synchronization problems in vllm-ascend.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
213
tests/e2e/singlecard/test_async_scheduling.py
Normal file
213
tests/e2e/singlecard/test_async_scheduling.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from itertools import repeat
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch._dynamo.config as dynamo_config
|
||||
from vllm import SamplingParams
|
||||
from vllm.v1.metrics.reader import Metric
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
from tests.e2e.model_utils import check_outputs_equal
|
||||
|
||||
# Model under test and the shared prompt set used by every configuration.
MODEL = "Qwen/Qwen3-0.6B"

# First prompt spells out the digits 0..9 as a comma-separated sequence.
first_prompt = "The following numbers of the sequence {} are:".format(
    ", ".join(map(str, range(10))))
example_prompts = [
    first_prompt,
    "In one word, the capital of France is ",
    *(f"Tell me about the number {i}: " for i in range(32)),
]

# Baseline sampling parameters; per-test overrides are merged on top.
default_params = {
    "temperature": 0.0,  # greedy
    "max_tokens": 23,
    "min_tokens": 18,
}
|
||||
|
||||
|
||||
def test_without_spec_decoding(monkeypatch: pytest.MonkeyPatch):
    """Check output consistency across executor / async-scheduling combos.

    Speculative decoding, preemption and prefill chunking are all disabled
    in the configurations listed here.
    """
    # A single empty override: only ``default_params`` is exercised.
    sampling_overrides: list[dict[str, Any]] = [{}]

    # Each tuple is laid out as:
    #   (test_preemption, executor, async_scheduling,
    #    spec_config, test_prefill_chunking)
    executor_and_async = (("mp", False), ("mp", True), ("uni", True))
    configs = [(False, executor, async_sched, None, False)
               for executor, async_sched in executor_and_async]

    run_tests(monkeypatch, MODEL, configs, sampling_overrides)
|
||||
|
||||
|
||||
@dynamo_config.patch(cache_size_limit=16)
def run_tests(
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    test_configs: list[tuple],
    test_sampling_params: list[dict[str, Any]],
):
    """Run every configuration and compare each one against the first.

    Args:
        monkeypatch: pytest fixture; only its ``context()`` scope is entered
            around the runs.
        model: Model name forwarded to ``run_test``.
        test_configs: Tuples of ``(test_preemption, executor,
            async_scheduling, spec_config, test_prefill_chunking)``. The
            first entry acts as the baseline.
        test_sampling_params: Sampling-parameter overrides; one result set
            is produced per entry.

    Raises:
        AssertionError: The first mismatch encountered, re-raised only after
            every configuration has been compared and reported.
    """
    num_configs = len(test_configs)
    with monkeypatch.context():
        outputs: list[tuple[str, list, list]] = [
            run_test(
                model,
                f"{position}/{num_configs}",
                test_sampling_params,
                preemption,
                executor,
                async_sched,
                spec_cfg,
                test_prefill_chunking=chunking,
            ) for position, (preemption, executor, async_sched, spec_cfg,
                             chunking) in enumerate(test_configs, 1)
        ]

    baseline_config, baseline_tests, _ = outputs[0]

    # Reference acceptance rates come from the first run that recorded any.
    baseline_acceptances = None
    for _, _, recorded_rates in outputs:
        if recorded_rates is not None:
            baseline_acceptances = recorded_rates
            break

    print(
        f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}"
    )

    first_failure = None
    for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
        comparisons = zip(
            baseline_tests,
            baseline_acceptances or repeat(None),
            test_outputs,
            test_acceptance_rates or repeat(None),
            test_sampling_params,
        )
        for base_outs, base_rate, test_outs, test_rate, params in comparisons:
            try:
                check_outputs_equal(
                    outputs_0_lst=base_outs,
                    outputs_1_lst=test_outs,
                    name_0=f"baseline=[{baseline_config}], params={params}",
                    name_1=f"config=[{test_config}], params={params}",
                )

                both_rates_known = (base_rate is not None
                                    and test_rate is not None)
                if both_rates_known:
                    if "spec_mml=None" in test_config:
                        assert (test_rate > base_rate
                                or test_rate == pytest.approx(base_rate,
                                                              rel=5e-2))
                    else:
                        # Currently the reported acceptance rate is expected
                        # to be lower when we sometimes skip drafting
                        # altogether.
                        assert test_rate > 0.1
                print(f"PASSED: config=[{test_config}], params={params}"
                      f" accept_rate={test_rate}")
            except AssertionError as err:
                print(f"FAILED: config=[{test_config}], params={params}"
                      f" accept_rate={test_rate}")
                # Remember only the earliest failure; keep reporting the rest.
                if first_failure is None:
                    first_failure = err

    if first_failure is not None:
        raise first_failure
|
||||
|
||||
|
||||
def run_test(
    model: str,
    test_str: str,
    sampling_param_tests: list[dict[str, Any]],
    test_preemption: bool,
    executor: str,
    async_scheduling: bool,
    spec_config: dict[str, Any] | None,
    test_prefill_chunking: bool,
):
    """Run one engine configuration over every sampling-parameter override.

    Args:
        model: Model name passed to ``VllmRunner``.
        test_str: Progress label (e.g. ``"2/3"``) used only in log output.
        sampling_param_tests: Overrides merged on top of ``default_params``;
            one generation pass is made per entry.
        test_preemption: When true, the block count is overridden to 2 to
            force preemptions, and at least one preemption is asserted.
        executor: Value for ``distributed_executor_backend``.
        async_scheduling: Whether async scheduling is enabled.
        spec_config: Speculative-decoding config, or ``None`` to disable it.
        test_prefill_chunking: Enables chunked prefill with a 48-token batch
            budget.

    Returns:
        ``(test_config, results, acceptance_rates)``: the human-readable
        config string, one generation result per override, and the
        per-override acceptance rates (``None`` without spec decoding).
    """
    # NOTE: written straight into the process environment and not restored.
    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
    spec_decoding = spec_config is not None
    cache_arg: dict[str, Any] = (
        # Force preemptions
        dict(num_gpu_blocks_override=2) if test_preemption else dict(
            gpu_memory_utilization=0.9))
    spec_mml = (spec_config or {}).get("max_model_len")
    test_config = (f"executor={executor}, preemption={test_preemption}, "
                   f"async_sched={async_scheduling}, "
                   f"chunk_prefill={test_prefill_chunking}, "
                   f"spec_decoding={spec_decoding}, spec_mml={spec_mml}")
    print("-" * 80)
    print(f"---- TESTING {test_str}: {test_config}")
    print("-" * 80)
    with VllmRunner(
            model,
            max_model_len=512,
            enable_chunked_prefill=test_prefill_chunking,
            # Force prefill chunking
            max_num_batched_tokens=48 if test_prefill_chunking else None,
            enforce_eager=True,
            async_scheduling=async_scheduling,
            distributed_executor_backend=executor,
            dtype="float16",  # avoid precision errors
            speculative_config=spec_config,
            disable_log_stats=False,
            **cache_arg,
    ) as vllm_model:
        results = []
        acceptance_rates: list[float] | None = [] if spec_decoding else None
        for override_params in sampling_param_tests:
            # Snapshot metrics around the pass so counters can be diffed.
            metrics_before = vllm_model.model.get_metrics()
            print(f"----------- RUNNING PARAMS: {override_params}")
            results.append(
                vllm_model.generate(
                    example_prompts,
                    sampling_params=SamplingParams(**default_params,
                                                   **override_params),
                ))
            metrics_after = vllm_model.model.get_metrics()
            if acceptance_rates is not None:
                acceptance_rate = _get_acceptance_rate(metrics_before,
                                                       metrics_after)
                acceptance_rates.append(acceptance_rate)
                print(f"ACCEPTANCE RATE {acceptance_rate}")

            # NOTE(review): source indentation was lost; this check is
            # assumed to run once per pass — confirm against the repository.
            if test_preemption:
                preemptions = _get_count(metrics_before, metrics_after,
                                         "vllm:num_preemptions")
                assert preemptions > 0, "preemption test had no preemptions"

        if len(results) > 1:
            # First check that the different parameter configs
            # actually result in different output.
            baseline_params = sampling_param_tests[0]
            for other_test_outs, params in zip(results[1:],
                                               sampling_param_tests[1:]):
                with pytest.raises(AssertionError):
                    # Compare whole result lists, exactly as ``run_tests``
                    # does. Indexing ``results[0][0]`` would hand a single
                    # prompt's output to a list-vs-list comparison.
                    check_outputs_equal(
                        outputs_0_lst=results[0],
                        outputs_1_lst=other_test_outs,
                        name_0=f"baseline params={baseline_params}",
                        name_1=f"other params={params}",
                    )

    return test_config, results, acceptance_rates
|
||||
|
||||
|
||||
def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float:
    """Return accepted / drafted speculative tokens between two snapshots.

    Both counters are diffed via ``_get_count``. The result is ``0.0`` when
    no draft tokens were counted in the interval.
    """
    num_drafted = _get_count(before, after,
                             "vllm:spec_decode_num_draft_tokens")
    num_accepted = _get_count(before, after,
                              "vllm:spec_decode_num_accepted_tokens")
    if num_drafted > 0:
        return num_accepted / num_drafted
    return 0.0
|
||||
|
||||
|
||||
def _get_count(before: list[Metric], after: list[Metric], name: str) -> int:
    """Return how much the metric called ``name`` grew between two snapshots.

    Args:
        before: Metrics captured before the measured work.
        after: Metrics captured after the measured work.
        name: Metric name to look up; the first match in each list is used.

    Returns:
        ``after`` value minus ``before`` value.

    Raises:
        ValueError: ``name`` is missing from either snapshot. This replaces
            the bare ``StopIteration`` that ``next()`` would otherwise leak.
    """

    def _value_of(snapshot, label):
        # First metric with a matching name wins.
        for metric in snapshot:
            if metric.name == name:
                return metric.value
        raise ValueError(
            f"metric {name!r} not found in the {label!r} metrics snapshot")

    before_val = _value_of(before, "before")
    after_val = _value_of(after, "after")
    return after_val - before_val
|
||||
@@ -87,6 +87,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
|
||||
self.mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
self.mock_device = 'cpu:0'
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
self.builder = AscendAttentionMetadataBuilder(None, None,
|
||||
self.mock_vllm_config,
|
||||
self.mock_device)
|
||||
|
||||
@@ -299,6 +299,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
mock_device = 'cpu'
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
|
||||
mock_dcp.world_size = 1
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
@@ -534,6 +535,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
mock_get_pcp_group):
|
||||
mock_npu_available.return_value = False
|
||||
mock_dcp_world_size.return_value = 1
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
@@ -599,6 +601,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
mock_get_pcp_group):
|
||||
mock_npu_available.return_value = False
|
||||
mock_dcp_world_size.return_value = 1
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
@@ -660,6 +663,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
mock_dcp_world_size,
|
||||
mock_get_pcp_group):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
@@ -713,6 +718,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
mock_dcp_world_size,
|
||||
mock_get_pcp_group):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
@@ -767,6 +774,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
mock_dcp_world_size,
|
||||
mock_get_pcp_group):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
|
||||
@@ -317,8 +317,8 @@ class AscendAttentionMetadataBuilder:
|
||||
query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
|
||||
])
|
||||
|
||||
query_start_loc = query_start_loc_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
query_start_loc = query_start_loc_cpu.pin_memory().to(
|
||||
self.device, non_blocking=True)
|
||||
|
||||
attn_metadata = AscendMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
|
||||
@@ -556,35 +556,43 @@ class AscendMLAMetadataBuilder:
|
||||
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
chunked_context_metadata = \
|
||||
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=local_chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
|
||||
chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
|
||||
device, non_blocking=True),
|
||||
starts=local_chunk_starts.pin_memory().to(
|
||||
device, non_blocking=True),
|
||||
seq_tot=padded_local_chunk_seq_lens.sum(
|
||||
dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
|
||||
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
|
||||
local_context_lens_allranks=local_context_lens_allranks.tolist(),
|
||||
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
|
||||
device, non_blocking=True
|
||||
),
|
||||
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.
|
||||
npu(),
|
||||
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens
|
||||
.tolist(),
|
||||
local_context_lens_allranks=local_context_lens_allranks
|
||||
.tolist(),
|
||||
padded_local_cu_seq_lens=
|
||||
padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
|
||||
device, non_blocking=True),
|
||||
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
|
||||
chunk_size=padded_local_max_context_chunk_across_ranks,
|
||||
)
|
||||
else:
|
||||
chunked_context_metadata = \
|
||||
chunked_context_metadata = (
|
||||
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
|
||||
device, non_blocking=True),
|
||||
starts=chunk_starts.pin_memory().to(
|
||||
device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(
|
||||
dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
))
|
||||
prefill_input_positions = input_positions[tokens_start:]
|
||||
cos = self.cos_cache[
|
||||
prefill_input_positions].unsqueeze( # type: ignore
|
||||
@@ -616,7 +624,8 @@ class AscendMLAMetadataBuilder:
|
||||
cos = common_attn_metadata.cos
|
||||
sin = common_attn_metadata.sin
|
||||
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
||||
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
|
||||
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
|
||||
1].tolist()
|
||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||
seq_lens = seq_lens[:num_decodes]
|
||||
input_positions = input_positions[:num_decode_tokens]
|
||||
|
||||
@@ -142,6 +142,9 @@ class MtpProposer(Proposer):
|
||||
self.arange = torch.arange(max_num_slots_for_arange,
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
self.arange_cpu = torch.arange(max_num_slots_for_arange,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
@@ -157,6 +160,7 @@ class MtpProposer(Proposer):
|
||||
)
|
||||
self.use_sparse = hasattr(vllm_config.model_config.hf_config,
|
||||
"index_topk")
|
||||
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
|
||||
|
||||
def load_model(self, model) -> None:
|
||||
loader = get_model_loader(self.vllm_config.load_config)
|
||||
@@ -351,6 +355,8 @@ class MtpProposer(Proposer):
|
||||
self.runner.discard_request_indices.gpu,
|
||||
self.runner.num_discarded_requests
|
||||
)
|
||||
self._copy_valid_sampled_token_count(next_token_ids,
|
||||
valid_sampled_tokens_count)
|
||||
|
||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
if self.pcp_size > 1:
|
||||
@@ -430,6 +436,28 @@ class MtpProposer(Proposer):
|
||||
|
||||
return draft_token_ids
|
||||
|
||||
def _copy_valid_sampled_token_count(
        self, next_token_ids: torch.Tensor,
        valid_sampled_tokens_count: torch.Tensor) -> None:
    """Start a non-blocking copy of the valid sampled-token counts to host.

    Does nothing when the runner has no ``valid_sampled_token_count_event``.
    Otherwise the copy is issued on the runner's dedicated copy stream, the
    event is recorded, and ``next_token_ids`` (with a trailing dimension
    added) is stored as the input batch's ``prev_sampled_token_ids``.

    NOTE(review): indentation was lost in the source; the final assignment
    is assumed to sit inside the event check but outside the stream scope —
    confirm against the repository.
    """
    runner = self.runner
    if runner.valid_sampled_token_count_event is None:
        return

    main_stream = torch.npu.current_stream()
    copy_stream = runner.valid_sampled_token_count_copy_stream
    # A separate stream lets the copy overlap with prepare_input of the
    # draft model.
    with torch.npu.stream(copy_stream):
        # Order the copy after work already queued on the current stream.
        copy_stream.wait_stream(main_stream)  # type: ignore
        num_counts = valid_sampled_tokens_count.shape[0]
        runner.valid_sampled_token_count_cpu[:num_counts].copy_(
            valid_sampled_tokens_count, non_blocking=True)
        runner.valid_sampled_token_count_event.record()

    runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)
|
||||
|
||||
def _init_mtp_model(self):
|
||||
architecture = self.vllm_config.model_config.architecture
|
||||
target_device = self.vllm_config.device_config.device
|
||||
@@ -696,6 +724,11 @@ class MtpProposer(Proposer):
|
||||
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
|
||||
aclgraph_runtime_mode, batch_descriptor = \
|
||||
self.runner.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
|
||||
if self.use_async_scheduling:
|
||||
# there is synchronize between mtp steps when enable aclgraph,
|
||||
# disable aclgraph when use async scheduling to avoid the
|
||||
# synchronize overhead.
|
||||
aclgraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||
) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
@@ -822,7 +855,7 @@ class MtpProposer(Proposer):
|
||||
# When disable_padded_drafter_batch=False, it should not to be updating these params, maybe.
|
||||
if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \
|
||||
aclgraph_runtime_mode != CUDAGraphMode.FULL):
|
||||
decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
|
||||
decode_metadata.actual_seq_lengths_q = self.arange_cpu[
|
||||
1:batch_size + 1].tolist()
|
||||
if aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
decode_metadata.actual_seq_lengths_q = \
|
||||
@@ -847,7 +880,9 @@ class MtpProposer(Proposer):
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
||||
positions[:batch_size])
|
||||
# Increment the sequence lengths.
|
||||
attn_metadata_i.seq_lens[:batch_size] += 1
|
||||
# This is an out-of-place operation to avoid modifying the original tensor
|
||||
# when enable async_scheduling.
|
||||
attn_metadata_i.seq_lens = attn_metadata_i.seq_lens + 1
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
|
||||
|
||||
@@ -97,6 +97,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||
make_empty_encoder_model_runner_output)
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
|
||||
@@ -213,6 +214,7 @@ class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
|
||||
sampled_token_ids: torch.Tensor,
|
||||
invalid_req_indices: list[int],
|
||||
async_output_copy_stream: torch.npu.Stream,
|
||||
vocab_size: int,
|
||||
):
|
||||
self._model_runner_output = model_runner_output
|
||||
self._invalid_req_indices = invalid_req_indices
|
||||
@@ -223,7 +225,7 @@ class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
|
||||
# Keep a reference to the device tensor to avoid it being
|
||||
# deallocated until we finish copying it to the host.
|
||||
self._sampled_token_ids = sampled_token_ids
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
# Initiate the copy on a separate stream, but do not synchronize it.
|
||||
default_stream = torch.npu.current_stream()
|
||||
with torch.npu.stream(async_output_copy_stream):
|
||||
@@ -242,10 +244,17 @@ class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
|
||||
# Release the device tensor once the copy has completed
|
||||
del self._sampled_token_ids
|
||||
|
||||
valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
|
||||
for i in self._invalid_req_indices:
|
||||
valid_sampled_token_ids[i].clear()
|
||||
|
||||
max_gen_len = self._sampled_token_ids_cpu.shape[-1]
|
||||
if max_gen_len == 1:
|
||||
valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
|
||||
for i in self._invalid_req_indices:
|
||||
valid_sampled_token_ids[i].clear()
|
||||
else:
|
||||
valid_sampled_token_ids, _ = RejectionSampler.parse_output(
|
||||
self._sampled_token_ids_cpu,
|
||||
self.vocab_size,
|
||||
self._invalid_req_indices,
|
||||
return_cu_num_tokens=False)
|
||||
output = self._model_runner_output
|
||||
output.sampled_token_ids = valid_sampled_token_ids
|
||||
return output
|
||||
@@ -567,6 +576,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
self.async_output_copy_stream = torch.npu.Stream() if \
|
||||
self.use_async_scheduling else None
|
||||
self.num_spec_tokens = 0
|
||||
if self.speculative_config:
|
||||
self.num_spec_tokens = self.speculative_config.num_speculative_tokens # noqa
|
||||
self.valid_sampled_token_count_event: torch.npu.Event | None = None
|
||||
self.valid_sampled_token_count_copy_stream: torch.npu.Stream | None = None
|
||||
if self.use_async_scheduling and self.num_spec_tokens:
|
||||
self.valid_sampled_token_count_event = torch.npu.Event()
|
||||
self.valid_sampled_token_count_copy_stream = torch.npu.Stream()
|
||||
self.valid_sampled_token_count_cpu = torch.empty(
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
# Input Batch
|
||||
# NOTE(Chen): Ideally, we should initialize the input batch inside
|
||||
# `initialize_kv_cache` based on the kv cache config. However, as in
|
||||
@@ -791,13 +814,40 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
# Update the states of the running/resumed requests.
|
||||
is_last_rank = get_pp_group().is_last_rank
|
||||
req_data = scheduler_output.scheduled_cached_reqs
|
||||
# wait until valid_sampled_tokens_count is copied to cpu,
|
||||
# then use it to update actual num_computed_tokens of each request.
|
||||
valid_sampled_token_count = self._get_valid_sampled_token_count()
|
||||
for i, req_id in enumerate(req_data.req_ids):
|
||||
req_state = self.requests[req_id]
|
||||
num_computed_tokens = req_data.num_computed_tokens[i]
|
||||
new_block_ids = req_data.new_block_ids[i]
|
||||
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
||||
|
||||
# Update the cached states.
|
||||
resumed_from_preemption = req_id in req_data.resumed_req_ids
|
||||
num_output_tokens = req_data.num_output_tokens[i]
|
||||
req_index = self.input_batch.req_id_to_index.get(req_id)
|
||||
# prev_num_draft_len is used in async scheduling mode with
|
||||
# spec decode. it indicates if need to update num_computed_tokens
|
||||
# of the request. for example:
|
||||
# first step: num_computed_tokens = 0, spec_tokens = [],
|
||||
# prev_num_draft_len = 0.
|
||||
# second step: num_computed_tokens = 100(prompt length),
|
||||
# spec_tokens = [a,b], prev_num_draft_len = 0.
|
||||
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
|
||||
# prev_num_draft_len = 2.
|
||||
# num_computed_tokens in first step and second step doesn't contain
|
||||
# the spec tokens length, but in third step it contains the
|
||||
# spec tokens length. we only need to update num_computed_tokens
|
||||
# when prev_num_draft_len > 0.
|
||||
if req_state.prev_num_draft_len:
|
||||
if req_index is None:
|
||||
req_state.prev_num_draft_len = 0
|
||||
else:
|
||||
assert self.input_batch.prev_req_id_to_index is not None
|
||||
prev_req_index = self.input_batch.prev_req_id_to_index[
|
||||
req_id]
|
||||
num_accepted = valid_sampled_token_count[prev_req_index] - 1
|
||||
num_rejected = req_state.prev_num_draft_len - num_accepted
|
||||
num_computed_tokens -= num_rejected
|
||||
req_state.output_token_ids.extend([-1] * num_accepted)
|
||||
req_state.num_computed_tokens = num_computed_tokens
|
||||
|
||||
if not is_last_rank:
|
||||
@@ -828,12 +878,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
# The request is resumed from preemption.
|
||||
# Replace the existing block IDs with the new ones.
|
||||
req_state.block_ids = new_block_ids
|
||||
|
||||
req_index = self.input_batch.req_id_to_index.get(req_id)
|
||||
if req_index is None:
|
||||
# The request is not in the persistent batch.
|
||||
# The request was either preempted and resumed later, or was not
|
||||
# scheduled in the previous step and needs to be added again.
|
||||
# The request was either preempted and resumed later, or was
|
||||
# not scheduled in the previous step and needs to be added
|
||||
# again.
|
||||
|
||||
if self.use_async_scheduling and num_output_tokens > 0:
|
||||
# We must recover the output token ids for resumed requests
|
||||
# in the async scheduling case, so that correct input_ids
|
||||
# are obtained.
|
||||
resumed_token_ids = req_data.all_token_ids[req_id]
|
||||
req_state.output_token_ids = resumed_token_ids[
|
||||
-num_output_tokens:]
|
||||
|
||||
req_ids_to_add.append(req_id)
|
||||
continue
|
||||
|
||||
@@ -860,8 +918,10 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
# Add spec_token_ids to token_ids_cpu.
|
||||
spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
|
||||
if spec_token_ids:
|
||||
num_spec_tokens = len(spec_token_ids)
|
||||
num_spec_tokens = len(spec_token_ids)
|
||||
if self.use_async_scheduling:
|
||||
req_state.prev_num_draft_len = num_spec_tokens
|
||||
if num_spec_tokens:
|
||||
start_index = self.input_batch.num_tokens_no_spec[req_index]
|
||||
end_token_index = start_index + num_spec_tokens
|
||||
self.input_batch.token_ids_cpu[
|
||||
@@ -882,6 +942,17 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
# Refresh batch metadata with any pending updates.
|
||||
self.input_batch.refresh_metadata()
|
||||
|
||||
def _get_valid_sampled_token_count(self) -> list[int]:
    """Return the host-side valid sampled-token counts as a Python list.

    An empty list is returned when there is no
    ``valid_sampled_token_count_event`` or no ``prev_sampled_token_ids``.
    Otherwise the event is synchronized first, so the copy into
    ``valid_sampled_token_count_cpu`` has finished, and one count per row of
    ``prev_sampled_token_ids`` is returned.
    """
    prev_ids = self.input_batch.prev_sampled_token_ids
    copy_done = self.valid_sampled_token_count_event
    if copy_done is None or prev_ids is None:
        return []

    host_counts = self.valid_sampled_token_count_cpu
    # Block until the count copy to the host buffer has completed.
    copy_done.synchronize()
    num_prev_reqs = prev_ids.shape[0]
    return host_counts[:num_prev_reqs].tolist()
|
||||
|
||||
def _init_mrope_positions(self, req_state: CachedRequestState):
|
||||
assert supports_mrope(self.model), "MROPE is not supported"
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = \
|
||||
@@ -901,26 +972,25 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
# immediately once the other two flags are no longer needed.
|
||||
if self.dp_size == 1:
|
||||
return num_tokens, None, with_prefill
|
||||
|
||||
# Sync num_tokens, with_prefill across dp ranks
|
||||
num_tokens_tensor = torch.tensor([
|
||||
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
device="cpu")
|
||||
|
||||
flags_tensor = torch.tensor([int(with_prefill)],
|
||||
dtype=torch.int32,
|
||||
device="npu")
|
||||
device="cpu")
|
||||
|
||||
packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
|
||||
|
||||
dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
|
||||
# use cpu_group to avoid cpu synchronization issue.
|
||||
# it can be overlapped with main model execution on npu.
|
||||
dist.all_reduce(packed_tensor, group=get_dp_group().cpu_group)
|
||||
|
||||
# Unpack the results
|
||||
num_tokens_across_dp = packed_tensor[:-1]
|
||||
synced_flags = packed_tensor[-1:]
|
||||
|
||||
max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
|
||||
global_with_prefill = bool(synced_flags[0])
|
||||
|
||||
@@ -1195,7 +1265,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
|
||||
return cu_num_tokens, arange
|
||||
|
||||
def _prepare_input_ids(self, total_num_scheduled_tokens: int,
|
||||
def _prepare_input_ids(self, scheduler_output: "SchedulerOutput",
|
||||
total_num_scheduled_tokens: int,
|
||||
cu_num_tokens: np.ndarray) -> None:
|
||||
"""Prepare the input IDs for the current batch.
|
||||
|
||||
@@ -1218,21 +1289,44 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
# on the NPU from prev_sampled_token_ids.
|
||||
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
|
||||
assert prev_req_id_to_index is not None
|
||||
flattened_indices = []
|
||||
prev_common_req_indices = []
|
||||
sample_flattened_indices: list[int] = []
|
||||
spec_flattened_indices: list[int] = []
|
||||
prev_common_req_indices: list[int] = []
|
||||
prev_draft_token_indices: list[int] = []
|
||||
indices_match = True
|
||||
max_flattened_index = -1
|
||||
total_num_spec_tokens = 0
|
||||
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
|
||||
for req_id, cur_index in self.input_batch.req_id_to_index.items():
|
||||
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
|
||||
prev_common_req_indices.append(prev_index)
|
||||
# We need to compute the flattened input_ids index of the
|
||||
# last token in each common request.
|
||||
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
|
||||
total_num_spec_tokens += draft_len
|
||||
flattened_index = cu_num_tokens[cur_index].item() - 1
|
||||
flattened_indices.append(flattened_index)
|
||||
indices_match &= (prev_index == flattened_index)
|
||||
# example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
|
||||
# sample_flattened_indices = [0, 2, 5]
|
||||
# spec_flattened_indices = [1, 3, 4, 6, 7]
|
||||
sample_flattened_indices.append(flattened_index - draft_len)
|
||||
spec_flattened_indices.extend(
|
||||
range(flattened_index - draft_len + 1,
|
||||
flattened_index + 1))
|
||||
start = prev_index * self.num_spec_tokens
|
||||
# prev_draft_token_indices is used to find which draft_tokens_id
|
||||
# should be copied to input_ids
|
||||
# example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
|
||||
# flatten draft_tokens_id [1,2,3,4,5,6]
|
||||
# draft_len of each request [1, 2, 1]
|
||||
# then prev_draft_token_indices is [0, 2, 3, 4]
|
||||
prev_draft_token_indices.extend(range(start,
|
||||
start + draft_len))
|
||||
indices_match &= prev_index == flattened_index
|
||||
max_flattened_index = max(max_flattened_index, flattened_index)
|
||||
num_commmon_tokens = len(flattened_indices)
|
||||
if num_commmon_tokens < total_num_scheduled_tokens:
|
||||
num_commmon_tokens = len(sample_flattened_indices)
|
||||
total_without_spec = (total_num_scheduled_tokens -
|
||||
total_num_spec_tokens)
|
||||
if num_commmon_tokens < total_without_spec:
|
||||
# If not all requests are decodes from the last iteration,
|
||||
# We need to copy the input_ids_cpu to the NPU first.
|
||||
self.input_ids[:total_num_scheduled_tokens].copy_(
|
||||
@@ -1256,21 +1350,45 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
non_blocking=True)
|
||||
self.is_token_ids.gpu[:num_commmon_tokens] = True
|
||||
return
|
||||
# Upload the index tensors asynchronously
|
||||
# so the scatter can be non-blocking.
|
||||
input_ids_index_tensor = torch.tensor(flattened_indices,
|
||||
dtype=torch.int64,
|
||||
pin_memory=self.pin_memory).to(
|
||||
self.device,
|
||||
non_blocking=True)
|
||||
# Upload the index tensors asynchronously so the scatter can be non-blocking.
|
||||
sampled_tokens_index_tensor = torch.tensor(
|
||||
sample_flattened_indices,
|
||||
dtype=torch.int64,
|
||||
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
|
||||
prev_common_req_indices_tensor = torch.tensor(
|
||||
prev_common_req_indices,
|
||||
dtype=torch.int64,
|
||||
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
|
||||
self.input_ids.scatter_(dim=0,
|
||||
index=input_ids_index_tensor,
|
||||
src=self.input_batch.prev_sampled_token_ids[
|
||||
prev_common_req_indices_tensor, 0])
|
||||
self.input_ids.scatter_(
|
||||
dim=0,
|
||||
index=sampled_tokens_index_tensor,
|
||||
src=self.input_batch.prev_sampled_token_ids[
|
||||
prev_common_req_indices_tensor, 0],
|
||||
)
|
||||
|
||||
# scatter the draft tokens after the sampled tokens are scattered.
|
||||
if self._draft_token_ids is None or not spec_flattened_indices:
|
||||
return
|
||||
|
||||
assert isinstance(self._draft_token_ids, torch.Tensor)
|
||||
draft_tokens_index_tensor = torch.tensor(
|
||||
spec_flattened_indices,
|
||||
dtype=torch.int64,
|
||||
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
|
||||
prev_draft_token_indices_tensor = torch.tensor(
|
||||
prev_draft_token_indices,
|
||||
dtype=torch.int64,
|
||||
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
|
||||
|
||||
# input_ids has dtype torch.int32,
|
||||
# so convert draft_token_ids to torch.int32 here.
|
||||
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
|
||||
self._draft_token_ids = None
|
||||
self.input_ids.scatter_(
|
||||
dim=0,
|
||||
index=draft_tokens_index_tensor,
|
||||
src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
|
||||
)
|
||||
|
||||
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
"""
|
||||
@@ -1544,7 +1662,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
|
||||
# Copy the tensors to the NPU.
|
||||
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
|
||||
self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens,
|
||||
cu_num_tokens)
|
||||
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
|
||||
self.positions[:num_input_tokens].copy_(
|
||||
self.positions_cpu[:num_input_tokens], non_blocking=True)
|
||||
@@ -1993,8 +2112,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
cu_num_scheduled_tokens - num_sampled_tokens,
|
||||
num_sampled_tokens)
|
||||
logits_indices_pcp += arange
|
||||
logits_indices_pcp = torch.from_numpy(logits_indices_pcp).to(
|
||||
self.device, non_blocking=True)
|
||||
logits_indices_pcp = torch.from_numpy(
|
||||
logits_indices_pcp).pin_memory().to(self.device,
|
||||
non_blocking=True)
|
||||
|
||||
# Compute the bonus logits indices.
|
||||
bonus_logits_indices = cu_num_sampled_tokens - 1
|
||||
@@ -2015,16 +2135,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
target_logits_indices += arange
|
||||
|
||||
# TODO: Optimize the CPU -> NPU copy.
|
||||
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
|
||||
self.device, non_blocking=True)
|
||||
cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
|
||||
self.device, non_blocking=True)
|
||||
logits_indices = torch.from_numpy(logits_indices).to(self.device,
|
||||
non_blocking=True)
|
||||
target_logits_indices = torch.from_numpy(target_logits_indices).to(
|
||||
self.device, non_blocking=True)
|
||||
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
|
||||
self.device, non_blocking=True)
|
||||
cu_num_draft_tokens = (
|
||||
torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
|
||||
self.device, non_blocking=True))
|
||||
cu_num_sampled_tokens = (
|
||||
torch.from_numpy(cu_num_sampled_tokens).pin_memory().to(
|
||||
self.device, non_blocking=True))
|
||||
logits_indices = (torch.from_numpy(logits_indices).pin_memory().to(
|
||||
self.device, non_blocking=True))
|
||||
target_logits_indices = (
|
||||
torch.from_numpy(target_logits_indices).pin_memory().to(
|
||||
self.device, non_blocking=True))
|
||||
bonus_logits_indices = torch.from_numpy(
|
||||
bonus_logits_indices).pin_memory().to(self.device,
|
||||
non_blocking=True)
|
||||
|
||||
# Compute the draft token ids.
|
||||
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
|
||||
@@ -2466,7 +2590,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
sampler_output.sampled_token_ids = output_token_ids
|
||||
if self.need_accepted_tokens:
|
||||
self._update_states_after_model_execute(output_token_ids)
|
||||
|
||||
discard_sampled_tokens_req_indices = \
|
||||
self.discard_request_indices.np[:self.num_discarded_requests]
|
||||
for i in discard_sampled_tokens_req_indices:
|
||||
@@ -2494,6 +2617,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
|
||||
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
|
||||
if not self.use_async_scheduling:
|
||||
# Get the valid generated tokens.
|
||||
max_gen_len = sampled_token_ids.shape[-1]
|
||||
@@ -2514,13 +2638,14 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
|
||||
)
|
||||
invalid_req_indices_set = set(invalid_req_indices)
|
||||
assert sampled_token_ids.shape[-1] == 1
|
||||
if self.num_spec_tokens <= 0:
|
||||
assert sampled_token_ids.shape[-1] == 1
|
||||
# Cache the sampled tokens on the NPU and avoid CPU sync.
|
||||
# These will be copied into input_ids in the next step
|
||||
# when preparing inputs.
|
||||
self.input_batch.prev_sampled_token_ids = sampled_token_ids
|
||||
|
||||
|
||||
# Cache the sampled tokens on the NPU and avoid CPU sync.
|
||||
# These will be copied into input_ids in the next step
|
||||
# when preparing inputs.
|
||||
self.input_batch.prev_sampled_token_ids = \
|
||||
sampled_token_ids
|
||||
self.input_batch.prev_sampled_token_ids_invalid_indices = \
|
||||
invalid_req_indices_set
|
||||
self.input_batch.prev_req_id_to_index = {
|
||||
@@ -2629,6 +2754,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
invalid_req_indices=invalid_req_indices,
|
||||
async_output_copy_stream=self.async_output_copy_stream,
|
||||
vocab_size=self.input_batch.vocab_size,
|
||||
)
|
||||
|
||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||
|
||||
@@ -68,6 +68,8 @@ class CachedRequestState:
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
prompt_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
prev_num_draft_len: int = 0 # previous number of draft tokens
|
||||
|
||||
def __post_init__(self):
|
||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
self.prompt_token_ids, self.prompt_embeds)
|
||||
|
||||
Reference in New Issue
Block a user