"""Two-batch-overlap (TBO) execution helpers.

A batch is split into a "left" and a "right" half.  Each half is run by its
own Python worker thread, and the two threads hand control to each other at
every tensor-parallel all-reduce (see ``tbo_all_reduce``) using a pair of
semaphores, while the all-reduce itself is issued on a separate CUDA stream.
"""
import gc
import os
import queue
import threading
from typing import List, Optional, Tuple

import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.two_batch_overlap.model_input_split import (
    is_supported_attention_metadata, split_capture_attention_metadata,
    split_model_input)
from vllm.logger import init_logger
from vllm.profiler.prof import profile
from vllm import envs
from vllm.utils import weak_ref_tensor
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import (
    is_enable_tbo_v1, tbo_all_reduce_v1)

# When VLLM_TBO_ONE_STREAM=1 the all-reduce is issued on the compute stream
# instead of the dedicated all-reduce stream.
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'

logger = init_logger(__name__)

# Module-wide CUDA streams, created lazily by ``TwoBatchOverlap.__init__``.
# ``_run_once`` temporarily rebinds ``tbo_step_stream``.
tbo_step_stream = None
all_reduce_stream = None


def _drain_queue(pending):
    """Remove and discard every item currently stored in *pending*."""
    while True:
        try:
            pending.get_nowait()
        except queue.Empty:
            return


def _split_intermediate_tensors(intermediate_tensors, query_tokens_split):
    """Split every tensor of *intermediate_tensors* along dim 0 in two.

    Args:
        intermediate_tensors: ``IntermediateTensors`` to split.
        query_tokens_split: two-element list ``[left_size, right_size]``
            handed to ``torch.split``.

    Returns:
        ``(left, right)`` pair of new ``IntermediateTensors``.
    """
    left = {}
    right = {}
    for key, tensor in intermediate_tensors.tensors.items():
        left[key], right[key] = torch.split(tensor, query_tokens_split, dim=0)
    return IntermediateTensors(left), IntermediateTensors(right)


class TwoBatchOverlap():
    """State shared between the caller and the two TBO worker threads.

    The caller publishes per-step inputs through ``set_model_input`` /
    ``set_capture_model_input`` and collects results through
    ``get_model_output``.  The worker threads are (re)created for every step
    by ``init_tbo_thread`` and joined by ``finish_thread``.
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream
        # One input queue and one output queue per half.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # Thread idents, filled in by the worker threads themselves.
        self.left_tid = 0
        self.right_tid = 0
        # Hand-off semaphores: a thread runs only after acquiring its own.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # True until the left thread has gone through its first hand-off.
        self.left_first = False
        self.tbo_running = False
        self.tbo_in_capture = False
        # The two streams are guarded independently: ``_run_once`` may have
        # already bound ``tbo_step_stream`` before this object is first
        # built, and ``tbo_all_reduce`` still needs ``all_reduce_stream``.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
        if all_reduce_stream is None:
            all_reduce_stream = torch.cuda.Stream()
        self.step_event = torch.cuda.Event(enable_timing=False)
        # c2t is recorded before the all-reduce and waited on by the
        # all-reduce stream; t2c is recorded after it and waited on by the
        # step stream (see ``tbo_all_reduce``).
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    def init_tbo_thread(self):
        """Start a fresh left and right worker thread for one step."""
        # ``Queue.empty()`` only reports whether the queue is empty; drain
        # explicitly so a stale input can never be consumed by the new
        # workers.
        _drain_queue(self.model_input_left_queue)
        _drain_queue(self.model_input_right_queue)
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue, ))
        self.left_thread.start()
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue, ))
        self.right_thread.start()
        if get_tp_group().rank == 0:
            logger.info('tbo:two batch overlap start')

    def finish_thread(self):
        """Join both worker threads and drop the references to them."""
        self.left_thread.join()
        self.left_thread = None
        self.right_thread.join()
        self.right_thread = None

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker body: run the model on one half and publish the result.

        Args:
            queue: the input queue this worker reads from.  Receiving
                ``self.model_input_left_queue`` makes this the left worker,
                anything else the right worker.
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = tid
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)

        with torch.cuda.stream(tbo_step_stream):
            # Blocks until the caller publishes this step's input.
            model_input = queue.get()
            profile.ProfRangePush('start')
            # Initial hand-off; decides which worker runs first.
            self.tbo_thread_synchronize(tid)

            if is_left_thread:
                model_kwargs = self.model_kwargs_left
                intermediate_tensors = self.intermediate_tensors_left
            else:
                model_kwargs = self.model_kwargs_right
                intermediate_tensors = self.intermediate_tensors_right

            hidden_or_intermediate_states = None
            if self.tbo_in_capture:
                if is_left_thread:
                    attn_metadata = self.attn_metadata_left
                    input_tokens = self.input_tokens_left
                    input_positions = self.split_input_positions[0]
                else:
                    attn_metadata = self.attn_metadata_right
                    input_tokens = self.input_tokens_right
                    input_positions = self.split_input_positions[1]
                # NOTE(review): ``self.multi_modal_kwargs`` is only assigned
                # by ``set_model_input``; ``set_capture_model_input`` does
                # not set it - confirm it is always populated before capture.
                with set_forward_context(attn_metadata, self.vllm_config,
                                         self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=input_tokens,
                        positions=input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(
                            self.multi_modal_kwargs,
                            device=self.self_device),
                        **model_kwargs,
                    )
            elif model_input is not None:
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(
                            self.multi_modal_kwargs,
                            device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **model_kwargs,
                    )

            if is_left_thread:
                # Let the right worker run to completion.
                self.sem_right.release()
                self.states_left_queue.put(hidden_or_intermediate_states)
            else:
                self.states_right_queue.put(hidden_or_intermediate_states)
            profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Hand execution to the other worker and wait to be resumed.

        Args:
            tid: ident of the calling thread; compared with ``left_tid`` to
                decide which side is calling.

        Returns:
            The ``(c2t, t2c)`` event pair belonging to the calling side.
        """
        if tid == self.left_tid:
            # On its very first hand-off the left worker does not wake the
            # right one; the right worker is then still waiting on its own
            # semaphore, so the left half runs first.
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            return self.event_left_c2t, self.event_left_t2c

        self.sem_left.release()
        profile.ProfRangePop()
        self.sem_right.acquire()
        profile.ProfRangePush('right')
        return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self, model_input_left, model_input_right,
                        vllm_config, virtual_engine, model_executable,
                        intermediate_tensors_left, intermediate_tensors_right,
                        multi_modal_kwargs, self_device,
                        seqlen_agnostic_kwargs, model_kwargs_left,
                        model_kwargs_right):
        """Store the per-step state and release both workers (eager path).

        The shared attributes are written before the inputs are queued, so
        the workers see them once ``queue.get()`` returns.
        """
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def set_capture_model_input(self, input_tokens_left, input_tokens_right,
                                split_input_positions, vllm_config,
                                virtual_engine, runner_model, runner_device,
                                intermediate_tensors_left,
                                intermediate_tensors_right, model_kwargs_left,
                                model_kwargs_right, attn_metadata_left,
                                attn_metadata_right):
        """Store the per-step state and release both workers (capture path).

        ``None`` is queued as the model input; the workers read their inputs
        from the attributes set here because ``tbo_in_capture`` is set.
        """
        self.input_tokens_left = input_tokens_left
        self.input_tokens_right = input_tokens_right
        self.split_input_positions = split_input_positions
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = runner_model
        self.self_device = runner_device
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.attn_metadata_left = attn_metadata_left
        self.attn_metadata_right = attn_metadata_right
        self.model_input_left_queue.put(None)
        self.model_input_right_queue.put(None)

    def get_model_output(self):
        """Block until both workers have published and return their results.

        Returns:
            ``(states_left, states_right)``.
        """
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right


# Process-wide singleton, created on first use.
tbo_obj = None


def init_two_batch_overlap():
    """Create the singleton if needed and start its two worker threads."""
    global tbo_obj
    if tbo_obj is None:
        tbo_obj = TwoBatchOverlap()
    # Workers are joined after every step, so they are restarted each time.
    tbo_obj.init_tbo_thread()


def tbo_all_reduce(obj):
    """All-reduce *obj*, yielding to the other TBO worker while it runs.

    Delegates to ``tbo_all_reduce_v1`` when TBO v1 is enabled, and falls back
    to a plain ``tensor_model_parallel_all_reduce`` when TBO is not running.

    Args:
        obj: value passed to ``tensor_model_parallel_all_reduce``.

    Returns:
        The all-reduce result.
    """
    if is_enable_tbo_v1():
        return tbo_all_reduce_v1(obj)

    if not (envs.VLLM_ENABLE_TBO and tbo_obj is not None
            and tbo_obj.tbo_running):
        return tensor_model_parallel_all_reduce(obj)

    tid = threading.get_ident()
    if tbo_one_stream:
        output = tensor_model_parallel_all_reduce(obj)
        tbo_obj.tbo_thread_synchronize(tid)
        return output

    if tid == tbo_obj.left_tid:
        event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
    else:
        event_c2t, event_t2c = (tbo_obj.event_right_c2t,
                                tbo_obj.event_right_t2c)

    # Recorded on the caller's current stream so the all-reduce stream can
    # wait for the work queued so far.
    event_c2t.record()
    with torch.cuda.stream(all_reduce_stream):
        all_reduce_stream.wait_event(event_c2t)
        output = tensor_model_parallel_all_reduce(obj)
        # Recorded on the all-reduce stream (current inside this context).
        event_t2c.record()
    # Yield to the other worker while the all-reduce is in flight.
    tbo_obj.tbo_thread_synchronize(tid)
    tbo_step_stream.wait_event(event_t2c)
    return output


def merge_model_output(states_left, states_right):
    """Concatenate the two halves' outputs along dim 0.

    Args:
        states_left: output of the left half; either an
            ``IntermediateTensors`` or a value accepted by ``torch.concat``.
        states_right: output of the right half, of the same kind.

    Returns:
        An ``IntermediateTensors`` with per-key concatenated tensors when
        *states_left* is one, otherwise the concatenated tensor.
    """
    if isinstance(states_left, IntermediateTensors):
        return IntermediateTensors({
            key: torch.concat(
                [states_left.tensors[key], states_right.tensors[key]], dim=0)
            for key in states_left.tensors
        })
    return torch.concat([states_left, states_right], dim=0)


def tbo_model_executable(
    model_input,
    vllm_config,
    virtual_engine,
    model_executable,
    intermediate_tensors,
    multi_modal_kwargs,
    self_device,
    seqlen_agnostic_kwargs,
    model_kwargs,
):
    """Run one forward pass, using two-batch overlap when it is applicable.

    The plain single-batch path is used when the batch has exactly one
    sequence, when the attention metadata type is unsupported, or when this
    is a decode step and any of the following holds:
    ``VLLM_TBO_DECODE_BS < 2``, the batch is smaller than
    ``VLLM_TBO_DECODE_BS``, or the metadata has ``use_cuda_graph`` set.

    Returns:
        Whatever ``model_executable`` returns; on the TBO path the two
        halves' results merged by ``merge_model_output``.
    """
    is_support = is_supported_attention_metadata(model_input.attn_metadata)
    if not is_support:
        # The type must be interpolated through a placeholder; passing it as
        # a bare extra argument makes the logging call fail to format.
        logger.info("tbo:not surpport yet %s",
                    type(model_input.attn_metadata))

    batch_size = len(model_input.attn_metadata.seq_lens)
    is_decode_tbo_invalid = not model_input.is_prompt and (
        envs.VLLM_TBO_DECODE_BS < 2
        or batch_size < envs.VLLM_TBO_DECODE_BS
        or model_input.attn_metadata.use_cuda_graph)

    if batch_size == 1 or is_decode_tbo_invalid or not is_support:
        with set_forward_context(model_input.attn_metadata, vllm_config,
                                 virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                inputs_embeds=model_input.inputs_embeds,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )
        return hidden_or_intermediate_states

    # NOTE(review): unlike the path above, the worker threads do not forward
    # ``inputs_embeds`` to the model - confirm this is intended.
    profile.ProfRangePush('tbo_model_executable')
    init_two_batch_overlap()
    tbo_obj.tbo_running = True
    tbo_obj.left_first = True

    # The right half receives the extra sequence of an odd-sized batch.
    batch_size_left = batch_size // 2
    batch_size_right = batch_size - batch_size_left

    model_input_left, model_input_right = split_model_input(
        model_input, self_device, batch_size_left, batch_size_right)

    model_kwargs_left = model_kwargs.copy()
    model_kwargs_right = model_kwargs.copy()
    intermediate_tensors_left = None
    intermediate_tensors_right = None

    has_previous_hidden = "previous_hidden_states" in model_kwargs
    if has_previous_hidden or intermediate_tensors is not None:
        # Token counts per half; only computed when something is split.
        query_tokens_split = [
            sum(model_input.query_lens[0:batch_size_left]),
            sum(model_input.query_lens[batch_size_left:]),
        ]

    if has_previous_hidden:
        split_previous_hidden_states = torch.split(
            model_kwargs["previous_hidden_states"], query_tokens_split, dim=0)
        model_kwargs_left["previous_hidden_states"] = \
            split_previous_hidden_states[0]
        model_kwargs_right["previous_hidden_states"] = \
            split_previous_hidden_states[1]

    if intermediate_tensors is not None:
        intermediate_tensors_left, intermediate_tensors_right = \
            _split_intermediate_tensors(intermediate_tensors,
                                        query_tokens_split)

    # Make the TBO step stream wait for work already queued on the caller's
    # stream.
    tbo_obj.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj.step_event)
        tbo_obj.set_model_input(model_input_left, model_input_right,
                                vllm_config, virtual_engine, model_executable,
                                intermediate_tensors_left,
                                intermediate_tensors_right,
                                multi_modal_kwargs, self_device,
                                seqlen_agnostic_kwargs, model_kwargs_left,
                                model_kwargs_right)
        states_left, states_right = tbo_obj.get_model_output()
        hidden_or_intermediate_states = merge_model_output(
            states_left, states_right)
        tbo_obj.tbo_running = False
        # Recorded on the step stream so the caller's stream can wait for
        # the merged output below.
        tbo_obj.step_event.record()
        tbo_obj.finish_thread()
    current_stream.wait_event(tbo_obj.step_event)
    profile.ProfRangePop()
    return hidden_or_intermediate_states


def _run_once(vllm_config, virtual_engine, runner, self_device,
              input_ids: torch.Tensor, positions: torch.Tensor,
              intermediate_inputs: Optional[IntermediateTensors],
              attn_metadata: AttentionMetadata, stream: torch.cuda.Stream,
              **kwargs):
    """Run one TBO forward pass through ``runner.model`` in capture mode.

    ``tbo_step_stream`` is rebound to *stream* for the duration of the call
    and restored before returning.

    Returns:
        The two halves' outputs merged by ``merge_model_output``.
    """
    global tbo_step_stream
    stream_back = tbo_step_stream
    tbo_step_stream = stream

    init_two_batch_overlap()
    tbo_obj.left_first = True

    # The split is by rows of ``input_ids``; the right half takes the extra
    # row of an odd-sized batch.
    decode_batch_size = input_ids.shape[0]
    batch_size_left = decode_batch_size // 2
    batch_size_right = decode_batch_size - batch_size_left
    query_tokens_split = [batch_size_left, batch_size_right]

    input_tokens_left, input_tokens_right = torch.split(input_ids,
                                                        query_tokens_split,
                                                        dim=0)
    split_input_positions = torch.split(positions, query_tokens_split, dim=0)

    model_kwargs_left = kwargs.copy()
    model_kwargs_right = kwargs.copy()
    intermediate_tensors_left = None
    intermediate_tensors_right = None

    if "previous_hidden_states" in kwargs:
        split_previous_hidden_states = torch.split(
            kwargs["previous_hidden_states"], query_tokens_split, dim=0)
        model_kwargs_left["previous_hidden_states"] = \
            split_previous_hidden_states[0]
        model_kwargs_right["previous_hidden_states"] = \
            split_previous_hidden_states[1]

    if intermediate_inputs is not None:
        intermediate_tensors_left, intermediate_tensors_right = \
            _split_intermediate_tensors(intermediate_inputs,
                                        query_tokens_split)

    attn_metadata_left, attn_metadata_right = \
        split_capture_attention_metadata(attn_metadata, batch_size_left,
                                         batch_size_right)

    # Both flags are set before the inputs are queued: the workers read
    # ``tbo_in_capture`` only after ``queue.get()`` returns.
    tbo_obj.tbo_running = True
    tbo_obj.tbo_in_capture = True
    tbo_obj.set_capture_model_input(input_tokens_left, input_tokens_right,
                                    split_input_positions, vllm_config,
                                    virtual_engine, runner.model, self_device,
                                    intermediate_tensors_left,
                                    intermediate_tensors_right,
                                    model_kwargs_left, model_kwargs_right,
                                    attn_metadata_left, attn_metadata_right)
    states_left, states_right = tbo_obj.get_model_output()
    output_hidden_or_intermediate_states = merge_model_output(
        states_left, states_right)
    tbo_obj.tbo_in_capture = False
    tbo_obj.tbo_running = False
    tbo_obj.finish_thread()
    tbo_step_stream = stream_back
    return output_hidden_or_intermediate_states


def tbo_capture(vllm_config, virtual_engine, _NUM_WARMUP_ITERS, runner,
                self_device, input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_inputs: Optional[IntermediateTensors],
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                memory_pool: Optional[Tuple[int, int]],
                stream: torch.cuda.Stream, **kwargs):
    """Warm up, then capture a TBO forward pass into a CUDA graph.

    Side effects on *runner*: sets ``_graph``, ``input_buffers`` and
    ``output_buffers``.

    Args:
        _NUM_WARMUP_ITERS: number of un-captured runs executed first.
        memory_pool: passed as ``pool`` to ``torch.cuda.graph``.
        stream: passed as ``stream`` to ``torch.cuda.graph``.
    """
    for _ in range(_NUM_WARMUP_ITERS):
        _run_once(vllm_config, virtual_engine, runner, self_device,
                  input_ids, positions, intermediate_inputs, attn_metadata,
                  torch.cuda.current_stream(), **kwargs)
    # Let the warm-up work finish before the capture starts.
    torch.cuda.synchronize()

    runner._graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(runner._graph, pool=memory_pool, stream=stream):
        output_hidden_or_intermediate_states = _run_once(
            vllm_config, virtual_engine, runner, self_device, input_ids,
            positions, intermediate_inputs, attn_metadata,
            torch.cuda.current_stream(), **kwargs)

        # Only weak references to the outputs are kept.
        if isinstance(output_hidden_or_intermediate_states, torch.Tensor):
            hidden_or_intermediate_states = weak_ref_tensor(
                output_hidden_or_intermediate_states)
        elif isinstance(output_hidden_or_intermediate_states,
                        IntermediateTensors):
            hidden_or_intermediate_states = IntermediateTensors(
                tensors={
                    key: weak_ref_tensor(value)
                    for key, value in
                    output_hidden_or_intermediate_states.tensors.items()
                })

        del output_hidden_or_intermediate_states
        # make sure `output_hidden_or_intermediate_states` is deleted
        # in the graph's memory pool
        gc.collect()
    torch.cuda.synchronize()

    # Save the input and output buffers.
    runner.input_buffers = {
        "input_ids": input_ids,
        "positions": positions,
        "kv_caches": kv_caches,
        **runner.attn_state.get_graph_input_buffers(
            attn_metadata, runner._is_encoder_decoder_model),
        **kwargs,
    }
    if intermediate_inputs is not None:
        runner.input_buffers.update(intermediate_inputs.tensors)
    if get_pp_group().is_last_rank:
        runner.output_buffers = {
            "hidden_states": hidden_or_intermediate_states
        }
    else:
        runner.output_buffers = hidden_or_intermediate_states