"""Two-batch-overlap (TBO) execution.

Splits one decode step into a "left" and a "right" micro-batch that run on
two worker threads sharing a dedicated CUDA stream, so that the all-reduce
of one half can overlap with the compute of the other half.
"""
import os
import queue
import threading

import torch

from vllm import envs
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.profiler.prof import profile
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context

logger = init_logger(__name__)

# CUDA streams shared by all TwoBatchOverlap instances; created lazily by the
# first TwoBatchOverlap() construction (requires a CUDA context).
tbo_step_stream = None
all_reduce_stream = None


class TwoBatchOverlap:
    """State holder for one two-batch-overlap execution pipeline.

    Owns the per-half input/output queues, the two worker threads, the
    semaphores used for the left/right ping-pong handshake, and the CUDA
    events used to hand work between the step stream and the all-reduce
    stream.
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream
        # Per-half work queues: a put() wakes the corresponding worker thread.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Per-half result queues, consumed by get_model_output().
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # Thread idents, filled in by the worker threads themselves.
        self.left_tid = 0
        self.right_tid = 0
        # Handshake semaphores for the left/right ping-pong in
        # tbo_thread_synchronize().
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        self.left_first = False
        self.tbo_running = False
        self.tbo_in_capture = False
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        # Events used to order work across streams (timing disabled: they are
        # synchronization-only).
        self.step_event = torch.cuda.Event(enable_timing=False)
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain(work_queue):
        """Discard any stale items in *work_queue* without blocking."""
        try:
            while True:
                work_queue.get_nowait()
        except queue.Empty:
            pass

    def init_tbo_thread(self):
        """Spawn the left and right worker threads for one TBO step."""
        # BUG FIX: the original called Queue.empty(), which only *reports*
        # emptiness and discarded the result; actually drain leftovers so a
        # fresh step never consumes a stale input.
        self._drain(self.model_input_left_queue)
        self._drain(self.model_input_right_queue)
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.left_thread.start()
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.right_thread.start()
        if get_tp_group().rank == 0:
            logger.info('tbo:two batch overlap start')

    def finish_thread(self):
        """Join both worker threads and drop the references."""
        self.left_thread.join()
        self.left_thread = None
        self.right_thread.join()
        self.right_thread = None

    @torch.inference_mode()
    def thread_two_batch_overlap(self, input_queue):
        """Worker-thread body: run one half of the split batch.

        *input_queue* identifies which half this thread serves (renamed from
        ``queue`` to avoid shadowing the stdlib module). The thread blocks on
        the queue until set_model_input() publishes a step, synchronizes with
        its sibling, then runs the model on its half of the inputs and
        publishes the output on its states queue.
        """
        is_left_thread = False
        tid = threading.get_ident()
        if input_queue is self.model_input_left_queue:
            self.left_tid = tid
            is_left_thread = True
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)
        # All GPU work of both halves is issued on the shared step stream.
        with torch.cuda.stream(tbo_step_stream):
            input_queue.get()
            self.tbo_thread_synchronize(tid)
            if is_left_thread:
                attn_metadata = self.attn_metadata_left
                num_input_tokens = self.num_input_tokens_left
                input_ids = self.input_ids_left
                positions = self.positions_left
            else:
                attn_metadata = self.attn_metadata_right
                num_input_tokens = self.num_input_tokens_right
                input_ids = self.input_ids_right
                positions = self.positions_right
            model_output = None
            # Run the decoder.
            # Use persistent buffers for CUDA graphs.
            with set_forward_context(attn_metadata,
                                     self.model_runner.vllm_config,
                                     num_tokens=num_input_tokens,
                                     num_tokens_across_dp=self.num_tokens_across_dp,
                                     skip_cuda_graphs=True):
                model_output = self.model_runner.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=self.intermediate_tensors,
                    inputs_embeds=self.inputs_embeds,
                )
            if is_left_thread:
                # Let the right half proceed past its final handshake before
                # publishing the left output.
                self.sem_right.release()
                self.states_left_queue.put(model_output)
            else:
                self.states_right_queue.put(model_output)

    def tbo_thread_synchronize(self, tid):
        """Ping-pong handshake between the two halves.

        Each call releases the sibling and waits for its own turn;
        ``left_first`` lets the very first left call run without releasing
        the right half first. Returns the (compute->transfer,
        transfer->compute) CUDA event pair for the calling side.
        """
        if tid == self.left_tid:
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            self.sem_left.acquire()
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            self.sem_right.acquire()
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self, model_runner, attn_metadata_left,
                        attn_metadata_right, num_input_tokens_left,
                        num_input_tokens_right, input_ids_left,
                        input_ids_right, positions_left, positions_right,
                        num_tokens_across_dp, intermediate_tensors,
                        inputs_embeds):
        """Publish one step's split inputs and wake both worker threads."""
        self.model_runner = model_runner
        self.attn_metadata_left = attn_metadata_left
        self.attn_metadata_right = attn_metadata_right
        self.num_input_tokens_left = num_input_tokens_left
        self.num_input_tokens_right = num_input_tokens_right
        self.input_ids_left = input_ids_left
        self.input_ids_right = input_ids_right
        self.positions_left = positions_left
        self.positions_right = positions_right
        self.num_tokens_across_dp = num_tokens_across_dp
        self.intermediate_tensors = intermediate_tensors
        self.inputs_embeds = inputs_embeds
        # The queued value is a pure wake-up token; the threads read the
        # attributes set above.
        self.model_input_left_queue.put(None)
        self.model_input_right_queue.put(None)

    def get_model_output(self):
        """Block until both halves finish and return (left, right) outputs."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right


# Process-wide singleton managed by init_two_batch_overlap().
tbo_obj_v1 = None


def is_enable_tbo_v1():
    """Return True once the TBO singleton has been created."""
    return tbo_obj_v1 is not None


def init_two_batch_overlap():
    """Create the TBO singleton if needed and (re)spawn its worker threads.

    BUG FIX: the worker threads serve exactly one step and are joined by
    finish_thread(), so they must be respawned on every call — the original
    only spawned them on first creation, deadlocking any later step.
    """
    global tbo_obj_v1
    if tbo_obj_v1 is None:
        tbo_obj_v1 = TwoBatchOverlap()
    if tbo_obj_v1.left_thread is None:
        tbo_obj_v1.init_tbo_thread()


def tbo_all_reduce_v1(obj):
    """Tensor-parallel all-reduce that overlaps with the sibling half.

    When a TBO step is running, the all-reduce is issued on the dedicated
    all-reduce stream, bracketed by CUDA events so the step stream only
    waits where it must; otherwise falls back to a plain all-reduce.
    """
    if envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None \
            and tbo_obj_v1.tbo_running:
        tid = threading.get_ident()
        if tid == tbo_obj_v1.left_tid:
            event_c2t, event_t2c = (tbo_obj_v1.event_left_c2t,
                                    tbo_obj_v1.event_left_t2c)
        else:
            event_c2t, event_t2c = (tbo_obj_v1.event_right_c2t,
                                    tbo_obj_v1.event_right_t2c)
        # compute stream -> all-reduce stream handoff.
        event_c2t.record()
        with torch.cuda.stream(all_reduce_stream):
            all_reduce_stream.wait_event(event_c2t)
            output = tensor_model_parallel_all_reduce(obj)
            event_t2c.record()
        # Yield to the sibling half while the all-reduce is in flight.
        tbo_obj_v1.tbo_thread_synchronize(tid)
        tbo_step_stream.wait_event(event_t2c)
        return output
    return tensor_model_parallel_all_reduce(obj)


def merge_model_output(states_left, states_right):
    """Concatenate the left/right half outputs back into one batch (dim 0)."""
    if isinstance(states_left, IntermediateTensors):
        output = IntermediateTensors({
            key: torch.concat(
                [states_left.tensors[key], states_right.tensors[key]], dim=0)
            for key in states_left.tensors
        })
    else:
        output = torch.concat([states_left, states_right], dim=0)
    return output


def tbo_model_executable_v1(
    model_runner,
    attn_metadata_left,
    attn_metadata_right,
    num_input_tokens_left,
    num_input_tokens_right,
    num_tokens_across_dp,
    input_ids,
    positions,
    intermediate_tensors,
    inputs_embeds,
):
    """Run one model step with two-batch overlap and return merged output.

    Splits ``input_ids``/``positions`` at ``num_input_tokens_left``, hands
    both halves to the worker threads, then merges their outputs. CUDA
    events fence the shared step stream against the caller's stream on
    entry and exit.
    """
    init_two_batch_overlap()
    tbo_obj_v1.tbo_running = True
    tbo_obj_v1.left_first = True
    # Fence: step stream must not start before the caller's pending work.
    tbo_obj_v1.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj_v1.step_event)
        tokens_split = [num_input_tokens_left, num_input_tokens_right]
        input_ids_left, input_ids_right = torch.split(
            input_ids, tokens_split, dim=0)
        positions_left, positions_right = torch.split(
            positions, tokens_split, dim=0)
        tbo_obj_v1.set_model_input(
            model_runner, attn_metadata_left, attn_metadata_right,
            num_input_tokens_left, num_input_tokens_right, input_ids_left,
            input_ids_right, positions_left, positions_right,
            num_tokens_across_dp, intermediate_tensors, inputs_embeds)
        model_output_left, model_output_right = tbo_obj_v1.get_model_output()
        hidden_or_intermediate_states = merge_model_output(
            model_output_left, model_output_right)
        tbo_obj_v1.tbo_running = False
        # Fence: caller's stream must wait for the step stream's tail work.
        tbo_obj_v1.step_event.record()
    tbo_obj_v1.finish_thread()
    current_stream.wait_event(tbo_obj_v1.step_event)
    return hidden_or_intermediate_states