[Misc] Upgrade vllm hash to 12_14 (#5000)

### What this PR does / why we need it?

### Does this PR introduce _any_ user-facing change?
1. fix https://github.com/vllm-project/vllm/pull/27938
2. fix https://github.com/vllm-project/vllm/pull/27145
pooling models now supports chunked prefill and prefix caching,
3. fix https://github.com/vllm-project/vllm/pull/30181
define the CPU fields in the field config where they really belong.
4. fix https://github.com/vllm-project/vllm/pull/28168
define the CPU fields in the field config where they really belong.
5. fix https://github.com/vllm-project/vllm/pull/30201
some moudle rename
6. fix https://github.com/vllm-project/vllm/pull/29067
fusedmoe moudle refactor
7. fix https://github.com/vllm-project/vllm/pull/29066
fusedmoe moudle refactor
8. fix https://github.com/vllm-project/vllm/pull/29624
### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
Li Wang
2025-12-15 19:54:23 +08:00
committed by GitHub
parent 3b7eb5179f
commit 8d2998d0e4
17 changed files with 167 additions and 1183 deletions

View File

@@ -74,7 +74,7 @@ jobs:
name: e2e-full name: e2e-full
strategy: strategy:
matrix: matrix:
vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0] vllm_version: [97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0]
needs: [changes] needs: [changes]
if: ${{ needs.changes.outputs.e2e_tracker == 'true' }} if: ${{ needs.changes.outputs.e2e_tracker == 'true' }}
uses: ./.github/workflows/_e2e_test.yaml uses: ./.github/workflows/_e2e_test.yaml

View File

@@ -42,7 +42,7 @@ jobs:
lint: lint:
uses: ./.github/workflows/_pre_commit.yml uses: ./.github/workflows/_pre_commit.yml
with: with:
vllm: ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 vllm: 97f2f160fda2805f9149b0e44da76b5d3b1f7c7e
changes: changes:
runs-on: linux-aarch64-a2-0 runs-on: linux-aarch64-a2-0
outputs: outputs:
@@ -90,7 +90,7 @@ jobs:
SOC_VERSION: ascend910b1 SOC_VERSION: ascend910b1
strategy: strategy:
matrix: matrix:
vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0] vllm_version: [97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0]
steps: steps:
- name: Free up disk space - name: Free up disk space
@@ -154,7 +154,7 @@ jobs:
name: e2e-light name: e2e-light
strategy: strategy:
matrix: matrix:
vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0] vllm_version: [97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0]
# Note (yikun): If CI resource are limited we can split job into two chain jobs # Note (yikun): If CI resource are limited we can split job into two chain jobs
needs: [lint, changes] needs: [lint, changes]
# only trigger e2e test after lint passed and the change is e2e related with pull request. # only trigger e2e test after lint passed and the change is e2e related with pull request.

View File

@@ -45,7 +45,7 @@ The table below is the release compatibility matrix for vLLM Ascend release.
For main branch of vLLM Ascend, we usually make it compatible with the latest vLLM release and a newer commit hash of vLLM. Please note that this table is usually updated. Please check it regularly. For main branch of vLLM Ascend, we usually make it compatible with the latest vLLM release and a newer commit hash of vLLM. Please note that this table is usually updated. Please check it regularly.
| vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu | | vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu |
|-------------|--------------|------------------|-------------|--------------------| |-------------|--------------|------------------|-------------|--------------------|
| main | ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0 tag | >= 3.10, < 3.12 | 8.3.RC2 | 2.8.0 / 2.8.0 | | main | 97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0 tag | >= 3.10, < 3.12 | 8.3.RC2 | 2.8.0 / 2.8.0 |
## Release cadence ## Release cadence

View File

@@ -803,7 +803,9 @@ class TestPCPDCPGraphParams(TestBase):
(q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads, (q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads,
scale, num_kv_heads, out, lse)) scale, num_kv_heads, out, lse))
update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, 4) with patch("torch_npu._C._npu_setStream", return_value=None):
update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,
4)
_mock_graph_task_end.assert_called_once() _mock_graph_task_end.assert_called_once()
@@ -842,6 +844,7 @@ class TestPCPDCPGraphParams(TestBase):
block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q, block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q,
out, lse, 2, 0, 0)) out, lse, 2, 0, 0))
update_attn_dcp_pcp_params(self.update_stream, forward_context, 4) with patch("torch_npu._C._npu_setStream", return_value=None):
update_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
_mock_graph_task_end.assert_called_once() _mock_graph_task_end.assert_called_once()

View File

@@ -95,6 +95,8 @@ class TestEagleProposerLoadModel(TestBase):
mock_model = MagicMock() mock_model = MagicMock()
mock_model.model.embed_tokens = MagicMock() mock_model.model.embed_tokens = MagicMock()
mock_model.lm_head = MagicMock() mock_model.lm_head = MagicMock()
mock_model.multimodal_cpu_fields = None
mock_model.merge_by_field_config = None
mock_get_model.return_value = MagicMock() mock_get_model.return_value = MagicMock()
self.proposer.name = SpecDcodeType.EAGLE self.proposer.name = SpecDcodeType.EAGLE
@@ -117,6 +119,8 @@ class TestEagleProposerLoadModel(TestBase):
mock_model = MagicMock() mock_model = MagicMock()
original_embed = MagicMock() original_embed = MagicMock()
mock_model.multimodal_cpu_fields = None
mock_model.merge_by_field_config = None
mock_get_model.return_value = MagicMock(model=MagicMock( mock_get_model.return_value = MagicMock(model=MagicMock(
embed_tokens=original_embed)) embed_tokens=original_embed))

View File

@@ -1,375 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import inspect
from collections.abc import Sequence
from typing import Optional
import numpy as np
import pytest
import torch
from vllm.sampling_params import SamplingParams
from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm_ascend.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
MAX_PROMPT_SIZE = 100
MAX_NUM_PROMPT_TOKENS = 64
def _compare_objs(obj1,
obj2,
skip: Sequence = ("logitsprocs", "batch_update_builder")):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
a[0] for a in attrs
if not (a[0].startswith('__') and a[0].endswith('__'))
])
for attr_name in attr_names:
if attr_name in skip:
continue
a = getattr(obj1, attr_name)
b = getattr(obj2, attr_name)
is_same = False
if isinstance(a, torch.Tensor):
if (a.numel() == 0 or b.numel() == 0):
is_same = (a.numel() == 0 and b.numel() == 0)
elif torch.allclose(a, b):
is_same = True
elif isinstance(a, np.ndarray):
if np.allclose(a, b):
is_same = True
elif isinstance(a, MultiGroupBlockTable):
for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i)
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
_compare_objs(a, b)
is_same = True # if we make it here must be same
elif a == b:
is_same = True
elif isinstance(a, CpuGpuBuffer):
is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu)
assert is_same, f"Attribute {attr_name} is different"\
f" in {obj1} and {obj2}: {a} != {b}"
def _remove_requests(input_batch: InputBatch, batch_size: int,
reqs: list[CachedRequestState]) -> set[str]:
"""
Remove some requests randomly from the batch and returns
set of request removed
"""
num_reqs_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove: set[int] = set()
for _ in range(num_reqs_to_remove):
req_index_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove.add(req_index_to_remove)
req_ids_to_remove: set[str] = set()
for index in req_indices_to_remove:
input_batch.remove_request(reqs[index].req_id)
req_ids_to_remove.add(reqs[index].req_id)
return req_ids_to_remove
def _construct_expected_sampling_metadata(
reqs: list[CachedRequestState],
req_ids_retained: set[int],
req_id_index_in_input_batch: dict[str, int],
device: torch.device,
) -> SamplingMetadata:
"""
Constructs and returns the expected SamplingMetadata for this
batch.
"""
num_reqs = len(req_ids_retained)
output_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
prompt_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
presence_penalties = [0.0 for _ in range(num_reqs)]
frequency_penalties = [0.0 for _ in range(num_reqs)]
repetition_penalties = [1.0 for _ in range(num_reqs)]
top_k = [0 for _ in range(num_reqs)]
top_p = [0.0 for _ in range(num_reqs)]
temperature = [0.0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
allowed_token_ids_mask = torch.zeros(num_reqs,
VOCAB_SIZE,
dtype=torch.bool,
device=device)
bad_words_token_ids = {}
for req in reqs:
if req.req_id not in req_ids_retained:
continue
index_in_input_batch = req_id_index_in_input_batch[req.req_id]
output_token_ids[index_in_input_batch] = req.output_token_ids
prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
presence_penalties[
index_in_input_batch] = req.sampling_params.presence_penalty
frequency_penalties[index_in_input_batch] = (
req.sampling_params.frequency_penalty)
repetition_penalties[index_in_input_batch] = (
req.sampling_params.repetition_penalty)
top_k[index_in_input_batch] = req.sampling_params.top_k
top_p[index_in_input_batch] = req.sampling_params.top_p
temperature[index_in_input_batch] = req.sampling_params.temperature
min_tokens[index_in_input_batch] = (
req.sampling_params.min_tokens,
req.sampling_params.all_stop_token_ids)
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
if req.sampling_params.allowed_token_ids:
allowed_token_ids_mask[index_in_input_batch][
req.sampling_params.allowed_token_ids] = True
if req.sampling_params.bad_words_token_ids:
bad_words_token_ids[
index_in_input_batch] = req.sampling_params.bad_words_token_ids
return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
all_greedy=False,
all_random=True,
top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
top_p, dtype=torch.float, device=device),
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
top_k, dtype=torch.int, device=device),
generators={},
max_num_logprobs=0,
prompt_token_ids=make_tensor_with_pad(
prompt_token_ids,
pad=VOCAB_SIZE,
device=torch.device(device),
dtype=torch.int64,
),
frequency_penalties=torch.tensor(frequency_penalties,
dtype=torch.float,
device=device),
presence_penalties=torch.tensor(presence_penalties,
dtype=torch.float,
device=device),
repetition_penalties=torch.tensor(repetition_penalties,
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=bad_words_token_ids,
logitsprocs=LogitsProcessors(),
)
def _create_sampling_params():
return SamplingParams(
top_k=np.random.randint(1, 10),
top_p=np.random.uniform(0.0, 1.0),
presence_penalty=np.random.uniform(-2.0, 2.0),
repetition_penalty=np.random.uniform(0.0, 2.0),
frequency_penalty=np.random.uniform(-2.0, 2.0),
min_tokens=np.random.randint(1, 10),
stop_token_ids=[
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(10))
],
logit_bias={0: np.random.uniform(-3.0, 3.0)},
)
def _construct_cached_request_state(req_id_suffix: int):
prompt_token_ids = [
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, MAX_PROMPT_SIZE))
]
output_token_ids = [
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
]
return CachedRequestState(
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(),
pooling_params=None,
mm_kwargs=[],
mm_positions=[],
block_ids=([], ),
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids,
mm_hashes=None,
)
@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize("batch_size", [1, 2, 32, 64])
def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
"""
Tests the logic for managing sampling metadata in the InputBatch.
This test involves adding a set of requests to the InputBatch,
followed by removing a subset of them. Afterward, the batch is compacted,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=False,
vocab_size=1024,
block_sizes=[1],
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
assigned_req_index = input_batch.add_request(req)
assert req_index == assigned_req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
# Remove some requests
req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs)
req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove
# Compact the input batch
input_batch.condense()
# Generate the sampling metadata
sampling_metadata = input_batch._make_sampling_metadata()
# Create expected output.
expected_sampling_metadata = _construct_expected_sampling_metadata(
reqs,
req_ids_retained,
input_batch.req_id_to_index,
device=torch.device(device))
def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
return (t1 is None
and t2 is None) or (t1 is not None and t2 is not None
and torch.allclose(t1, t2))
# Assert the actual and expected output.
assert torch.allclose(expected_sampling_metadata.temperature,
sampling_metadata.temperature)
assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
assert torch.allclose(
expected_sampling_metadata.frequency_penalties,
sampling_metadata.frequency_penalties,
)
assert torch.allclose(
expected_sampling_metadata.presence_penalties,
sampling_metadata.presence_penalties,
)
assert torch.allclose(
expected_sampling_metadata.repetition_penalties,
sampling_metadata.repetition_penalties,
)
assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
sampling_metadata.prompt_token_ids)
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
if sampling_metadata.allowed_token_ids_mask:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
sampling_metadata.allowed_token_ids_mask)
assert expected_sampling_metadata.bad_words_token_ids == \
sampling_metadata.bad_words_token_ids
@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("swap_list", [((0, 1), )])
def test_swap_states_in_input_batch(device: str, batch_size: int,
swap_list: list):
"""
Tests the logic for managing sampling metadata in the InputBatch.
This test involves adding a set of requests to the InputBatch,
followed by removing a subset of them. Afterward, the batch is compacted,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=False,
vocab_size=1024,
block_sizes=[1],
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=False,
vocab_size=1024,
block_sizes=[1],
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
assigned_req_index = input_batch.add_request(req)
assert assigned_req_index == req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
reordered_reqs = reqs.copy()
for swap_pair in swap_list:
reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \
reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]]
input_batch.swap_states(swap_pair[0], swap_pair[1])
for req_index in range(batch_size):
req = reordered_reqs[req_index]
assigned_req_index = ref_input_batch.add_request(req)
assert assigned_req_index == req_index
input_batch.refresh_metadata()
ref_input_batch.refresh_metadata()
_compare_objs(input_batch, ref_input_batch)

View File

@@ -41,7 +41,7 @@ from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
flashcomm2_o_shared_enabled, is_enable_nz, flashcomm2_o_shared_enabled, is_enable_nz,
weak_ref_tensors) weak_ref_tensors)
from vllm_ascend.worker.npu_input_batch import InputBatch from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
@@ -280,7 +280,7 @@ class AscendMLAMetadataBuilder:
dtype=torch.uint8, dtype=torch.uint8,
device=device) device=device)
def reorder_batch(self, input_batch: "InputBatch", def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are at # We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the using the least amount # the front and the "prefill" requests are at the using the least amount

View File

@@ -32,7 +32,7 @@ from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
_round_up, dispose_layer, enable_sp, _round_up, dispose_layer, enable_sp,
is_enable_nz, replace_layer) is_enable_nz, replace_layer)
from vllm_ascend.worker.npu_input_batch import InputBatch from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
@@ -149,7 +149,7 @@ class AscendSFAMetadataBuilder:
self.enable_sfa_cp = enable_sp() and \ self.enable_sfa_cp = enable_sp() and \
hasattr(self.model_config.hf_config, "index_topk") hasattr(self.model_config.hf_config, "index_topk")
def reorder_batch(self, input_batch: "InputBatch", def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
# No need to reorder for Ascend SFA # No need to reorder for Ascend SFA
return False return False

View File

@@ -24,7 +24,7 @@ _MOE_LOAD_ASYNC_STREAM = None
def get_expert_map(self, layer_id): def get_expert_map(self, layer_id):
return self.model.layers[layer_id].mlp.experts.get_map() return self.model.layers[layer_id].mlp.experts.expert_map
def get_log2phy_map(self, layer_id): def get_log2phy_map(self, layer_id):

View File

@@ -153,7 +153,7 @@ class AscendFusedMoE(FusedMoE):
AscendFusedMoE.moe_counter += 1 AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter self.moe_instance_id = AscendFusedMoE.moe_counter
self.expert_map = None self._expert_map = None
self.log2phy = None self.log2phy = None
if self.quant_config is None: if self.quant_config is None:
@@ -184,7 +184,7 @@ class AscendFusedMoE(FusedMoE):
dtype=vllm_config.model_config.dtype) dtype=vllm_config.model_config.dtype)
# init moe. # init moe.
self.local_num_experts, self.expert_map, _ = determine_expert_map( self.local_num_experts, self._expert_map, _ = determine_expert_map(
self.ep_size, self.ep_rank, self.global_num_experts) self.ep_size, self.ep_rank, self.global_num_experts)
# TODO: Temporary flag to indicate if static EPLB is enabled. This is a # TODO: Temporary flag to indicate if static EPLB is enabled. This is a
# workaround to bypass a quantization check that fails with float weights. # workaround to bypass a quantization check that fails with float weights.
@@ -200,7 +200,7 @@ class AscendFusedMoE(FusedMoE):
self.expert_load_balancer.get_global_redundant_expert_num()) self.expert_load_balancer.get_global_redundant_expert_num())
self.global_num_experts = num_experts + self.global_redundant_expert_num self.global_num_experts = num_experts + self.global_redundant_expert_num
try: try:
self.local_num_experts, self.expert_map = ( self.local_num_experts, self._expert_map = (
self.expert_load_balancer.get_rank_placement_map( self.expert_load_balancer.get_rank_placement_map(
self.moe_instance_id, self.ep_rank)) self.moe_instance_id, self.ep_rank))
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map( self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
@@ -216,16 +216,16 @@ class AscendFusedMoE(FusedMoE):
if self.dynamic_eplb: if self.dynamic_eplb:
self.log2phy = determine_default_log2phy_map( self.log2phy = determine_default_log2phy_map(
self.global_num_experts, self.ep_size, self.ep_rank).npu() self.global_num_experts, self.ep_size, self.ep_rank).npu()
if self.expert_map is not None and isinstance(self.expert_map, if self._expert_map is not None and isinstance(self._expert_map,
torch.Tensor): torch.Tensor):
logger.info_once( logger.info_once(
"[EP Rank %s/%s] Expert parallelism is enabled. Local/global" "[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
" number of experts: %s/%s. Experts local to global index map:" " number of experts: %s/%s. Experts local to global index map:"
" %s.", self.ep_rank, self.ep_size, self.local_num_experts, " %s.", self.ep_rank, self.ep_size, self.local_num_experts,
self.global_num_experts, self.global_num_experts,
get_compressed_expert_map(self.expert_map)) get_compressed_expert_map(self._expert_map))
local_num_experts = (torch.sum( local_num_experts = (torch.sum(
self.expert_map != -1) if self.expert_map is not None else self._expert_map != -1) if self._expert_map is not None else
self.global_num_experts) self.global_num_experts)
if self.dynamic_eplb: if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, self.moe_load = torch.zeros(local_num_experts,
@@ -276,10 +276,16 @@ class AscendFusedMoE(FusedMoE):
return QuantType.NONE return QuantType.NONE
def update_expert_map(self, new_expert_map): def update_expert_map(self, new_expert_map):
self.expert_map = new_expert_map self._expert_map = new_expert_map
def get_map(self): @property
return self.expert_map def expert_map(self) -> torch.Tensor | None:
return self._expert_map
@expert_map.setter
def expert_map(self, new_expert_map):
# TODO(Potabk): Remove this once we drop vllm v0.12.0(This makes backward compatibility with vllm v0.12.0)
self._expert_map = new_expert_map
def get_log2phy_map(self): def get_log2phy_map(self):
return self.log2phy return self.log2phy

View File

@@ -17,10 +17,15 @@
import os import os
import vllm_ascend.patch.platform.patch_distributed # noqa import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_ec_connector # noqa
import vllm_ascend.patch.platform.patch_mamba_config # noqa import vllm_ascend.patch.platform.patch_mamba_config # noqa
import vllm_ascend.patch.platform.patch_sched_yield # noqa import vllm_ascend.patch.platform.patch_sched_yield # noqa
from vllm_ascend.utils import vllm_version_is
if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv( if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv(
"EXPERT_MAP_RECORD", "false") == "true": "EXPERT_MAP_RECORD", "false") == "true":
import vllm_ascend.patch.platform.patch_multiproc_executor # noqa import vllm_ascend.patch.platform.patch_multiproc_executor # noqa
if vllm_version_is("0.12.0"):
import vllm_ascend.patch.platform.patch_ec_connector012 # noqa
else:
import vllm_ascend.patch.platform.patch_ec_connector # noqa

View File

@@ -1,16 +1,15 @@
import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import vllm.distributed.ec_transfer.ec_connector.example_connector
from safetensors.torch import load_file from safetensors.torch import load_file
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata from vllm.distributed.ec_transfer.ec_connector.example_connector import (
from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import ( ECConnectorMetadata, ECExampleConnector)
ECSharedStorageConnector, ECSharedStorageConnectorMetadata)
from vllm.logger import logger from vllm.logger import logger
class AscendECSharedStorageConnector(ECSharedStorageConnector): class AscendECExampleConnector(ECExampleConnector):
def start_load_caches(self, encoder_cache, **kwargs) -> None: def start_load_caches(self, encoder_cache, **kwargs) -> None:
metadata: ECConnectorMetadata = self._get_connector_metadata() metadata: ECConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, ECSharedStorageConnectorMetadata) assert isinstance(metadata, ECConnectorMetadata)
assert encoder_cache is not None assert encoder_cache is not None
if metadata is None: if metadata is None:
logger.warning(( logger.warning((
@@ -29,4 +28,4 @@ class AscendECSharedStorageConnector(ECSharedStorageConnector):
mm_data.mm_hash) mm_data.mm_hash)
vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector vllm.distributed.ec_transfer.ec_connector.example_connector.ECExampleConnector = AscendECExampleConnector

View File

@@ -0,0 +1,33 @@
import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector # type: ignore[import-not-found] # noqa
from safetensors.torch import load_file
from vllm.distributed.ec_transfer.ec_connector.base import \
ECConnectorMetadata # type: ignore[import-not-found] # noqa
from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import ( # type: ignore[import-not-found] # noqa
ECSharedStorageConnector, ECSharedStorageConnectorMetadata)
from vllm.logger import logger
class AscendECSharedStorageConnector(ECSharedStorageConnector):
def start_load_caches(self, encoder_cache, **kwargs) -> None:
metadata: ECConnectorMetadata = self._get_connector_metadata()
assert isinstance(metadata, ECSharedStorageConnectorMetadata)
assert encoder_cache is not None
if metadata is None:
logger.warning((
"In connector.start_load_caches, ",
"but the connector metadata is None",
))
return
# Load the EC for each mm data
for mm_data in metadata.mm_datas:
if mm_data.mm_hash in encoder_cache:
continue
filename = self._generate_filename_debug(mm_data.mm_hash)
ec_cache = load_file(filename)["ec_cache"].npu()
encoder_cache[mm_data.mm_hash] = ec_cache
logger.debug("Success load encoder cache for hash %s",
mm_data.mm_hash)
vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector

View File

@@ -365,6 +365,10 @@ class NPUPlatform(Platform):
use_mla, use_mla,
has_sink=False, has_sink=False,
use_sparse=False, use_sparse=False,
# NOTE: Please pay special attention to the order of these parameters.
# Although we are only using some of them so far
# vllm passes them in sequence when using this interface.
use_mm_prefix: bool = False,
attn_type: str | None = None, attn_type: str | None = None,
): ):
# choose attention backend based on use_mla # choose attention backend based on use_mla

View File

@@ -476,9 +476,10 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
# Calculate maximum supported batch sizes considering model architecture # Calculate maximum supported batch sizes considering model architecture
resources_per_graph = num_hidden_layers + 1 resources_per_graph = num_hidden_layers + 1
if vllm_config.speculative_config is not None: # For suffix decoding, use the suffix path when no draft_model_config is provided.
draft_model_hf_config = vllm_config.speculative_config.draft_model_config.hf_config if (spec := vllm_config.speculative_config) and \
resources_per_graph += draft_model_hf_config.num_hidden_layers + 1 (draft := spec.draft_model_config):
resources_per_graph += draft.hf_config.num_hidden_layers + 1
# TODO: Find out whether we need to take into account the pp_size # TODO: Find out whether we need to take into account the pp_size
num_comm_groups = sum(size > 1 for size in [ num_comm_groups = sum(size > 1 for size in [

View File

@@ -121,8 +121,8 @@ from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendDeviceType, ProfileExecuteDuration, AscendDeviceType, ProfileExecuteDuration,
enable_sp, get_ascend_device_type, is_enable_nz, enable_sp, get_ascend_device_type, is_enable_nz,
is_moe_model, lmhead_tp_enable) is_moe_model, lmhead_tp_enable, vllm_version_is)
from vllm_ascend.worker.npu_input_batch import InputBatch from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING: if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped] import xgrammar as xgr # type: ignore[import-untyped]
@@ -249,13 +249,24 @@ class NPUModelRunner(GPUModelRunner):
# Set up Attention # Set up Attention
self.use_sparse = hasattr(self.vllm_config.model_config.hf_config, self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
"index_topk") "index_topk")
self.attn_backend = get_attn_backend(0, if vllm_version_is('0.12.0'):
self.dtype, self.attn_backend = get_attn_backend(
None, 0,
self.block_size, self.dtype,
use_mla=self.model_config.use_mla, None,
use_sparse=self.use_sparse) self.block_size,
use_mla=self.model_config.use_mla,
use_sparse=self.use_sparse)
else:
self.attn_backend = get_attn_backend(
0,
self.dtype,
None,
self.block_size,
use_mla=self.model_config.use_mla,
use_sparse=self.use_sparse,
use_mm_prefix=self.model_config is not None
and self.model_config.is_mm_prefix_lm)
self.attn_mask_builder = AttentionMaskBuilder(self.device) self.attn_mask_builder = AttentionMaskBuilder(self.device)
self._set_up_drafter() self._set_up_drafter()
@@ -353,7 +364,7 @@ class NPUModelRunner(GPUModelRunner):
# solution, we initialize the input batch here, and re-initialize it # solution, we initialize the input batch here, and re-initialize it
# in `initialize_kv_cache` if the block_sizes here is different from # in `initialize_kv_cache` if the block_sizes here is different from
# the block_sizes in the kv cache config. # the block_sizes in the kv cache config.
self.input_batch = InputBatch( self.input_batch = NPUInputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len, max_model_len=self.model_config.max_model_len,
max_num_batched_tokens=self.max_num_tokens, max_num_batched_tokens=self.max_num_tokens,
@@ -2000,19 +2011,36 @@ class NPUModelRunner(GPUModelRunner):
self.speculative_config.method == "mtp": self.speculative_config.method == "mtp":
attn_state = AscendAttentionState.SpecDecoding attn_state = AscendAttentionState.SpecDecoding
common_metadata = CommonAttentionMetadata( if vllm_version_is("0.12.0"):
query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], common_metadata = CommonAttentionMetadata(
query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + query_start_loc=self.query_start_loc.gpu[:num_reqs +
1], 1],
seq_lens_cpu=self.seq_lens.cpu[:num_reqs], query_start_loc_cpu=self.query_start_loc.
seq_lens=self.seq_lens.cpu[:num_reqs], cpu[:num_reqs + 1],
num_reqs=num_reqs, seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
num_actual_tokens=num_tokens, seq_lens=self.seq_lens.cpu[:num_reqs],
block_table_tensor=block_table_tensor[:num_reqs], num_reqs=num_reqs,
slot_mapping=slot_mapping.gpu, num_actual_tokens=num_tokens,
num_computed_tokens_cpu=num_computed_tokens_cpu, block_table_tensor=block_table_tensor[:num_reqs],
max_query_len=max_query_len, slot_mapping=slot_mapping.gpu,
max_seq_len=seq_lens) num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
max_seq_len=seq_lens)
else:
common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[:num_reqs +
1],
query_start_loc_cpu=self.query_start_loc.
cpu[:num_reqs + 1],
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
seq_lens=self.seq_lens.cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
block_table_tensor=block_table_tensor[:num_reqs],
slot_mapping=slot_mapping.gpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
max_seq_len=seq_lens)
for attn_group in self.attn_groups[kv_cache_group_id]: for attn_group in self.attn_groups[kv_cache_group_id]:
builder = attn_group.get_metadata_builder() builder = attn_group.get_metadata_builder()
@@ -2778,7 +2806,7 @@ class NPUModelRunner(GPUModelRunner):
"Cannot re-initialize the input batch when CPU weight " "Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details.") "for more details.")
self.input_batch = InputBatch( self.input_batch = NPUInputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len, max_model_len=self.model_config.max_model_len,
max_num_batched_tokens=self.max_num_tokens, max_num_batched_tokens=self.max_num_tokens,

View File

@@ -17,92 +17,29 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py # Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py
# #
from dataclasses import dataclass
from typing import Optional, cast
import numpy as np import numpy as np
import torch import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem,
MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
LogitsProcessors, LogitsProcessors)
MoveDirectionality) from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
from vllm_ascend.worker.block_table import MultiGroupBlockTable from vllm_ascend.worker.block_table import MultiGroupBlockTable
@dataclass class PoolingStates:
class CachedRequestState: # NOTE: This should be removed after we drop support of vLLM v0.12.0
def __init__(self):
# for chunked prefill with ALL pooling
self.hidden_states_cache: list[torch.Tensor] = []
req_id: str def clean(self):
prompt_token_ids: Optional[list[int]] self.hidden_states_cache.clear()
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
output_token_ids: list[int]
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
mm_features: Optional[list[MultiModalFeatureSpec]] = None
# for back-compatibility, will be removed in next major release
mm_kwargs: Optional[list[MultiModalKwargsItem]] = None
mm_positions: Optional[list[PlaceholderRange]] = None
mm_hashes: Optional[list[PlaceholderRange]] = None
lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None
prev_num_draft_len: int = 0 # previous number of draft tokens
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
# Temporary back-compatibility for plugins that define model runner
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargsItems]:
assert self.mm_features is not None
return [
MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
if f.data is not None
]
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
if self.prompt_token_ids is None:
raise ValueError(
f"Tried to access token index {idx}, but that token was "
"provided via prompt_embeds, and its ID is unknown.")
return self.prompt_token_ids[idx]
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
return self.output_token_ids[idx - self.num_prompt_tokens]
else:
return -1
class InputBatch: class NPUInputBatch(InputBatch):
def __init__( def __init__(
self, self,
@@ -113,12 +50,12 @@ class InputBatch:
pin_memory: bool, pin_memory: bool,
vocab_size: int, vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None, kernel_block_sizes: list[list[int]],
logitsprocs: LogitsProcessors | None = None,
logitsprocs_need_output_token_ids: bool = False, logitsprocs_need_output_token_ids: bool = False,
is_spec_decode: bool = False, is_spec_decode: bool = False,
is_pooling_model: bool = False, is_pooling_model: bool = False,
num_speculative_tokens: int = 0, num_speculative_tokens: int = 0,
kernel_block_sizes: Optional[list[list[int]]] = None,
cp_kv_cache_interleave_size: int = 1, cp_kv_cache_interleave_size: int = 1,
): ):
self.is_pooling_model = is_pooling_model self.is_pooling_model = is_pooling_model
@@ -130,12 +67,12 @@ class InputBatch:
self.pin_memory = pin_memory self.pin_memory = pin_memory
self.vocab_size = vocab_size self.vocab_size = vocab_size
self._req_ids: list[Optional[str]] = [] self._req_ids: list[str | None] = []
self.req_id_to_index: dict[str, int] = {} self.req_id_to_index: dict[str, int] = {}
# TODO(woosuk): This buffer could be too large if max_model_len is big. # TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage. # Find a way to reduce the CPU memory usage.
# This buffer is not directly transferred to the NPU, so it does not # This buffer is not directly transferred to the GPU, so it does not
# need to be pinned. # need to be pinned.
self.token_ids_cpu_tensor = torch.zeros( self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len), (max_num_reqs, max_model_len),
@@ -162,8 +99,8 @@ class InputBatch:
dtype=torch.int32, dtype=torch.int32,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
self.num_computed_tokens_cpu = \ self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy(
self.num_computed_tokens_cpu_tensor.numpy() )
# Block table. # Block table.
self.block_table = MultiGroupBlockTable( self.block_table = MultiGroupBlockTable(
@@ -222,8 +159,8 @@ class InputBatch:
dtype=torch.float, dtype=torch.float,
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.frequency_penalties_cpu = \ self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy(
self.frequency_penalties_cpu_tensor.numpy() )
self.frequency_penalties_reqs: set[str] = set() self.frequency_penalties_reqs: set[str] = set()
# Presence penalty related data structures # Presence penalty related data structures
@@ -247,8 +184,8 @@ class InputBatch:
dtype=torch.float, dtype=torch.float,
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.repetition_penalties_cpu = \ self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy(
self.repetition_penalties_cpu_tensor.numpy() )
self.repetition_penalties_reqs: set[str] = set() self.repetition_penalties_reqs: set[str] = set()
# Speculative decoding # Speculative decoding
@@ -256,12 +193,12 @@ class InputBatch:
dtype=torch.int64, dtype=torch.int64,
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.num_accepted_tokens_cpu = \ self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy(
self.num_accepted_tokens_cpu_tensor.numpy() )
# lora related # lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ), self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32) dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
@@ -271,9 +208,6 @@ class InputBatch:
self.generators: dict[int, torch.Generator] = {} self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {} self.num_logprobs: dict[str, int] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps. # To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
@@ -287,8 +221,8 @@ class InputBatch:
self.has_allowed_token_ids: set[str] = set() self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed, # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
# the value is False. Since we use masked_fill_ to set -inf. # the value is False. Since we use masked_fill_ to set -inf.
self.allowed_token_ids_mask: Optional[torch.Tensor] = None self.allowed_token_ids_mask: torch.Tensor | None = None
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None
# req_index -> bad_words_token_ids # req_index -> bad_words_token_ids
self.bad_words_token_ids: dict[int, list[list[int]]] = {} self.bad_words_token_ids: dict[int, list[list[int]]] = {}
@@ -296,7 +230,7 @@ class InputBatch:
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
dtype=bool) dtype=bool)
self.req_output_token_ids: list[Optional[list[int]]] = [] self.req_output_token_ids: list[list[int] | None] = []
# Store provided logitsprocs. If none are provided, initialize empty # Store provided logitsprocs. If none are provided, initialize empty
# data structure # data structure
@@ -310,673 +244,15 @@ class InputBatch:
# This is updated each time the batch constituents change. # This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata() self.sampling_metadata = self._make_sampling_metadata()
# for pooling models
self.pooling_params: dict[str, PoolingParams] = {} self.pooling_params: dict[str, PoolingParams] = {}
self.pooling_states: dict[str, PoolingStates] = {}
# Cached reference to the GPU tensor of previously sampled tokens # Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: torch.Tensor | None = None self.prev_sampled_token_ids: torch.Tensor | None = None
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
self.prev_req_id_to_index: dict[str, int] | None = None self.prev_req_id_to_index: dict[str, int] | None = None
# These are used to update output_token_ids with real sampled # These are used to update output_token_ids with real sampled
# ids from prior step, if required by current sampling params # ids from prior step, if required by current sampling params
# (e.g. penalties). # (e.g. penalties).
self.sampled_token_ids_cpu: torch.Tensor | None = None self.sampled_token_ids_cpu: torch.Tensor | None = None
self.async_copy_ready_event: torch.Event | None = None self.async_copy_ready_event: torch.Event | None = None
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
# while performing state updates to the batch.
return cast(list[str], self._req_ids)
def _register_add_request(self, request: "CachedRequestState") -> int:
"""Track add-request operations for logits processors.
Not applicable to pooling models.
"""
# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs
assert request.sampling_params
# Fill the next empty index if there is one.
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
# Append to end otherwise.
new_req_index = self.num_reqs
assert new_req_index < self.max_num_reqs
self.batch_update_builder.added.append(
(new_req_index, request.sampling_params, request.prompt_token_ids,
request.output_token_ids))
return new_req_index
def add_request(
self,
request: "CachedRequestState",
) -> int:
if not self.is_pooling_model:
# New request index bookkeeping for autoregressive models.
req_index = self._register_add_request(request)
else:
req_index = self.num_reqs
req_id = request.req_id
if req_index == len(self._req_ids):
self._req_ids.append(req_id)
self.req_output_token_ids.append(request.output_token_ids)
self.spec_token_ids.append([])
else:
self._req_ids[req_index] = req_id
self.req_output_token_ids[req_index] = request.output_token_ids
self.spec_token_ids[req_index].clear()
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds)
self.num_prompt_tokens[req_index] = num_prompt_tokens
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
self.is_token_ids[req_index, start_idx:end_idx] = True
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params:
if (self.is_spec_decode
and is_spec_decode_unsupported(sampling_params)):
self.spec_decode_unsupported_reqs.add(req_id)
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = (self.vocab_size
if sampling_params.logprobs == -1
else sampling_params.logprobs)
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[
req_id] = sampling_params.prompt_logprobs
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
elif pooling_params := request.pooling_params:
self.pooling_params[req_id] = pooling_params
self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids)
else:
raise NotImplementedError(request)
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
if lora_id not in self.lora_id_to_request_ids:
self.lora_id_to_request_ids[lora_id] = set()
self.request_lora_mapping[req_index] = lora_id
self.lora_id_to_request_ids[lora_id].add(request.req_id)
self.lora_id_to_lora_request[lora_id] = request.lora_request
else:
# No LoRA
self.request_lora_mapping[req_index] = 0
return req_index
def remove_request(self, req_id: str) -> Optional[int]:
"""This method must always be followed by a call to condense().
Args:
req_id: request to remove
Returns:
Removed request index, or `None` if `req_id` not recognized
"""
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
if not self.is_pooling_model:
# Autoregressive models require bookkeeping of removed requests to
# support logitsprocs.
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
self.spec_token_ids[req_index].clear()
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0
if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
return req_index
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.spec_decode_unsupported_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
if self.prev_req_id_to_index is not None:
self.prev_req_id_to_index.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
self.lora_id_to_request_ids[lora_id].discard(req_id)
if len(self.lora_id_to_request_ids[lora_id]) == 0:
self.lora_id_to_request_ids.pop(lora_id)
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
self.pooling_params.pop(req_id, None)
return req_index
def swap_states(self, i1: int, i2: int) -> None:
# For autoregressive models, track detailed request reordering info
# to support logitsprocs
self.batch_update_builder.moved.append(
(i1, i2, MoveDirectionality.SWAP))
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] =\
self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
self.spec_token_ids[i1], self.spec_token_ids[i2] = (
self.spec_token_ids[i2],
self.spec_token_ids[i1],
)
assert old_id_i1 is not None and old_id_i2 is not None
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
self.num_tokens[i1], self.num_tokens[i2] =\
self.num_tokens[i2], self.num_tokens[i1]
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\
self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporiarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
# Swap prompt embeddings if they exist
embeds_i1 = self.req_prompt_embeds.get(i1)
embeds_i2 = self.req_prompt_embeds.get(i2)
if embeds_i1 is not None:
self.req_prompt_embeds[i2] = embeds_i1
else:
self.req_prompt_embeds.pop(i2, None)
if embeds_i2 is not None:
self.req_prompt_embeds[i1] = embeds_i2
else:
self.req_prompt_embeds.pop(i1, None)
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[i1], \
self.allowed_token_ids_mask_cpu_tensor[i2] =\
self.allowed_token_ids_mask_cpu_tensor[i2], \
self.allowed_token_ids_mask_cpu_tensor[i1]
self.block_table.swap_row(i1, i2)
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
Any consecutive empty indices at the very end of the list are not
filled.
Args:
empty_req_indices: empty indices which may be filled.
Returns:
swaps: list of (from,to) swap tuples for moved requests
empty_req_indices: indices not filled by condensation
"""
num_reqs = self.num_reqs
if self.is_pooling_model:
# Will be contiguous in pooling case, just trim the lists.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
return
if not (empty_req_indices := self.batch_update_builder.removed):
# All removed requests were replaced by added requests, or else no
# requests were removed at all. No condense() needed
return
if num_reqs == 0:
# The batched states are empty.
self._req_ids.clear()
self.req_output_token_ids.clear()
self.spec_token_ids.clear()
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index = num_reqs + len(empty_req_indices) - 1
while empty_req_indices:
# Find the largest non-empty index.
while last_req_index in empty_req_indices:
last_req_index -= 1
# Find the smallest empty index.
empty_index = self.batch_update_builder.peek_removed()
assert empty_index is not None
if empty_index >= last_req_index:
break
# Move active request down into empty request
# index.
self.batch_update_builder.pop_removed()
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self.batch_update_builder.moved.append(
(last_req_index, empty_index,
MoveDirectionality.UNIDIRECTIONAL))
req_id = self._req_ids[last_req_index]
output_token_ids = self.req_output_token_ids[last_req_index]
assert req_id is not None
self._req_ids[empty_index] = req_id
self._req_ids[last_req_index] = None
self.req_output_token_ids[empty_index] = output_token_ids
self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
if last_req_index != empty_index:
(
self.spec_token_ids[last_req_index],
self.spec_token_ids[empty_index],
) = (
self.spec_token_ids[empty_index],
self.spec_token_ids[last_req_index],
)
self.spec_token_ids[last_req_index].clear()
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens]
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
last_req_index, :num_tokens]
if last_req_index in self.req_prompt_embeds:
self.req_prompt_embeds[
empty_index] = self.req_prompt_embeds.pop(last_req_index)
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index]
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.frequency_penalties_cpu[
empty_index] = self.frequency_penalties_cpu[last_req_index]
self.presence_penalties_cpu[
empty_index] = self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[
empty_index] = self.repetition_penalties_cpu[last_req_index]
self.num_accepted_tokens_cpu[
empty_index] = self.num_accepted_tokens_cpu[last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index]
# TODO convert these to LogitsProcessors
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
last_req_index]
bad_words_token_ids = self.bad_words_token_ids.pop(
last_req_index, None)
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
# Decrement last_req_index since it is now empty.
last_req_index -= 1
# Trim lists to the batch size.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
del self.spec_token_ids[num_reqs:]
def refresh_metadata(self):
"""Apply any batch updates to sampling metadata."""
if self.is_pooling_model:
# Batch changes every step for pooling models.
self.sampling_metadata = self._make_sampling_metadata()
return
# For non-pooling models - generate and apply logitsprocs update;
# reset batch update tracking.
# Update sampling metadata if batch state is changed.
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
if batch_update:
self.sampling_metadata = self._make_sampling_metadata()
def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs
if not self.all_greedy:
temperature = copy_slice(self.temperature_cpu_tensor,
self.temperature, num_reqs)
else:
temperature = None
if not self.no_top_p:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
if not self.no_top_k:
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
if not self.no_penalties:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
copy_slice(self.frequency_penalties_cpu_tensor,
self.frequency_penalties, num_reqs)
copy_slice(self.presence_penalties_cpu_tensor,
self.presence_penalties, num_reqs)
copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs)
needs_prompt_token_ids = (
not self.no_penalties
or self.logits_processing_needs_token_ids[:num_reqs].any())
if needs_prompt_token_ids:
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids = self._make_prompt_token_ids_tensor()
else:
prompt_token_ids = None
allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
spec_token_ids=cast(list[list[int]], self.spec_token_ids),
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs,
)
def get_pooling_params(self) -> list[PoolingParams]:
assert len(self.req_ids) == len(self.pooling_params)
return [self.pooling_params[req_id] for req_id in self.req_ids]
def get_pooling_metadata(self) -> PoolingMetadata:
pooling_params = self.get_pooling_params()
return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[:self.
num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(
req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values())
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
def set_async_sampled_token_ids(
self,
sampled_token_ids_cpu: torch.Tensor,
async_copy_ready_event: torch.Event,
) -> None:
"""
In async scheduling case, store ref to sampled_token_ids_cpu
tensor and corresponding copy-ready event. Used to repair
output_token_ids prior to sampling, if needed by logits processors.
"""
if self.sampling_metadata.output_token_ids:
self.sampled_token_ids_cpu = sampled_token_ids_cpu
self.async_copy_ready_event = async_copy_ready_event
else:
self.sampled_token_ids_cpu = None
self.async_copy_ready_event = None
def update_async_output_token_ids(self) -> None:
"""
In async scheduling case, update output_token_ids in sampling metadata
from prior steps sampled token ids once they've finished copying to CPU.
This is called right before they are needed by the logits processors.
"""
output_token_ids = self.sampling_metadata.output_token_ids
if self.sampled_token_ids_cpu is None or not output_token_ids:
# Output token ids not needed or not async scheduling.
return
assert self.prev_req_id_to_index is not None
sampled_token_ids = None
for index, req_id in enumerate(self.req_ids):
prev_index = self.prev_req_id_to_index.get(req_id)
if prev_index is None:
continue
req_output_token_ids = output_token_ids[index]
if not req_output_token_ids or req_output_token_ids[-1] != -1:
# Final output id is not a placeholder, some tokens must have
# been discarded after a kv-load failure.
continue
if sampled_token_ids is None:
assert self.async_copy_ready_event is not None
self.async_copy_ready_event.synchronize()
sampled_token_ids = self.sampled_token_ids_cpu.squeeze(
-1).tolist()
# Replace placeholder token id with actual sampled id.
req_output_token_ids[-1] = sampled_token_ids[prev_index]
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_penalties(self) -> bool:
return (len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0)
@property
def max_num_logprobs(self) -> Optional[int]:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0