init src 0.9.2

This commit is contained in:
2026-01-09 15:09:53 +08:00
parent 0eb2c0a4b3
commit 41d98d4359
1438 changed files with 417605 additions and 683 deletions

View File

@@ -0,0 +1,335 @@
from typing import Any, Optional, Union
import numpy as np
import torch
from vllm import envs
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_model_executable_v1
from vllm.utils import async_tensor_h2d
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.block_table import BlockTable
class TBOModelInputSplit():
    """Holds the two halves produced when a batch is split for
    two-batch-overlap (TBO) execution."""

    def __init__(self):
        # Request ids and request counts for each half of the split.
        self.req_ids_left, self.req_ids_right = [], []
        self.req_num_left, self.req_num_right = 0, 0
        # Per-half SchedulerOutput objects, filled by split_scheduler_output().
        self.scheduler_output_left = None
        self.scheduler_output_right = None
        # Lazily-allocated query_start_loc buffer for the right half.
        self.query_start_loc_right = None


# Module-level singleton reused across steps.
input_split = TBOModelInputSplit()
def split_scheduler_output(runner, scheduler_output: SchedulerOutput):
    """Split *scheduler_output* into two halves with roughly equal token
    counts and store the result in the module-level ``input_split``.

    The split point is chosen at the request boundary whose cumulative
    token count is closest to half of the total scheduled tokens; the
    right half is always kept non-empty.
    """
    split_tokens = scheduler_output.total_num_scheduled_tokens // 2
    req_ids = runner.input_batch.req_ids

    # Find the request index whose cumulative token count best
    # approximates the half-way point.
    tokens_counter = 0
    best_idx = -1
    best_diff = 0
    for idx, req_id in enumerate(req_ids):
        tokens_counter += scheduler_output.num_scheduled_tokens[req_id]
        diff = abs(tokens_counter - split_tokens)
        if best_idx == -1 or diff < best_diff:
            best_idx = idx
            best_diff = diff
        if tokens_counter > split_tokens or diff == 0:
            break

    input_split.req_num_left = best_idx + 1
    # Keep the right half non-empty: never assign every request to the left.
    if input_split.req_num_left == len(req_ids):
        input_split.req_num_left -= 1
    input_split.req_ids_left = req_ids[:input_split.req_num_left]
    input_split.req_ids_right = req_ids[input_split.req_num_left:]
    input_split.req_num_right = len(req_ids) - input_split.req_num_left

    # Set for O(1) membership tests in the partition loops below.
    left_ids = set(input_split.req_ids_left)

    # Partition newly scheduled requests.
    new_req_data_left = []
    new_req_data_right = []
    for new_req in scheduler_output.scheduled_new_reqs:
        if new_req.req_id in left_ids:
            new_req_data_left.append(new_req)
        else:
            new_req_data_right.append(new_req)

    # Partition cached (already running) requests.
    cached = scheduler_output.scheduled_cached_reqs
    cached_reqs_left = CachedRequestData.make_empty()
    cached_reqs_right = CachedRequestData.make_empty()
    has_new_token_ids = len(cached.new_token_ids) > 0
    for req_idx, req_id in enumerate(cached.req_ids):
        dst = cached_reqs_left if req_id in left_ids else cached_reqs_right
        dst.req_ids.append(req_id)
        dst.resumed_from_preemption.append(
            cached.resumed_from_preemption[req_idx])
        if has_new_token_ids:
            dst.new_token_ids.append(cached.new_token_ids[req_idx])
        dst.new_block_ids.append(cached.new_block_ids[req_idx])
        dst.num_computed_tokens.append(cached.num_computed_tokens[req_idx])

    # Partition the per-request scheduled-token counts.
    num_scheduled_tokens_left = {}
    num_scheduled_tokens_right = {}
    total_num_scheduled_tokens_left = 0
    total_num_scheduled_tokens_right = 0
    for key, value in scheduler_output.num_scheduled_tokens.items():
        if key in left_ids:
            num_scheduled_tokens_left[key] = value
            total_num_scheduled_tokens_left += value
        else:
            num_scheduled_tokens_right[key] = value
            total_num_scheduled_tokens_right += value

    def _make_half(new_reqs, cached_reqs, num_tokens, total_tokens):
        # Fields that are not split (spec decode tokens, encoder inputs --
        # encoder inputs are not yet supported in TBO mode, common prefix
        # blocks) are shared by both halves.
        # finished_req_ids is an existing state in the scheduler,
        # instead of being newly scheduled in this step.  It contains the
        # request IDs that finished between the previous and current steps.
        return SchedulerOutput(
            scheduled_new_reqs=new_reqs,
            scheduled_cached_reqs=cached_reqs,
            num_scheduled_tokens=num_tokens,
            total_num_scheduled_tokens=total_tokens,
            scheduled_spec_decode_tokens=scheduler_output.scheduled_spec_decode_tokens,
            scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
            num_common_prefix_blocks=scheduler_output.num_common_prefix_blocks,
            finished_req_ids=scheduler_output.finished_req_ids,
            free_encoder_input_ids=scheduler_output.free_encoder_input_ids,
            structured_output_request_ids=scheduler_output.structured_output_request_ids,
            grammar_bitmask=scheduler_output.grammar_bitmask,
        )

    input_split.scheduler_output_left = _make_half(
        new_req_data_left, cached_reqs_left,
        num_scheduled_tokens_left, total_num_scheduled_tokens_left)
    input_split.scheduler_output_right = _make_half(
        new_req_data_right, cached_reqs_right,
        num_scheduled_tokens_right, total_num_scheduled_tokens_right)
def prepare_tbo_atten_metadata(
    runner,
    scheduler_output: "SchedulerOutput",
    req_ids,
    req_offset
) -> dict[str, Any]:
    """Build per-layer attention metadata for one half of a TBO split.

    Args:
        runner: Model runner owning the persistent buffers and the
            attention metadata builders.
        scheduler_output: The half-batch output from split_scheduler_output().
        req_ids: Request ids belonging to this half.
        req_offset: Index of this half's first request within the full
            batch; 0 selects the left half, > 0 the right half.

    Returns:
        Mapping from layer name to that layer's attention metadata.
        (The previous ``-> tuple[...]`` annotation was wrong: only this
        dict is returned.)
    """
    total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    assert total_num_scheduled_tokens > 0
    num_reqs = len(req_ids)
    assert num_reqs > 0
    seq_len_offset = req_offset
    # Get the number of scheduled tokens for each request.
    tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
    num_scheduled_tokens = np.array(tokens, dtype=np.int32)
    max_num_scheduled_tokens = max(tokens)
    if req_offset > 0:  # right half: use a private query_start_loc buffer
        if input_split.query_start_loc_right is None:
            # TODO: create when system init
            input_split.query_start_loc_right = torch.zeros(
                runner.max_num_reqs + 1,
                dtype=torch.int32,
                device=runner.device)
        cu_num_tokens, arange = runner._get_cumsum_and_arange(
            num_scheduled_tokens)
        # Prepare the attention metadata.
        runner.query_start_loc_np[0] = 0
        runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
        input_split.query_start_loc_right[0: num_reqs + 1].copy_(
            runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention requires that
        input_split.query_start_loc_right[num_reqs + 1:].fill_(
            runner.query_start_loc_cpu[num_reqs].item())
        query_start_loc = input_split.query_start_loc_right[: num_reqs + 1]
    else:
        query_start_loc = runner.query_start_loc[:num_reqs + 1]
    seq_lens = runner.seq_lens[seq_len_offset: seq_len_offset + num_reqs]
    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        num_reqs=num_reqs,
        num_actual_tokens=total_num_scheduled_tokens,
        max_query_len=max_num_scheduled_tokens)
    attn_metadata: dict[str, Any] = {}
    # Prepare the attention metadata for each KV cache group and make layers
    # in the same group share the same metadata.
    for kv_cache_group_id, kv_cache_group_spec in enumerate(
            runner.kv_cache_config.kv_cache_groups):
        # Prepare for cascade attention if enabled & beneficial.
        common_prefix_len = 0
        metadata_builder = runner.attn_metadata_builders[kv_cache_group_id]
        if runner.cascade_attn_enabled:
            common_prefix_len = runner._compute_cascade_attn_prefix_len(
                num_scheduled_tokens,
                scheduler_output.
                num_common_prefix_blocks[kv_cache_group_id],
                kv_cache_group_spec.kv_cache_spec,
                metadata_builder,
            )
        is_mla = False
        if req_offset > 0:
            # Temporarily narrow the builder's block table / slot mappings
            # to this half's rows & tokens; restored in `finally` below.
            origin_block_table = metadata_builder.block_table.block_table
            origin_slot_mapping = metadata_builder.block_table.slot_mapping
            origin_slot_map_cpu = metadata_builder.block_table.slot_mapping_cpu
            left_tokens = \
                input_split.scheduler_output_left.total_num_scheduled_tokens
            metadata_builder.block_table.block_table = \
                origin_block_table[req_offset:, :]
            metadata_builder.block_table.slot_mapping = \
                origin_slot_mapping[left_tokens:]
            metadata_builder.block_table.slot_mapping_cpu = \
                origin_slot_map_cpu[left_tokens:]
            is_mla = isinstance(metadata_builder, MLACommonMetadataBuilder)
            if is_mla:  # now support prefill only
                _num_decodes_record = metadata_builder._num_decodes
                _num_prefills_record = metadata_builder._num_prefills
                _num_decode_tokens_record = metadata_builder._num_decode_tokens
                _num_prefill_tokens_record = metadata_builder._num_prefill_tokens
                # Treat the whole half as prefill-only.
                metadata_builder._num_decodes = 0
                metadata_builder._num_prefills = num_reqs
                metadata_builder._num_decode_tokens = 0
                metadata_builder._num_prefill_tokens = total_num_scheduled_tokens
        try:
            attn_metadata_i = metadata_builder.build(
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata)
        finally:
            if req_offset > 0:
                # Restore the builder's shared state even if build() raised,
                # so a failed right-half build does not corrupt later steps.
                metadata_builder.block_table.block_table = origin_block_table
                metadata_builder.block_table.slot_mapping = origin_slot_mapping
                metadata_builder.block_table.slot_mapping_cpu = origin_slot_map_cpu
                if is_mla:
                    metadata_builder._num_decodes = _num_decodes_record
                    metadata_builder._num_prefills = _num_prefills_record
                    metadata_builder._num_decode_tokens = _num_decode_tokens_record
                    metadata_builder._num_prefill_tokens = _num_prefill_tokens_record
        for layer_name in kv_cache_group_spec.layer_names:
            attn_metadata[layer_name] = attn_metadata_i
    return attn_metadata
def pad_num_input_tokens(self, scheduler_output):
    """Return (num_input_tokens, num_tokens_across_dp) for this step,
    padding the scheduled token count for CUDA graphs, sequence
    parallelism and data parallelism as configured."""
    num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
    use_graph = (self.use_cuda_graph
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1])
    if use_graph:
        # Piecewise CUDA graphs: pad up to the nearest captured batch size.
        num_input_tokens = self.vllm_config.pad_for_cudagraph(
            num_scheduled_tokens)
    else:
        # Eager mode: pad to a multiple of the tensor-parallel size when
        # collective fusion for sequence parallelism is enabled.
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        sp_enabled = (self.vllm_config.compilation_config.pass_config.
                      enable_sequence_parallelism)
        if sp_enabled and tp_size > 1:
            from vllm.utils import round_up
            num_input_tokens = round_up(num_scheduled_tokens, tp_size)
        else:
            num_input_tokens = num_scheduled_tokens
    # Data-parallel padding so every DP rank runs the same token count.
    num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
    return num_input_tokens + num_pad, num_tokens_across_dp
def tbo_split_and_execute_model(
    runner,
    attn_metadata,
    num_input_tokens,
    num_tokens_across_dp,
    input_ids,
    positions,
    inputs_embeds,
    scheduler_output: "SchedulerOutput",
    intermediate_tensors: Optional[IntermediateTensors] = None,
    skip_cuda_graphs: bool = True,
) -> tuple[Union[ModelRunnerOutput, IntermediateTensors], Any, Any]:
    """Execute the model, overlapping two half-batches (TBO) when worthwhile.

    The batch is split when it is not an MLA decode batch, contains more
    than one request, and exceeds VLLM_TBO_MIN_TOKENS; otherwise a single
    forward pass is run.

    Returns:
        (model_output, finished_sending, finished_recving).  The previous
        annotation omitted the two KV-transfer results.
    """
    # MLA decode batches are not split: TBO currently supports prefill only.
    first_builder = runner.attn_metadata_builders[0]
    is_mla_decode = (isinstance(first_builder, MLACommonMetadataBuilder)
                     and first_builder._num_decodes > 0)
    use_tbo = False
    if not is_mla_decode:
        if (len(scheduler_output.num_scheduled_tokens) > 1
                and num_input_tokens > envs.VLLM_TBO_MIN_TOKENS):
            split_scheduler_output(runner, scheduler_output)
            use_tbo = True
    if use_tbo:
        num_input_tokens_left = \
            input_split.scheduler_output_left.total_num_scheduled_tokens
        num_input_tokens_right = num_input_tokens - num_input_tokens_left
        attn_metadata_left = prepare_tbo_atten_metadata(
            runner, input_split.scheduler_output_left,
            input_split.req_ids_left, 0)
        attn_metadata_right = prepare_tbo_atten_metadata(
            runner, input_split.scheduler_output_right,
            input_split.req_ids_right, input_split.req_num_left)
        with set_forward_context(attn_metadata,
                                 runner.vllm_config,
                                 num_tokens=num_input_tokens,
                                 num_tokens_across_dp=num_tokens_across_dp,
                                 skip_cuda_graphs=True):
            runner.maybe_setup_kv_connector(scheduler_output)
            model_output = tbo_model_executable_v1(
                runner,
                attn_metadata_left,
                attn_metadata_right,
                num_input_tokens_left,
                num_input_tokens_right,
                num_tokens_across_dp,
                input_ids,
                positions,
                intermediate_tensors,
                inputs_embeds)
            runner.maybe_wait_for_kv_save()
            finished_sending, finished_recving = (
                runner.get_finished_kv_transfers(scheduler_output))
    else:
        # Run the decoder in a single pass.
        # Use persistent buffers for CUDA graphs.
        # Temporarily disable TBO so tbo_all_reduce_v1 takes the plain
        # all-reduce path; restore the flag even if the forward raises.
        envs.VLLM_ENABLE_TBO = False
        try:
            with set_forward_context(attn_metadata,
                                     runner.vllm_config,
                                     num_tokens=num_input_tokens,
                                     num_tokens_across_dp=num_tokens_across_dp,
                                     skip_cuda_graphs=skip_cuda_graphs):
                runner.maybe_setup_kv_connector(scheduler_output)
                model_output = runner.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )
                runner.maybe_wait_for_kv_save()
                finished_sending, finished_recving = (
                    runner.get_finished_kv_transfers(scheduler_output))
        finally:
            envs.VLLM_ENABLE_TBO = True
    return model_output, finished_sending, finished_recving

View File

@@ -0,0 +1,235 @@
import os
import queue
import threading
import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.logger import init_logger
from vllm.profiler.prof import profile
from vllm import envs
logger = init_logger(__name__)
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap():
    """Coordinates two worker threads that each run half of a TBO-split
    batch, overlapping their execution on a dedicated CUDA stream."""

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream
        # Hand-off queues: model-input tickets in, model outputs out.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        self.left_tid = 0
        self.right_tid = 0
        # Semaphores ping-pong execution between the two threads.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        self.left_first = False
        self.tbo_running = False
        self.tbo_in_capture = False
        # The streams are process-wide; create them only once.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        # step_event orders tbo_step_stream against the caller's stream.
        # The c2t/t2c event pairs order per-thread compute work against
        # the all-reduce stream (see tbo_all_reduce_v1).
        self.step_event = torch.cuda.Event(enable_timing=False)
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    def init_tbo_thread(self):
        """(Re)start the left/right worker threads for one TBO step."""
        # Drain stale inputs from a previous step.  The original code
        # called Queue.empty(), which only *reports* emptiness and does
        # not clear the queue.
        for stale_queue in (self.model_input_left_queue,
                            self.model_input_right_queue):
            while not stale_queue.empty():
                stale_queue.get_nowait()
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.left_thread.start()
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.right_thread.start()
        if get_tp_group().rank == 0:
            logger.info('tbo:two batch overlap start')

    def finish_thread(self):
        """Join both worker threads after the step completes."""
        self.left_thread.join()
        self.left_thread = None
        self.right_thread.join()
        self.right_thread = None

    @torch.inference_mode()
    def thread_two_batch_overlap(self, input_queue):
        """Worker-thread body: wait for an input ticket, run this half's
        forward pass on tbo_step_stream, and publish the output."""
        is_left_thread = False
        tid = threading.get_ident()
        # Identity of the queue object tells us which half we are.
        if input_queue is self.model_input_left_queue:
            self.left_tid = tid
            is_left_thread = True
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)
        with torch.cuda.stream(tbo_step_stream):
            # Block until set_model_input() posts this half's ticket.
            input_queue.get()
            self.tbo_thread_synchronize(tid)
            if is_left_thread:
                attn_metadata = self.attn_metadata_left
                num_input_tokens = self.num_input_tokens_left
                input_ids = self.input_ids_left
                positions = self.positions_left
            else:
                attn_metadata = self.attn_metadata_right
                num_input_tokens = self.num_input_tokens_right
                input_ids = self.input_ids_right
                positions = self.positions_right
            model_output = None
            # Run the decoder.
            # Use persistent buffers for CUDA graphs.
            with set_forward_context(attn_metadata,
                                     self.model_runner.vllm_config,
                                     num_tokens=num_input_tokens,
                                     num_tokens_across_dp=self.num_tokens_across_dp,
                                     skip_cuda_graphs=True):
                model_output = self.model_runner.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=self.intermediate_tensors,
                    inputs_embeds=self.inputs_embeds,
                )
            if is_left_thread:
                # Unblock the right thread before publishing our output.
                self.sem_right.release()
                self.states_left_queue.put(model_output)
            else:
                self.states_right_queue.put(model_output)

    def tbo_thread_synchronize(self, tid):
        """Alternate execution between the two threads; returns this
        thread's (c2t, t2c) CUDA event pair."""
        if tid == self.left_tid:
            # The very first left-side sync of a step must not release the
            # right thread (left runs first); subsequent syncs do.
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            self.sem_left.acquire()
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            self.sem_right.acquire()
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_runner,
                        attn_metadata_left,
                        attn_metadata_right,
                        num_input_tokens_left,
                        num_input_tokens_right,
                        input_ids_left,
                        input_ids_right,
                        positions_left,
                        positions_right,
                        num_tokens_across_dp,
                        intermediate_tensors,
                        inputs_embeds):
        """Stash this step's inputs on the object and wake both threads."""
        self.model_runner = model_runner
        self.attn_metadata_left = attn_metadata_left
        self.attn_metadata_right = attn_metadata_right
        self.num_input_tokens_left = num_input_tokens_left
        self.num_input_tokens_right = num_input_tokens_right
        self.input_ids_left = input_ids_left
        self.input_ids_right = input_ids_right
        self.positions_left = positions_left
        self.positions_right = positions_right
        self.num_tokens_across_dp = num_tokens_across_dp
        self.intermediate_tensors = intermediate_tensors
        self.inputs_embeds = inputs_embeds
        # Post one ticket per thread; the payload travels via self.
        self.model_input_left_queue.put(None)
        self.model_input_right_queue.put(None)

    def get_model_output(self):
        """Block until both halves finish; return (left, right) outputs."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right
# Process-wide TwoBatchOverlap singleton; created lazily by
# init_two_batch_overlap().
tbo_obj_v1 = None


def is_enable_tbo_v1():
    """Return True once the TBO singleton has been initialized."""
    # Reading a module global needs no `global` declaration.
    return tbo_obj_v1 is not None
def init_two_batch_overlap():
    """Create the TBO singleton on first use and (re)start its worker
    threads.

    The object is created only once; the threads are restarted on every
    call because finish_thread() joins them at the end of each step.
    """
    global tbo_obj_v1
    if tbo_obj_v1 is None:
        tbo_obj_v1 = TwoBatchOverlap()
    tbo_obj_v1.init_tbo_thread()
def tbo_all_reduce_v1(obj):
    """Tensor-parallel all-reduce that, while TBO is running, issues the
    collective on a dedicated stream so the other half's compute overlaps
    with it.  Falls back to the plain all-reduce when TBO is off or idle.
    """
    if envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None and tbo_obj_v1.tbo_running:
        tid = threading.get_ident()
        if tid == tbo_obj_v1.left_tid:
            event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
        else:
            event_c2t, event_t2c = tbo_obj_v1.event_right_c2t, tbo_obj_v1.event_right_t2c
        # Ordering: compute stream -> all-reduce stream -> compute stream,
        # expressed with the c2t/t2c event pair.
        event_c2t.record()
        with torch.cuda.stream(all_reduce_stream):
            all_reduce_stream.wait_event(event_c2t)
            output = tensor_model_parallel_all_reduce(obj)
            event_t2c.record()
        # Yield to the other thread while the collective runs.
        tbo_obj_v1.tbo_thread_synchronize(tid)
        tbo_step_stream.wait_event(event_t2c)
        return output
    return tensor_model_parallel_all_reduce(obj)
def merge_model_output(states_left, states_right):
    """Concatenate the left/right half outputs along the batch dimension.

    Handles both plain hidden-state tensors and IntermediateTensors
    (pipeline-parallel intermediate outputs), merging key by key.
    """
    if not isinstance(states_left, IntermediateTensors):
        return torch.concat([states_left, states_right], dim=0)
    merged = {
        key: torch.concat(
            [states_left.tensors[key], states_right.tensors[key]], dim=0)
        for key in states_left.tensors
    }
    return IntermediateTensors(merged)
def tbo_model_executable_v1(
    model_runner,
    attn_metadata_left,
    attn_metadata_right,
    num_input_tokens_left,
    num_input_tokens_right,
    num_tokens_across_dp,
    input_ids,
    positions,
    intermediate_tensors,
    inputs_embeds
):
    """Run one TBO step: split the flat inputs at the left/right token
    boundary, hand both halves to the worker threads, and return the
    merged hidden states (or IntermediateTensors)."""
    # Create the singleton if needed and (re)start the worker threads.
    init_two_batch_overlap()
    tbo_obj_v1.tbo_running = True
    tbo_obj_v1.left_first = True
    # Record on the caller's stream so tbo_step_stream can wait for all
    # previously queued work before the threads begin.
    tbo_obj_v1.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj_v1.step_event)
        # Split the flattened token inputs at the left/right boundary.
        tokens_split = [num_input_tokens_left, num_input_tokens_right]
        input_ids_left, input_ids_right = torch.split(input_ids, tokens_split, dim=0)
        positions_left, positions_right = torch.split(positions, tokens_split, dim=0)
        # Publish the inputs and wake both worker threads.
        tbo_obj_v1.set_model_input(model_runner,
                        attn_metadata_left,
                        attn_metadata_right,
                        num_input_tokens_left,
                        num_input_tokens_right,
                        input_ids_left,
                        input_ids_right,
                        positions_left,
                        positions_right,
                        num_tokens_across_dp,
                        intermediate_tensors,
                        inputs_embeds)
        # Blocks until both halves have produced their outputs.
        model_output_left, model_output_right = tbo_obj_v1.get_model_output()
        hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
        tbo_obj_v1.tbo_running = False
        # Recorded on tbo_step_stream: marks the end of this TBO step.
        tbo_obj_v1.step_event.record()
        tbo_obj_v1.finish_thread()
    # Make the caller's stream wait for the merged result.
    current_stream.wait_event(tbo_obj_v1.step_event)
    return hidden_or_intermediate_states