Files
enginex-hygon-vllm/vllm/two_batch_overlap/two_batch_overlap.py
2026-01-09 15:09:53 +08:00

482 lines
21 KiB
Python

import gc
import os
import queue
import threading
from typing import List, Optional, Tuple
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_capture_attention_metadata, split_model_input
from vllm.logger import init_logger
from vllm.profiler.prof import profile
from vllm import envs
from vllm.utils import weak_ref_tensor
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import is_enable_tbo_v1, tbo_all_reduce_v1
# True only when the environment variable VLLM_TBO_ONE_STREAM is exactly '1'.
# tbo_all_reduce then issues the all-reduce on the caller's current stream
# instead of switching to all_reduce_stream.
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'
logger = init_logger(__name__)
# Stream the TBO worker threads run the model on. It stays None until
# TwoBatchOverlap.__init__ creates it; _run_once temporarily replaces it with
# the stream it is given.
tbo_step_stream = None
# Stream tbo_all_reduce switches to for the all-reduce when tbo_one_stream is
# False. Created by TwoBatchOverlap.__init__.
all_reduce_stream = None
class TwoBatchOverlap():
    """Run the two halves of a batch on two threads that take turns.

    A "left" and a "right" worker thread each execute the model on one half
    of the batch. Only one of them runs at a time: a pair of semaphores hands
    control back and forth every time ``tbo_thread_synchronize`` is called
    (``tbo_all_reduce`` calls it around each all-reduce). Inputs are handed to
    the workers through queues and their outputs are returned through queues.

    The worker threads are single-shot: each processes exactly one queue item
    and exits, so ``init_tbo_thread`` / ``finish_thread`` bracket every run.
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream
        # Hand-off queues: inputs to the workers, results back to the caller.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # Thread idents, filled in by the workers themselves when they start.
        self.left_tid = 0
        self.right_tid = 0
        # Each semaphore is the "permission to run" for the matching thread.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # Set by the caller before each run; consumed by the left thread's
        # first tbo_thread_synchronize call.
        self.left_first = False
        self.tbo_running = False
        self.tbo_in_capture = False

        # Per-run state normally filled in by set_model_input /
        # set_capture_model_input. Defaults are provided so that the capture
        # path, which never assigns multi_modal_kwargs, does not fail with an
        # AttributeError inside a worker thread when it is the first run.
        # NOTE(review): when a regular run happened earlier, the capture path
        # still sees that run's multi_modal_kwargs - confirm this is intended.
        self.vllm_config = None
        self.virtual_engine = None
        self.model_executable = None
        self.self_device = None
        self.intermediate_tensors_left = None
        self.intermediate_tensors_right = None
        self.multi_modal_kwargs = {}
        self.seqlen_agnostic_kwargs = {}
        self.model_kwargs_left = {}
        self.model_kwargs_right = {}
        self.input_tokens_left = None
        self.input_tokens_right = None
        self.split_input_positions = None
        self.attn_metadata_left = None
        self.attn_metadata_right = None

        # Create the shared streams lazily. Each one is guarded on its own:
        # _run_once installs a caller-provided stream in tbo_step_stream
        # before this constructor may run, and all_reduce_stream must still
        # be created in that case.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
        if all_reduce_stream is None:
            all_reduce_stream = torch.cuda.Stream()

        self.step_event = torch.cuda.Event(enable_timing=False)
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(pending):
        """Discard every item currently waiting in *pending* without blocking."""
        while True:
            try:
                pending.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Start a fresh left and right worker thread for one run.

        Both input queues are cleared first so a leftover item from an earlier
        run cannot be picked up by the new workers. (The original code called
        ``Queue.empty()``, which only reports whether the queue is empty.)
        """
        self._drain_queue(self.model_input_left_queue)
        self._drain_queue(self.model_input_right_queue)
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.left_thread.start()
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.right_thread.start()
        if get_tp_group().rank == 0:
            logger.info('tbo:two batch overlap start')

    def finish_thread(self):
        """Join both worker threads and drop the references to them."""
        self.left_thread.join()
        self.left_thread = None
        self.right_thread.join()
        self.right_thread = None

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker body: run the model once on one half of the batch.

        Args:
            queue: the input queue this worker reads from. Being handed
                ``self.model_input_left_queue`` is what makes the thread the
                "left" worker; any other queue makes it the "right" one.

        The result (or ``None`` when there was nothing to run) is put on
        ``states_left_queue`` / ``states_right_queue``.
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = tid
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)

        with torch.cuda.stream(tbo_step_stream):
            # Blocks until the caller publishes the inputs for this run.
            model_input = queue.get()
            profile.ProfRangePush('start')
            # Wait for this thread's turn before touching the model.
            self.tbo_thread_synchronize(tid)

            if is_left_thread:
                model_kwargs = self.model_kwargs_left
                intermediate_tensors = self.intermediate_tensors_left
            else:
                model_kwargs = self.model_kwargs_right
                intermediate_tensors = self.intermediate_tensors_right

            hidden_or_intermediate_states = None
            if self.tbo_in_capture:
                # Capture runs put None on the queue; the real inputs come
                # from the attributes set by set_capture_model_input.
                if is_left_thread:
                    attn_metadata = self.attn_metadata_left
                    input_tokens = self.input_tokens_left
                    input_positions = self.split_input_positions[0]
                else:
                    attn_metadata = self.attn_metadata_right
                    input_tokens = self.input_tokens_right
                    input_positions = self.split_input_positions[1]
                with set_forward_context(attn_metadata,
                                         self.vllm_config,
                                         self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=input_tokens,
                        positions=input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **model_kwargs,
                    )
            elif model_input is not None:
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **model_kwargs,
                    )

            if is_left_thread:
                # The right thread is parked on sem_right after its last
                # hand-over; let it finish its final segment.
                self.sem_right.release()
                self.states_left_queue.put(hidden_or_intermediate_states)
            else:
                self.states_right_queue.put(hidden_or_intermediate_states)
            profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Hand execution to the other worker and block until it hands back.

        Args:
            tid: ident of the calling thread. A value equal to
                ``self.left_tid`` is treated as the left worker; anything else
                is treated as the right worker.

        Returns:
            The ``(c2t, t2c)`` event pair belonging to the calling side.
        """
        if tid == self.left_tid:
            # On the very first call of a run (left_first is True) the right
            # thread must not be woken, so that the left thread goes first.
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            profile.ProfRangePop()
            self.sem_right.acquire()
            profile.ProfRangePush('right')
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors_left,
                        intermediate_tensors_right,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs_left,
                        model_kwargs_right):
        """Store the shared run state, then publish one input per worker.

        The attributes are assigned before the queue ``put`` calls because a
        worker starts reading them as soon as its ``queue.get()`` returns.
        """
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def set_capture_model_input(self,
                                input_tokens_left,
                                input_tokens_right,
                                split_input_positions,
                                vllm_config,
                                virtual_engine,
                                runner_model,
                                runner_device,
                                intermediate_tensors_left,
                                intermediate_tensors_right,
                                model_kwargs_left,
                                model_kwargs_right,
                                attn_metadata_left,
                                attn_metadata_right):
        """Store the inputs for a capture run and wake both workers.

        ``None`` is put on both input queues purely as a wake-up token; the
        workers read the real inputs from the attributes assigned here (the
        caller is expected to have set ``tbo_in_capture``).
        """
        self.input_tokens_left = input_tokens_left
        self.input_tokens_right = input_tokens_right
        self.split_input_positions = split_input_positions
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = runner_model
        self.self_device = runner_device
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.attn_metadata_left = attn_metadata_left
        self.attn_metadata_right = attn_metadata_right
        self.model_input_left_queue.put(None)
        self.model_input_right_queue.put(None)

    def get_model_output(self):
        """Block until both workers have produced a result.

        Returns:
            ``(states_left, states_right)`` in that order.
        """
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right
# Module-wide TwoBatchOverlap instance; stays None until the first call to
# init_two_batch_overlap() creates it.
tbo_obj = None
def init_two_batch_overlap():
    """Create the shared TwoBatchOverlap object if needed and start its workers.

    The object is built only once, but ``init_tbo_thread`` is called on every
    invocation: callers in this module pair each call with
    ``tbo_obj.finish_thread()``, which joins the worker threads, so a new pair
    of threads is needed for the next run.
    """
    global tbo_obj
    if tbo_obj is None:
        tbo_obj = TwoBatchOverlap()
    tbo_obj.init_tbo_thread()
def tbo_all_reduce(obj):
    """All-reduce *obj* across the tensor-parallel group, yielding to the
    other TBO worker thread while a two-batch-overlap run is active.

    Dispatch order:
      1. If ``is_enable_tbo_v1()`` is true, delegate to ``tbo_all_reduce_v1``.
      2. If no overlap run is active, perform a plain all-reduce.
      3. Otherwise perform the all-reduce and then call
         ``tbo_obj.tbo_thread_synchronize`` so the other worker gets its turn.
         Unless ``tbo_one_stream`` is set, the all-reduce is issued on
         ``all_reduce_stream`` and ordered against the step stream with the
         calling side's event pair.

    Returns:
        Whatever the underlying all-reduce call returns.
    """
    if is_enable_tbo_v1():
        return tbo_all_reduce_v1(obj)

    overlap_active = (envs.VLLM_ENABLE_TBO
                      and tbo_obj is not None
                      and tbo_obj.tbo_running)
    if not overlap_active:
        return tensor_model_parallel_all_reduce(obj)

    caller_tid = threading.get_ident()

    if tbo_one_stream:
        # Single-stream mode: reduce in place, then give up the turn.
        reduced = tensor_model_parallel_all_reduce(obj)
        tbo_obj.tbo_thread_synchronize(caller_tid)
        return reduced

    # Any caller that is not the left worker uses the right-hand events.
    if caller_tid == tbo_obj.left_tid:
        before_reduce = tbo_obj.event_left_c2t
        after_reduce = tbo_obj.event_left_t2c
    else:
        before_reduce = tbo_obj.event_right_c2t
        after_reduce = tbo_obj.event_right_t2c

    # Mark the point on the current stream the all-reduce has to wait for.
    before_reduce.record()
    with torch.cuda.stream(all_reduce_stream):
        all_reduce_stream.wait_event(before_reduce)
        reduced = tensor_model_parallel_all_reduce(obj)
        after_reduce.record()

    # Let the other worker run; when control comes back, make the step
    # stream wait for the all-reduce issued above.
    tbo_obj.tbo_thread_synchronize(caller_tid)
    tbo_step_stream.wait_event(after_reduce)
    return reduced
def merge_model_output(states_left, states_right):
    """Concatenate the left and right outputs along dimension 0.

    When ``states_left`` is an ``IntermediateTensors``, every entry of its
    ``tensors`` mapping is concatenated with the entry of the same key in
    ``states_right.tensors`` and a new ``IntermediateTensors`` is returned.
    Otherwise both arguments are concatenated directly.
    """
    if not isinstance(states_left, IntermediateTensors):
        return torch.concat([states_left, states_right], dim=0)

    merged = {
        name: torch.concat([left_part, states_right.tensors[name]], dim=0)
        for name, left_part in states_left.tensors.items()
    }
    return IntermediateTensors(merged)
def tbo_model_executable(
    model_input,
    vllm_config,
    virtual_engine,
    model_executable,
    intermediate_tensors,
    multi_modal_kwargs,
    self_device,
    seqlen_agnostic_kwargs,
    model_kwargs,
):
    """Run one forward pass, using two-batch overlap when the batch allows it.

    The model is called directly (no overlap) when any of these holds:
      * the batch has fewer than two sequences,
      * the attention metadata type is not supported by the splitter,
      * this is a decode step and ``envs.VLLM_TBO_DECODE_BS`` is below 2, the
        batch is smaller than ``envs.VLLM_TBO_DECODE_BS``, or the attention
        metadata has ``use_cuda_graph`` set.

    Otherwise the batch is split into a left half (``batch_size // 2``
    sequences) and a right half (the rest), each half is executed by one TBO
    worker thread on ``tbo_step_stream``, and the two results are
    concatenated.

    Returns:
        The hidden states (or intermediate tensors) for the whole batch.

    NOTE(review): the direct path forwards ``model_input.inputs_embeds`` to
    the model, the overlap path does not - confirm that is intended.
    """
    is_support = is_supported_attention_metadata(model_input.attn_metadata)
    if not is_support:
        # The type is passed as a lazy %-argument; without the placeholder
        # the logging module reports a formatting error instead of the text.
        logger.info("tbo:not surpport yet %s", type(model_input.attn_metadata))

    batch_size = len(model_input.attn_metadata.seq_lens)
    is_decode_tbo_invalid = not model_input.is_prompt and (
        envs.VLLM_TBO_DECODE_BS < 2 or
        batch_size < envs.VLLM_TBO_DECODE_BS or
        model_input.attn_metadata.use_cuda_graph)

    # "< 2" rather than "== 1": an empty batch cannot be split either.
    if batch_size < 2 or is_decode_tbo_invalid or not is_support:
        with set_forward_context(model_input.attn_metadata,
                                 vllm_config, virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                inputs_embeds=model_input.inputs_embeds,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )
        return hidden_or_intermediate_states

    profile.ProfRangePush('tbo_model_executable')
    init_two_batch_overlap()
    tbo_obj.tbo_running = True
    tbo_obj.left_first = True

    # The right half takes the extra sequence when the batch size is odd.
    batch_size_left = batch_size // 2
    batch_size_right = batch_size - batch_size_left
    model_input_left, model_input_right = split_model_input(
        model_input, self_device, batch_size_left, batch_size_right)

    model_kwargs_left = model_kwargs.copy()
    model_kwargs_right = model_kwargs.copy()
    intermediate_tensors_left = None
    intermediate_tensors_right = None

    has_previous_hidden = "previous_hidden_states" in model_kwargs
    if has_previous_hidden or intermediate_tensors is not None:
        # Number of query tokens owned by each half; used to split every
        # per-token tensor along dimension 0.
        query_lens = model_input.query_lens
        query_tokens_split = [sum(query_lens[:batch_size_left]),
                              sum(query_lens[batch_size_left:])]

    if has_previous_hidden:
        split_previous_hidden_states = torch.split(
            model_kwargs["previous_hidden_states"], query_tokens_split, dim=0)
        model_kwargs_left["previous_hidden_states"] = \
            split_previous_hidden_states[0]
        model_kwargs_right["previous_hidden_states"] = \
            split_previous_hidden_states[1]

    if intermediate_tensors is not None:
        left_tensors = {}
        right_tensors = {}
        for key, tensor in intermediate_tensors.tensors.items():
            left_tensors[key], right_tensors[key] = torch.split(
                tensor, query_tokens_split, dim=0)
        intermediate_tensors_left = IntermediateTensors(left_tensors)
        intermediate_tensors_right = IntermediateTensors(right_tensors)

    # Order the step stream after everything queued so far on the caller's
    # stream.
    tbo_obj.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj.step_event)
        tbo_obj.set_model_input(model_input_left,
                                model_input_right,
                                vllm_config,
                                virtual_engine,
                                model_executable,
                                intermediate_tensors_left,
                                intermediate_tensors_right,
                                multi_modal_kwargs,
                                self_device,
                                seqlen_agnostic_kwargs,
                                model_kwargs_left,
                                model_kwargs_right)
        states_left, states_right = tbo_obj.get_model_output()
        hidden_or_intermediate_states = merge_model_output(
            states_left, states_right)
        tbo_obj.tbo_running = False
        # Recorded while the step stream is current, so the caller's stream
        # can wait for the merged result below.
        tbo_obj.step_event.record()
        tbo_obj.finish_thread()
    current_stream.wait_event(tbo_obj.step_event)
    profile.ProfRangePop()
    return hidden_or_intermediate_states
def _run_once(vllm_config, virtual_engine,
              runner,
              self_device,
              input_ids: torch.Tensor,
              positions: torch.Tensor,
              intermediate_inputs: Optional[IntermediateTensors],
              attn_metadata: AttentionMetadata,
              stream: torch.cuda.Stream,
              **kwargs):
    """Run ``runner.model`` once through the TBO workers in capture mode.

    ``input_ids`` is split along dimension 0 into a left part of
    ``input_ids.shape[0] // 2`` rows and a right part with the remainder;
    ``positions``, ``previous_hidden_states`` (if present in ``kwargs``) and
    every tensor of ``intermediate_inputs`` are split with the same sizes.

    While the run is in progress the module-level ``tbo_step_stream`` is
    replaced by ``stream``; the previous value is restored afterwards, also
    when an exception is raised.

    Returns:
        The concatenated output of the left and right workers.
    """
    global tbo_step_stream
    stream_back = tbo_step_stream
    # Must be installed before the workers start: they read the global when
    # they enter their stream context.
    tbo_step_stream = stream
    try:
        init_two_batch_overlap()
        tbo_obj.left_first = True

        decode_batch_size = input_ids.shape[0]
        batch_size_left = decode_batch_size // 2
        batch_size_right = decode_batch_size - batch_size_left
        query_tokens_split = [batch_size_left, batch_size_right]

        input_tokens_left, input_tokens_right = torch.split(
            input_ids, query_tokens_split, dim=0)
        split_input_positions = torch.split(
            positions, query_tokens_split, dim=0)

        model_kwargs_left = kwargs.copy()
        model_kwargs_right = kwargs.copy()
        intermediate_tensors_left = None
        intermediate_tensors_right = None

        if "previous_hidden_states" in kwargs:
            split_previous_hidden_states = torch.split(
                kwargs["previous_hidden_states"], query_tokens_split, dim=0)
            model_kwargs_left["previous_hidden_states"] = \
                split_previous_hidden_states[0]
            model_kwargs_right["previous_hidden_states"] = \
                split_previous_hidden_states[1]

        if intermediate_inputs is not None:
            left_tensors = {}
            right_tensors = {}
            for key, tensor in intermediate_inputs.tensors.items():
                left_tensors[key], right_tensors[key] = torch.split(
                    tensor, query_tokens_split, dim=0)
            intermediate_tensors_left = IntermediateTensors(left_tensors)
            intermediate_tensors_right = IntermediateTensors(right_tensors)

        attn_metadata_left, attn_metadata_right = \
            split_capture_attention_metadata(
                attn_metadata, batch_size_left, batch_size_right)

        tbo_obj.tbo_running = True
        tbo_obj.tbo_in_capture = True
        tbo_obj.set_capture_model_input(input_tokens_left,
                                        input_tokens_right,
                                        split_input_positions,
                                        vllm_config,
                                        virtual_engine,
                                        runner.model,
                                        self_device,
                                        intermediate_tensors_left,
                                        intermediate_tensors_right,
                                        model_kwargs_left,
                                        model_kwargs_right,
                                        attn_metadata_left,
                                        attn_metadata_right)
        states_left, states_right = tbo_obj.get_model_output()
        output_hidden_or_intermediate_states = merge_model_output(
            states_left, states_right)
        tbo_obj.tbo_in_capture = False
        tbo_obj.tbo_running = False
        tbo_obj.finish_thread()
    finally:
        # Never leave the caller's stream installed as the global step
        # stream, even if the run failed.
        tbo_step_stream = stream_back
    return output_hidden_or_intermediate_states
def tbo_capture(vllm_config, virtual_engine, _NUM_WARMUP_ITERS,
                runner,
                self_device,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_inputs: Optional[IntermediateTensors],
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                memory_pool: Optional[Tuple[int, int]],
                stream: torch.cuda.Stream,
                **kwargs):
    """Warm up and then capture one two-batch-overlap forward pass as a
    CUDA graph stored on ``runner``.

    Steps:
      1. Call ``_run_once`` ``_NUM_WARMUP_ITERS`` times outside any graph,
         then synchronize the device.
      2. Create ``runner._graph`` and call ``_run_once`` once more inside
         ``torch.cuda.graph``; only a weak reference to its output is kept.
      3. Store the input tensors on ``runner.input_buffers`` and the weakly
         referenced output on ``runner.output_buffers``.

    Returns:
        None; all results are stored on ``runner``.
    """
    # Everything _run_once needs except the stream, which has to be looked up
    # at each call (the current stream differs inside the graph context).
    shared_args = (vllm_config, virtual_engine, runner, self_device,
                   input_ids, positions, intermediate_inputs, attn_metadata)

    for _ in range(_NUM_WARMUP_ITERS):
        _run_once(*shared_args, torch.cuda.current_stream(), **kwargs)
    torch.cuda.synchronize()

    runner._graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(runner._graph, pool=memory_pool, stream=stream):
        output_hidden_or_intermediate_states = _run_once(
            *shared_args, torch.cuda.current_stream(), **kwargs)

        if isinstance(output_hidden_or_intermediate_states, torch.Tensor):
            hidden_or_intermediate_states = weak_ref_tensor(
                output_hidden_or_intermediate_states)
        elif isinstance(output_hidden_or_intermediate_states,
                        IntermediateTensors):
            weak_tensors = {
                name: weak_ref_tensor(tensor)
                for name, tensor in
                output_hidden_or_intermediate_states.tensors.items()
            }
            hidden_or_intermediate_states = IntermediateTensors(
                tensors=weak_tensors)

        del output_hidden_or_intermediate_states
        # make sure `output_hidden_or_intermediate_states` is deleted
        # in the graph's memory pool
        gc.collect()
    torch.cuda.synchronize()

    # Save the input and output buffers.
    buffers = {
        "input_ids": input_ids,
        "positions": positions,
        "kv_caches": kv_caches,
    }
    buffers.update(runner.attn_state.get_graph_input_buffers(
        attn_metadata, runner._is_encoder_decoder_model))
    buffers.update(kwargs)
    runner.input_buffers = buffers
    if intermediate_inputs is not None:
        runner.input_buffers.update(intermediate_inputs.tensors)

    # The last pipeline rank stores its output under "hidden_states"; other
    # ranks store the output object as is.
    if get_pp_group().is_last_rank:
        runner.output_buffers = {
            "hidden_states": hidden_or_intermediate_states
        }
    else:
        runner.output_buffers = hidden_or_intermediate_states