335 lines
16 KiB
Python
335 lines
16 KiB
Python
|
|
from typing import Any, Optional, Union
|
|
import numpy as np
|
|
import torch
|
|
from vllm import envs
|
|
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
|
|
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
|
from vllm.forward_context import set_forward_context
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_model_executable_v1
|
|
from vllm.utils import async_tensor_h2d
|
|
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
|
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
|
|
from vllm.v1.outputs import ModelRunnerOutput
|
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
|
from vllm.v1.worker.block_table import BlockTable
|
|
|
|
|
|
class TBOModelInputSplit:
    """Mutable holder describing how one scheduler step is split in two halves.

    ``split_scheduler_output`` fills in the request-id lists, the per-half
    request counts and the per-half ``SchedulerOutput`` objects;
    ``prepare_tbo_atten_metadata`` lazily allocates ``query_start_loc_right``
    and reuses it on later calls.
    """

    def __init__(self):
        # Request ids (in input-batch order) assigned to each half; two
        # distinct list objects.
        self.req_ids_left, self.req_ids_right = [], []
        # Number of requests in each half.
        self.req_num_left = self.req_num_right = 0
        # Per-half SchedulerOutput objects, filled in by the splitter.
        self.scheduler_output_left = self.scheduler_output_right = None
        # Device buffer for the right half's query_start_loc; allocated on
        # first use.
        self.query_start_loc_right = None
|
|
|
|
# Module-level split state shared by split_scheduler_output,
# prepare_tbo_atten_metadata and tbo_split_and_execute_model. Its fields are
# overwritten by split_scheduler_output on every step that takes the TBO path.
input_split = TBOModelInputSplit()
|
|
|
|
def _choose_num_left_reqs(req_ids, num_scheduled_tokens, total_num_scheduled_tokens):
    """Return how many leading requests of ``req_ids`` go to the left half.

    Picks the prefix whose cumulative scheduled-token count is closest to
    half of ``total_num_scheduled_tokens``. If that prefix would be every
    request, one is moved back so the right half is never empty.
    """
    target = total_num_scheduled_tokens // 2
    running = 0
    best_idx = -1
    best_diff = 0
    for idx, req_id in enumerate(req_ids):
        running += num_scheduled_tokens[req_id]
        diff = abs(running - target)
        if best_idx == -1 or diff < best_diff:
            best_idx, best_diff = idx, diff
        # Once the running sum has passed the target, later prefixes can only
        # be farther from it; an exact hit cannot be beaten either.
        if running > target or diff == 0:
            break
    num_left = best_idx + 1
    # Keep at least one request on the right side (guarded so that an empty
    # request list yields 0 rather than -1).
    if req_ids and num_left == len(req_ids):
        num_left -= 1
    return num_left


def _append_cached_req(dst, src, idx):
    """Copy entry ``idx`` of the cached-request data ``src`` into ``dst``."""
    dst.req_ids.append(src.req_ids[idx])
    dst.resumed_from_preemption.append(src.resumed_from_preemption[idx])
    # new_token_ids may be left empty by the scheduler; copy only when present.
    if len(src.new_token_ids) > 0:
        dst.new_token_ids.append(src.new_token_ids[idx])
    dst.new_block_ids.append(src.new_block_ids[idx])
    dst.num_computed_tokens.append(src.num_computed_tokens[idx])


def _make_half_scheduler_output(scheduler_output, new_reqs, cached_reqs,
                                num_scheduled_tokens):
    """Build a SchedulerOutput for one half, sharing the step-wide fields."""
    return SchedulerOutput(
        scheduled_new_reqs=new_reqs,
        scheduled_cached_reqs=cached_reqs,
        num_scheduled_tokens=num_scheduled_tokens,
        total_num_scheduled_tokens=sum(num_scheduled_tokens.values()),
        scheduled_spec_decode_tokens=scheduler_output.scheduled_spec_decode_tokens,
        scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,  # not supported yet
        num_common_prefix_blocks=scheduler_output.num_common_prefix_blocks,
        # finished_req_ids is an existing state in the scheduler,
        # instead of being newly scheduled in this step.
        # It contains the request IDs that are finished in between
        # the previous and the current steps.
        finished_req_ids=scheduler_output.finished_req_ids,
        free_encoder_input_ids=scheduler_output.free_encoder_input_ids,
        structured_output_request_ids=scheduler_output.structured_output_request_ids,
        grammar_bitmask=scheduler_output.grammar_bitmask,
    )


def split_scheduler_output(runner, scheduler_output: SchedulerOutput):
    """Split ``scheduler_output`` into left/right halves stored on ``input_split``.

    The left half is a prefix of ``runner.input_batch.req_ids`` chosen so its
    scheduled-token count is as close as possible to half the total. New
    requests, cached requests and per-request token counts are routed to the
    half that owns their request id. Step-wide fields (spec-decode tokens,
    encoder inputs, finished ids, grammar bitmask, ...) are passed unchanged
    to both halves.

    Sets ``req_num_left``, ``req_num_right``, ``req_ids_left``,
    ``req_ids_right``, ``scheduler_output_left`` and
    ``scheduler_output_right`` on the module-level ``input_split``.
    """
    req_ids = runner.input_batch.req_ids
    num_left = _choose_num_left_reqs(
        req_ids,
        scheduler_output.num_scheduled_tokens,
        scheduler_output.total_num_scheduled_tokens,
    )

    input_split.req_num_left = num_left
    input_split.req_ids_left = req_ids[:num_left]
    input_split.req_ids_right = req_ids[num_left:]
    input_split.req_num_right = len(req_ids) - num_left

    # Set for O(1) membership tests in the routing loops below.
    left_ids = set(input_split.req_ids_left)

    new_reqs_left, new_reqs_right = [], []
    for new_req in scheduler_output.scheduled_new_reqs:
        target = new_reqs_left if new_req.req_id in left_ids else new_reqs_right
        target.append(new_req)

    cached = scheduler_output.scheduled_cached_reqs
    cached_left = CachedRequestData.make_empty()
    cached_right = CachedRequestData.make_empty()
    for req_idx, req_id in enumerate(cached.req_ids):
        target = cached_left if req_id in left_ids else cached_right
        _append_cached_req(target, cached, req_idx)

    tokens_left, tokens_right = {}, {}
    for req_id, num_tokens in scheduler_output.num_scheduled_tokens.items():
        target = tokens_left if req_id in left_ids else tokens_right
        target[req_id] = num_tokens

    input_split.scheduler_output_left = _make_half_scheduler_output(
        scheduler_output, new_reqs_left, cached_left, tokens_left)
    input_split.scheduler_output_right = _make_half_scheduler_output(
        scheduler_output, new_reqs_right, cached_right, tokens_right)
|
|
|
|
|
|
def _build_right_query_start_loc(runner, num_scheduled_tokens, num_reqs):
    """Return the device ``query_start_loc`` slice (length ``num_reqs + 1``) for the right half.

    The right half does not start at request 0 of the full batch, so its
    prefix sums cannot be sliced from ``runner.query_start_loc``. They are
    recomputed from ``num_scheduled_tokens`` into a separate device buffer
    cached on ``input_split``.
    """
    if input_split.query_start_loc_right is None:
        # TODO: allocate once at system init instead of on first use.
        input_split.query_start_loc_right = torch.zeros(
            runner.max_num_reqs + 1,
            dtype=torch.int32,
            device=runner.device)
    buf = input_split.query_start_loc_right

    cu_num_tokens, _ = runner._get_cumsum_and_arange(num_scheduled_tokens)

    # NOTE(review): the prefix sums are staged in the runner's own
    # query_start_loc_np / query_start_loc_cpu buffers (this code writes the
    # former and reads the latter, so they presumably alias). That overwrites
    # the full-batch values stored there; confirm nothing later in the step
    # still reads them.
    runner.query_start_loc_np[0] = 0
    runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
    buf[:num_reqs + 1].copy_(
        runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
    # Pad query_start_loc to be non-decreasing, as kernels like
    # FlashAttention require that.
    buf[num_reqs + 1:].fill_(runner.query_start_loc_cpu[num_reqs].item())
    return buf[:num_reqs + 1]


def prepare_tbo_atten_metadata(
    runner,
    scheduler_output: "SchedulerOutput",
    req_ids,
    req_offset
) -> dict[str, Any]:
    """Build per-layer attention metadata for one TBO half.

    Args:
        runner: the model runner; its ``query_start_loc``, ``seq_lens`` and
            attention metadata builders describe the full batch.
        scheduler_output: the half's SchedulerOutput (as produced by
            ``split_scheduler_output``).
        req_ids: request ids of this half, in input-batch order.
        req_offset: index of this half's first request in the full batch;
            0 for the left half, > 0 for the right half.

    Returns:
        Mapping from attention layer name to the metadata built for it;
        layers in the same KV-cache group share one metadata object.

    The builders are shared with the full-batch path, so their block table,
    slot mapping and (for MLA) prefill/decode counters are overridden only
    for the duration of ``build()`` and always restored afterwards.
    """
    total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    assert total_num_scheduled_tokens > 0
    num_reqs = len(req_ids)
    assert num_reqs > 0

    # Number of scheduled tokens for each request of this half.
    tokens = [scheduler_output.num_scheduled_tokens[req_id] for req_id in req_ids]
    num_scheduled_tokens = np.array(tokens, dtype=np.int32)
    max_num_scheduled_tokens = max(tokens)

    if req_offset > 0:
        query_start_loc = _build_right_query_start_loc(
            runner, num_scheduled_tokens, num_reqs)
    else:
        # The left half is a prefix of the full batch, so the runner's
        # full-batch prefix sums are already correct for it.
        query_start_loc = runner.query_start_loc[:num_reqs + 1]

    seq_lens = runner.seq_lens[req_offset:req_offset + num_reqs]

    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        num_reqs=num_reqs,
        num_actual_tokens=total_num_scheduled_tokens,
        max_query_len=max_num_scheduled_tokens)

    attn_metadata: dict[str, Any] = {}
    # Prepare the attention metadata for each KV cache group and make layers
    # in the same group share the same metadata.
    for kv_cache_group_id, kv_cache_group_spec in enumerate(
            runner.kv_cache_config.kv_cache_groups):
        metadata_builder = runner.attn_metadata_builders[kv_cache_group_id]

        # Prepare for cascade attention if enabled & beneficial.
        common_prefix_len = 0
        if runner.cascade_attn_enabled:
            common_prefix_len = runner._compute_cascade_attn_prefix_len(
                num_scheduled_tokens,
                scheduler_output.num_common_prefix_blocks[kv_cache_group_id],
                kv_cache_group_spec.kv_cache_spec,
                metadata_builder,
            )

        is_right = req_offset > 0
        is_mla = isinstance(metadata_builder, MLACommonMetadataBuilder)

        # Snapshot the shared builder state before overriding it.
        if is_right:
            origin_block_table = metadata_builder.block_table.block_table
            origin_slot_mapping = metadata_builder.block_table.slot_mapping
            origin_slot_map_cpu = metadata_builder.block_table.slot_mapping_cpu
            # Tokens of the left half precede this half in the slot mapping.
            left_num_tokens = (
                input_split.scheduler_output_left.total_num_scheduled_tokens)
        if is_mla:
            saved_mla_counts = (
                metadata_builder._num_decodes,
                metadata_builder._num_prefills,
                metadata_builder._num_decode_tokens,
                metadata_builder._num_prefill_tokens,
            )

        try:
            if is_right:
                # Shift the builder's views so row/slot 0 is this half's first
                # request/token.
                metadata_builder.block_table.block_table = \
                    origin_block_table[req_offset:, :]
                metadata_builder.block_table.slot_mapping = \
                    origin_slot_mapping[left_num_tokens:]
                metadata_builder.block_table.slot_mapping_cpu = \
                    origin_slot_map_cpu[left_num_tokens:]
            if is_mla:
                # TBO is only used for prefill batches: treat every request
                # of this half as a prefill.
                metadata_builder._num_decodes = 0
                metadata_builder._num_prefills = num_reqs
                metadata_builder._num_decode_tokens = 0
                metadata_builder._num_prefill_tokens = total_num_scheduled_tokens
            attn_metadata_i = metadata_builder.build(
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata)
        finally:
            # Always restore the shared builder, even if build() raised.
            if is_right:
                metadata_builder.block_table.block_table = origin_block_table
                metadata_builder.block_table.slot_mapping = origin_slot_mapping
                metadata_builder.block_table.slot_mapping_cpu = origin_slot_map_cpu
            if is_mla:
                (metadata_builder._num_decodes,
                 metadata_builder._num_prefills,
                 metadata_builder._num_decode_tokens,
                 metadata_builder._num_prefill_tokens) = saved_mla_counts

        for layer_name in kv_cache_group_spec.layer_names:
            attn_metadata[layer_name] = attn_metadata_i

    return attn_metadata
|
|
|
|
def pad_num_input_tokens(self, scheduler_output):
    """Return ``(num_input_tokens, num_tokens_across_dp)`` for this step.

    Written as a free function that takes the runner as ``self``. The
    scheduled token count is first padded for CUDA graphs when it fits the
    largest captured size. Otherwise, in eager mode, it is rounded up to a
    multiple of the TP size if sequence parallelism is enabled with TP > 1.
    The data-parallel padding from ``self.get_dp_padding`` is then added.
    """
    config = self.vllm_config
    padded = scheduler_output.total_num_scheduled_tokens

    fits_cuda_graph = (self.use_cuda_graph
                       and padded <= self.cudagraph_batch_sizes[-1])
    if fits_cuda_graph:
        # Piecewise CUDA graphs: pad up to a captured batch size.
        padded = config.pad_for_cudagraph(padded)
    else:
        # Eager mode: collective fusion for SP needs a multiple of TP size.
        tp_size = config.parallel_config.tensor_parallel_size
        sp_enabled = config.compilation_config.pass_config.enable_sequence_parallelism
        if sp_enabled and tp_size > 1:
            from vllm.utils import round_up
            padded = round_up(padded, tp_size)

    # Padding for DP.
    dp_pad, num_tokens_across_dp = self.get_dp_padding(padded)
    return padded + dp_pad, num_tokens_across_dp
|
|
|
|
def tbo_split_and_execute_model(
    runner,
    attn_metadata,
    num_input_tokens,
    num_tokens_across_dp,
    input_ids,
    positions,
    inputs_embeds,
    scheduler_output: "SchedulerOutput",
    intermediate_tensors: Optional[IntermediateTensors] = None,
    skip_cuda_graphs: bool = True,
) -> tuple[Any, Any, Any]:
    """Run the model for one step, using two-batch overlap (TBO) when worthwhile.

    TBO is used when:
      * the first attention metadata builder is not an MLA builder with
        pending decodes;
      * more than one request is scheduled; and
      * ``num_input_tokens`` exceeds ``envs.VLLM_TBO_MIN_TOKENS``.

    In that case the batch is split with ``split_scheduler_output``,
    attention metadata is built for each half, and ``tbo_model_executable_v1``
    runs the model with CUDA graphs skipped. Otherwise the model runs once on
    the whole batch while ``envs.VLLM_ENABLE_TBO`` is temporarily False; the
    previous value is restored afterwards, even if the model raises.

    Returns:
        ``(model_output, finished_sending, finished_recving)``; the last two
        come from ``runner.get_finished_kv_transfers``.
    """
    first_builder = runner.attn_metadata_builders[0]
    is_mla_decode = (isinstance(first_builder, MLACommonMetadataBuilder)
                     and first_builder._num_decodes > 0)
    use_tbo = (not is_mla_decode
               and len(scheduler_output.num_scheduled_tokens) > 1
               and num_input_tokens > envs.VLLM_TBO_MIN_TOKENS)

    if use_tbo:
        split_scheduler_output(runner, scheduler_output)
        num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens
        # Any padding in num_input_tokens is attributed to the right half.
        num_input_tokens_right = num_input_tokens - num_input_tokens_left

        attn_metadata_left = prepare_tbo_atten_metadata(
            runner, input_split.scheduler_output_left,
            input_split.req_ids_left, 0)
        attn_metadata_right = prepare_tbo_atten_metadata(
            runner, input_split.scheduler_output_right,
            input_split.req_ids_right, input_split.req_num_left)

        # The TBO path always skips CUDA graphs, regardless of the
        # skip_cuda_graphs argument.
        with set_forward_context(attn_metadata,
                                 runner.vllm_config,
                                 num_tokens=num_input_tokens,
                                 num_tokens_across_dp=num_tokens_across_dp,
                                 skip_cuda_graphs=True):
            runner.maybe_setup_kv_connector(scheduler_output)

            model_output = tbo_model_executable_v1(
                runner,
                attn_metadata_left,
                attn_metadata_right,
                num_input_tokens_left,
                num_input_tokens_right,
                num_tokens_across_dp,
                input_ids,
                positions,
                intermediate_tensors,
                inputs_embeds)

            runner.maybe_wait_for_kv_save()
            finished_sending, finished_recving = (
                runner.get_finished_kv_transfers(scheduler_output))
    else:
        # Run the decoder on the whole batch with TBO disabled for the
        # duration of the forward pass. Restore the previous setting in a
        # finally block so an exception cannot leave TBO switched off.
        prev_enable_tbo = getattr(envs, "VLLM_ENABLE_TBO", True)
        envs.VLLM_ENABLE_TBO = False
        try:
            with set_forward_context(attn_metadata,
                                     runner.vllm_config,
                                     num_tokens=num_input_tokens,
                                     num_tokens_across_dp=num_tokens_across_dp,
                                     skip_cuda_graphs=skip_cuda_graphs):
                runner.maybe_setup_kv_connector(scheduler_output)

                model_output = runner.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )

                runner.maybe_wait_for_kv_save()
                finished_sending, finished_recving = (
                    runner.get_finished_kv_transfers(scheduler_output))
        finally:
            envs.VLLM_ENABLE_TBO = prev_enable_tbo

    return model_output, finished_sending, finished_recving