from contextlib import contextmanager, nullcontext
from typing import Any, Optional, Union

import numpy as np
import torch

from vllm import envs
from vllm.distributed.kv_transfer.kv_transfer_state import (
    get_kv_transfer_group, has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import (
    tbo_model_executable_v1)
from vllm.utils import async_tensor_h2d
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.block_table import BlockTable


class TBOModelInputSplit:
    """Mutable holder describing how the current step is split in two.

    A single module-level instance (``input_split``) is overwritten by
    ``split_scheduler_output`` and read by ``prepare_tbo_atten_metadata``
    and ``tbo_split_and_execute_model``.
    """

    def __init__(self):
        # Request ids (in ``runner.input_batch.req_ids`` order) per half.
        self.req_ids_left: list = []
        self.req_ids_right: list = []
        self.req_num_left: int = 0
        self.req_num_right: int = 0
        # Per-half SchedulerOutput objects built by split_scheduler_output.
        self.scheduler_output_left: Optional[SchedulerOutput] = None
        self.scheduler_output_right: Optional[SchedulerOutput] = None
        # Lazily allocated device buffer for the right half's
        # query_start_loc (see _right_half_query_start_loc).
        self.query_start_loc_right: Optional[torch.Tensor] = None


input_split = TBOModelInputSplit()


def _choose_left_request_count(req_ids, scheduler_output: SchedulerOutput
                               ) -> int:
    """Return how many leading requests of ``req_ids`` go to the left half.

    Picks the prefix whose cumulative scheduled-token count is closest to
    half of ``total_num_scheduled_tokens``. At least one request is kept on
    the right side; an empty ``req_ids`` yields 0.
    """
    split_tokens = scheduler_output.total_num_scheduled_tokens // 2
    tokens_counter = 0
    best_idx = -1
    best_diff = 0
    for i, req_id in enumerate(req_ids):
        tokens_counter += scheduler_output.num_scheduled_tokens[req_id]
        diff = abs(tokens_counter - split_tokens)
        if best_idx == -1 or diff < best_diff:
            best_idx = i
            best_diff = diff
        # Past the midpoint (or exactly on it): later prefixes only move
        # further away, so stop scanning.
        if tokens_counter > split_tokens or diff == 0:
            break
    # Never put every request on the left; clamp at 0 for empty input.
    return max(0, min(best_idx + 1, len(req_ids) - 1))


def _append_cached_req(dst: CachedRequestData, src: CachedRequestData,
                       idx: int) -> None:
    """Copy entry ``idx`` of every per-request list in ``src`` into ``dst``."""
    dst.req_ids.append(src.req_ids[idx])
    dst.resumed_from_preemption.append(src.resumed_from_preemption[idx])
    # new_token_ids may be empty altogether; only copy when populated.
    if len(src.new_token_ids) > 0:
        dst.new_token_ids.append(src.new_token_ids[idx])
    dst.new_block_ids.append(src.new_block_ids[idx])
    dst.num_computed_tokens.append(src.num_computed_tokens[idx])


def _make_half_scheduler_output(scheduler_output: SchedulerOutput, new_reqs,
                                cached_reqs: CachedRequestData,
                                num_scheduled_tokens: dict) -> SchedulerOutput:
    """Build a SchedulerOutput for one half.

    The per-request fields are restricted to that half. All other fields are
    shared with the full-step ``scheduler_output``.
    """
    return SchedulerOutput(
        scheduled_new_reqs=new_reqs,
        scheduled_cached_reqs=cached_reqs,
        num_scheduled_tokens=num_scheduled_tokens,
        total_num_scheduled_tokens=sum(num_scheduled_tokens.values()),
        scheduled_spec_decode_tokens=scheduler_output.
        scheduled_spec_decode_tokens,
        # Not split per half yet (unsupported).
        scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
        num_common_prefix_blocks=scheduler_output.num_common_prefix_blocks,
        # finished_req_ids is an existing state in the scheduler,
        # instead of being newly scheduled in this step.
        # It contains the request IDs that are finished in between
        # the previous and the current steps.
        finished_req_ids=scheduler_output.finished_req_ids,
        free_encoder_input_ids=scheduler_output.free_encoder_input_ids,
        structured_output_request_ids=scheduler_output.
        structured_output_request_ids,
        grammar_bitmask=scheduler_output.grammar_bitmask,
    )


def split_scheduler_output(runner, scheduler_output: SchedulerOutput):
    """Split one step's SchedulerOutput into left/right halves.

    The boundary is chosen over ``runner.input_batch.req_ids`` so that the
    left half's scheduled-token count is as close as possible to half of the
    step total (see ``_choose_left_request_count``). Results are written to
    the module-level ``input_split``; nothing is returned.
    """
    req_ids = runner.input_batch.req_ids
    num_left = _choose_left_request_count(req_ids, scheduler_output)

    input_split.req_num_left = num_left
    input_split.req_ids_left = req_ids[:num_left]
    input_split.req_ids_right = req_ids[num_left:]
    input_split.req_num_right = len(req_ids) - num_left

    # Set for O(1) membership tests; any request not on the left (including
    # ids absent from the input batch) is treated as right.
    left_ids = set(input_split.req_ids_left)

    new_req_data_left = []
    new_req_data_right = []
    for new_req in scheduler_output.scheduled_new_reqs:
        if new_req.req_id in left_ids:
            new_req_data_left.append(new_req)
        else:
            new_req_data_right.append(new_req)

    cached = scheduler_output.scheduled_cached_reqs
    cached_reqs_left = CachedRequestData.make_empty()
    cached_reqs_right = CachedRequestData.make_empty()
    for req_idx, req_id in enumerate(cached.req_ids):
        target = cached_reqs_left if req_id in left_ids else cached_reqs_right
        _append_cached_req(target, cached, req_idx)

    num_scheduled_tokens_left = {
        key: value
        for key, value in scheduler_output.num_scheduled_tokens.items()
        if key in left_ids
    }
    num_scheduled_tokens_right = {
        key: value
        for key, value in scheduler_output.num_scheduled_tokens.items()
        if key not in left_ids
    }

    input_split.scheduler_output_left = _make_half_scheduler_output(
        scheduler_output, new_req_data_left, cached_reqs_left,
        num_scheduled_tokens_left)
    input_split.scheduler_output_right = _make_half_scheduler_output(
        scheduler_output, new_req_data_right, cached_reqs_right,
        num_scheduled_tokens_right)


def _right_half_query_start_loc(runner, num_scheduled_tokens: np.ndarray,
                                num_reqs: int) -> torch.Tensor:
    """Return a device query_start_loc for the right half, starting at 0.

    The cumulative offsets are computed on the host and copied into the
    lazily allocated ``input_split.query_start_loc_right`` buffer.
    """
    if input_split.query_start_loc_right is None:
        # TODO: create when system init
        input_split.query_start_loc_right = torch.zeros(
            runner.max_num_reqs + 1, dtype=torch.int32, device=runner.device)
    buf = input_split.query_start_loc_right

    cu_num_tokens, _ = runner._get_cumsum_and_arange(num_scheduled_tokens)
    # NOTE(review): this overwrites the runner's host-side query_start_loc
    # staging buffer. The code writes query_start_loc_np and then reads
    # query_start_loc_cpu, so the two presumably share memory. Confirm that
    # nothing after this point still needs the full-batch host values.
    runner.query_start_loc_np[0] = 0
    runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
    buf[:num_reqs + 1].copy_(runner.query_start_loc_cpu[:num_reqs + 1],
                             non_blocking=True)
    # Note: pad query_start_loc to be non-decreasing, as kernels
    # like FlashAttention requires that
    buf[num_reqs + 1:].fill_(runner.query_start_loc_cpu[num_reqs].item())
    return buf[:num_reqs + 1]


@contextmanager
def _right_half_builder_view(metadata_builder, req_offset: int,
                             token_offset: int, num_reqs: int,
                             num_tokens: int):
    """Temporarily narrow a shared metadata builder to the right half.

    Slices the builder's block-table rows from ``req_offset`` and its slot
    mappings from ``token_offset``. For MLA builders it also overrides the
    decode/prefill counters so they describe this half as prefill only. The
    original state is restored on exit even if the body raises, because the
    same builder object is reused by later steps.
    """
    block_table = metadata_builder.block_table
    saved_tables = (block_table.block_table, block_table.slot_mapping,
                    block_table.slot_mapping_cpu)
    is_mla = isinstance(metadata_builder, MLACommonMetadataBuilder)
    if is_mla:
        saved_counts = (metadata_builder._num_decodes,
                        metadata_builder._num_prefills,
                        metadata_builder._num_decode_tokens,
                        metadata_builder._num_prefill_tokens)
    try:
        block_table.block_table = saved_tables[0][req_offset:, :]
        block_table.slot_mapping = saved_tables[1][token_offset:]
        block_table.slot_mapping_cpu = saved_tables[2][token_offset:]
        if is_mla:
            # now support prefill only: tbo_split_and_execute_model skips TBO
            # when the first builder is MLA with _num_decodes > 0.
            metadata_builder._num_decodes = 0
            metadata_builder._num_prefills = num_reqs
            metadata_builder._num_decode_tokens = 0
            metadata_builder._num_prefill_tokens = num_tokens
        yield
    finally:
        (block_table.block_table, block_table.slot_mapping,
         block_table.slot_mapping_cpu) = saved_tables
        if is_mla:
            (metadata_builder._num_decodes, metadata_builder._num_prefills,
             metadata_builder._num_decode_tokens,
             metadata_builder._num_prefill_tokens) = saved_counts


def prepare_tbo_atten_metadata(
        runner, scheduler_output: "SchedulerOutput", req_ids,
        req_offset) -> dict[str, Any]:
    """Build per-layer attention metadata for one half of a TBO step.

    Args:
        runner: the model runner that owns the persistent input buffers and
            the attention metadata builders.
        scheduler_output: the half's SchedulerOutput (from
            ``split_scheduler_output``).
        req_ids: request ids in this half, in input-batch order.
        req_offset: index of this half's first request in the full batch;
            0 selects the left half, anything > 0 the right half.

    Returns:
        dict mapping each attention layer name to the metadata object
        built for its KV-cache group. Layers in the same group share one
        object.
    """
    total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    assert total_num_scheduled_tokens > 0
    num_reqs = len(req_ids)
    assert num_reqs > 0
    is_right = req_offset > 0

    # Get the number of scheduled tokens for each request.
    tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
    num_scheduled_tokens = np.array(tokens, dtype=np.int32)
    max_num_scheduled_tokens = max(tokens)

    if is_right:
        query_start_loc = _right_half_query_start_loc(
            runner, num_scheduled_tokens, num_reqs)
    else:
        # The left half starts at request 0, so the runner's buffer
        # already has the right offsets.
        query_start_loc = runner.query_start_loc[:num_reqs + 1]
    seq_lens = runner.seq_lens[req_offset:req_offset + num_reqs]

    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        num_reqs=num_reqs,
        num_actual_tokens=total_num_scheduled_tokens,
        max_query_len=max_num_scheduled_tokens)

    attn_metadata: dict[str, Any] = {}
    # Prepare the attention metadata for each KV cache group and make layers
    # in the same group share the same metadata.
    for kv_cache_group_id, kv_cache_group_spec in enumerate(
            runner.kv_cache_config.kv_cache_groups):
        # Prepare for cascade attention if enabled & beneficial.
        common_prefix_len = 0
        metadata_builder = runner.attn_metadata_builders[kv_cache_group_id]
        if runner.cascade_attn_enabled:
            common_prefix_len = runner._compute_cascade_attn_prefix_len(
                num_scheduled_tokens,
                scheduler_output.num_common_prefix_blocks[kv_cache_group_id],
                kv_cache_group_spec.kv_cache_spec,
                metadata_builder,
            )

        # NOTE(review): the left half builds with the builder's unmodified
        # (full-batch) MLA counters. Confirm the builder tolerates counters
        # that exceed this half's num_reqs.
        if is_right:
            builder_view = _right_half_builder_view(
                metadata_builder, req_offset,
                input_split.scheduler_output_left.total_num_scheduled_tokens,
                num_reqs, total_num_scheduled_tokens)
        else:
            builder_view = nullcontext()
        with builder_view:
            attn_metadata_i = metadata_builder.build(
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata)

        for layer_name in kv_cache_group_spec.layer_names:
            attn_metadata[layer_name] = attn_metadata_i
    return attn_metadata


def pad_num_input_tokens(self, scheduler_output):
    """Return ``(num_input_tokens, num_tokens_across_dp)`` for this step.

    The scheduled token count is padded with ``pad_for_cudagraph`` when CUDA
    graphs are in use and the count is within ``cudagraph_batch_sizes[-1]``.
    In eager mode it is rounded up to a multiple of the TP size when
    sequence parallelism is enabled with TP > 1. In both cases the
    data-parallel padding from ``self.get_dp_padding`` is then added.
    ``self`` is the model runner.
    """
    num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    if (self.use_cuda_graph
            and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
        # Use piecewise CUDA graphs.
        # Add padding to the batch size.
        num_input_tokens = self.vllm_config.pad_for_cudagraph(
            num_scheduled_tokens)
    else:
        # Eager mode.
        # Pad tokens to multiple of tensor_parallel_size when
        # enabled collective fusion for SP
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        if self.vllm_config.compilation_config.pass_config. \
                enable_sequence_parallelism and tp_size > 1:
            from vllm.utils import round_up
            num_input_tokens = round_up(num_scheduled_tokens, tp_size)
        else:
            num_input_tokens = num_scheduled_tokens

    # Padding for DP
    num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
    num_input_tokens += num_pad
    return num_input_tokens, num_tokens_across_dp


def tbo_split_and_execute_model(
    runner,
    attn_metadata,
    num_input_tokens,
    num_tokens_across_dp,
    input_ids,
    positions,
    inputs_embeds,
    scheduler_output: "SchedulerOutput",
    intermediate_tensors: Optional[IntermediateTensors] = None,
    skip_cuda_graphs: bool = True,
) -> tuple[Any, Any, Any]:
    """Run the model forward, using two-batch overlap (TBO) when eligible.

    TBO is used only when all of these hold:
    - the first attention builder is not an MLA builder with pending
      decodes;
    - more than one request is scheduled;
    - ``num_input_tokens`` exceeds ``envs.VLLM_TBO_MIN_TOKENS``.

    In that case the step is split with ``split_scheduler_output``, per-half
    attention metadata is built, and ``tbo_model_executable_v1`` runs the
    forward. Otherwise ``runner.model`` runs directly, with
    ``envs.VLLM_ENABLE_TBO`` temporarily cleared.

    Returns:
        ``(model_output, finished_sending, finished_recving)``, where the last
        two come from ``runner.get_finished_kv_transfers``.
    """
    first_builder = runner.attn_metadata_builders[0]
    is_mla_decode = (isinstance(first_builder, MLACommonMetadataBuilder)
                     and first_builder._num_decodes > 0)
    use_tbo = (not is_mla_decode
               and len(scheduler_output.num_scheduled_tokens) > 1
               and num_input_tokens > envs.VLLM_TBO_MIN_TOKENS)

    if use_tbo:
        split_scheduler_output(runner, scheduler_output)
        num_input_tokens_left = (
            input_split.scheduler_output_left.total_num_scheduled_tokens)
        # Any tokens in num_input_tokens beyond the left half's scheduled
        # tokens (e.g. padding) are assigned to the right half.
        num_input_tokens_right = num_input_tokens - num_input_tokens_left
        attn_metadata_left = prepare_tbo_atten_metadata(
            runner, input_split.scheduler_output_left,
            input_split.req_ids_left, 0)
        attn_metadata_right = prepare_tbo_atten_metadata(
            runner, input_split.scheduler_output_right,
            input_split.req_ids_right, input_split.req_num_left)
        with set_forward_context(attn_metadata,
                                 runner.vllm_config,
                                 num_tokens=num_input_tokens,
                                 num_tokens_across_dp=num_tokens_across_dp,
                                 skip_cuda_graphs=True):
            runner.maybe_setup_kv_connector(scheduler_output)
            model_output = tbo_model_executable_v1(
                runner, attn_metadata_left, attn_metadata_right,
                num_input_tokens_left, num_input_tokens_right,
                num_tokens_across_dp, input_ids, positions,
                intermediate_tensors, inputs_embeds)
            runner.maybe_wait_for_kv_save()
            finished_sending, finished_recving = (
                runner.get_finished_kv_transfers(scheduler_output))
        return model_output, finished_sending, finished_recving

    # Run the decoder.
    # Use persistent buffers for CUDA graphs.
    # Clear the TBO flag for the duration of the plain forward and restore
    # its previous value afterwards, even if the forward raises.
    prev_enable_tbo = envs.VLLM_ENABLE_TBO
    envs.VLLM_ENABLE_TBO = False
    try:
        with set_forward_context(attn_metadata,
                                 runner.vllm_config,
                                 num_tokens=num_input_tokens,
                                 num_tokens_across_dp=num_tokens_across_dp,
                                 skip_cuda_graphs=skip_cuda_graphs):
            runner.maybe_setup_kv_connector(scheduler_output)
            model_output = runner.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )
            runner.maybe_wait_for_kv_save()
            finished_sending, finished_recving = (
                runner.get_finished_kv_transfers(scheduler_output))
    finally:
        envs.VLLM_ENABLE_TBO = prev_enable_tbo
    return model_output, finished_sending, finished_recving