"""
Usage:
pip install opencv-python-headless
python3 srt_example_llava.py
"""
2024-07-18 04:55:39 +10:00
import argparse
2024-05-14 07:57:00 +08:00
import csv
2024-07-18 04:55:39 +10:00
import os
2024-05-14 07:57:00 +08:00
import time
2024-07-18 04:55:39 +10:00
import sglang as sgl
2024-05-14 07:57:00 +08:00
@sgl.function
def video_qa(s, num_frames, video_path, question):
    """Single-turn video QA: show the clip frames, ask the question, generate an answer."""
    prompt = sgl.video(video_path, num_frames) + question
    s += sgl.user(prompt)
    s += sgl.assistant(sgl.gen("answer"))
def single(path, num_frames=16):
    """Describe one video via video_qa and print the generated answer."""
    prompt = "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes"
    state = video_qa.run(
        num_frames=num_frames,
        video_path=path,
        question=prompt,
        temperature=0.0,
        max_new_tokens=1024,
    )
    answer = state["answer"]
    print(answer, "\n")
def split_into_chunks(lst, num_chunks):
    """Split a list into exactly ``num_chunks`` contiguous chunks.

    Items are distributed as evenly as possible: the first
    ``len(lst) % num_chunks`` chunks get one extra item. Trailing chunks are
    empty when there are fewer items than chunks.

    Fixes two defects of the previous implementation:
    - an empty ``lst`` made the chunk size 0 and ``range(..., 0)`` raised
      ``ValueError``;
    - a non-divisible length produced MORE than ``num_chunks`` chunks
      (e.g. 10 items into 3 chunks yielded 4 chunks), contradicting the
      stated contract of returning exactly ``num_chunks`` chunks.
    """
    base, remainder = divmod(len(lst), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        # The first `remainder` chunks absorb one extra item each.
        end = start + base + (1 if i < remainder else 0)
        chunks.append(lst[start:end])
        start = end
    return chunks
def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
    """Persist one batch's (video_name, answer) pairs to a per-batch CSV file."""
    out_path = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
    rows = [
        [os.path.basename(video_path), state["answer"]]
        for video_path, state in zip(batch_video_files, states)
    ]
    with open(out_path, "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["video_name", "answer"])
        writer.writerows(rows)
2024-07-18 04:55:39 +10:00
2024-05-14 07:57:00 +08:00
def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
    """Merge this chunk's per-batch CSVs into one final CSV, deleting the batch files."""
    final_path = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
    with open(final_path, "w", newline="") as out_fh:
        writer = csv.writer(out_fh)
        writer.writerow(["video_name", "answer"])
        for batch_idx in range(num_batches):
            part_path = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
            with open(part_path, "r") as part_fh:
                part_reader = csv.reader(part_fh)
                next(part_reader)  # Skip the per-batch header row.
                writer.writerows(part_reader)
            # Batch file has been folded into the final CSV; drop it.
            os.remove(part_path)
2024-07-18 04:55:39 +10:00
2024-05-14 07:57:00 +08:00
def find_video_files(video_dir):
    """Collect video file paths under ``video_dir``.

    If ``video_dir`` is itself a file, it is returned as a single-element
    list; otherwise the directory tree is walked and every file ending in
    .mp4/.avi/.mov is collected.
    """
    # A single file path is treated as a one-item result.
    if os.path.isfile(video_dir):
        return [video_dir]

    matches = []
    for root, _dirs, files in os.walk(video_dir):
        matches.extend(
            os.path.join(root, name)
            for name in files
            if name.endswith((".mp4", ".avi", ".mov"))
        )
    return matches
2024-07-18 04:55:39 +10:00
2024-05-14 07:57:00 +08:00
def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
    """Run the description prompt over this worker's chunk of videos.

    The full file list is split into ``num_chunks`` chunks and only chunk
    ``cur_chunk`` is processed here, ``batch_size`` videos at a time. Each
    batch's answers are written to a per-batch CSV in ``save_dir``; at the
    end the per-batch CSVs are merged into one final CSV for the chunk.
    """
    all_files = find_video_files(video_dir)
    my_files = split_into_chunks(all_files, num_chunks)[cur_chunk]
    batch_count = 0
    for offset in range(0, len(my_files), batch_size):
        current_batch = my_files[offset : offset + batch_size]
        print(f"Processing batch of {len(current_batch)} video(s)...")
        if not current_batch:
            print("No video files found in the specified directory.")
            return

        prompts = [
            {
                "num_frames": num_frames,
                "video_path": video_path,
                "question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.",
            }
            for video_path in current_batch
        ]
        started = time.time()
        states = video_qa.run_batch(prompts, max_new_tokens=512, temperature=0.2)
        elapsed = time.time() - started
        per_video = elapsed / len(current_batch)
        print(
            f"Number of videos in batch: {len(current_batch)}. Average processing time per video: {per_video:.2f} seconds. Total time for this batch: {elapsed:.2f} seconds"
        )
        save_batch_results(current_batch, states, cur_chunk, batch_count, save_dir)
        batch_count += 1
    compile_and_cleanup_final_results(cur_chunk, batch_count, save_dir)
if __name__ == "__main__":
    # Create the parser
    parser = argparse.ArgumentParser(
        description="Run video processing with specified port."
    )
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="The master port for distributed serving.",
    )
    parser.add_argument(
        "--chunk-idx", type=int, default=0, help="The index of the chunk to process."
    )
    parser.add_argument(
        "--num-chunks", type=int, default=8, help="The number of chunks to process."
    )
    parser.add_argument(
        "--save-dir",
        type=str,
        default="./work_dirs/llava_video",
        help="The directory to save the processed video files.",
    )
    parser.add_argument(
        "--video-dir",
        type=str,
        default="./videos/Q98Z4OTh8RwmDonc.mp4",
        help="The directory or path for the processed video files.",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="lmms-lab/LLaVA-NeXT-Video-7B",
        help="The model path for the video processing.",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=16,
        help="The number of frames to process in each video.",
    )
    parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)

    # Parse the arguments
    args = parser.parse_args()
    cur_port = args.port
    cur_chunk = args.chunk_idx
    num_chunks = args.num_chunks
    num_frames = args.num_frames

    # Pick the tokenizer that matches the model size.
    if "34b" in args.model_path.lower():
        tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer"
    elif "7b" in args.model_path.lower():
        tokenizer_path = "llava-hf/llava-1.5-7b-hf"
    else:
        print("Invalid model path. Please specify a valid model path.")
        exit()

    # Model config overrides passed to the sglang runtime.
    model_overide_args = {}
    model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
    model_overide_args["num_frames"] = args.num_frames
    model_overide_args["model_type"] = "llava"

    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    if args.num_frames == 32:
        # 32 frames need a longer context; stretch RoPE and the max lengths.
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
    elif args.num_frames < 32:
        pass
    else:
        print(
            "The maximum number of frames to process is 32. Please specify a valid number of frames."
        )
        exit()

    runtime = sgl.Runtime(
        model_path=args.model_path,  # "liuhaotian/llava-v1.6-vicuna-7b",
        tokenizer_path=tokenizer_path,
        port=cur_port,
        additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
        model_overide_args=model_overide_args,
        tp_size=1,
    )
    sgl.set_default_backend(runtime)
    print(f"chat template: {runtime.endpoint.chat_template.name}")

    # Run a single request. The try/finally guarantees exactly one
    # runtime.shutdown() (the previous version called it twice
    # unconditionally, leftover from a commented-out try/except).
    try:
        print("\n========== single ==========\n")
        root = args.video_dir
        if os.path.isfile(root):
            video_files = [root]
        else:
            video_files = [
                os.path.join(root, f)
                for f in os.listdir(root)
                if f.endswith((".mp4", ".avi", ".mov"))
            ]  # Add more extensions if needed

        # Only the first video is processed in this single-request demo.
        processed = video_files[:1]
        start_time = time.time()  # Start time for processing
        for cur_video in processed:
            print(cur_video)
            single(cur_video, num_frames)
        total_time = time.time() - start_time
        # BUG FIX: divide by the number of videos actually processed, not by
        # len(video_files) — only processed[:1] ran, so the old average was
        # wrong for multi-file dirs and divided by zero for empty ones.
        if processed:
            average_time = total_time / len(processed)
            print(f"Average processing time per video: {average_time:.2f} seconds")
    finally:
        runtime.shutdown()

    # # # Run a batch of requests
    # print("\n========== batch ==========\n")
    # if not os.path.exists(args.save_dir):
    #     os.makedirs(args.save_dir)
    # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)
    # runtime.shutdown()