235 lines
9.3 KiB
Python
235 lines
9.3 KiB
Python
import os
|
|
import queue
|
|
import threading
|
|
import torch
|
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
|
from vllm.distributed.parallel_state import get_tp_group
|
|
from vllm.forward_context import set_forward_context
|
|
from vllm.multimodal.inputs import MultiModalKwargs
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
|
|
from vllm.logger import init_logger
|
|
from vllm.profiler.prof import profile
|
|
from vllm import envs
|
|
|
|
logger = init_logger(__name__)

# Module-level CUDA streams shared by all TwoBatchOverlap instances.
# Lazily created in TwoBatchOverlap.__init__ on first construction:
# tbo_step_stream runs the per-micro-batch model steps, all_reduce_stream
# overlaps tensor-parallel all-reduces with compute.
tbo_step_stream = None
all_reduce_stream = None
|
|
|
|
class TwoBatchOverlap:
    """Overlap the execution of two micro-batches ("left"/"right").

    Two worker threads share a dedicated CUDA stream (``tbo_step_stream``) and
    alternate via a pair of semaphores, so that while one micro-batch is doing
    tensor-parallel communication (see ``tbo_all_reduce_v1``) the other can
    compute. CUDA events order work between the step stream and the
    all-reduce stream.
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream
        # Handoff queues: main thread -> workers (inputs) and
        # workers -> main thread (outputs), one pair per side.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        self.left_tid = 0
        self.right_tid = 0
        # Both semaphores start at 0: each worker blocks until released by
        # its peer (or by the main thread via the left_first flag).
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        self.left_first = False
        self.tbo_running = False
        self.tbo_in_capture = False
        # The streams are module-level so every instance shares them; create
        # them once, on the first construction.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        # Events ordering work across streams:
        #   *_c2t: compute stream -> transfer (all-reduce) stream
        #   *_t2c: transfer stream -> compute stream
        self.step_event = torch.cuda.Event(enable_timing=False)
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(q):
        """Remove all pending items from *q* without blocking."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                break

    def init_tbo_thread(self):
        """Drain stale inputs and start the two worker threads."""
        # BUG FIX: the original called Queue.empty(), which only *reports*
        # emptiness (returning a discarded bool) — it never cleared the
        # queues. Drain any leftover items explicitly instead.
        self._drain_queue(self.model_input_left_queue)
        self._drain_queue(self.model_input_right_queue)
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.left_thread.start()
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.right_thread.start()
        if get_tp_group().rank == 0:
            logger.info('tbo:two batch overlap start')

    def finish_thread(self):
        """Join both worker threads and drop the references."""
        self.left_thread.join()
        self.left_thread = None
        self.right_thread.join()
        self.right_thread = None

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker body: run the model over one micro-batch on the TBO stream.

        *queue* identifies the side: the thread handed the left input queue
        becomes the "left" worker. Blocks on ``queue.get()`` until the main
        thread publishes inputs via ``set_model_input``.
        """
        is_left_thread = False
        tid = threading.get_ident()
        # Identity compare: a Queue has no value equality, we only care which
        # object we were given.
        if queue is self.model_input_left_queue:
            self.left_tid = tid
            is_left_thread = True
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)
        with torch.cuda.stream(tbo_step_stream):
            # Wait for the main thread's go signal (the item itself is None).
            queue.get()
            self.tbo_thread_synchronize(tid)
            if is_left_thread:
                attn_metadata = self.attn_metadata_left
                num_input_tokens = self.num_input_tokens_left
                input_ids = self.input_ids_left
                positions = self.positions_left
            else:
                attn_metadata = self.attn_metadata_right
                num_input_tokens = self.num_input_tokens_right
                input_ids = self.input_ids_right
                positions = self.positions_right

            model_output = None
            # Run the decoder.
            # Use persistent buffers for CUDA graphs.
            with set_forward_context(attn_metadata,
                                     self.model_runner.vllm_config,
                                     num_tokens=num_input_tokens,
                                     num_tokens_across_dp=self.num_tokens_across_dp,
                                     skip_cuda_graphs=True):
                model_output = self.model_runner.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=self.intermediate_tensors,
                    inputs_embeds=self.inputs_embeds,
                )
            if is_left_thread:
                # Let the right worker proceed past its final rendezvous
                # before publishing our output.
                self.sem_right.release()
                self.states_left_queue.put(model_output)
            else:
                self.states_right_queue.put(model_output)

    def tbo_thread_synchronize(self, tid):
        """Rendezvous point: wake the peer worker, then wait for our turn.

        Returns the (c2t, t2c) event pair for the calling side so callers
        (``tbo_all_reduce_v1``) can order cross-stream work.
        The ``left_first`` flag, set by the main thread, lets the left worker
        start its first step without first releasing the right worker.
        """
        if tid == self.left_tid:
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            self.sem_left.acquire()
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            self.sem_right.acquire()
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_runner,
                        attn_metadata_left,
                        attn_metadata_right,
                        num_input_tokens_left,
                        num_input_tokens_right,
                        input_ids_left,
                        input_ids_right,
                        positions_left,
                        positions_right,
                        num_tokens_across_dp,
                        intermediate_tensors,
                        inputs_embeds):
        """Publish per-side inputs on self, then signal both workers."""
        self.model_runner = model_runner
        self.attn_metadata_left = attn_metadata_left
        self.attn_metadata_right = attn_metadata_right
        self.num_input_tokens_left = num_input_tokens_left
        self.num_input_tokens_right = num_input_tokens_right
        self.input_ids_left = input_ids_left
        self.input_ids_right = input_ids_right
        self.positions_left = positions_left
        self.positions_right = positions_right
        self.num_tokens_across_dp = num_tokens_across_dp
        self.intermediate_tensors = intermediate_tensors
        self.inputs_embeds = inputs_embeds

        # The queued value is a sentinel; workers read the attributes above.
        self.model_input_left_queue.put(None)
        self.model_input_right_queue.put(None)

    def get_model_output(self):
        """Block until both workers finish; return (left, right) outputs."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right
|
|
|
|
# Global singleton; created lazily by init_two_batch_overlap().
tbo_obj_v1 = None


def is_enable_tbo_v1():
    """Return True if the global TwoBatchOverlap singleton has been created."""
    # Idiom fix: compare against None with `is not`, not `!=`.
    return tbo_obj_v1 is not None
|
|
|
|
def init_two_batch_overlap():
    """Create the global TwoBatchOverlap singleton (once) and start its
    worker threads. Subsequent calls are no-ops for creation but note that
    the thread start happens only on first creation."""
    global tbo_obj_v1
    # Idiom fix: identity comparison with None.
    if tbo_obj_v1 is None:
        tbo_obj_v1 = TwoBatchOverlap()
        tbo_obj_v1.init_tbo_thread()
|
|
|
|
def tbo_all_reduce_v1(obj):
    """Tensor-parallel all-reduce, overlapped with the peer micro-batch.

    When two-batch overlap is active, the calling worker records a
    compute-to-transfer event, performs the all-reduce on the dedicated
    all-reduce stream, yields to the peer worker (tbo_thread_synchronize),
    and finally makes the step stream wait for the transfer-to-compute
    event. Otherwise this degrades to a plain all-reduce on the current
    stream.
    """
    # Idiom fix: `is not None` instead of `!= None`.
    if envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None and tbo_obj_v1.tbo_running:
        tid = threading.get_ident()
        # Pick this side's cross-stream event pair.
        if tid == tbo_obj_v1.left_tid:
            event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
        else:
            event_c2t, event_t2c = tbo_obj_v1.event_right_c2t, tbo_obj_v1.event_right_t2c
        # Mark the point on the compute stream the all-reduce must wait for.
        event_c2t.record()
        with torch.cuda.stream(all_reduce_stream):
            all_reduce_stream.wait_event(event_c2t)
            output = tensor_model_parallel_all_reduce(obj)
            event_t2c.record()
        # Hand the GPU to the peer micro-batch while our all-reduce runs.
        tbo_obj_v1.tbo_thread_synchronize(tid)
        # Resume compute only after our all-reduce has finished.
        tbo_step_stream.wait_event(event_t2c)
        return output
    return tensor_model_parallel_all_reduce(obj)
|
|
|
|
def merge_model_output(states_left, states_right):
    """Concatenate the left/right micro-batch outputs along dim 0.

    Handles both plain tensors and IntermediateTensors (pipeline-parallel
    intermediate state), concatenating each named tensor key-wise.
    """
    if isinstance(states_left, IntermediateTensors):
        merged = {
            name: torch.cat(
                (states_left.tensors[name], states_right.tensors[name]), dim=0)
            for name in states_left.tensors
        }
        return IntermediateTensors(merged)
    return torch.cat((states_left, states_right), dim=0)
|
|
|
|
def tbo_model_executable_v1(
        model_runner,
        attn_metadata_left,
        attn_metadata_right,
        num_input_tokens_left,
        num_input_tokens_right,
        num_tokens_across_dp,
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds
):
    """Run one model step with two-batch overlap and return merged states.

    Splits input_ids/positions into left/right micro-batches at
    num_input_tokens_left, hands them to the TwoBatchOverlap worker threads,
    blocks until both finish, and returns the concatenated hidden states (or
    IntermediateTensors). The caller's current stream is synchronized with
    the TBO step stream via step_event on both entry and exit.
    """
    init_two_batch_overlap()
    tbo_obj_v1.tbo_running = True
    # Let the left worker take the first turn at the rendezvous.
    tbo_obj_v1.left_first = True
    # Record current-stream progress so the step stream can wait on it.
    tbo_obj_v1.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj_v1.step_event)
        # Partition the flattened batch into the two micro-batches.
        tokens_split = [num_input_tokens_left, num_input_tokens_right]
        input_ids_left, input_ids_right = torch.split(input_ids, tokens_split, dim=0)
        positions_left, positions_right = torch.split(positions, tokens_split, dim=0)
        tbo_obj_v1.set_model_input(model_runner,
                                   attn_metadata_left,
                                   attn_metadata_right,
                                   num_input_tokens_left,
                                   num_input_tokens_right,
                                   input_ids_left,
                                   input_ids_right,
                                   positions_left,
                                   positions_right,
                                   num_tokens_across_dp,
                                   intermediate_tensors,
                                   inputs_embeds)
        # Blocks until both worker threads have produced their outputs.
        model_output_left, model_output_right = tbo_obj_v1.get_model_output()

        hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
        tbo_obj_v1.tbo_running = False
        # Record the end of the merged work on the step stream so the
        # caller's stream can wait on it below.
        tbo_obj_v1.step_event.record()
    tbo_obj_v1.finish_thread()
    current_stream.wait_event(tbo_obj_v1.step_event)

    return hidden_or_intermediate_states