[Feature] model_runner refactor (#4764)

### What this PR does / why we need it?
Refactor the NPU model runner so its structure stays close to vLLM's GPU model runner: per-request CPU/GPU state is consolidated into `CpuGpuBuffer`s, the MC2 mask/capacity and MLA cos/sin caches move into module-level helpers in `ascend_forward_context`, and `InputBatch` picks up the upstream spec-token and async-sampling bookkeeping.
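
The central pattern is `CpuGpuBuffer` from `vllm.v1.utils`, which bundles the pinned CPU tensor, its numpy view, and the device tensor that `BlockTable` previously kept as three separate attributes. A minimal sketch of the usage (illustrative only; the shape, dtype, and device below are placeholders, not values from this PR):

```python
# Illustrative sketch of the CpuGpuBuffer pattern adopted by this refactor;
# shape, dtype, and device are placeholders, not values from the PR.
import torch
from vllm.v1.utils import CpuGpuBuffer

# One buffer replaces the old block_table / block_table_cpu / block_table_np trio.
block_table = CpuGpuBuffer(4, 16,           # placeholder (rows, cols)
                           dtype=torch.int32,
                           device="npu",    # placeholder device string
                           pin_memory=True)

block_table.np[0, :3] = [10, 11, 12]  # stage block ids through the numpy view
block_table.copy_to_gpu(1)            # non-blocking H2D copy of the first row
on_device = block_table.gpu[:1]       # device-side view consumed by the kernels
```

Consumers then read the `.np`, `.cpu`, and `.gpu` views directly, which is why `MtpProposer` and the `_compare_objs` test helper are updated in this diff.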

### Does this PR introduce _any_ user-facing change?
NO

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Signed-off-by: zhenwenqi2024 <155598497+zhenwenqi2024@users.noreply.github.com>
zhenwenqi2024
2025-12-12 17:27:09 +08:00
committed by GitHub
parent 5b12c068f9
commit f708d919f8
10 changed files with 676 additions and 1815 deletions

View File

@@ -24,6 +24,7 @@ from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm_ascend.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -67,6 +68,8 @@ def _compare_objs(obj1,
is_same = True  # if we reach here, they must be the same
elif a == b:
is_same = True
elif isinstance(a, CpuGpuBuffer):
is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu)
assert is_same, f"Attribute {attr_name} is different"\
f" in {obj1} and {obj2}: {a} != {b}"

View File

@@ -1,113 +0,0 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from unittest.mock import MagicMock, patch
import pytest
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import AscendDeviceType
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
# yapf: disable
@pytest.mark.parametrize(
"soc_version, enable_expert_parallel, world_size, pipeline_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
[
# Case 1: Expert parallel is disabled, should always be 'allgather'
(AscendDeviceType._910B, False, 8, 2, 100, 256, None, MoECommType.ALLGATHER),
(AscendDeviceType._910_93, False, 16, 2, 500, 256, None, MoECommType.ALLGATHER),
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
(AscendDeviceType._910B, True, 8, 1, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendDeviceType._910B, True, 16, 1, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendDeviceType._910B, True, 16, 1, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
(AscendDeviceType._910B, True, 8, 2, 100, 256, None, MoECommType.ALLGATHER),
(AscendDeviceType._910B, True, 16, 2, 257, 256, None, MoECommType.ALLGATHER),
# Case 4: A3 SOC
(AscendDeviceType._910_93, True, 8, 2, 100, 256, None, MoECommType.MC2),
(AscendDeviceType._910_93, True, 8, 2, 257, 256, None, MoECommType.ALLTOALL),
])
# yapf: enable
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
world_size, pipeline_size, num_tokens,
mc2_tokens_capacity, quant_type,
expected_method):
"""
Tests the _select_moe_comm_method with various configurations including quant_type.
"""
# Mock the NPUModelRunner instance and its dependencies
mock_runner = MagicMock(spec=NPUModelRunner)
mock_runner.parallel_config = MagicMock()
mock_runner.parallel_config.enable_expert_parallel = enable_expert_parallel
mock_runner.parallel_config.world_size_across_dp = world_size
mock_runner.parallel_config.pipeline_parallel_size = pipeline_size
mock_runner.mc2_tokens_capacity = mc2_tokens_capacity
# Add vllm_config.model_config.hf_config mock with moe_quantize
mock_hf_config = MagicMock()
mock_hf_config.moe_quantize = quant_type
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_runner.vllm_config = mock_vllm_config
# Patch the helper functions
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_device_type',
return_value=soc_version), \
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
return_value=True), \
patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
return_value=True):
# Bind the real method to the mock object
method = NPUModelRunner._select_moe_comm_method(
mock_runner, num_tokens)
# Assert the result
assert method == expected_method
def test_select_moe_comm_method_unsupported_soc():
"""
Tests that _select_moe_comm_method raises ValueError for an unsupported SOC.
"""
mock_runner = MagicMock(spec=NPUModelRunner)
mock_runner.parallel_config = MagicMock()
mock_runner.parallel_config.enable_expert_parallel = True
mock_runner.mc2_tokens_capacity = 256
# Add vllm_config.model_config.hf_config mock with moe_quantize
mock_hf_config = MagicMock()
mock_hf_config.moe_quantize = None
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_runner.vllm_config = mock_vllm_config
unsupported_soc = "UnsupportedSOC"
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_device_type',
return_value=unsupported_soc), \
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
return_value=True), \
patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
return_value=True), \
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):
NPUModelRunner._select_moe_comm_method(mock_runner, 100)

View File

@@ -35,7 +35,6 @@ def set_ascend_forward_context(
num_tokens_across_dp: Optional[torch.Tensor] = None,
with_prefill: bool = True,
in_profile_run: bool = False,
reserved_mc2_mask: Optional[torch.Tensor] = None,
moe_comm_type: Optional[MoECommType] = None,
num_actual_tokens: Optional[int] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
@@ -147,7 +146,7 @@ def set_ascend_forward_context(
# NOTE: the number of tokens to pad to when using mc2
forward_context.padded_num_tokens = math.ceil(
max_tokens_across_dp / tp_world_size) * tp_world_size
reserved_mc2_mask = get_mc2_mask()
if reserved_mc2_mask is not None:
mc2_mask = reserved_mc2_mask[:forward_context.
padded_num_tokens]
@@ -159,3 +158,76 @@ def set_ascend_forward_context(
yield
finally:
pass
_mc2_tokens_capacity: Optional[int] = None
_reserved_mc2_mask: Optional[torch.Tensor] = None
_sin: Optional[torch.Tensor] = None
_cos: Optional[torch.Tensor] = None
def set_mc2_tokens_capacity(vllm_config, max_num_reqs,
uniform_decode_query_len):
global _mc2_tokens_capacity
if _mc2_tokens_capacity is not None:
return
if vllm_config.compilation_config.cudagraph_capture_sizes:
max_num_tokens = vllm_config.compilation_config.max_cudagraph_capture_size
else:
# NOTE: To save memory, we cap the max number of tokens to 512.
max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
tp_size = vllm_config.parallel_config.tensor_parallel_size
# Use integer arithmetic for ceiling division.
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
_mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
def get_mc2_tokens_capacity():
return _mc2_tokens_capacity
def set_mc2_mask(vllm_config, device):
global _reserved_mc2_mask
if _reserved_mc2_mask is not None:
return
if is_moe_model(vllm_config):
_reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(),
dtype=torch.bool,
device=device)
else:
_reserved_mc2_mask = None
def get_mc2_mask():
return _reserved_mc2_mask
def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
device):
global _cos
global _sin
if _cos is not None:
return
compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config
if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
rope_dim = model_config.hf_text_config.qk_rope_head_dim
_cos = torch.ones(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
_sin = torch.zeros(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
else:
_cos = None
_sin = None
def get_cos_and_sin():
return _cos, _sin
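
For orientation, a minimal sketch of how these new module-level helpers fit together; `init_runner_state` and its arguments are assumed stand-ins for the real call sites in `model_runner_v1.py` (whose diff is suppressed further below), but the function signatures are the ones added above:

```python
# Illustrative wiring of the new module-level helpers; init_runner_state and
# its arguments are assumed stand-ins for the real call sites in
# model_runner_v1.py.
from vllm_ascend.ascend_forward_context import (set_cos_and_sin, set_mc2_mask,
                                                set_mc2_tokens_capacity)


def init_runner_state(vllm_config, max_num_reqs, uniform_decode_query_len,
                      decode_token_per_req, dtype, device):
    # Capacity first: set_mc2_mask() sizes its mask via get_mc2_tokens_capacity().
    set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len)
    set_mc2_mask(vllm_config, device)
    set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, device)
```

Per-step code then reads the cached state back with `get_mc2_mask()` (inside `set_ascend_forward_context`) and `get_cos_and_sin()` (in the MLA/SFA metadata builders below), instead of threading `cos`/`sin` and `reserved_mc2_mask` through the call chain.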

View File

@@ -25,6 +25,7 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import get_cos_and_sin
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
@@ -625,8 +626,7 @@ class AscendMLAMetadataBuilder:
decode_metadata = None
if num_decodes > 0:
cos = common_attn_metadata.cos
sin = common_attn_metadata.sin
cos, sin = get_cos_and_sin()
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
1].tolist()

View File

@@ -16,6 +16,7 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import get_cos_and_sin
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
@@ -186,8 +187,7 @@ class AscendSFAMetadataBuilder:
cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
cos = common_attn_metadata.cos
sin = common_attn_metadata.sin
cos, sin = get_cos_and_sin()
assert self.cos_cache is not None and self.sin_cache is not None
new_cos = self.cos_cache[input_positions][:, None, None]

View File

@@ -100,10 +100,6 @@ class AscendCommonAttentionMetadata:
# padding tokens. It is used to handle some padding operations.
num_input_tokens: int = 0
# NOTE: This is a temporary solution for rotary embedding in MLA
cos: torch.Tensor = None
sin: torch.Tensor = None
prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None

View File

@@ -256,11 +256,12 @@ class MtpProposer(Proposer):
self.runner.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs])
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.runner.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.runner.
query_start_loc_cpu[:num_reqs + 1],
seq_lens_cpu=self.runner.seq_lens_cpu,
seq_lens=self.runner.seq_lens[:num_reqs],
query_start_loc=self.runner.query_start_loc.gpu[:num_reqs +
1],
query_start_loc_cpu=self.runner.query_start_loc.
cpu[:num_reqs + 1],
seq_lens_cpu=self.runner.seq_lens.cpu,
seq_lens=self.runner.seq_lens.gpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
num_input_tokens=num_tokens,
@@ -268,16 +269,14 @@ class MtpProposer(Proposer):
num_computed_tokens_cpu=num_computed_tokens_cpu,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor()[:num_reqs],
get_device_tensor(),
slot_mapping=self.runner.input_batch.block_table[0].
slot_mapping,
positions=self.runner.positions,
slot_mapping.gpu,
positions=self.runner.positions.gpu,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
attn_state=self.runner.attn_state,
decode_token_per_req=self.runner.decode_token_per_req,
cos=self.runner.cos,
sin=self.runner.sin,
)
builder = self.runner.attn_groups[0][0].get_metadata_builder()
@@ -304,7 +303,6 @@ class MtpProposer(Proposer):
num_tokens=num_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_type=moe_comm_type,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0,
@@ -406,7 +404,8 @@ class MtpProposer(Proposer):
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
target_token_ids = self.runner.input_ids.gpu[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
@@ -435,7 +434,7 @@ class MtpProposer(Proposer):
target_positions = positions
target_hidden_states = hidden_states
else:
target_token_ids = self.runner.input_ids[token_indices]
target_token_ids = self.runner.input_ids.gpu[token_indices]
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
@@ -748,7 +747,7 @@ class MtpProposer(Proposer):
uniform_decode = False
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
aclgraph_runtime_mode, batch_descriptor = \
self.runner.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
if self.use_async_scheduling:
# there is synchronization between mtp steps when enabling aclgraph,
# disable aclgraph when use async scheduling to avoid the
@@ -781,7 +780,6 @@ class MtpProposer(Proposer):
num_tokens=num_input_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_type=moe_comm_type,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,

View File

@@ -4,6 +4,7 @@ import numpy as np
import torch
from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
class BlockTable:
@@ -76,32 +77,14 @@ class BlockTable:
duplicate_size = 1
if self.pcp_world_size > 1:
duplicate_size += num_speculative_tokens
self.block_table = torch.zeros(
(max_num_reqs * duplicate_size, logical_table_size),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu = torch.zeros(
(max_num_reqs * duplicate_size, logical_table_size),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_np = self.block_table_cpu.numpy()
self.block_table = self._make_buffer(max_num_reqs * duplicate_size,
logical_table_size,
dtype=torch.int32)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping_cpu = torch.zeros(
self.slot_mapping = self._make_buffer(
self.max_num_batched_tokens +
2 * self.pcp_world_size * self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
self.slot_mapping = torch.zeros(
self.max_num_batched_tokens +
2 * self.pcp_world_size * self.max_num_reqs,
dtype=torch.int32,
device=self.device)
dtype=torch.int32)
self.kernel_sizes = kernel_sizes
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
@@ -120,7 +103,7 @@ class BlockTable:
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
self.block_table.np[row_idx, start:start + num_blocks] = block_ids
self.num_blocks_per_row[row_idx] += num_blocks
def add_row(self, block_ids: list[int], row_idx: int) -> None:
@@ -129,7 +112,7 @@ class BlockTable:
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
self.block_table.np[tgt, :num_blocks] = self.block_table.np[
src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
@@ -139,7 +122,7 @@ class BlockTable:
self.num_blocks_per_row[src] = num_blocks_tgt
self.num_blocks_per_row[tgt] = num_blocks_src
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
self.block_table.np[[src, tgt]] = self.block_table.np[[tgt, src]]
def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
@@ -171,7 +154,7 @@ class BlockTable:
self.blocks_per_phys_block +
logical_block_idx)
block_numbers = self.block_table_np.ravel()[block_table_indices]
block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
@@ -186,7 +169,7 @@ class BlockTable:
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping_np[:req_indices.shape[0]] = np.where(
self.slot_mapping.np[:req_indices.shape[0]] = np.where(
mask, slot_mapping, -1)
else:
assert self.kernel_sizes is not None
@@ -203,24 +186,22 @@ class BlockTable:
req_indices * self.max_num_blocks_per_req *
self.blocks_per_phys_block + logical_block_idx)
block_numbers = self.block_table_np.ravel(
block_numbers = self.block_table.np.ravel(
)[block_table_indices]
block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping_np[:req_indices.shape[0]])
out=self.slot_mapping.np[:req_indices.shape[0]])
def commit_block_table(self, num_reqs: int) -> None:
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)
self.block_table.copy_to_gpu(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
self.slot_mapping[:num_tokens].copy_(
self.slot_mapping_cpu[:num_tokens], non_blocking=True)
self.slot_mapping.copy_to_gpu(num_tokens)
def clear(self) -> None:
self.block_table.fill_(0)
self.block_table_cpu.fill_(0)
self.block_table.cpu.fill_(0)
def _convert_physical_to_logical_blocks(
self, physical_blocks: np.ndarray) -> np.ndarray:
@@ -243,15 +224,22 @@ class BlockTable:
def get_device_tensor(self) -> torch.Tensor:
"""Returns the device tensor of the block table."""
return self.block_table
return self.block_table.gpu
def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
return self.block_table_cpu
return self.block_table.cpu
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table_np
return self.block_table.np
def _make_buffer(self, *size: int | torch.SymInt,
dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*size,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory)
class MultiGroupBlockTable:

File diff suppressed because it is too large

View File

@@ -114,6 +114,7 @@ class InputBatch:
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
logitsprocs_need_output_token_ids: bool = False,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
@@ -143,10 +144,11 @@ class InputBatch:
pin_memory=False,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
device="cpu",
dtype=bool,
pin_memory=False)
self.is_token_ids_tensor = torch.zeros((max_num_reqs, max_model_len),
device="cpu",
dtype=bool,
pin_memory=False)
self.is_token_ids = self.is_token_ids_tensor.numpy()
# Store prompt embeddings per request to avoid OOM from large upfront
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
@@ -299,6 +301,11 @@ class InputBatch:
# Store provided logitsprocs. If none are provided, initialize empty
# data structure
self.logitsprocs = logitsprocs or LogitsProcessors()
self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
# Store last speculative tokens for sampler.
self.spec_token_ids: list[list[int]] = [[]
for _ in range(max_num_reqs)]
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
@@ -306,9 +313,14 @@ class InputBatch:
self.pooling_params: dict[str, PoolingParams] = {}
# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
self.prev_sampled_token_ids: torch.Tensor | None = None
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
self.prev_req_id_to_index: Optional[dict[str, int]] = None
self.prev_req_id_to_index: dict[str, int] | None = None
# These are used to update output_token_ids with real sampled
# ids from prior step, if required by current sampling params
# (e.g. penalties).
self.sampled_token_ids_cpu: torch.Tensor | None = None
self.async_copy_ready_event: torch.Event | None = None
@property
def req_ids(self) -> list[str]:
@@ -350,9 +362,11 @@ class InputBatch:
if req_index == len(self._req_ids):
self._req_ids.append(req_id)
self.req_output_token_ids.append(request.output_token_ids)
self.spec_token_ids.append([])
else:
self._req_ids[req_index] = req_id
self.req_output_token_ids[req_index] = request.output_token_ids
self.spec_token_ids[req_index].clear()
self.req_id_to_index[req_id] = req_index
@@ -496,6 +510,21 @@ class InputBatch:
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
self.spec_token_ids[req_index].clear()
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0
if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
return req_index
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
@@ -510,6 +539,8 @@ class InputBatch:
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
if self.prev_req_id_to_index is not None:
self.prev_req_id_to_index.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
@@ -538,6 +569,10 @@ class InputBatch:
self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
self.spec_token_ids[i1], self.spec_token_ids[i2] = (
self.spec_token_ids[i2],
self.spec_token_ids[i1],
)
assert old_id_i1 is not None and old_id_i2 is not None
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
@@ -629,6 +664,7 @@ class InputBatch:
# The batched states are empty.
self._req_ids.clear()
self.req_output_token_ids.clear()
self.spec_token_ids.clear()
return
# NOTE(woosuk): This function assumes that the empty_req_indices
@@ -662,6 +698,16 @@ class InputBatch:
self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
if last_req_index != empty_index:
(
self.spec_token_ids[last_req_index],
self.spec_token_ids[empty_index],
) = (
self.spec_token_ids[empty_index],
self.spec_token_ids[last_req_index],
)
self.spec_token_ids[last_req_index].clear()
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens]
@@ -714,6 +760,7 @@ class InputBatch:
# Trim lists to the batch size.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
del self.spec_token_ids[num_reqs:]
def refresh_metadata(self):
"""Apply any batch updates to sampling metadata."""
@@ -787,6 +834,7 @@ class InputBatch:
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
spec_token_ids=cast(list[list[int]], self.spec_token_ids),
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
@@ -848,6 +896,53 @@ class InputBatch:
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
def set_async_sampled_token_ids(
self,
sampled_token_ids_cpu: torch.Tensor,
async_copy_ready_event: torch.Event,
) -> None:
"""
In async scheduling case, store ref to sampled_token_ids_cpu
tensor and corresponding copy-ready event. Used to repair
output_token_ids prior to sampling, if needed by logits processors.
"""
if self.sampling_metadata.output_token_ids:
self.sampled_token_ids_cpu = sampled_token_ids_cpu
self.async_copy_ready_event = async_copy_ready_event
else:
self.sampled_token_ids_cpu = None
self.async_copy_ready_event = None
def update_async_output_token_ids(self) -> None:
"""
In async scheduling case, update output_token_ids in sampling metadata
from the prior step's sampled token ids once they have finished copying to the CPU.
This is called right before they are needed by the logits processors.
"""
output_token_ids = self.sampling_metadata.output_token_ids
if self.sampled_token_ids_cpu is None or not output_token_ids:
# Output token ids not needed or not async scheduling.
return
assert self.prev_req_id_to_index is not None
sampled_token_ids = None
for index, req_id in enumerate(self.req_ids):
prev_index = self.prev_req_id_to_index.get(req_id)
if prev_index is None:
continue
req_output_token_ids = output_token_ids[index]
if not req_output_token_ids or req_output_token_ids[-1] != -1:
# Final output id is not a placeholder, some tokens must have
# been discarded after a kv-load failure.
continue
if sampled_token_ids is None:
assert self.async_copy_ready_event is not None
self.async_copy_ready_event.synchronize()
sampled_token_ids = self.sampled_token_ids_cpu.squeeze(
-1).tolist()
# Replace placeholder token id with actual sampled id.
req_output_token_ids[-1] = sampled_token_ids[prev_index]
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
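
For context, a minimal sketch of the call sequence implied by the two new async-scheduling hooks, pieced together from their docstrings; the helper names, their arguments, and the use of `torch.npu.Event` are assumptions, not code from this commit:

```python
# Illustrative sketch of the intended call sequence for the new hooks; the
# helper names, their arguments, and torch.npu.Event are assumptions, not
# code from this commit.
import torch


def hand_off_sampled_ids(input_batch, sampled_ids_gpu: torch.Tensor) -> None:
    """Step N: start the async D2H copy and register it with the batch."""
    sampled_ids_cpu = sampled_ids_gpu.to("cpu", non_blocking=True)
    copy_ready = torch.npu.Event()  # mirrors torch.cuda.Event on Ascend
    copy_ready.record()
    input_batch.set_async_sampled_token_ids(sampled_ids_cpu, copy_ready)


def before_next_sampling(input_batch) -> None:
    """Step N+1: repair the -1 placeholders before logits processors that
    need output_token_ids run; synchronizes on the event only if needed."""
    input_batch.update_async_output_token_ids()
```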