[Feature] model_runner refactor (#4764)
### What this PR does / why we need it?
Refactor npu_modelrunner so that its structure stays close to gpu_modelrunner.
### Does this PR introduce _any_ user-facing change?
NO
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Signed-off-by: zhenwenqi2024 <155598497+zhenwenqi2024@users.noreply.github.com>
This commit is contained in:
@@ -24,6 +24,7 @@ from vllm.utils.torch_utils import make_tensor_with_pad
|
|||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
from vllm.v1.utils import CpuGpuBuffer
|
||||||
|
|
||||||
from vllm_ascend.worker.block_table import BlockTable, MultiGroupBlockTable
|
from vllm_ascend.worker.block_table import BlockTable, MultiGroupBlockTable
|
||||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||||
@@ -67,6 +68,8 @@ def _compare_objs(obj1,
|
|||||||
is_same = True # if we make it here must be same
|
is_same = True # if we make it here must be same
|
||||||
elif a == b:
|
elif a == b:
|
||||||
is_same = True
|
is_same = True
|
||||||
|
elif isinstance(a, CpuGpuBuffer):
|
||||||
|
is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu)
|
||||||
assert is_same, f"Attribute {attr_name} is different"\
|
assert is_same, f"Attribute {attr_name} is different"\
|
||||||
f" in {obj1} and {obj2}: {a} != {b}"
|
f" in {obj1} and {obj2}: {a} != {b}"
|
||||||
|
|
||||||
|
|||||||
@@ -1,113 +0,0 @@
|
|||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# This file is a part of the vllm-ascend project.
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from vllm_ascend.ascend_forward_context import MoECommType
|
|
||||||
from vllm_ascend.utils import AscendDeviceType
|
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
|
||||||
|
|
||||||
|
|
||||||
# yapf: disable
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"soc_version, enable_expert_parallel, world_size, pipeline_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
|
|
||||||
[
|
|
||||||
# Case 1: Expert parallel is disabled, should always be 'allgather'
|
|
||||||
(AscendDeviceType._910B, False, 8, 2, 100, 256, None, MoECommType.ALLGATHER),
|
|
||||||
(AscendDeviceType._910_93, False, 16, 2, 500, 256, None, MoECommType.ALLGATHER),
|
|
||||||
|
|
||||||
# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
|
|
||||||
(AscendDeviceType._910B, True, 8, 1, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
|
|
||||||
(AscendDeviceType._910B, True, 16, 1, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
|
|
||||||
(AscendDeviceType._910B, True, 16, 1, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition
|
|
||||||
|
|
||||||
# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
|
|
||||||
(AscendDeviceType._910B, True, 8, 2, 100, 256, None, MoECommType.ALLGATHER),
|
|
||||||
(AscendDeviceType._910B, True, 16, 2, 257, 256, None, MoECommType.ALLGATHER),
|
|
||||||
|
|
||||||
# Case 4: A3 SOC
|
|
||||||
(AscendDeviceType._910_93, True, 8, 2, 100, 256, None, MoECommType.MC2),
|
|
||||||
(AscendDeviceType._910_93, True, 8, 2, 257, 256, None, MoECommType.ALLTOALL),
|
|
||||||
])
|
|
||||||
# yapf: enable
|
|
||||||
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
|
|
||||||
world_size, pipeline_size, num_tokens,
|
|
||||||
mc2_tokens_capacity, quant_type,
|
|
||||||
expected_method):
|
|
||||||
"""
|
|
||||||
Tests the _select_moe_comm_method with various configurations including quant_type.
|
|
||||||
"""
|
|
||||||
# Mock the NPUModelRunner instance and its dependencies
|
|
||||||
mock_runner = MagicMock(spec=NPUModelRunner)
|
|
||||||
mock_runner.parallel_config = MagicMock()
|
|
||||||
mock_runner.parallel_config.enable_expert_parallel = enable_expert_parallel
|
|
||||||
mock_runner.parallel_config.world_size_across_dp = world_size
|
|
||||||
mock_runner.parallel_config.pipeline_parallel_size = pipeline_size
|
|
||||||
mock_runner.mc2_tokens_capacity = mc2_tokens_capacity
|
|
||||||
|
|
||||||
# Add vllm_config.model_config.hf_config mock with moe_quantize
|
|
||||||
mock_hf_config = MagicMock()
|
|
||||||
mock_hf_config.moe_quantize = quant_type
|
|
||||||
mock_model_config = MagicMock()
|
|
||||||
mock_model_config.hf_config = mock_hf_config
|
|
||||||
mock_vllm_config = MagicMock()
|
|
||||||
mock_vllm_config.model_config = mock_model_config
|
|
||||||
mock_runner.vllm_config = mock_vllm_config
|
|
||||||
|
|
||||||
# Patch the helper functions
|
|
||||||
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_device_type',
|
|
||||||
return_value=soc_version), \
|
|
||||||
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
|
|
||||||
return_value=True), \
|
|
||||||
patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
|
|
||||||
return_value=True):
|
|
||||||
|
|
||||||
# Bind the real method to the mock object
|
|
||||||
method = NPUModelRunner._select_moe_comm_method(
|
|
||||||
mock_runner, num_tokens)
|
|
||||||
|
|
||||||
# Assert the result
|
|
||||||
assert method == expected_method
|
|
||||||
|
|
||||||
|
|
||||||
def test_select_moe_comm_method_unsupported_soc():
|
|
||||||
"""
|
|
||||||
Tests that _select_moe_comm_method raises ValueError for an unsupported SOC.
|
|
||||||
"""
|
|
||||||
mock_runner = MagicMock(spec=NPUModelRunner)
|
|
||||||
mock_runner.parallel_config = MagicMock()
|
|
||||||
mock_runner.parallel_config.enable_expert_parallel = True
|
|
||||||
mock_runner.mc2_tokens_capacity = 256
|
|
||||||
|
|
||||||
# Add vllm_config.model_config.hf_config mock with moe_quantize
|
|
||||||
mock_hf_config = MagicMock()
|
|
||||||
mock_hf_config.moe_quantize = None
|
|
||||||
mock_model_config = MagicMock()
|
|
||||||
mock_model_config.hf_config = mock_hf_config
|
|
||||||
mock_vllm_config = MagicMock()
|
|
||||||
mock_vllm_config.model_config = mock_model_config
|
|
||||||
mock_runner.vllm_config = mock_vllm_config
|
|
||||||
|
|
||||||
unsupported_soc = "UnsupportedSOC"
|
|
||||||
|
|
||||||
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_device_type',
|
|
||||||
return_value=unsupported_soc), \
|
|
||||||
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
|
|
||||||
return_value=True), \
|
|
||||||
patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
|
|
||||||
return_value=True), \
|
|
||||||
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):
|
|
||||||
|
|
||||||
NPUModelRunner._select_moe_comm_method(mock_runner, 100)
|
|
||||||
@@ -35,7 +35,6 @@ def set_ascend_forward_context(
|
|||||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||||
with_prefill: bool = True,
|
with_prefill: bool = True,
|
||||||
in_profile_run: bool = False,
|
in_profile_run: bool = False,
|
||||||
reserved_mc2_mask: Optional[torch.Tensor] = None,
|
|
||||||
moe_comm_type: Optional[MoECommType] = None,
|
moe_comm_type: Optional[MoECommType] = None,
|
||||||
num_actual_tokens: Optional[int] = None,
|
num_actual_tokens: Optional[int] = None,
|
||||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
@@ -147,7 +146,7 @@ def set_ascend_forward_context(
|
|||||||
# NOTE: token num which need to pad to when mc2
|
# NOTE: token num which need to pad to when mc2
|
||||||
forward_context.padded_num_tokens = math.ceil(
|
forward_context.padded_num_tokens = math.ceil(
|
||||||
max_tokens_across_dp / tp_world_size) * tp_world_size
|
max_tokens_across_dp / tp_world_size) * tp_world_size
|
||||||
|
reserved_mc2_mask = get_mc2_mask()
|
||||||
if reserved_mc2_mask is not None:
|
if reserved_mc2_mask is not None:
|
||||||
mc2_mask = reserved_mc2_mask[:forward_context.
|
mc2_mask = reserved_mc2_mask[:forward_context.
|
||||||
padded_num_tokens]
|
padded_num_tokens]
|
||||||
@@ -159,3 +158,76 @@ def set_ascend_forward_context(
|
|||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level caches. Each starts unset (None) and is populated by the
# matching set_* helper in this module; the get_* helpers just return them.

# Token capacity stored by set_mc2_tokens_capacity().
_mc2_tokens_capacity: Optional[int] = None
# Boolean mask tensor stored by set_mc2_mask().
_reserved_mc2_mask: Optional[torch.Tensor] = None
# cos / sin tensors stored by set_cos_and_sin().
_cos: Optional[torch.Tensor] = None
_sin: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_mc2_tokens_capacity(vllm_config, max_num_reqs,
                            uniform_decode_query_len):
    """Compute and cache the module-wide MC2 token capacity.

    Only the first call has an effect: once a capacity has been stored,
    later calls return immediately without recomputing it.

    Args:
        vllm_config: Config object. Reads
            ``compilation_config.cudagraph_capture_sizes``,
            ``compilation_config.max_cudagraph_capture_size`` and
            ``parallel_config.tensor_parallel_size``.
        max_num_reqs: Request count used for the fallback token bound.
        uniform_decode_query_len: Per-request query length used for the
            fallback token bound.
    """
    global _mc2_tokens_capacity
    if _mc2_tokens_capacity is not None:
        return

    compilation_config = vllm_config.compilation_config
    if compilation_config.cudagraph_capture_sizes:
        token_bound = compilation_config.max_cudagraph_capture_size
    else:
        # NOTE: To save memory, we cap the max number of tokens to 512.
        token_bound = min(max_num_reqs * uniform_decode_query_len, 512)

    tp_size = vllm_config.parallel_config.tensor_parallel_size
    # Round the bound up to the next multiple of tp_size using integer
    # ceiling division.
    per_rank_tokens = (token_bound + tp_size - 1) // tp_size
    _mc2_tokens_capacity = per_rank_tokens * tp_size
|
||||||
|
|
||||||
|
|
||||||
|
def get_mc2_tokens_capacity():
    """Return the cached MC2 token capacity, or None if it was never set."""
    capacity = _mc2_tokens_capacity
    return capacity
|
||||||
|
|
||||||
|
|
||||||
|
def set_mc2_mask(vllm_config, device):
    """Allocate and cache the reserved MC2 mask tensor.

    If a mask is already cached the call is a no-op. Otherwise, when
    ``is_moe_model(vllm_config)`` is truthy, a 1-D all-False ``torch.bool``
    tensor of length ``get_mc2_tokens_capacity()`` is created on ``device``;
    when it is falsy the cache is left as ``None``.

    Args:
        vllm_config: Config object passed to ``is_moe_model``.
        device: Device on which the mask tensor is allocated.
    """
    global _reserved_mc2_mask
    if _reserved_mc2_mask is not None:
        return

    if not is_moe_model(vllm_config):
        # No mask is cached for this case, so the guard above will not
        # short-circuit later calls either.
        _reserved_mc2_mask = None
        return

    mask_length = get_mc2_tokens_capacity()
    _reserved_mc2_mask = torch.zeros(mask_length,
                                     dtype=torch.bool,
                                     device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mc2_mask():
    """Return the cached reserved MC2 mask, or None if none is cached."""
    mask = _reserved_mc2_mask
    return mask
|
||||||
|
|
||||||
|
|
||||||
|
def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
                    device):
    """Allocate and cache the module-wide cos / sin tensors.

    If a cos tensor is already cached the call is a no-op. Tensors are only
    allocated when ``model_config.use_mla`` is truthy and
    ``compilation_config.cudagraph_mode`` equals
    ``CUDAGraphMode.FULL_DECODE_ONLY``; otherwise both caches are set to
    ``None``.

    When allocated, both tensors have shape
    ``(max_num_reqs * decode_token_per_req, 1, 1, rope_dim)`` with
    ``rope_dim = model_config.hf_text_config.qk_rope_head_dim``; cos is
    filled with ones and sin with zeros.

    Args:
        vllm_config: Config object providing ``compilation_config`` and
            ``model_config``.
        max_num_reqs: Multiplied by ``decode_token_per_req`` to size the
            first dimension.
        decode_token_per_req: See ``max_num_reqs``.
        dtype: dtype of the allocated tensors.
        device: Device on which the tensors are allocated.
    """
    global _cos
    global _sin
    if _cos is not None:
        return

    compilation_config = vllm_config.compilation_config
    model_config = vllm_config.model_config
    needs_buffers = (model_config.use_mla and compilation_config.cudagraph_mode
                     == CUDAGraphMode.FULL_DECODE_ONLY)
    if not needs_buffers:
        _cos = None
        _sin = None
        return

    rope_dim = model_config.hf_text_config.qk_rope_head_dim
    buffer_shape = (max_num_reqs * decode_token_per_req, 1, 1, rope_dim)
    _cos = torch.ones(buffer_shape, dtype=dtype, device=device)
    _sin = torch.zeros(buffer_shape, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cos_and_sin():
    """Return the cached ``(cos, sin)`` pair; either entry may be None."""
    cached_pair = (_cos, _sin)
    return cached_pair
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
|
|||||||
|
|
||||||
from vllm_ascend import envs
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.ascend_forward_context import get_cos_and_sin
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
maybe_save_kv_layer_to_connector,
|
maybe_save_kv_layer_to_connector,
|
||||||
@@ -625,8 +626,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
|
|
||||||
decode_metadata = None
|
decode_metadata = None
|
||||||
if num_decodes > 0:
|
if num_decodes > 0:
|
||||||
cos = common_attn_metadata.cos
|
cos, sin = get_cos_and_sin()
|
||||||
sin = common_attn_metadata.sin
|
|
||||||
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
||||||
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
|
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
|
||||||
1].tolist()
|
1].tolist()
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
|
|||||||
|
|
||||||
from vllm_ascend import envs
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.ascend_forward_context import get_cos_and_sin
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
|
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
@@ -186,8 +187,7 @@ class AscendSFAMetadataBuilder:
|
|||||||
cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
|
cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
|
||||||
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
|
seq_lens = common_attn_metadata.seq_lens[:num_reqs]
|
||||||
|
|
||||||
cos = common_attn_metadata.cos
|
cos, sin = get_cos_and_sin()
|
||||||
sin = common_attn_metadata.sin
|
|
||||||
|
|
||||||
assert self.cos_cache is not None and self.sin_cache is not None
|
assert self.cos_cache is not None and self.sin_cache is not None
|
||||||
new_cos = self.cos_cache[input_positions][:, None, None]
|
new_cos = self.cos_cache[input_positions][:, None, None]
|
||||||
|
|||||||
@@ -100,10 +100,6 @@ class AscendCommonAttentionMetadata:
|
|||||||
# padding tokens. It is used to handle some padding operations.
|
# padding tokens. It is used to handle some padding operations.
|
||||||
num_input_tokens: int = 0
|
num_input_tokens: int = 0
|
||||||
|
|
||||||
# NOTE: This is a temporary solution for rotary embedding in MLA
|
|
||||||
cos: torch.Tensor = None
|
|
||||||
sin: torch.Tensor = None
|
|
||||||
|
|
||||||
prefill_context_parallel_metadata: Optional[
|
prefill_context_parallel_metadata: Optional[
|
||||||
AscendPrefillContextParallelMetadata] = None
|
AscendPrefillContextParallelMetadata] = None
|
||||||
|
|
||||||
|
|||||||
@@ -256,11 +256,12 @@ class MtpProposer(Proposer):
|
|||||||
self.runner.input_batch.
|
self.runner.input_batch.
|
||||||
num_computed_tokens_cpu_tensor[:num_reqs])
|
num_computed_tokens_cpu_tensor[:num_reqs])
|
||||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||||
query_start_loc=self.runner.query_start_loc[:num_reqs + 1],
|
query_start_loc=self.runner.query_start_loc.gpu[:num_reqs +
|
||||||
query_start_loc_cpu=self.runner.
|
1],
|
||||||
query_start_loc_cpu[:num_reqs + 1],
|
query_start_loc_cpu=self.runner.query_start_loc.
|
||||||
seq_lens_cpu=self.runner.seq_lens_cpu,
|
cpu[:num_reqs + 1],
|
||||||
seq_lens=self.runner.seq_lens[:num_reqs],
|
seq_lens_cpu=self.runner.seq_lens.cpu,
|
||||||
|
seq_lens=self.runner.seq_lens.gpu[:num_reqs],
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
num_actual_tokens=num_tokens,
|
num_actual_tokens=num_tokens,
|
||||||
num_input_tokens=num_tokens,
|
num_input_tokens=num_tokens,
|
||||||
@@ -268,16 +269,14 @@ class MtpProposer(Proposer):
|
|||||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
block_table_tensor=self.runner.input_batch.block_table[0].
|
||||||
get_device_tensor()[:num_reqs],
|
get_device_tensor(),
|
||||||
slot_mapping=self.runner.input_batch.block_table[0].
|
slot_mapping=self.runner.input_batch.block_table[0].
|
||||||
slot_mapping,
|
slot_mapping.gpu,
|
||||||
positions=self.runner.positions,
|
positions=self.runner.positions.gpu,
|
||||||
attn_mask=self.runner.attn_mask,
|
attn_mask=self.runner.attn_mask,
|
||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
spec_attn_mask=self.runner.spec_attn_mask,
|
||||||
attn_state=self.runner.attn_state,
|
attn_state=self.runner.attn_state,
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
decode_token_per_req=self.runner.decode_token_per_req,
|
||||||
cos=self.runner.cos,
|
|
||||||
sin=self.runner.sin,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||||
@@ -304,7 +303,6 @@ class MtpProposer(Proposer):
|
|||||||
num_tokens=num_tokens,
|
num_tokens=num_tokens,
|
||||||
with_prefill=with_prefill,
|
with_prefill=with_prefill,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
|
||||||
moe_comm_type=moe_comm_type,
|
moe_comm_type=moe_comm_type,
|
||||||
in_profile_run=self.runner.in_profile_run,
|
in_profile_run=self.runner.in_profile_run,
|
||||||
num_actual_tokens=0,
|
num_actual_tokens=0,
|
||||||
@@ -406,7 +404,8 @@ class MtpProposer(Proposer):
|
|||||||
else:
|
else:
|
||||||
token_indices_to_sample = None
|
token_indices_to_sample = None
|
||||||
# input_ids can be None for multimodal models.
|
# input_ids can be None for multimodal models.
|
||||||
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
|
target_token_ids = self.runner.input_ids.gpu[:
|
||||||
|
num_scheduled_tokens]
|
||||||
target_positions = positions[:num_scheduled_tokens]
|
target_positions = positions[:num_scheduled_tokens]
|
||||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||||
else:
|
else:
|
||||||
@@ -435,7 +434,7 @@ class MtpProposer(Proposer):
|
|||||||
target_positions = positions
|
target_positions = positions
|
||||||
target_hidden_states = hidden_states
|
target_hidden_states = hidden_states
|
||||||
else:
|
else:
|
||||||
target_token_ids = self.runner.input_ids[token_indices]
|
target_token_ids = self.runner.input_ids.gpu[token_indices]
|
||||||
target_positions = positions[token_indices]
|
target_positions = positions[token_indices]
|
||||||
target_hidden_states = hidden_states[token_indices]
|
target_hidden_states = hidden_states[token_indices]
|
||||||
|
|
||||||
@@ -748,7 +747,7 @@ class MtpProposer(Proposer):
|
|||||||
uniform_decode = False
|
uniform_decode = False
|
||||||
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
|
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
|
||||||
aclgraph_runtime_mode, batch_descriptor = \
|
aclgraph_runtime_mode, batch_descriptor = \
|
||||||
self.runner.aclgraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
|
self.runner.cudagraph_dispatcher.dispatch(num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
|
||||||
if self.use_async_scheduling:
|
if self.use_async_scheduling:
|
||||||
# there is synchronization between mtp steps when enabling aclgraph,
|
# there is synchronization between mtp steps when enabling aclgraph,
|
||||||
# disable aclgraph when use async scheduling to avoid the
|
# disable aclgraph when use async scheduling to avoid the
|
||||||
@@ -781,7 +780,6 @@ class MtpProposer(Proposer):
|
|||||||
num_tokens=num_input_tokens,
|
num_tokens=num_input_tokens,
|
||||||
with_prefill=with_prefill,
|
with_prefill=with_prefill,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
|
||||||
moe_comm_type=moe_comm_type,
|
moe_comm_type=moe_comm_type,
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||||
batch_descriptor=batch_descriptor,
|
batch_descriptor=batch_descriptor,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from vllm.distributed import get_dcp_group, get_pcp_group
|
from vllm.distributed import get_dcp_group, get_pcp_group
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
|
from vllm.v1.utils import CpuGpuBuffer
|
||||||
|
|
||||||
|
|
||||||
class BlockTable:
|
class BlockTable:
|
||||||
@@ -76,32 +77,14 @@ class BlockTable:
|
|||||||
duplicate_size = 1
|
duplicate_size = 1
|
||||||
if self.pcp_world_size > 1:
|
if self.pcp_world_size > 1:
|
||||||
duplicate_size += num_speculative_tokens
|
duplicate_size += num_speculative_tokens
|
||||||
self.block_table = torch.zeros(
|
self.block_table = self._make_buffer(max_num_reqs * duplicate_size,
|
||||||
(max_num_reqs * duplicate_size, logical_table_size),
|
logical_table_size,
|
||||||
device=self.device,
|
dtype=torch.int32)
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
self.block_table_cpu = torch.zeros(
|
|
||||||
(max_num_reqs * duplicate_size, logical_table_size),
|
|
||||||
device="cpu",
|
|
||||||
dtype=torch.int32,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
)
|
|
||||||
self.block_table_np = self.block_table_cpu.numpy()
|
|
||||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||||
|
self.slot_mapping = self._make_buffer(
|
||||||
self.slot_mapping_cpu = torch.zeros(
|
|
||||||
self.max_num_batched_tokens +
|
self.max_num_batched_tokens +
|
||||||
2 * self.pcp_world_size * self.max_num_reqs,
|
2 * self.pcp_world_size * self.max_num_reqs,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32)
|
||||||
device="cpu",
|
|
||||||
pin_memory=self.pin_memory)
|
|
||||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
|
||||||
self.slot_mapping = torch.zeros(
|
|
||||||
self.max_num_batched_tokens +
|
|
||||||
2 * self.pcp_world_size * self.max_num_reqs,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
self.kernel_sizes = kernel_sizes
|
self.kernel_sizes = kernel_sizes
|
||||||
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
||||||
@@ -120,7 +103,7 @@ class BlockTable:
|
|||||||
num_blocks = len(block_ids)
|
num_blocks = len(block_ids)
|
||||||
start = self.num_blocks_per_row[row_idx]
|
start = self.num_blocks_per_row[row_idx]
|
||||||
|
|
||||||
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
|
self.block_table.np[row_idx, start:start + num_blocks] = block_ids
|
||||||
self.num_blocks_per_row[row_idx] += num_blocks
|
self.num_blocks_per_row[row_idx] += num_blocks
|
||||||
|
|
||||||
def add_row(self, block_ids: list[int], row_idx: int) -> None:
|
def add_row(self, block_ids: list[int], row_idx: int) -> None:
|
||||||
@@ -129,7 +112,7 @@ class BlockTable:
|
|||||||
|
|
||||||
def move_row(self, src: int, tgt: int) -> None:
|
def move_row(self, src: int, tgt: int) -> None:
|
||||||
num_blocks = self.num_blocks_per_row[src]
|
num_blocks = self.num_blocks_per_row[src]
|
||||||
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
|
self.block_table.np[tgt, :num_blocks] = self.block_table.np[
|
||||||
src, :num_blocks]
|
src, :num_blocks]
|
||||||
self.num_blocks_per_row[tgt] = num_blocks
|
self.num_blocks_per_row[tgt] = num_blocks
|
||||||
|
|
||||||
@@ -139,7 +122,7 @@ class BlockTable:
|
|||||||
self.num_blocks_per_row[src] = num_blocks_tgt
|
self.num_blocks_per_row[src] = num_blocks_tgt
|
||||||
self.num_blocks_per_row[tgt] = num_blocks_src
|
self.num_blocks_per_row[tgt] = num_blocks_src
|
||||||
|
|
||||||
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
|
self.block_table.np[[src, tgt]] = self.block_table.np[[tgt, src]]
|
||||||
|
|
||||||
def compute_slot_mapping(self, req_indices: np.ndarray,
|
def compute_slot_mapping(self, req_indices: np.ndarray,
|
||||||
positions: np.ndarray) -> None:
|
positions: np.ndarray) -> None:
|
||||||
@@ -171,7 +154,7 @@ class BlockTable:
|
|||||||
self.blocks_per_phys_block +
|
self.blocks_per_phys_block +
|
||||||
logical_block_idx)
|
logical_block_idx)
|
||||||
|
|
||||||
block_numbers = self.block_table_np.ravel()[block_table_indices]
|
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||||
# Use virtual_block_size for mask calculation, which marks local
|
# Use virtual_block_size for mask calculation, which marks local
|
||||||
# tokens.
|
# tokens.
|
||||||
virtual_block_offsets = positions % virtual_block_size
|
virtual_block_offsets = positions % virtual_block_size
|
||||||
@@ -186,7 +169,7 @@ class BlockTable:
|
|||||||
# Calculate slot_mapping
|
# Calculate slot_mapping
|
||||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||||
# Write final slots, use -1 for not-local
|
# Write final slots, use -1 for not-local
|
||||||
self.slot_mapping_np[:req_indices.shape[0]] = np.where(
|
self.slot_mapping.np[:req_indices.shape[0]] = np.where(
|
||||||
mask, slot_mapping, -1)
|
mask, slot_mapping, -1)
|
||||||
else:
|
else:
|
||||||
assert self.kernel_sizes is not None
|
assert self.kernel_sizes is not None
|
||||||
@@ -203,24 +186,22 @@ class BlockTable:
|
|||||||
req_indices * self.max_num_blocks_per_req *
|
req_indices * self.max_num_blocks_per_req *
|
||||||
self.blocks_per_phys_block + logical_block_idx)
|
self.blocks_per_phys_block + logical_block_idx)
|
||||||
|
|
||||||
block_numbers = self.block_table_np.ravel(
|
block_numbers = self.block_table.np.ravel(
|
||||||
)[block_table_indices]
|
)[block_table_indices]
|
||||||
block_offsets = positions % self.block_size
|
block_offsets = positions % self.block_size
|
||||||
np.add(block_numbers * self.block_size,
|
np.add(block_numbers * self.block_size,
|
||||||
block_offsets,
|
block_offsets,
|
||||||
out=self.slot_mapping_np[:req_indices.shape[0]])
|
out=self.slot_mapping.np[:req_indices.shape[0]])
|
||||||
|
|
||||||
def commit_block_table(self, num_reqs: int) -> None:
    """Push the block table to the device via ``copy_to_gpu(num_reqs)``.

    Args:
        num_reqs: Forwarded unchanged to the buffer's ``copy_to_gpu``;
            presumably the number of leading rows to copy (confirm against
            ``CpuGpuBuffer``).
    """
    table_buffer = self.block_table
    table_buffer.copy_to_gpu(num_reqs)
||||||
|
|
||||||
def commit_slot_mapping(self, num_tokens: int) -> None:
    """Push the slot mapping to the device via ``copy_to_gpu(num_tokens)``.

    Args:
        num_tokens: Forwarded unchanged to the buffer's ``copy_to_gpu``;
            presumably the number of leading entries to copy (confirm
            against ``CpuGpuBuffer``).
    """
    mapping_buffer = self.slot_mapping
    mapping_buffer.copy_to_gpu(num_tokens)
||||||
|
|
||||||
def clear(self) -> None:
    """Zero both the device and the host copy of the block table.

    ``self.block_table`` is the ``CpuGpuBuffer`` built by ``_make_buffer``.
    The buffer wrapper exposes its tensors as ``.gpu`` and ``.cpu`` and has
    no ``fill_`` of its own, so each tensor is cleared explicitly.
    """
    # Calling fill_ on the wrapper itself would raise AttributeError.
    self.block_table.gpu.fill_(0)
    self.block_table.cpu.fill_(0)
|
||||||
|
|
||||||
def _convert_physical_to_logical_blocks(
|
def _convert_physical_to_logical_blocks(
|
||||||
self, physical_blocks: np.ndarray) -> np.ndarray:
|
self, physical_blocks: np.ndarray) -> np.ndarray:
|
||||||
@@ -243,15 +224,22 @@ class BlockTable:
|
|||||||
|
|
||||||
def get_device_tensor(self) -> torch.Tensor:
    """Return the block table's device-side tensor (the buffer's ``.gpu``)."""
    device_tensor = self.block_table.gpu
    return device_tensor
|
||||||
|
|
||||||
def get_cpu_tensor(self) -> torch.Tensor:
    """Return the block table's host-side tensor (the buffer's ``.cpu``)."""
    host_tensor = self.block_table.cpu
    return host_tensor
|
||||||
|
|
||||||
def get_numpy_array(self) -> np.ndarray:
    """Return the block table's NumPy array (the buffer's ``.np``)."""
    table_array = self.block_table.np
    return table_array
|
||||||
|
|
||||||
|
def _make_buffer(self, *size: int | torch.SymInt,
                 dtype: torch.dtype) -> CpuGpuBuffer:
    """Build a ``CpuGpuBuffer`` of the given size and dtype.

    Args:
        *size: Dimension sizes, forwarded positionally to ``CpuGpuBuffer``.
        dtype: Element dtype of the buffer.

    Returns:
        A ``CpuGpuBuffer`` created with this table's ``device`` and
        ``pin_memory`` settings.
    """
    buffer_options = {
        "dtype": dtype,
        "device": self.device,
        "pin_memory": self.pin_memory,
    }
    return CpuGpuBuffer(*size, **buffer_options)
|
||||||
|
|
||||||
|
|
||||||
class MultiGroupBlockTable:
|
class MultiGroupBlockTable:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -114,6 +114,7 @@ class InputBatch:
|
|||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
block_sizes: list[int], # The block_size of each kv cache group
|
block_sizes: list[int], # The block_size of each kv cache group
|
||||||
logitsprocs: Optional[LogitsProcessors] = None,
|
logitsprocs: Optional[LogitsProcessors] = None,
|
||||||
|
logitsprocs_need_output_token_ids: bool = False,
|
||||||
is_spec_decode: bool = False,
|
is_spec_decode: bool = False,
|
||||||
is_pooling_model: bool = False,
|
is_pooling_model: bool = False,
|
||||||
num_speculative_tokens: int = 0,
|
num_speculative_tokens: int = 0,
|
||||||
@@ -143,10 +144,11 @@ class InputBatch:
|
|||||||
pin_memory=False,
|
pin_memory=False,
|
||||||
)
|
)
|
||||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||||
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
|
self.is_token_ids_tensor = torch.zeros((max_num_reqs, max_model_len),
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=bool,
|
dtype=bool,
|
||||||
pin_memory=False)
|
pin_memory=False)
|
||||||
|
self.is_token_ids = self.is_token_ids_tensor.numpy()
|
||||||
# Store prompt embeddings per request to avoid OOM from large upfront
|
# Store prompt embeddings per request to avoid OOM from large upfront
|
||||||
# allocation if max_model_len is big.
|
# allocation if max_model_len is big.
|
||||||
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
|
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
|
||||||
@@ -299,6 +301,11 @@ class InputBatch:
|
|||||||
# Store provided logitsprocs. If none are provided, initialize empty
|
# Store provided logitsprocs. If none are provided, initialize empty
|
||||||
# data structure
|
# data structure
|
||||||
self.logitsprocs = logitsprocs or LogitsProcessors()
|
self.logitsprocs = logitsprocs or LogitsProcessors()
|
||||||
|
self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
|
||||||
|
|
||||||
|
# Store last speculative tokens for sampler.
|
||||||
|
self.spec_token_ids: list[list[int]] = [[]
|
||||||
|
for _ in range(max_num_reqs)]
|
||||||
|
|
||||||
# This is updated each time the batch constituents change.
|
# This is updated each time the batch constituents change.
|
||||||
self.sampling_metadata = self._make_sampling_metadata()
|
self.sampling_metadata = self._make_sampling_metadata()
|
||||||
@@ -306,9 +313,14 @@ class InputBatch:
|
|||||||
self.pooling_params: dict[str, PoolingParams] = {}
|
self.pooling_params: dict[str, PoolingParams] = {}
|
||||||
|
|
||||||
# Cached reference to the GPU tensor of previously sampled tokens
|
# Cached reference to the GPU tensor of previously sampled tokens
|
||||||
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
|
self.prev_sampled_token_ids: torch.Tensor | None = None
|
||||||
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
||||||
self.prev_req_id_to_index: Optional[dict[str, int]] = None
|
self.prev_req_id_to_index: dict[str, int] | None = None
|
||||||
|
# These are used to update output_token_ids with real sampled
|
||||||
|
# ids from prior step, if required by current sampling params
|
||||||
|
# (e.g. penalties).
|
||||||
|
self.sampled_token_ids_cpu: torch.Tensor | None = None
|
||||||
|
self.async_copy_ready_event: torch.Event | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def req_ids(self) -> list[str]:
|
def req_ids(self) -> list[str]:
|
||||||
@@ -350,9 +362,11 @@ class InputBatch:
|
|||||||
if req_index == len(self._req_ids):
|
if req_index == len(self._req_ids):
|
||||||
self._req_ids.append(req_id)
|
self._req_ids.append(req_id)
|
||||||
self.req_output_token_ids.append(request.output_token_ids)
|
self.req_output_token_ids.append(request.output_token_ids)
|
||||||
|
self.spec_token_ids.append([])
|
||||||
else:
|
else:
|
||||||
self._req_ids[req_index] = req_id
|
self._req_ids[req_index] = req_id
|
||||||
self.req_output_token_ids[req_index] = request.output_token_ids
|
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||||
|
self.spec_token_ids[req_index].clear()
|
||||||
|
|
||||||
self.req_id_to_index[req_id] = req_index
|
self.req_id_to_index[req_id] = req_index
|
||||||
|
|
||||||
@@ -496,6 +510,21 @@ class InputBatch:
|
|||||||
self.batch_update_builder.removed_append(req_index)
|
self.batch_update_builder.removed_append(req_index)
|
||||||
self._req_ids[req_index] = None
|
self._req_ids[req_index] = None
|
||||||
self.req_output_token_ids[req_index] = None
|
self.req_output_token_ids[req_index] = None
|
||||||
|
self.spec_token_ids[req_index].clear()
|
||||||
|
|
||||||
|
# LoRA
|
||||||
|
lora_id = self.request_lora_mapping[req_index]
|
||||||
|
if lora_id != 0:
|
||||||
|
lora_req_ids = self.lora_id_to_request_ids[lora_id]
|
||||||
|
lora_req_ids.discard(req_id)
|
||||||
|
if not lora_req_ids:
|
||||||
|
del self.lora_id_to_request_ids[lora_id]
|
||||||
|
del self.lora_id_to_lora_request[lora_id]
|
||||||
|
self.request_lora_mapping[req_index] = 0
|
||||||
|
|
||||||
|
if self.is_pooling_model:
|
||||||
|
self.pooling_params.pop(req_id, None)
|
||||||
|
return req_index
|
||||||
|
|
||||||
self.greedy_reqs.discard(req_id)
|
self.greedy_reqs.discard(req_id)
|
||||||
self.random_reqs.discard(req_id)
|
self.random_reqs.discard(req_id)
|
||||||
@@ -510,6 +539,8 @@ class InputBatch:
|
|||||||
self.num_prompt_logprobs.pop(req_id, None)
|
self.num_prompt_logprobs.pop(req_id, None)
|
||||||
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
||||||
|
|
||||||
|
if self.prev_req_id_to_index is not None:
|
||||||
|
self.prev_req_id_to_index.pop(req_id, None)
|
||||||
# LoRA
|
# LoRA
|
||||||
lora_id = self.request_lora_mapping[req_index]
|
lora_id = self.request_lora_mapping[req_index]
|
||||||
if lora_id != 0:
|
if lora_id != 0:
|
||||||
@@ -538,6 +569,10 @@ class InputBatch:
|
|||||||
self._req_ids[i2], self._req_ids[i1] # noqa
|
self._req_ids[i2], self._req_ids[i1] # noqa
|
||||||
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
||||||
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
||||||
|
self.spec_token_ids[i1], self.spec_token_ids[i2] = (
|
||||||
|
self.spec_token_ids[i2],
|
||||||
|
self.spec_token_ids[i1],
|
||||||
|
)
|
||||||
assert old_id_i1 is not None and old_id_i2 is not None
|
assert old_id_i1 is not None and old_id_i2 is not None
|
||||||
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
||||||
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
||||||
@@ -629,6 +664,7 @@ class InputBatch:
|
|||||||
# The batched states are empty.
|
# The batched states are empty.
|
||||||
self._req_ids.clear()
|
self._req_ids.clear()
|
||||||
self.req_output_token_ids.clear()
|
self.req_output_token_ids.clear()
|
||||||
|
self.spec_token_ids.clear()
|
||||||
return
|
return
|
||||||
|
|
||||||
# NOTE(woosuk): This function assumes that the empty_req_indices
|
# NOTE(woosuk): This function assumes that the empty_req_indices
|
||||||
@@ -662,6 +698,16 @@ class InputBatch:
|
|||||||
self.req_output_token_ids[last_req_index] = None
|
self.req_output_token_ids[last_req_index] = None
|
||||||
self.req_id_to_index[req_id] = empty_index
|
self.req_id_to_index[req_id] = empty_index
|
||||||
|
|
||||||
|
if last_req_index != empty_index:
|
||||||
|
(
|
||||||
|
self.spec_token_ids[last_req_index],
|
||||||
|
self.spec_token_ids[empty_index],
|
||||||
|
) = (
|
||||||
|
self.spec_token_ids[empty_index],
|
||||||
|
self.spec_token_ids[last_req_index],
|
||||||
|
)
|
||||||
|
self.spec_token_ids[last_req_index].clear()
|
||||||
|
|
||||||
num_tokens = self.num_tokens[last_req_index]
|
num_tokens = self.num_tokens[last_req_index]
|
||||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||||
last_req_index, :num_tokens]
|
last_req_index, :num_tokens]
|
||||||
@@ -714,6 +760,7 @@ class InputBatch:
|
|||||||
# Trim lists to the batch size.
|
# Trim lists to the batch size.
|
||||||
del self._req_ids[num_reqs:]
|
del self._req_ids[num_reqs:]
|
||||||
del self.req_output_token_ids[num_reqs:]
|
del self.req_output_token_ids[num_reqs:]
|
||||||
|
del self.spec_token_ids[num_reqs:]
|
||||||
|
|
||||||
def refresh_metadata(self):
|
def refresh_metadata(self):
|
||||||
"""Apply any batch updates to sampling metadata."""
|
"""Apply any batch updates to sampling metadata."""
|
||||||
@@ -787,6 +834,7 @@ class InputBatch:
|
|||||||
presence_penalties=self.presence_penalties[:num_reqs],
|
presence_penalties=self.presence_penalties[:num_reqs],
|
||||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||||
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
||||||
|
spec_token_ids=cast(list[list[int]], self.spec_token_ids),
|
||||||
no_penalties=self.no_penalties,
|
no_penalties=self.no_penalties,
|
||||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||||
bad_words_token_ids=self.bad_words_token_ids,
|
bad_words_token_ids=self.bad_words_token_ids,
|
||||||
@@ -848,6 +896,53 @@ class InputBatch:
|
|||||||
|
|
||||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||||
|
|
||||||
|
def set_async_sampled_token_ids(
|
||||||
|
self,
|
||||||
|
sampled_token_ids_cpu: torch.Tensor,
|
||||||
|
async_copy_ready_event: torch.Event,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
In async scheduling case, store ref to sampled_token_ids_cpu
|
||||||
|
tensor and corresponding copy-ready event. Used to repair
|
||||||
|
output_token_ids prior to sampling, if needed by logits processors.
|
||||||
|
"""
|
||||||
|
if self.sampling_metadata.output_token_ids:
|
||||||
|
self.sampled_token_ids_cpu = sampled_token_ids_cpu
|
||||||
|
self.async_copy_ready_event = async_copy_ready_event
|
||||||
|
else:
|
||||||
|
self.sampled_token_ids_cpu = None
|
||||||
|
self.async_copy_ready_event = None
|
||||||
|
|
||||||
|
def update_async_output_token_ids(self) -> None:
|
||||||
|
"""
|
||||||
|
In async scheduling case, update output_token_ids in sampling metadata
|
||||||
|
from prior steps sampled token ids once they've finished copying to CPU.
|
||||||
|
This is called right before they are needed by the logits processors.
|
||||||
|
"""
|
||||||
|
output_token_ids = self.sampling_metadata.output_token_ids
|
||||||
|
if self.sampled_token_ids_cpu is None or not output_token_ids:
|
||||||
|
# Output token ids not needed or not async scheduling.
|
||||||
|
return
|
||||||
|
|
||||||
|
assert self.prev_req_id_to_index is not None
|
||||||
|
sampled_token_ids = None
|
||||||
|
for index, req_id in enumerate(self.req_ids):
|
||||||
|
prev_index = self.prev_req_id_to_index.get(req_id)
|
||||||
|
if prev_index is None:
|
||||||
|
continue
|
||||||
|
req_output_token_ids = output_token_ids[index]
|
||||||
|
if not req_output_token_ids or req_output_token_ids[-1] != -1:
|
||||||
|
# Final output id is not a placeholder, some tokens must have
|
||||||
|
# been discarded after a kv-load failure.
|
||||||
|
continue
|
||||||
|
if sampled_token_ids is None:
|
||||||
|
assert self.async_copy_ready_event is not None
|
||||||
|
self.async_copy_ready_event.synchronize()
|
||||||
|
sampled_token_ids = self.sampled_token_ids_cpu.squeeze(
|
||||||
|
-1).tolist()
|
||||||
|
# Replace placeholder token id with actual sampled id.
|
||||||
|
req_output_token_ids[-1] = sampled_token_ids[prev_index]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_reqs(self) -> int:
|
def num_reqs(self) -> int:
|
||||||
return len(self.req_id_to_index)
|
return len(self.req_id_to_index)
|
||||||
|
|||||||
Reference in New Issue
Block a user