init src 0.9.2
vllm/two_batch_overlap/forward_context.py (new file, 35 lines)
@@ -0,0 +1,35 @@
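# Keeps a separate forward context for the "left" and the "right" two-batch-overlap
# (TBO) worker thread, keyed by the thread ids registered through
# init_tbo_forward_context().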
import threading

_forward_context_left = None
_forward_context_right = None

_left_tid = 0
_right_tid = 0


def init_tbo_forward_context(left_flag, tid):
    global _left_tid
    global _right_tid
    if left_flag:
        _left_tid = tid
    else:
        _right_tid = tid


def set_tbo_forward_context(_forward_context):
    global _forward_context_left
    global _forward_context_right
    tid = threading.get_ident()
    if tid == _left_tid:
        _forward_context_left = _forward_context
    else:
        _forward_context_right = _forward_context


def get_tbo_forward_context():
    tid = threading.get_ident()
    if tid == _left_tid:
        return _forward_context_left
    else:
        return _forward_context_right
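A minimal usage sketch (not part of the commit) of how the helpers above are driven. In this patch they are called from TwoBatchOverlap.thread_two_batch_overlap; the worker body and the dict payloads below are illustrative placeholders only.

import threading

from vllm.two_batch_overlap.forward_context import (
    get_tbo_forward_context, init_tbo_forward_context, set_tbo_forward_context)


def _worker(is_left, ctx):
    # Each TBO worker registers its own thread id once, then only ever
    # touches its own left/right slot.
    init_tbo_forward_context(is_left, threading.get_ident())
    set_tbo_forward_context(ctx)
    assert get_tbo_forward_context() is ctx


left = threading.Thread(target=_worker, args=(True, {"side": "left"}))
right = threading.Thread(target=_worker, args=(False, {"side": "right"}))
left.start()
right.start()
left.join()
right.join()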
vllm/two_batch_overlap/model_input_split.py (new file, 399 lines)
@@ -0,0 +1,399 @@
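# Splits a batched ModelInput and its attention metadata (MLACommonMetadata,
# ROCmFlashAttentionMetadata or FlashMLAMetadata) into a "left" and a "right"
# half along the batch dimension so the two halves can run on the two TBO
# worker threads.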
import torch

from vllm.attention.backends.flashmla import FlashMLAMetadata
from vllm.attention.backends.mla.common import MLACommonMetadata
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import async_tensor_h2d


def cumsum(lst):
    cum_lst = [0]
    total = 0
    for value in lst:
        total += value
        cum_lst.append(total)
    return cum_lst


def is_supported_attention_metadata(atten_metadata):
    return isinstance(atten_metadata, (ROCmFlashAttentionMetadata,
                                       FlashMLAMetadata, MLACommonMetadata))


def split_model_input(model_input, self_device, batch_size_left, batch_size_right):
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
    query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]),
                          sum(model_input.query_lens[batch_size_left:])]
    batch_size_split = [batch_size_left, batch_size_right]
    split_input_tokens = torch.split(model_input.input_tokens, query_tokens_split, dim=0)
    split_input_positions = torch.split(model_input.input_positions, query_tokens_split, dim=0)
    seq_lens_left = model_input.attn_metadata.seq_lens[0:batch_size_left]
    seq_lens_right = model_input.attn_metadata.seq_lens[batch_size_left:]
    query_lens_left = model_input.query_lens[0:batch_size_left]
    query_lens_right = model_input.query_lens[batch_size_left:]
    split_seq_lens_tensor = torch.split(model_input.attn_metadata.seq_lens_tensor, batch_size_split, dim=0)
    split_block_tables = torch.split(model_input.attn_metadata.block_tables, batch_size_split, dim=0)
    num_prefills_left = 0
    num_prefills_right = 0
    num_prefill_tokens_left = 0
    num_prefill_tokens_right = 0
    num_decode_tokens_left = 0
    num_decode_tokens_right = 0
    max_prefill_seq_len_left = 0
    max_prefill_seq_len_right = 0
    max_decode_seq_len_left = 0
    max_decode_seq_len_right = 0
    max_decode_query_len_left = None
    max_decode_query_len_right = None
    encoder_seq_lens_left = None
    encoder_seq_lens_right = None
    encoder_seq_lens_tensor_left = None
    encoder_seq_lens_tensor_right = None
    max_encoder_seq_len_left = None
    max_encoder_seq_len_right = None
    num_encoder_tokens_left = None
    num_encoder_tokens_right = None
    cross_slot_mapping_left = None
    cross_slot_mapping_right = None
    cross_block_tables_left = None
    cross_block_tables_right = None
    if model_input.is_prompt:
        num_prefills_left = batch_size_left
        num_prefills_right = batch_size_right
        num_prefill_tokens_left = sum(model_input.query_lens[0:batch_size_left])
        num_prefill_tokens_right = sum(model_input.query_lens[batch_size_left:])
        max_prefill_seq_len_left = max(model_input.attn_metadata.seq_lens[0:batch_size_left])
        max_prefill_seq_len_right = max(model_input.attn_metadata.seq_lens[batch_size_left:])
    else:
        num_decode_tokens_left = batch_size_left
        num_decode_tokens_right = batch_size_right
        max_decode_seq_len_left = max(model_input.attn_metadata.seq_lens[0:batch_size_left])
        max_decode_seq_len_right = max(model_input.attn_metadata.seq_lens[batch_size_left:])
    split_slot_mapping = torch.split(model_input.attn_metadata.slot_mapping, query_tokens_split, dim=0)
    max_query_len_left = max(model_input.query_lens[0:batch_size_left])
    max_query_len_right = max(model_input.query_lens[batch_size_left:])
    zero_tensor = torch.tensor([0], device=self_device, dtype=torch.int32)
    query_start_loc_left_list = cumsum(query_lens_left)
    query_start_loc_right_list = cumsum(query_lens_right)
    query_start_loc_left = async_tensor_h2d(query_start_loc_left_list,
                                            torch.int32, self_device, True)
    query_start_loc_right = async_tensor_h2d(query_start_loc_right_list,
                                             torch.int32, self_device, True)
    seq_start_loc_left = torch.cat((zero_tensor, split_seq_lens_tensor[0].cumsum(dim=0)), dim=0).to(torch.int32)
    seq_start_loc_right = torch.cat((zero_tensor, split_seq_lens_tensor[1].cumsum(dim=0)), dim=0).to(torch.int32)

    split_context_lens_tensor = torch.split(model_input.attn_metadata.context_lens_tensor, batch_size_split, dim=0)
    request_ids_to_seq_ids_left = {}
    request_ids_to_seq_ids_right = {}
    counter = 0
    for key, value in model_input.request_ids_to_seq_ids.items():
        if counter < batch_size_left:
            request_ids_to_seq_ids_left[key] = value
        else:
            request_ids_to_seq_ids_right[key] = value
        counter += 1

    previous_hidden_states_left = None
    previous_hidden_states_right = None
    if model_input.previous_hidden_states is not None:
        split_previous_hidden_states = torch.split(model_input.previous_hidden_states, query_tokens_split, dim=0)
        previous_hidden_states_left = split_previous_hidden_states[0]
        previous_hidden_states_right = split_previous_hidden_states[1]

    if isinstance(model_input.attn_metadata, MLACommonMetadata):
        attn_metadata_left = MLACommonMetadata(
            num_prefills=num_prefills_left,
            num_prefill_tokens=num_prefill_tokens_left,
            num_decode_tokens=num_decode_tokens_left,
            slot_mapping=split_slot_mapping[0],
            multi_modal_placeholder_index_maps=model_input.attn_metadata.multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=model_input.attn_metadata.enable_kv_scales_calculation,
            use_cuda_graph=model_input.attn_metadata.use_cuda_graph,
            input_positions=split_input_positions[0],
            seq_lens=seq_lens_left,
            seq_lens_tensor=split_seq_lens_tensor[0],
            max_prefill_seq_len=max_prefill_seq_len_left,
            max_decode_seq_len=max_decode_seq_len_left,
            context_lens_tensor=split_context_lens_tensor[0],
            block_tables=split_block_tables[0],
            max_query_len=max_query_len_left,
            max_decode_query_len=max_decode_query_len_left,
            query_start_loc=query_start_loc_left,
            seq_start_loc=seq_start_loc_left,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            head_dim=model_input.attn_metadata.head_dim,
            is_profile_run=model_input.attn_metadata.is_profile_run,
            context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
            context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
            context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
            context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
            context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
        )
        attn_metadata_right = MLACommonMetadata(
            num_prefills=num_prefills_right,
            num_prefill_tokens=num_prefill_tokens_right,
            num_decode_tokens=num_decode_tokens_right,
            slot_mapping=split_slot_mapping[1],
            multi_modal_placeholder_index_maps=model_input.attn_metadata.multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=model_input.attn_metadata.enable_kv_scales_calculation,
            use_cuda_graph=model_input.attn_metadata.use_cuda_graph,
            input_positions=split_input_positions[1],
            seq_lens=seq_lens_right,
            seq_lens_tensor=split_seq_lens_tensor[1],
            max_prefill_seq_len=max_prefill_seq_len_right,
            max_decode_seq_len=max_decode_seq_len_right,
            context_lens_tensor=split_context_lens_tensor[1],
            block_tables=split_block_tables[1],
            max_query_len=max_query_len_right,
            max_decode_query_len=max_decode_query_len_right,
            query_start_loc=query_start_loc_right,
            seq_start_loc=seq_start_loc_right,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            head_dim=model_input.attn_metadata.head_dim,
            is_profile_run=model_input.attn_metadata.is_profile_run,
            context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
            context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
            context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
            context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
            context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
        )

    if isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata):
        block_tables_list_left = model_input.attn_metadata.block_tables_list[0:batch_size_left]
        block_tables_list_right = model_input.attn_metadata.block_tables_list[batch_size_left:]
        attn_metadata_left = ROCmFlashAttentionMetadata(
            seq_lens_tensor=split_seq_lens_tensor[0],
            max_decode_seq_len=max_decode_seq_len_left,
            block_tables=split_block_tables[0],
            num_prefills=num_prefills_left,
            num_prefill_tokens=num_prefill_tokens_left,
            num_decode_tokens=num_decode_tokens_left,
            slot_mapping=split_slot_mapping[0],
            multi_modal_placeholder_index_maps={},
            enable_kv_scales_calculation=model_input.attn_metadata.enable_kv_scales_calculation,
            seq_lens=seq_lens_left,
            max_prefill_seq_len=max_prefill_seq_len_left,
            use_cuda_graph=model_input.attn_metadata.use_cuda_graph,
            max_query_len=max_query_len_left,
            query_start_loc=query_start_loc_left,
            seq_start_loc=seq_start_loc_left,
            context_lens_tensor=split_context_lens_tensor[0],
            max_decode_query_len=max_decode_query_len_left,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            tree_attention_masks_tensor=None,
            block_tables_list=block_tables_list_left,
            encoder_seq_lens=encoder_seq_lens_left,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor_left,
            max_encoder_seq_len=max_encoder_seq_len_left,
            num_encoder_tokens=num_encoder_tokens_left,
            cross_slot_mapping=cross_slot_mapping_left,
            cross_block_tables=cross_block_tables_left,
        )
        attn_metadata_right = ROCmFlashAttentionMetadata(
            seq_lens_tensor=split_seq_lens_tensor[1],
            max_decode_seq_len=max_decode_seq_len_right,
            block_tables=split_block_tables[1],
            num_prefills=num_prefills_right,
            num_prefill_tokens=num_prefill_tokens_right,
            num_decode_tokens=num_decode_tokens_right,
            slot_mapping=split_slot_mapping[1],
            multi_modal_placeholder_index_maps={},
            enable_kv_scales_calculation=model_input.attn_metadata.enable_kv_scales_calculation,
            seq_lens=seq_lens_right,
            max_prefill_seq_len=max_prefill_seq_len_right,
            use_cuda_graph=model_input.attn_metadata.use_cuda_graph,
            max_query_len=max_query_len_right,
            query_start_loc=query_start_loc_right,
            seq_start_loc=seq_start_loc_right,
            context_lens_tensor=split_context_lens_tensor[1],
            max_decode_query_len=max_decode_query_len_right,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            tree_attention_masks_tensor=None,
            block_tables_list=block_tables_list_right,
            encoder_seq_lens=encoder_seq_lens_right,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor_right,
            max_encoder_seq_len=max_encoder_seq_len_right,
            num_encoder_tokens=num_encoder_tokens_right,
            cross_slot_mapping=cross_slot_mapping_right,
            cross_block_tables=cross_block_tables_right,
        )

    if isinstance(model_input.attn_metadata, FlashMLAMetadata):
        attn_metadata_left = FlashMLAMetadata(
            num_prefills=num_prefills_left,
            num_prefill_tokens=num_prefill_tokens_left,
            num_decode_tokens=num_decode_tokens_left,
            slot_mapping=split_slot_mapping[0],
            multi_modal_placeholder_index_maps=model_input.attn_metadata.multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=model_input.attn_metadata.enable_kv_scales_calculation,
            use_cuda_graph=model_input.attn_metadata.use_cuda_graph,
            input_positions=split_input_positions[0],
            seq_lens=seq_lens_left,
            seq_lens_tensor=split_seq_lens_tensor[0],
            max_prefill_seq_len=max_prefill_seq_len_left,
            max_decode_seq_len=max_decode_seq_len_left,
            context_lens_tensor=split_context_lens_tensor[0],
            block_tables=split_block_tables[0],
            max_query_len=max_query_len_left,
            max_decode_query_len=max_decode_query_len_left,
            query_start_loc=query_start_loc_left,
            seq_start_loc=seq_start_loc_left,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            head_dim=model_input.attn_metadata.head_dim,
            is_profile_run=model_input.attn_metadata.is_profile_run,
            context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
            context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
            context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
            context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
            context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
            decode_tile_scheduler_metadata=model_input.attn_metadata.decode_tile_scheduler_metadata,
            decode_num_splits=model_input.attn_metadata.decode_num_splits,
        )
        attn_metadata_right = FlashMLAMetadata(
            num_prefills=num_prefills_right,
            num_prefill_tokens=num_prefill_tokens_right,
            num_decode_tokens=num_decode_tokens_right,
            slot_mapping=split_slot_mapping[1],
            multi_modal_placeholder_index_maps=model_input.attn_metadata.multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=model_input.attn_metadata.enable_kv_scales_calculation,
            use_cuda_graph=model_input.attn_metadata.use_cuda_graph,
            input_positions=split_input_positions[1],
            seq_lens=seq_lens_right,
            seq_lens_tensor=split_seq_lens_tensor[1],
            max_prefill_seq_len=max_prefill_seq_len_right,
            max_decode_seq_len=max_decode_seq_len_right,
            context_lens_tensor=split_context_lens_tensor[1],
            block_tables=split_block_tables[1],
            max_query_len=max_query_len_right,
            max_decode_query_len=max_decode_query_len_right,
            query_start_loc=query_start_loc_right,
            seq_start_loc=seq_start_loc_right,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            head_dim=model_input.attn_metadata.head_dim,
            is_profile_run=model_input.attn_metadata.is_profile_run,
            context_chunk_cu_seq_lens=model_input.attn_metadata.context_chunk_cu_seq_lens,
            context_chunk_starts=model_input.attn_metadata.context_chunk_starts,
            context_chunk_seq_tot=model_input.attn_metadata.context_chunk_seq_tot,
            context_chunk_max_seq_lens=model_input.attn_metadata.context_chunk_max_seq_lens,
            context_chunk_workspace=model_input.attn_metadata.context_chunk_workspace,
            decode_tile_scheduler_metadata=model_input.attn_metadata.decode_tile_scheduler_metadata,
            decode_num_splits=model_input.attn_metadata.decode_num_splits,
        )

    model_input_left = ModelInputForGPUWithSamplingMetadata(
        input_tokens=split_input_tokens[0],
        input_positions=split_input_positions[0],
        token_types=None,
        seq_lens=seq_lens_left,
        query_lens=query_lens_left,
        lora_mapping=model_input.lora_mapping,
        lora_requests=model_input.lora_requests,
        attn_metadata=attn_metadata_left,
        prompt_adapter_mapping=model_input.prompt_adapter_mapping,
        prompt_adapter_requests=model_input.prompt_adapter_requests,
        multi_modal_kwargs=model_input.multi_modal_kwargs,
        request_ids_to_seq_ids=request_ids_to_seq_ids_left,
        finished_requests_ids=model_input.finished_requests_ids,
        virtual_engine=model_input.virtual_engine,
        async_callback=model_input.async_callback,
        scheduler_outputs=model_input.scheduler_outputs,
        previous_hidden_states=previous_hidden_states_left,
        sampling_metadata=None,  # TBO does not require sampling_metadata
        is_prompt=model_input.is_prompt,
    )
    model_input_right = ModelInputForGPUWithSamplingMetadata(
        input_tokens=split_input_tokens[1],
        input_positions=split_input_positions[1],
        token_types=None,
        seq_lens=seq_lens_right,
        query_lens=query_lens_right,
        lora_mapping=model_input.lora_mapping,
        lora_requests=model_input.lora_requests,
        attn_metadata=attn_metadata_right,
        prompt_adapter_mapping=model_input.prompt_adapter_mapping,
        prompt_adapter_requests=model_input.prompt_adapter_requests,
        multi_modal_kwargs=model_input.multi_modal_kwargs,
        request_ids_to_seq_ids=request_ids_to_seq_ids_right,
        finished_requests_ids=model_input.finished_requests_ids,
        virtual_engine=model_input.virtual_engine,
        async_callback=model_input.async_callback,
        scheduler_outputs=model_input.scheduler_outputs,
        previous_hidden_states=previous_hidden_states_right,
        sampling_metadata=None,  # TBO does not require sampling_metadata
        is_prompt=model_input.is_prompt,
    )
    return model_input_left, model_input_right


def split_capture_attention_metadata(attn_metadata, batch_size_left, batch_size_right):
    batch_size_split = [batch_size_left, batch_size_right]
    split_seq_lens_tensor = torch.split(attn_metadata.seq_lens_tensor, batch_size_split, dim=0)
    split_block_tables = torch.split(attn_metadata.block_tables, batch_size_split, dim=0)
    split_slot_mapping = torch.split(attn_metadata.slot_mapping, batch_size_split, dim=0)
    if isinstance(attn_metadata, ROCmFlashAttentionMetadata):
        attn_metadata_left = ROCmFlashAttentionMetadata(
            seq_lens_tensor=split_seq_lens_tensor[0],
            max_decode_seq_len=attn_metadata.max_decode_seq_len,
            block_tables=split_block_tables[0],
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size_left,
            slot_mapping=split_slot_mapping[0],
            multi_modal_placeholder_index_maps=attn_metadata.multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=attn_metadata.enable_kv_scales_calculation,
            seq_lens=None,
            max_prefill_seq_len=0,
            use_cuda_graph=attn_metadata.use_cuda_graph,
            max_query_len=1,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            max_decode_query_len=1,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            tree_attention_masks_tensor=None,
            block_tables_list=None,
            encoder_seq_lens=None,
            encoder_seq_lens_tensor=None,
            max_encoder_seq_len=None,
            num_encoder_tokens=None,
            cross_slot_mapping=None,
            cross_block_tables=None,
        )
        attn_metadata_right = ROCmFlashAttentionMetadata(
            seq_lens_tensor=split_seq_lens_tensor[1],
            max_decode_seq_len=attn_metadata.max_decode_seq_len,
            block_tables=split_block_tables[1],
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size_right,
            slot_mapping=split_slot_mapping[1],
            multi_modal_placeholder_index_maps=attn_metadata.multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=attn_metadata.enable_kv_scales_calculation,
            seq_lens=None,
            max_prefill_seq_len=0,
            use_cuda_graph=attn_metadata.use_cuda_graph,
            max_query_len=1,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            max_decode_query_len=1,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            tree_attention_masks_tensor=None,
            block_tables_list=None,
            encoder_seq_lens=None,
            encoder_seq_lens_tensor=None,
            max_encoder_seq_len=None,
            num_encoder_tokens=None,
            cross_slot_mapping=None,
            cross_block_tables=None,
        )
    else:
        print("tbo: attention metadata type not supported in cuda-graph capture:",
              type(attn_metadata))
    return attn_metadata_left, attn_metadata_right
vllm/two_batch_overlap/two_batch_overlap.py (new file, 481 lines)
@@ -0,0 +1,481 @@
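# V0 two-batch-overlap driver: TwoBatchOverlap owns the left/right worker
# threads, the queues and semaphores used to hand model inputs to them, and the
# CUDA streams/events used to interleave their execution with all-reduces.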
import gc
import os
import queue
import threading
from typing import List, Optional, Tuple

import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_capture_attention_metadata, split_model_input
from vllm.logger import init_logger
from vllm.profiler.prof import profile
from vllm import envs
from vllm.utils import weak_ref_tensor
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import is_enable_tbo_v1, tbo_all_reduce_v1

tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'

logger = init_logger(__name__)

tbo_step_stream = None
all_reduce_stream = None


class TwoBatchOverlap:

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        self.left_tid = 0
        self.right_tid = 0
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        self.left_first = False
        self.tbo_running = False
        self.tbo_in_capture = False
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        self.step_event = torch.cuda.Event(enable_timing=False)
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    def init_tbo_thread(self):
        self.model_input_left_queue.empty()
        self.model_input_right_queue.empty()
        self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
                                            args=(self.model_input_left_queue,))
        self.left_thread.start()
        self.right_thread = threading.Thread(target=self.thread_two_batch_overlap,
                                             args=(self.model_input_right_queue,))
        self.right_thread.start()
        if get_tp_group().rank == 0:
            logger.info('tbo: two batch overlap start')

    def finish_thread(self):
        self.left_thread.join()
        self.left_thread = None
        self.right_thread.join()
        self.right_thread = None

    @torch.inference_mode()
    def thread_two_batch_overlap(self, input_queue):
        is_left_thread = False
        tid = threading.get_ident()
        if input_queue is self.model_input_left_queue:
            self.left_tid = tid
            is_left_thread = True
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)
        with torch.cuda.stream(tbo_step_stream):
            model_input = input_queue.get()
            profile.ProfRangePush('start')
            self.tbo_thread_synchronize(tid)
            model_kwargs = None
            intermediate_tensors = None
            if is_left_thread:
                model_kwargs = self.model_kwargs_left
                intermediate_tensors = self.intermediate_tensors_left
            else:
                model_kwargs = self.model_kwargs_right
                intermediate_tensors = self.intermediate_tensors_right
            hidden_or_intermediate_states = None
            if self.tbo_in_capture:
                if is_left_thread:
                    attn_metadata = self.attn_metadata_left
                    input_tokens = self.input_tokens_left
                    input_positions = self.split_input_positions[0]
                else:
                    attn_metadata = self.attn_metadata_right
                    input_tokens = self.input_tokens_right
                    input_positions = self.split_input_positions[1]
                with set_forward_context(attn_metadata,
                                         self.vllm_config, self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=input_tokens,
                        positions=input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **model_kwargs,
                    )
            elif model_input is not None:
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config, self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **model_kwargs,
                    )
            if is_left_thread:
                self.sem_right.release()
                self.states_left_queue.put(hidden_or_intermediate_states)
            else:
                self.states_right_queue.put(hidden_or_intermediate_states)
            profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        if tid == self.left_tid:
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            profile.ProfRangePop()
            self.sem_right.acquire()
            profile.ProfRangePush('right')
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors_left,
                        intermediate_tensors_right,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs_left,
                        model_kwargs_right):
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def set_capture_model_input(self,
                                input_tokens_left,
                                input_tokens_right,
                                split_input_positions,
                                vllm_config,
                                virtual_engine,
                                runner_model,
                                runner_device,
                                intermediate_tensors_left,
                                intermediate_tensors_right,
                                model_kwargs_left,
                                model_kwargs_right,
                                attn_metadata_left,
                                attn_metadata_right):
        self.input_tokens_left = input_tokens_left
        self.input_tokens_right = input_tokens_right
        self.split_input_positions = split_input_positions
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = runner_model
        self.self_device = runner_device
        self.intermediate_tensors_left = intermediate_tensors_left
        self.intermediate_tensors_right = intermediate_tensors_right
        self.model_kwargs_left = model_kwargs_left
        self.model_kwargs_right = model_kwargs_right
        self.attn_metadata_left = attn_metadata_left
        self.attn_metadata_right = attn_metadata_right
        self.model_input_left_queue.put(None)
        self.model_input_right_queue.put(None)

    def get_model_output(self):
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right


tbo_obj = None


def init_two_batch_overlap():
    global tbo_obj
    if tbo_obj is None:
        tbo_obj = TwoBatchOverlap()
    tbo_obj.init_tbo_thread()


def tbo_all_reduce(obj):
    if is_enable_tbo_v1():
        return tbo_all_reduce_v1(obj)
    if envs.VLLM_ENABLE_TBO and tbo_obj is not None and tbo_obj.tbo_running:
        tid = threading.get_ident()
        if not tbo_one_stream:
            if tid == tbo_obj.left_tid:
                event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
            else:
                event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c
            event_c2t.record()
            with torch.cuda.stream(all_reduce_stream):
                all_reduce_stream.wait_event(event_c2t)
                output = tensor_model_parallel_all_reduce(obj)
                event_t2c.record()
            tbo_obj.tbo_thread_synchronize(tid)
            tbo_step_stream.wait_event(event_t2c)
        else:
            output = tensor_model_parallel_all_reduce(obj)
            tbo_obj.tbo_thread_synchronize(tid)
        return output
    return tensor_model_parallel_all_reduce(obj)


def merge_model_output(states_left, states_right):
    if isinstance(states_left, IntermediateTensors):
        output_map = {}
        for key in states_left.tensors:
            output_map[key] = torch.concat([states_left.tensors[key], states_right.tensors[key]], dim=0)
        output = IntermediateTensors(output_map)
    else:
        output = torch.concat([states_left, states_right], dim=0)
    return output


def tbo_model_executable(
    model_input,
    vllm_config,
    virtual_engine,
    model_executable,
    intermediate_tensors,
    multi_modal_kwargs,
    self_device,
    seqlen_agnostic_kwargs,
    model_kwargs,
):
    is_support = is_supported_attention_metadata(model_input.attn_metadata)
    if not is_support:
        logger.info("tbo: attention metadata type %s is not supported yet",
                    type(model_input.attn_metadata))
    batch_size = len(model_input.attn_metadata.seq_lens)
    is_decode_tbo_invalid = not model_input.is_prompt and (
        envs.VLLM_TBO_DECODE_BS < 2 or
        batch_size < envs.VLLM_TBO_DECODE_BS or
        model_input.attn_metadata.use_cuda_graph)
    if batch_size == 1 or is_decode_tbo_invalid or not is_support:
        with set_forward_context(model_input.attn_metadata,
                                 vllm_config, virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                inputs_embeds=model_input.inputs_embeds,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )
        return hidden_or_intermediate_states
    profile.ProfRangePush('tbo_model_executable')
    init_two_batch_overlap()
    tbo_obj.tbo_running = True
    tbo_obj.left_first = True
    batch_size_left = batch_size // 2
    batch_size_right = batch_size_left
    if batch_size % 2 == 1:
        batch_size_right += 1

    model_input_left, model_input_right = split_model_input(
        model_input, self_device, batch_size_left, batch_size_right)

    model_kwargs_left = model_kwargs.copy()
    model_kwargs_right = model_kwargs.copy()
    intermediate_tensors_left = None
    intermediate_tensors_right = None
    if "previous_hidden_states" in model_kwargs:
        previous_hidden_states = model_kwargs["previous_hidden_states"]
        query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]),
                              sum(model_input.query_lens[batch_size_left:])]
        split_previous_hidden_states = torch.split(previous_hidden_states, query_tokens_split, dim=0)
        model_kwargs_left["previous_hidden_states"] = split_previous_hidden_states[0]
        model_kwargs_right["previous_hidden_states"] = split_previous_hidden_states[1]
    if intermediate_tensors is not None:
        query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]),
                              sum(model_input.query_lens[batch_size_left:])]
        intermediate_tensors_left = {}
        intermediate_tensors_right = {}
        for key in intermediate_tensors.tensors:
            split_intermediate_tensors = torch.split(intermediate_tensors.tensors[key], query_tokens_split, dim=0)
            intermediate_tensors_left[key] = split_intermediate_tensors[0]
            intermediate_tensors_right[key] = split_intermediate_tensors[1]
        intermediate_tensors_left = IntermediateTensors(intermediate_tensors_left)
        intermediate_tensors_right = IntermediateTensors(intermediate_tensors_right)

    tbo_obj.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj.step_event)
        tbo_obj.set_model_input(model_input_left,
                                model_input_right,
                                vllm_config,
                                virtual_engine,
                                model_executable,
                                intermediate_tensors_left,
                                intermediate_tensors_right,
                                multi_modal_kwargs,
                                self_device,
                                seqlen_agnostic_kwargs,
                                model_kwargs_left,
                                model_kwargs_right)

    states_left, states_right = tbo_obj.get_model_output()

    hidden_or_intermediate_states = merge_model_output(states_left, states_right)
    tbo_obj.tbo_running = False
    tbo_obj.step_event.record()
    tbo_obj.finish_thread()
    current_stream.wait_event(tbo_obj.step_event)
    profile.ProfRangePop()
    return hidden_or_intermediate_states


def _run_once(vllm_config, virtual_engine,
              runner,
              self_device,
              input_ids: torch.Tensor,
              positions: torch.Tensor,
              intermediate_inputs: Optional[IntermediateTensors],
              attn_metadata: AttentionMetadata,
              stream: torch.cuda.Stream,
              **kwargs):
    global tbo_step_stream
    stream_back = tbo_step_stream
    tbo_step_stream = stream
    init_two_batch_overlap()
    tbo_obj.left_first = True
    decode_batch_size = input_ids.shape[0]
    batch_size_left = decode_batch_size // 2
    batch_size_right = decode_batch_size - batch_size_left
    query_tokens_split = [batch_size_left, batch_size_right]
    input_tokens_left, input_tokens_right = torch.split(input_ids, query_tokens_split, dim=0)
    split_input_positions = torch.split(positions, query_tokens_split, dim=0)
    model_kwargs_left = kwargs.copy()
    model_kwargs_right = kwargs.copy()
    intermediate_tensors_left = None
    intermediate_tensors_right = None
    if "previous_hidden_states" in kwargs:
        previous_hidden_states = kwargs["previous_hidden_states"]
        split_previous_hidden_states = torch.split(previous_hidden_states, query_tokens_split, dim=0)
        model_kwargs_left["previous_hidden_states"] = split_previous_hidden_states[0]
        model_kwargs_right["previous_hidden_states"] = split_previous_hidden_states[1]
    if intermediate_inputs is not None:
        query_tokens_split = [batch_size_left, batch_size_right]
        intermediate_tensors_left = {}
        intermediate_tensors_right = {}
        for key in intermediate_inputs.tensors:
            split_intermediate_tensors = torch.split(intermediate_inputs.tensors[key], query_tokens_split, dim=0)
            intermediate_tensors_left[key] = split_intermediate_tensors[0]
            intermediate_tensors_right[key] = split_intermediate_tensors[1]
        intermediate_tensors_left = IntermediateTensors(intermediate_tensors_left)
        intermediate_tensors_right = IntermediateTensors(intermediate_tensors_right)
    attn_metadata_left, attn_metadata_right = split_capture_attention_metadata(attn_metadata, batch_size_left, batch_size_right)
    tbo_obj.tbo_running = True
    tbo_obj.tbo_in_capture = True
    tbo_obj.set_capture_model_input(input_tokens_left,
                                    input_tokens_right,
                                    split_input_positions,
                                    vllm_config,
                                    virtual_engine,
                                    runner.model,
                                    self_device,
                                    intermediate_tensors_left,
                                    intermediate_tensors_right,
                                    model_kwargs_left,
                                    model_kwargs_right,
                                    attn_metadata_left,
                                    attn_metadata_right)

    states_left, states_right = tbo_obj.get_model_output()
    output_hidden_or_intermediate_states = merge_model_output(states_left, states_right)
    tbo_obj.tbo_in_capture = False
    tbo_obj.tbo_running = False
    tbo_obj.finish_thread()
    tbo_step_stream = stream_back
    return output_hidden_or_intermediate_states


def tbo_capture(vllm_config, virtual_engine, _NUM_WARMUP_ITERS,
                runner,
                self_device,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_inputs: Optional[IntermediateTensors],
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                memory_pool: Optional[Tuple[int, int]],
                stream: torch.cuda.Stream,
                **kwargs):
    for _ in range(_NUM_WARMUP_ITERS):
        _run_once(vllm_config,
                  virtual_engine,
                  runner,
                  self_device,
                  input_ids,
                  positions,
                  intermediate_inputs,
                  attn_metadata,
                  torch.cuda.current_stream(),
                  **kwargs)
    torch.cuda.synchronize()
    runner._graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(runner._graph, pool=memory_pool, stream=stream):
        output_hidden_or_intermediate_states = _run_once(vllm_config,
                                                         virtual_engine,
                                                         runner,
                                                         self_device,
                                                         input_ids,
                                                         positions,
                                                         intermediate_inputs,
                                                         attn_metadata,
                                                         torch.cuda.current_stream(),
                                                         **kwargs)
        if isinstance(output_hidden_or_intermediate_states, torch.Tensor):
            hidden_or_intermediate_states = weak_ref_tensor(
                output_hidden_or_intermediate_states)
        elif isinstance(output_hidden_or_intermediate_states,
                        IntermediateTensors):
            hidden_or_intermediate_states = IntermediateTensors(
                tensors={
                    key: weak_ref_tensor(value)
                    for key, value in
                    output_hidden_or_intermediate_states.tensors.items()
                })

        del output_hidden_or_intermediate_states
        # make sure `output_hidden_or_intermediate_states` is deleted
        # in the graph's memory pool
        gc.collect()
    torch.cuda.synchronize()

    # Save the input and output buffers.
    runner.input_buffers = {
        "input_ids": input_ids,
        "positions": positions,
        "kv_caches": kv_caches,
        **runner.attn_state.get_graph_input_buffers(
            attn_metadata, runner._is_encoder_decoder_model),
        **kwargs,
    }
    if intermediate_inputs is not None:
        runner.input_buffers.update(intermediate_inputs.tensors)
    if get_pp_group().is_last_rank:
        runner.output_buffers = {
            "hidden_states": hidden_or_intermediate_states
        }
    else:
        runner.output_buffers = hidden_or_intermediate_states
vllm/two_batch_overlap/v1/model_input_split_v1.py (new file, 335 lines)
@@ -0,0 +1,335 @@
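# V1-engine counterpart of model_input_split.py: splits a SchedulerOutput into a
# left and a right half (balancing the number of scheduled tokens) and rebuilds
# per-half attention metadata from the runner's metadata builders.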
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from vllm import envs
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_model_executable_v1
|
||||
from vllm.utils import async_tensor_h2d
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
|
||||
class TBOModelInputSplit():
|
||||
def __init__(self):
|
||||
self.req_ids_left = []
|
||||
self.req_ids_right = []
|
||||
self.req_num_left = 0
|
||||
self.req_num_right = 0
|
||||
self.scheduler_output_left = None
|
||||
self.scheduler_output_right = None
|
||||
self.query_start_loc_right = None
|
||||
|
||||
input_split = TBOModelInputSplit()
|
||||
|
||||
def split_scheduler_output(runner, scheduler_output:SchedulerOutput):
|
||||
split_tokens = scheduler_output.total_num_scheduled_tokens // 2
|
||||
req_ids = runner.input_batch.req_ids
|
||||
tokens_counter = 0
|
||||
min_idx = -1
|
||||
min_counter = 0
|
||||
for i, id in enumerate(req_ids):
|
||||
tokens_counter += scheduler_output.num_scheduled_tokens[id]
|
||||
diff = abs(tokens_counter - split_tokens)
|
||||
if min_idx == -1 or diff < min_counter:
|
||||
min_idx = i
|
||||
min_counter = diff
|
||||
if tokens_counter > split_tokens or diff == 0:
|
||||
break
|
||||
input_split.req_num_left = min_idx + 1
|
||||
if input_split.req_num_left == len(req_ids):
|
||||
input_split.req_num_left = input_split.req_num_left - 1
|
||||
input_split.req_ids_left = req_ids[:input_split.req_num_left]
|
||||
input_split.req_ids_right = req_ids[input_split.req_num_left:]
|
||||
input_split.req_num_right = len(req_ids) - input_split.req_num_left
|
||||
new_req_data_left = []
|
||||
new_req_data_right = []
|
||||
cached_reqs_left = []
|
||||
cached_reqs_right = []
|
||||
num_scheduled_tokens_left = {}
|
||||
num_scheduled_tokens_right = {}
|
||||
total_num_scheduled_tokens_left = 0
|
||||
total_num_scheduled_tokens_right = 0
|
||||
for new_req in scheduler_output.scheduled_new_reqs:
|
||||
if new_req.req_id in input_split.req_ids_left:
|
||||
new_req_data_left.append(new_req)
|
||||
else:
|
||||
new_req_data_right.append(new_req)
|
||||
|
||||
cached_reqs_left = CachedRequestData.make_empty()
|
||||
cached_reqs_right = CachedRequestData.make_empty()
|
||||
for req_idx, req_id in enumerate(scheduler_output.scheduled_cached_reqs.req_ids):
|
||||
if req_id in input_split.req_ids_left:
|
||||
cached_reqs_left.req_ids.append(req_id)
|
||||
cached_reqs_left.resumed_from_preemption.append(scheduler_output.scheduled_cached_reqs.resumed_from_preemption[req_idx])
|
||||
if len(scheduler_output.scheduled_cached_reqs.new_token_ids) > 0:
|
||||
cached_reqs_left.new_token_ids.append(scheduler_output.scheduled_cached_reqs.new_token_ids[req_idx])
|
||||
cached_reqs_left.new_block_ids.append(scheduler_output.scheduled_cached_reqs.new_block_ids[req_idx])
|
||||
cached_reqs_left.num_computed_tokens.append(scheduler_output.scheduled_cached_reqs.num_computed_tokens[req_idx])
|
||||
else:
|
||||
cached_reqs_right.req_ids.append(req_id)
|
||||
cached_reqs_right.resumed_from_preemption.append(scheduler_output.scheduled_cached_reqs.resumed_from_preemption[req_idx])
|
||||
if len(scheduler_output.scheduled_cached_reqs.new_token_ids) > 0:
|
||||
cached_reqs_right.new_token_ids.append(scheduler_output.scheduled_cached_reqs.new_token_ids[req_idx])
|
||||
cached_reqs_right.new_block_ids.append(scheduler_output.scheduled_cached_reqs.new_block_ids[req_idx])
|
||||
cached_reqs_right.num_computed_tokens.append(scheduler_output.scheduled_cached_reqs.num_computed_tokens[req_idx])
|
||||
for key, value in scheduler_output.num_scheduled_tokens.items():
|
||||
if key in input_split.req_ids_left:
|
||||
num_scheduled_tokens_left[key] = value
|
||||
total_num_scheduled_tokens_left += value
|
||||
else:
|
||||
num_scheduled_tokens_right[key] = value
|
||||
total_num_scheduled_tokens_right += value
|
||||
|
||||
|
||||
input_split.scheduler_output_left = SchedulerOutput(
|
||||
scheduled_new_reqs=new_req_data_left,
|
||||
scheduled_cached_reqs=cached_reqs_left,
|
||||
num_scheduled_tokens=num_scheduled_tokens_left,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens_left,
|
||||
scheduled_spec_decode_tokens=scheduler_output.scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, ##unsupport yet
|
||||
num_common_prefix_blocks=scheduler_output.num_common_prefix_blocks,
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=scheduler_output.finished_req_ids,
|
||||
free_encoder_input_ids=scheduler_output.free_encoder_input_ids,
|
||||
structured_output_request_ids=scheduler_output.structured_output_request_ids,
|
||||
grammar_bitmask=scheduler_output.grammar_bitmask,
|
||||
)
|
||||
input_split.scheduler_output_right = SchedulerOutput(
|
||||
scheduled_new_reqs=new_req_data_right,
|
||||
scheduled_cached_reqs=cached_reqs_right,
|
||||
num_scheduled_tokens=num_scheduled_tokens_right,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens_right,
|
||||
scheduled_spec_decode_tokens=scheduler_output.scheduled_spec_decode_tokens,
|
||||
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, ##unsupport yet
|
||||
num_common_prefix_blocks=scheduler_output.num_common_prefix_blocks,
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
# It contains the request IDs that are finished in between
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=scheduler_output.finished_req_ids,
|
||||
free_encoder_input_ids=scheduler_output.free_encoder_input_ids,
|
||||
structured_output_request_ids=scheduler_output.structured_output_request_ids,
|
||||
grammar_bitmask=scheduler_output.grammar_bitmask,
|
||||
)
|
||||
|
||||
|
||||
def prepare_tbo_atten_metadata(
|
||||
runner,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
req_ids,
|
||||
req_offset
|
||||
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]:
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = len(req_ids)
|
||||
assert num_reqs > 0
|
||||
|
||||
seq_len_offset = req_offset
|
||||
# Get the number of scheduled tokens for each request.
|
||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
|
||||
if req_offset > 0: #right
|
||||
if input_split.query_start_loc_right == None:
|
||||
# TODO: create when system init
|
||||
input_split.query_start_loc_right = torch.zeros(runner.max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=runner.device)
|
||||
|
||||
cu_num_tokens, arange = runner._get_cumsum_and_arange(
|
||||
num_scheduled_tokens)
|
||||
|
||||
# Prepare the attention metadata.
|
||||
runner.query_start_loc_np[0] = 0
|
||||
runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
||||
|
||||
|
||||
input_split.query_start_loc_right[0: num_reqs + 1].copy_(
|
||||
runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
|
||||
# Note: pad query_start_loc to be non-decreasing, as kernels
|
||||
# like FlashAttention requires that
|
||||
input_split.query_start_loc_right[num_reqs + 1:].fill_(
|
||||
runner.query_start_loc_cpu[num_reqs].item())
|
||||
query_start_loc = input_split.query_start_loc_right[: num_reqs + 1]
|
||||
|
||||
|
||||
else:
|
||||
query_start_loc = runner.query_start_loc[:num_reqs + 1]
|
||||
|
||||
|
||||
seq_lens = runner.seq_lens[seq_len_offset : seq_len_offset + num_reqs]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens)
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
runner.kv_cache_config.kv_cache_groups):
|
||||
|
||||
# Prepare for cascade attention if enabled & beneficial.
|
||||
common_prefix_len = 0
|
||||
metadata_builder = runner.attn_metadata_builders[kv_cache_group_id]
|
||||
if runner.cascade_attn_enabled:
|
||||
common_prefix_len = runner._compute_cascade_attn_prefix_len(
|
||||
num_scheduled_tokens,
|
||||
scheduler_output.
|
||||
num_common_prefix_blocks[kv_cache_group_id],
|
||||
kv_cache_group_spec.kv_cache_spec,
|
||||
metadata_builder,
|
||||
)
|
||||
if req_offset > 0:
|
||||
origin_block_table = metadata_builder.block_table.block_table
|
||||
metadata_builder.block_table.block_table = origin_block_table[req_offset:, :]
|
||||
origin_slot_mapping = metadata_builder.block_table.slot_mapping
|
||||
metadata_builder.block_table.slot_mapping = \
|
||||
origin_slot_mapping[input_split.scheduler_output_left.total_num_scheduled_tokens:]
|
||||
origin_slot_map_cpu = metadata_builder.block_table.slot_mapping_cpu
|
||||
metadata_builder.block_table.slot_mapping_cpu = \
|
||||
origin_slot_map_cpu[input_split.scheduler_output_left.total_num_scheduled_tokens:]
|
||||
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
|
||||
_num_decodes_record = metadata_builder._num_decodes
|
||||
_num_prefills_record = metadata_builder._num_prefills
|
||||
_num_decode_tokens_record = metadata_builder._num_decode_tokens
|
||||
_num_prefill_tokens_record = metadata_builder._num_prefill_tokens
|
||||
|
||||
metadata_builder._num_decodes = 0
|
||||
metadata_builder._num_prefills = num_reqs
|
||||
metadata_builder._num_decode_tokens = 0
|
||||
metadata_builder._num_prefill_tokens = total_num_scheduled_tokens
|
||||
attn_metadata_i = (
|
||||
metadata_builder.build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
common_attn_metadata=common_attn_metadata)) # maybe FlashAttentionMetadata
|
||||
if req_offset > 0:
|
||||
metadata_builder.block_table.block_table = origin_block_table
|
||||
metadata_builder.block_table.slot_mapping = origin_slot_mapping
|
||||
metadata_builder.block_table.slot_mapping_cpu = origin_slot_map_cpu
|
||||
|
||||
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only
|
||||
metadata_builder._num_decodes = _num_decodes_record
|
||||
metadata_builder._num_prefills = _num_prefills_record
|
||||
metadata_builder._num_decode_tokens = _num_decode_tokens_record
|
||||
metadata_builder._num_prefill_tokens = _num_prefill_tokens_record
|
||||
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def pad_num_input_tokens(self, scheduler_output):
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if (self.use_cuda_graph
|
||||
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||
# Use piecewise CUDA graphs.
|
||||
# Add padding to the batch size.
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
num_scheduled_tokens)
|
||||
else:
|
||||
# Eager mode.
|
||||
# Pad tokens to multiple of tensor_parallel_size when
|
||||
# enabled collective fusion for SP
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
if self.vllm_config.compilation_config.pass_config. \
|
||||
enable_sequence_parallelism and tp_size > 1:
|
||||
from vllm.utils import round_up
|
||||
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
|
||||
else:
|
||||
num_input_tokens = num_scheduled_tokens
|
||||
|
||||
# Padding for DP
|
||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
|
||||
num_input_tokens += num_pad
|
||||
return num_input_tokens, num_tokens_across_dp
|
||||
|
||||
def tbo_split_and_execute_model(
    runner,
    attn_metadata,
    num_input_tokens,
    num_tokens_across_dp,
    input_ids,
    positions,
    inputs_embeds,
    scheduler_output: "SchedulerOutput",
    intermediate_tensors: Optional[IntermediateTensors] = None,
    skip_cuda_graphs: bool = True,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
    use_tbo = False
    if isinstance(runner.attn_metadata_builders[0], MLACommonMetadataBuilder) and \
            runner.attn_metadata_builders[0]._num_decodes > 0:
        # MLA decode batches are not split; keep the single-batch path.
        use_tbo = False
    else:
        if len(scheduler_output.num_scheduled_tokens) > 1 and \
                num_input_tokens > envs.VLLM_TBO_MIN_TOKENS:
            split_scheduler_output(runner, scheduler_output)
            use_tbo = True
    if use_tbo:
        num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens
        num_input_tokens_right = num_input_tokens - num_input_tokens_left

        attn_metadata_left = prepare_tbo_atten_metadata(
            runner, input_split.scheduler_output_left, input_split.req_ids_left, 0)
        attn_metadata_right = prepare_tbo_atten_metadata(
            runner, input_split.scheduler_output_right, input_split.req_ids_right,
            input_split.req_num_left)

        with set_forward_context(attn_metadata,
                                 runner.vllm_config,
                                 num_tokens=num_input_tokens,
                                 num_tokens_across_dp=num_tokens_across_dp,
                                 skip_cuda_graphs=True):
            runner.maybe_setup_kv_connector(scheduler_output)

            model_output = tbo_model_executable_v1(
                runner,
                attn_metadata_left,
                attn_metadata_right,
                num_input_tokens_left,
                num_input_tokens_right,
                num_tokens_across_dp,
                input_ids,
                positions,
                intermediate_tensors,
                inputs_embeds)

            runner.maybe_wait_for_kv_save()
            finished_sending, finished_recving = (
                runner.get_finished_kv_transfers(scheduler_output))
    else:
        # Run the decoder.
        # Use persistent buffers for CUDA graphs.
        envs.VLLM_ENABLE_TBO = False
        with set_forward_context(attn_metadata,
                                 runner.vllm_config,
                                 num_tokens=num_input_tokens,
                                 num_tokens_across_dp=num_tokens_across_dp,
                                 skip_cuda_graphs=skip_cuda_graphs):
            runner.maybe_setup_kv_connector(scheduler_output)

            model_output = runner.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )

            runner.maybe_wait_for_kv_save()
            finished_sending, finished_recving = (
                runner.get_finished_kv_transfers(scheduler_output))
        envs.VLLM_ENABLE_TBO = True
    return model_output, finished_sending, finished_recving
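The decision above reduces to a small predicate: overlap only when the batch holds more than one request, carries enough tokens to amortize a second micro-batch, and is not an MLA decode step. A minimal, standalone sketch of that gate (the function name and default threshold are illustrative, not part of the patch):

# Minimal sketch of the two-batch-overlap gate; illustrative names/threshold.
def should_use_tbo(num_requests: int,
                   num_input_tokens: int,
                   is_mla_decode: bool,
                   min_tokens: int = 1024) -> bool:
    if is_mla_decode:
        # MLA decode batches stay on the single-batch path.
        return False
    return num_requests > 1 and num_input_tokens > min_tokens

assert should_use_tbo(num_requests=4, num_input_tokens=4096, is_mla_decode=False)
assert not should_use_tbo(num_requests=1, num_input_tokens=4096, is_mla_decode=False)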
235
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
Normal file
@@ -0,0 +1,235 @@
import os
import queue
import threading

import torch

from vllm import envs
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.profiler.prof import profile
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context

logger = init_logger(__name__)

# Streams shared by the left/right worker threads; created lazily on first use.
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap:

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        self.left_tid = 0
        self.right_tid = 0
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        self.left_first = False
        self.tbo_running = False
        self.tbo_in_capture = False
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        self.step_event = torch.cuda.Event(enable_timing=False)
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)
    def init_tbo_thread(self):
        # Drain any stale inputs from a previous step; Queue.empty() only
        # reports emptiness, so pop entries until both queues are empty.
        while not self.model_input_left_queue.empty():
            self.model_input_left_queue.get_nowait()
        while not self.model_input_right_queue.empty():
            self.model_input_right_queue.get_nowait()
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue, ))
        self.left_thread.start()
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue, ))
        self.right_thread.start()
        if get_tp_group().rank == 0:
            logger.info('tbo: two batch overlap start')

    def finish_thread(self):
        self.left_thread.join()
        self.left_thread = None
        self.right_thread.join()
        self.right_thread = None
    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        is_left_thread = False
        tid = threading.get_ident()
        if queue is self.model_input_left_queue:
            self.left_tid = tid
            is_left_thread = True
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)
        with torch.cuda.stream(tbo_step_stream):
            # Block until set_model_input() publishes this step's inputs.
            queue.get()
            self.tbo_thread_synchronize(tid)
            if is_left_thread:
                attn_metadata = self.attn_metadata_left
                num_input_tokens = self.num_input_tokens_left
                input_ids = self.input_ids_left
                positions = self.positions_left
            else:
                attn_metadata = self.attn_metadata_right
                num_input_tokens = self.num_input_tokens_right
                input_ids = self.input_ids_right
                positions = self.positions_right

            model_output = None
            # Run the decoder.
            # Use persistent buffers for CUDA graphs.
            with set_forward_context(attn_metadata,
                                     self.model_runner.vllm_config,
                                     num_tokens=num_input_tokens,
                                     num_tokens_across_dp=self.num_tokens_across_dp,
                                     skip_cuda_graphs=True):
                model_output = self.model_runner.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=self.intermediate_tensors,
                    inputs_embeds=self.inputs_embeds,
                )
            if is_left_thread:
                self.sem_right.release()
                self.states_left_queue.put(model_output)
            else:
                self.states_right_queue.put(model_output)
    def tbo_thread_synchronize(self, tid):
        # Ping-pong hand-off between the two worker threads: each thread wakes
        # its peer, then blocks until the peer wakes it back. On the first
        # step of a pass the left thread is allowed to run ahead (left_first).
        if tid == self.left_tid:
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            self.sem_left.acquire()
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            self.sem_right.acquire()
            return self.event_right_c2t, self.event_right_t2c
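The hand-off above can be exercised in isolation. The following standalone toy (no CUDA, illustrative names only) shows two threads alternating through a pair of semaphores in the same ping-pong fashion as tbo_thread_synchronize:

# Standalone toy of the semaphore ping-pong used above (illustrative only).
import threading

sem_left = threading.Semaphore(0)
sem_right = threading.Semaphore(0)
order = []

def left_worker():
    for step in range(3):
        order.append(f"L{step}")
        sem_right.release()   # wake the right thread
        sem_left.acquire()    # wait until the right thread wakes us back

def right_worker():
    for step in range(3):
        sem_right.acquire()   # wait for the left thread
        order.append(f"R{step}")
        sem_left.release()    # hand control back to the left thread

t1 = threading.Thread(target=left_worker)
t2 = threading.Thread(target=right_worker)
t1.start(); t2.start()
t1.join(); t2.join()
assert order == ["L0", "R0", "L1", "R1", "L2", "R2"]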
    def set_model_input(self,
                        model_runner,
                        attn_metadata_left,
                        attn_metadata_right,
                        num_input_tokens_left,
                        num_input_tokens_right,
                        input_ids_left,
                        input_ids_right,
                        positions_left,
                        positions_right,
                        num_tokens_across_dp,
                        intermediate_tensors,
                        inputs_embeds):
        self.model_runner = model_runner
        self.attn_metadata_left = attn_metadata_left
        self.attn_metadata_right = attn_metadata_right
        self.num_input_tokens_left = num_input_tokens_left
        self.num_input_tokens_right = num_input_tokens_right
        self.input_ids_left = input_ids_left
        self.input_ids_right = input_ids_right
        self.positions_left = positions_left
        self.positions_right = positions_right
        self.num_tokens_across_dp = num_tokens_across_dp
        self.intermediate_tensors = intermediate_tensors
        self.inputs_embeds = inputs_embeds

        # Unblock both worker threads; the queue item is only a "go" signal,
        # the actual inputs are read from the attributes set above.
        self.model_input_left_queue.put(None)
        self.model_input_right_queue.put(None)
    def get_model_output(self):
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right
tbo_obj_v1 = None


def is_enable_tbo_v1():
    global tbo_obj_v1
    return tbo_obj_v1 is not None


def init_two_batch_overlap():
    global tbo_obj_v1
    if tbo_obj_v1 is None:
        tbo_obj_v1 = TwoBatchOverlap()
    tbo_obj_v1.init_tbo_thread()
def tbo_all_reduce_v1(obj):
    if envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None and tbo_obj_v1.tbo_running:
        tid = threading.get_ident()
        if tid == tbo_obj_v1.left_tid:
            event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
        else:
            event_c2t, event_t2c = tbo_obj_v1.event_right_c2t, tbo_obj_v1.event_right_t2c
        # Hand the collective off to a dedicated stream: record where the
        # compute stream is, run the all-reduce on all_reduce_stream, then
        # make the compute stream wait for its completion.
        event_c2t.record()
        with torch.cuda.stream(all_reduce_stream):
            all_reduce_stream.wait_event(event_c2t)
            output = tensor_model_parallel_all_reduce(obj)
            event_t2c.record()
        tbo_obj_v1.tbo_thread_synchronize(tid)
        tbo_step_stream.wait_event(event_t2c)
        return output
    return tensor_model_parallel_all_reduce(obj)
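tbo_all_reduce_v1 offloads the collective to a separate stream and fences it with events. The sketch below shows the same record/wait pattern on a toy workload; it is illustrative only, assumes a CUDA device is available, and uses a local reduction in place of tensor_model_parallel_all_reduce.

# Illustrative cross-stream event pattern (not part of the patch).
import torch

if torch.cuda.is_available():
    compute_stream = torch.cuda.current_stream()
    side_stream = torch.cuda.Stream()
    start_evt = torch.cuda.Event()
    done_evt = torch.cuda.Event()

    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x                          # work queued on the compute stream
    start_evt.record(compute_stream)   # marks "y is ready"

    with torch.cuda.stream(side_stream):
        side_stream.wait_event(start_evt)  # do not read y too early
        z = y.sum()                        # side-stream work, stands in for the all-reduce
        done_evt.record(side_stream)

    compute_stream.wait_event(done_evt)    # compute stream resumes only after z exists
    torch.cuda.synchronize()
    print(float(z))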
def merge_model_output(states_left, states_right):
    if isinstance(states_left, IntermediateTensors):
        output_map = {}
        for key in states_left.tensors:
            output_map[key] = torch.concat(
                [states_left.tensors[key], states_right.tensors[key]], dim=0)
        output = IntermediateTensors(output_map)
    else:
        output = torch.concat([states_left, states_right], dim=0)
    return output
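merge_model_output assumes the token dimension of the two partial outputs concatenates back into the original batch, mirroring the torch.split applied to the inputs. A tiny illustrative check of that invariant on CPU tensors:

# Tiny illustrative check of the split/merge invariant (CPU tensors).
import torch

hidden = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # 6 tokens, hidden size 2
left, right = torch.split(hidden, [4, 2], dim=0)              # 4 "left" + 2 "right" tokens
merged = torch.concat([left, right], dim=0)
assert torch.equal(merged, hidden)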
def tbo_model_executable_v1(
        model_runner,
        attn_metadata_left,
        attn_metadata_right,
        num_input_tokens_left,
        num_input_tokens_right,
        num_tokens_across_dp,
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds):

    init_two_batch_overlap()
    tbo_obj_v1.tbo_running = True
    tbo_obj_v1.left_first = True
    tbo_obj_v1.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj_v1.step_event)
        tokens_split = [num_input_tokens_left, num_input_tokens_right]
        input_ids_left, input_ids_right = torch.split(input_ids, tokens_split, dim=0)
        positions_left, positions_right = torch.split(positions, tokens_split, dim=0)
        tbo_obj_v1.set_model_input(model_runner,
                                   attn_metadata_left,
                                   attn_metadata_right,
                                   num_input_tokens_left,
                                   num_input_tokens_right,
                                   input_ids_left,
                                   input_ids_right,
                                   positions_left,
                                   positions_right,
                                   num_tokens_across_dp,
                                   intermediate_tensors,
                                   inputs_embeds)
        model_output_left, model_output_right = tbo_obj_v1.get_model_output()

        # Merge on tbo_step_stream so the concat is ordered after both worker
        # threads' kernels, then let the original stream wait on the result.
        hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
        tbo_obj_v1.tbo_running = False
        tbo_obj_v1.step_event.record()
        tbo_obj_v1.finish_thread()
        current_stream.wait_event(tbo_obj_v1.step_event)

    return hidden_or_intermediate_states