support llava video (#426)

This commit is contained in:
Yuanhan Zhang
2024-05-14 07:57:00 +08:00
committed by GitHub
parent 5dc55a5f02
commit 0992d85f92
37 changed files with 1139 additions and 222 deletions

.gitignore vendored
View File

@@ -177,3 +177,7 @@ tmp*.txt
# Plots
*.png
*.pdf
# personal
work_dirs/
*.csv

View File

@@ -410,4 +410,4 @@ https://github.com/sgl-project/sglang/issues/157
}
```
We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).
We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).

View File

@@ -14,7 +14,7 @@ def single():
state = image_qa.run(
image_path="images/cat.jpeg",
question="What is this?",
max_new_tokens=64)
max_new_tokens=128)
print(state["answer"], "\n")
@@ -36,7 +36,7 @@ def batch():
{"image_path": "images/cat.jpeg", "question":"What is this?"},
{"image_path": "images/dog.jpeg", "question":"What is this?"},
],
max_new_tokens=64,
max_new_tokens=128,
)
for s in states:
print(s["answer"], "\n")

View File

@@ -0,0 +1,208 @@
"""
Usage: python3 srt_example_llava_v.py
"""
import sglang as sgl
import os
import csv
import time
import argparse
@sgl.function
def video_qa(s, num_frames, video_path, question):
s += sgl.user(sgl.video(video_path, num_frames) + question)
s += sgl.assistant(sgl.gen("answer"))
def single(path, num_frames=16):
state = video_qa.run(
num_frames=num_frames,
video_path=path,
question="Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes",
temperature=0.0,
max_new_tokens=1024,
)
print(state["answer"], "\n")
def split_into_chunks(lst, num_chunks):
"""Split a list into a specified number of chunks."""
# Calculate the base chunk size with integer division; any remainder is handled below.
chunk_size = len(lst) // num_chunks
if chunk_size == 0:
chunk_size = len(lst)
# Generate chunks with a list comprehension. If the list doesn't divide evenly, the remainder forms an extra, smaller chunk.
chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
# Pad with empty chunks so there are at least num_chunks entries.
chunks.extend([[] for _ in range(num_chunks - len(chunks))])
return chunks
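# For illustration: split_into_chunks(list(range(10)), 3) returns
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]; the remainder forms a fourth chunk.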
def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
with open(csv_filename, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['video_name', 'answer'])
for video_path, state in zip(batch_video_files, states):
video_name = os.path.basename(video_path)
writer.writerow([video_name, state["answer"]])
def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
final_csv_filename = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
with open(final_csv_filename, 'w', newline='') as final_csvfile:
writer = csv.writer(final_csvfile)
writer.writerow(['video_name', 'answer'])
for batch_idx in range(num_batches):
batch_csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
with open(batch_csv_filename, 'r') as batch_csvfile:
reader = csv.reader(batch_csvfile)
next(reader) # Skip header row
for row in reader:
writer.writerow(row)
os.remove(batch_csv_filename)
def find_video_files(video_dir):
# Check if the video_dir is actually a file
if os.path.isfile(video_dir):
# If it's a file, return it as a single-element list
return [video_dir]
# Original logic to find video files in a directory
video_files = []
for root, dirs, files in os.walk(video_dir):
for file in files:
if file.endswith(('.mp4', '.avi', '.mov')):
video_files.append(os.path.join(root, file))
return video_files
def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
video_files = find_video_files(video_dir)
chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk]
if not chunked_video_files:
print("No video files found in the specified directory.")
return
num_batches = 0
for i in range(0, len(chunked_video_files), batch_size):
batch_video_files = chunked_video_files[i:i + batch_size]
print(f"Processing batch of {len(batch_video_files)} video(s)...")
batch_input = [
{
"num_frames": num_frames,
"video_path": video_path,
"question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.",
} for video_path in batch_video_files
]
start_time = time.time()
states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
total_time = time.time() - start_time
average_time = total_time / len(batch_video_files)
print(f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds")
save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir)
num_batches += 1
compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir)
if __name__ == "__main__":
# Create the parser
parser = argparse.ArgumentParser(description='Run video processing with specified port.')
# Add an argument for the port
parser.add_argument('--port', type=int, default=30000, help='The master port for distributed serving.')
parser.add_argument('--chunk-idx', type=int, default=0, help='The index of the chunk to process.')
parser.add_argument('--num-chunks', type=int, default=8, help='The number of chunks to process.')
parser.add_argument('--save-dir', type=str, default="./work_dirs/llava_video", help='The directory to save the processed video files.')
parser.add_argument('--video-dir', type=str, default="./videos/Q98Z4OTh8RwmDonc.mp4", help='The directory or path for the processed video files.')
parser.add_argument('--model-path', type=str, default="lmms-lab/LLaVA-NeXT-Video-7B", help='The model path for the video processing.')
parser.add_argument('--num-frames', type=int, default=16, help='The number of frames to process in each video.' )
parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)
# Parse the arguments
args = parser.parse_args()
cur_port = args.port
cur_chunk = args.chunk_idx
num_chunks = args.num_chunks
num_frames = args.num_frames
if "34b" in args.model_path.lower():
tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer"
elif "7b" in args.model_path.lower():
tokenizer_path = "llava-hf/llava-1.5-7b-hf"
else:
print("Invalid model path. Please specify a valid model path.")
exit()
model_overide_args = {}
model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
model_overide_args["num_frames"] = args.num_frames
model_overide_args["model_type"] = "llava"
if "34b" in args.model_path.lower():
model_overide_args["image_token_index"] = 64002
if args.num_frames == 32:
model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
model_overide_args["max_sequence_length"] = 4096 * 2
model_overide_args["tokenizer_model_max_length"] = 4096 * 2
elif args.num_frames < 32:
pass
else:
print("The maximum number of frames to process is 32. Please specify a valid number of frames.")
exit()
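# Sizing rationale for the 32-frame branch above: with the default spatial pool
# stride of 2 and the standard CLIP-L/14-336 tower, each frame costs
# (336 / 14 / 2) ** 2 = 144 tokens, so 32 frames alone take 32 * 144 = 4608
# tokens, beyond the 4096 default context, hence the doubled window.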
runtime = sgl.Runtime(
model_path=args.model_path, #"liuhaotian/llava-v1.6-vicuna-7b",
tokenizer_path=tokenizer_path,
port=cur_port,
additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
model_overide_args=model_overide_args,
tp_size=1
)
sgl.set_default_backend(runtime)
print(f"chat template: {runtime.endpoint.chat_template.name}")
# Run a single request
# try:
print("\n========== single ==========\n")
root = args.video_dir
if os.path.isfile(root):
video_files = [root]
else:
video_files = [os.path.join(root, f) for f in os.listdir(root) if f.endswith(('.mp4', '.avi', '.mov'))] # Add more extensions if needed
start_time = time.time() # Start time for processing a single video
for cur_video in video_files[:1]:
print(cur_video)
single(cur_video, num_frames)
end_time = time.time() # End time for processing a single video
total_time = end_time - start_time
average_time = total_time / len(video_files[:1])  # Only the first video is processed in this demo
print(f"Average processing time per video: {average_time:.2f} seconds")
runtime.shutdown()
# except Exception as e:
# print(e)
# runtime.shutdown()
# Run a batch of requests
# print("\n========== batch ==========\n")
# if not os.path.exists(args.save_dir):
# os.makedirs(args.save_dir)
# batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, batch_size=64)
# runtime.shutdown()
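A plausible single-node smoke test for the script above (port and video path are illustrative; the values shown are the argparse defaults):

```sh
python3 examples/usage/llava_video/srt_example_llava_v.py \
    --port 30000 \
    --video-dir ./videos/Q98Z4OTh8RwmDonc.mp4 \
    --model-path lmms-lab/LLaVA-NeXT-Video-7B \
    --num-frames 16
```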

View File

@@ -0,0 +1,130 @@
#!/bin/bash
##### USAGE #####
# - First node:
# ```sh
# bash examples/quick_start/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```
# - Second node:
# ```sh
# bash examples/quick_start/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```
# - The K-th node (node index K-1):
# ```sh
# bash examples/quick_start/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```
# Replace `K`, `YOUR_VIDEO_PATH`, `YOUR_MODEL_PATH`, and `FRAMES_PER_VIDEO` with your specific details.
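#
# Example (illustrative values): two nodes, launched on the first node:
#   bash examples/quick_start/srt_example_llava_v.sh 2 0 ./videos lmms-lab/LLaVA-NeXT-Video-7B 16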
CURRENT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
echo ${CURRENT_ROOT}
cd ${CURRENT_ROOT}
export PYTHONWARNINGS=ignore
START_TIME=$(date +%s) # Capture start time
NUM_NODES=$1
CUR_NODES_IDX=$2
VIDEO_DIR=$3
MODEL_PATH=$4
NUM_FRAMES=$5
# FRAME_FORMAT=$6
# FRAME_FORMAT=$(echo $FRAME_FORMAT | tr '[:lower:]' '[:upper:]')
# # Check if FRAME_FORMAT is either JPEG or PNG
# if [[ "$FRAME_FORMAT" != "JPEG" && "$FRAME_FORMAT" != "PNG" ]]; then
# echo "Error: FRAME_FORMAT must be either JPEG or PNG."
# exit 1
# fi
# export TARGET_FRAMES=$TARGET_FRAMES
echo "Each video you will sample $NUM_FRAMES frames"
# export FRAME_FORMAT=$FRAME_FORMAT
# echo "The frame format is $FRAME_FORMAT"
# Assuming GPULIST is a bash array containing your GPUs
GPULIST=(0 1 2 3 4 5 6 7)
LOCAL_CHUNKS=${#GPULIST[@]}
echo "Number of GPUs in GPULIST: $LOCAL_CHUNKS"
ALL_CHUNKS=$((NUM_NODES * LOCAL_CHUNKS))
# GPUs assigned to each chunk
GPUS_PER_CHUNK=8
echo $GPUS_PER_CHUNK
for IDX in $(seq 1 $LOCAL_CHUNKS); do
(
START=$(((IDX-1) * GPUS_PER_CHUNK))
LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index
CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH})
# Convert the chunk GPUs array to a comma-separated string
CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}")
LOCAL_IDX=$((CUR_NODES_IDX * LOCAL_CHUNKS + IDX))
echo "Chunk $(($LOCAL_IDX - 1)) will run on GPUs $CHUNK_GPUS_STR"
# Pick a random port for this chunk to avoid collisions between chunks.
PORT=$((10000 + RANDOM % 55536))
MAX_RETRIES=10
RETRY_COUNT=0
COMMAND_STATUS=1 # Initialize as failed
while [ $RETRY_COUNT -lt $MAX_RETRIES ] && [ $COMMAND_STATUS -ne 0 ]; do
echo "Running chunk $(($LOCAL_IDX - 1)) on GPUs $CHUNK_GPUS_STR with port $PORT. Attempt $(($RETRY_COUNT + 1))"
CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 examples/usage/llava_video/srt_example_llava_v.py \
--port $PORT \
--num-chunks $ALL_CHUNKS \
--chunk-idx $(($LOCAL_IDX - 1)) \
--save-dir work_dirs/llava_next_video_inference_results \
--video-dir $VIDEO_DIR \
--model-path $MODEL_PATH \
--num-frames $NUM_FRAMES #&
wait $! # Wait for the process to finish and capture its exit status
COMMAND_STATUS=$?
if [ $COMMAND_STATUS -ne 0 ]; then
echo "Execution failed for chunk $(($LOCAL_IDX - 1)), attempt $(($RETRY_COUNT + 1)). Retrying..."
RETRY_COUNT=$(($RETRY_COUNT + 1))
sleep 180 # Wait a bit before retrying
else
echo "Execution succeeded for chunk $(($LOCAL_IDX - 1))."
fi
done
if [ $COMMAND_STATUS -ne 0 ]; then
echo "Execution failed for chunk $(($LOCAL_IDX - 1)) after $MAX_RETRIES attempts."
fi
) #&
sleep 2 # Slight delay to stagger the start times
done
wait
cat work_dirs/llava_next_video_inference_results/final_results_chunk_*.csv > work_dirs/llava_next_video_inference_results/final_results_node_${CUR_NODES_IDX}.csv
END_TIME=$(date +%s) # Capture end time
ELAPSED_TIME=$(($END_TIME - $START_TIME))
echo "Total execution time: $ELAPSED_TIME seconds."

Binary file not shown.

View File

@@ -20,7 +20,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
"zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "outlines>=0.0.27", "packaging"]
"zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "packaging", "huggingface_hub", "hf_transfer", "outlines>=0.0.34"]
openai = ["openai>=1.0", "numpy", "tiktoken"]
anthropic = ["anthropic>=0.20.0", "numpy"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
@@ -33,4 +33,4 @@ all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
[tool.wheel]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]

View File

@@ -19,6 +19,7 @@ from sglang.api import (
user,
user_begin,
user_end,
video,
)
# SGL Backends
@@ -46,6 +47,7 @@ __all__ = [
"gen_int",
"gen_string",
"image",
"video",
"select",
"system",
"user",

View File

@@ -15,6 +15,7 @@ from sglang.lang.ir import (
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVideo,
)
@@ -151,6 +152,10 @@ def image(expr: SglExpr):
return SglImage(expr)
def video(path: str, num_frames: int):
return SglVideo(path, num_frames)
def select(
name: Optional[str] = None,
choices: List[str] = None,

View File

@@ -259,6 +259,8 @@ def match_vicuna(model_path: str):
return get_chat_template("vicuna_v1.1")
if "llava-v1.5" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if "llava-next-video-7b" in model_path.lower():
return get_chat_template("vicuna_v1.1")
@register_chat_template_matching_function
@@ -283,19 +285,24 @@ def match_llama3_instruct(model_path: str):
@register_chat_template_matching_function
def match_chat_ml(model_path: str):
model_path = model_path.lower()
if "tinyllama" in model_path:
return get_chat_template("chatml")
if "qwen" in model_path and "chat" in model_path:
return get_chat_template("chatml")
if "llava-v1.6-34b" in model_path:
if (
"llava-v1.6-34b" in model_path
or "llava-v1.6-yi-34b" in model_path
or "llava-next-video-34b" in model_path
):
return get_chat_template("chatml-llava")
@register_chat_template_matching_function
def match_chat_yi(model_path: str):
model_path = model_path.lower()
if "yi" in model_path:
if "yi" in model_path and "llava" not in model_path:
return get_chat_template("yi")

View File

@@ -28,8 +28,9 @@ from sglang.lang.ir import (
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
SglVideo,
)
from sglang.utils import encode_image_base64, get_exception_traceback
from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
def run_internal(state, program, func_args, func_kwargs, sync):
@@ -361,6 +362,8 @@ class StreamExecutor:
self._execute_role_end(other)
elif isinstance(other, SglImage):
self._execute_image(other)
elif isinstance(other, SglVideo):
self._execute_video(other)
elif isinstance(other, SglVariable):
self._execute_variable(other)
elif isinstance(other, SglVarScopeBegin):
@@ -397,6 +400,16 @@ class StreamExecutor:
self.cur_images.append((path, base64_data))
self.text_ += self.chat_template.image_token
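# A video is executed like an image: its sampled frames are packed into a single
# base64 payload stored in the same image slots, reusing the chat template's image token.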
def _execute_video(self, expr: SglVideo):
path = expr.path
num_frames = expr.num_frames
base64_data = encode_video_base64(path, num_frames)
self.images_.append((path, base64_data))
self.cur_images.append((path, base64_data))
self.text_ += self.chat_template.image_token
# if global_config.eager_fill_image:
# self.backend.fill_image(self)

View File

@@ -330,6 +330,15 @@ class SglImage(SglExpr):
return f"SglImage({self.path})"
class SglVideo(SglExpr):
def __init__(self, path, num_frames):
self.path = path
self.num_frames = num_frames
def __repr__(self) -> str:
return f"SglVideo({self.path}, {self.num_frames})"
class SglGen(SglExpr):
def __init__(
self,

View File

@@ -110,7 +110,7 @@ class TracerProgramState(ProgramState):
##################################
def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
assert (size >= 1)
assert size >= 1
if self.only_trace_prefix:
raise StopTracing()

View File

@@ -2,11 +2,10 @@ import argparse
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args, None)
launch_server(server_args, None)

View File

@@ -0,0 +1,31 @@
import argparse
import multiprocessing as mp
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
model_overide_args = {}
model_overide_args["mm_spatial_pool_stride"] = 2
model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
model_overide_args["num_frames"] = 16
model_overide_args["model_type"] = "llavavid"
if model_overide_args["num_frames"] == 32:
model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
model_overide_args["max_sequence_length"] = 4096 * 2
model_overide_args["tokenizer_model_max_length"] = 4096 * 2
model_overide_args["model_max_length"] = 4096 * 2
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
if "34b" in args.model_path.lower():
model_overide_args["image_token_index"] = 64002
server_args = ServerArgs.from_cli_args(args)
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
launch_server(server_args, pipe_writer, model_overide_args)
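Once this server is up, a minimal client sketch (the video path is hypothetical; `sgl.video` and `RuntimeEndpoint` are the primitives introduced or used elsewhere in this commit):

```python
import sglang as sgl

@sgl.function
def video_qa(s, video_path, question):
    s += sgl.user(sgl.video(video_path, 16) + question)
    s += sgl.assistant(sgl.gen("answer"))

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = video_qa.run(
    video_path="videos/demo.mp4",  # hypothetical path
    question="What is happening in this video?",
    max_new_tokens=128,
)
print(state["answer"])
```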

View File

@@ -30,10 +30,17 @@ def get_config_json(model_path: str):
return config
def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
def get_config(
model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
model_overide_args: Optional[dict] = None,
):
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision
)
if model_overide_args:
config.update(model_overide_args)
return config
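For instance, the video launcher above passes its overrides through this hook; a minimal sketch (override keys taken from launch_server_llavavid.py in this commit):

```python
config = get_config(
    "lmms-lab/LLaVA-NeXT-Video-7B",
    trust_remote_code=True,
    model_overide_args={
        "architectures": ["LlavaVidForCausalLM"],
        "num_frames": 16,
        "mm_spatial_pool_stride": 2,
    },
)
```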

View File

@@ -60,9 +60,7 @@ class RouterManager:
def start_router_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
):
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
@@ -70,7 +68,7 @@ def start_router_process(
)
try:
model_client = ModelRpcClient(server_args, port_args)
model_client = ModelRpcClient(server_args, port_args, model_overide_args)
router = RouterManager(model_client, port_args)
except Exception:
pipe_writer.send(get_exception_traceback())

View File

@@ -4,12 +4,13 @@ import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import List
from typing import Any, Dict, List, Optional, Tuple, Union
import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
try:
from vllm.logger import _default_handler as vllm_default_logger
except ImportError:
@@ -48,6 +49,7 @@ class ModelRpcServer:
tp_rank: int,
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args: Optional[dict] = None,
):
server_args, port_args = [obtain(x) for x in [server_args, port_args]]
@@ -62,6 +64,7 @@ class ModelRpcServer:
server_args.model_path,
server_args.trust_remote_code,
context_length=server_args.context_length,
model_overide_args=model_overide_args,
)
# For model end global settings
@@ -673,13 +676,15 @@ class ModelRpcService(rpyc.Service):
class ModelRpcClient:
def __init__(self, server_args: ServerArgs, port_args: PortArgs):
def __init__(
self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
):
tp_size = server_args.tp_size
if tp_size == 1:
# Init model
self.model_server = ModelRpcService().exposed_ModelRpcServer(
0, server_args, port_args
0, server_args, port_args, model_overide_args
)
# Wrap functions
@@ -700,7 +705,7 @@ class ModelRpcClient:
# Init model
def init_model(i):
return self.remote_services[i].ModelRpcServer(
i, server_args, port_args
i, server_args, port_args, model_overide_args
)
self.model_servers = executor.map(init_model, range(tp_size))
@@ -723,7 +728,11 @@ def _init_service(port):
t = ThreadedServer(
ModelRpcService(),
port=port,
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
protocol_config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 1800,
},
)
t.start()
@@ -739,7 +748,11 @@ def start_model_process(port):
con = rpyc.connect(
"localhost",
port,
config={"allow_pickle": True, "sync_request_timeout": 1800},
config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 1800,
},
)
break
except ConnectionRefusedError:

View File

@@ -9,11 +9,11 @@ from typing import List
import numpy as np
import torch
from vllm.distributed import initialize_model_parallel
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.distributed import initialize_model_parallel
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
@@ -143,7 +143,7 @@ class InputMetadata:
self.kv_last_page_len,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
self.model_runner.model_config.head_dim
self.model_runner.model_config.head_dim,
]
self.prefill_wrapper.begin_forward(*args)

View File

@@ -60,21 +60,29 @@ def get_pixel_values(
):
try:
processor = processor or global_processor
image = load_image(image_data)
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
)
pixel_values = processor.image_processor(image)["pixel_values"][0]
elif image_aspect_ratio == "anyres":
pixel_values = process_anyres_image(
image, processor.image_processor, image_grid_pinpoints
)
image, image_size = load_image(image_data)
if image_size is not None:
image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"]
pixel_values = [v.astype(np.float16) for v in pixel_values]
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
pixel_values = processor.image_processor(image)["pixel_values"][0]
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
)
pixel_values = processor.image_processor(image)["pixel_values"][0]
elif image_aspect_ratio == "anyres":
pixel_values = process_anyres_image(
image, processor.image_processor, image_grid_pinpoints
)
else:
pixel_values = processor.image_processor(image)["pixel_values"][0]
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
print("Exception in TokenizerManager:\n" + get_exception_traceback())
@@ -84,6 +92,7 @@ class TokenizerManager:
self,
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args: dict = None,
):
self.server_args = server_args
@@ -96,7 +105,9 @@ class TokenizerManager:
self.model_path = server_args.model_path
self.hf_config = get_config(
self.model_path, trust_remote_code=server_args.trust_remote_code
self.model_path,
trust_remote_code=server_args.trust_remote_code,
model_overide_args=model_overide_args,
)
self.context_len = get_context_length(self.hf_config)

View File

@@ -10,12 +10,16 @@ class ModelConfig:
trust_remote_code: bool = True,
revision: Optional[str] = None,
context_length: Optional[int] = None,
model_overide_args: Optional[dict] = None,
) -> None:
self.path = path
self.trust_remote_code = trust_remote_code
self.revision = revision
self.hf_config = get_config(self.path, trust_remote_code, revision)
if model_overide_args is not None:
self.hf_config.update(model_overide_args)
if context_length is not None:
self.context_len = context_length
else:

View File

@@ -27,29 +27,25 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn.parameter import Parameter
from transformers import PretrainedConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
@torch.compile

View File

@@ -5,37 +5,31 @@ from typing import Optional
import torch
import torch.nn as nn
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.models.dbrx_config import DbrxConfig
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class DbrxRouter(nn.Module):
@@ -291,7 +285,9 @@ class DbrxBlock(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, quant_config=quant_config)
self.norm_attn_norm = DbrxFusedNormAttention(
config, layer_id, quant_config=quant_config
)
self.ffn = DbrxExperts(config, quant_config=quant_config)
def forward(
@@ -322,7 +318,10 @@ class DbrxModel(nn.Module):
config.d_model,
)
self.blocks = nn.ModuleList(
[DbrxBlock(config, i, quant_config=quant_config) for i in range(config.n_layers)]
[
DbrxBlock(config, i, quant_config=quant_config)
for i in range(config.n_layers)
]
)
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():

View File

@@ -7,6 +7,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
@@ -14,21 +15,14 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class GemmaMLP(nn.Module):
@@ -46,7 +40,10 @@ class GemmaMLP(nn.Module):
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.act_fn = GeluAndMul()

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
@@ -13,24 +14,17 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlamaMLP(nn.Module):
@@ -49,7 +43,10 @@ class LlamaMLP(nn.Module):
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
)
if hidden_act != "silu":
raise ValueError(

View File

@@ -7,12 +7,7 @@ import torch
from torch import nn
from transformers import CLIPVisionModel, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
@@ -22,6 +17,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaLlamaForCausalLM(nn.Module):

View File

@@ -0,0 +1,307 @@
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
import os
from typing import List, Optional
import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
unpad_image_shape,
)
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaVidForCausalLM(nn.Module):
def __init__(
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.vision_tower = None
self.config.vision_config.hidden_size = config.mm_hidden_size
self.config.text_config.hidden_size = config.hidden_size
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
self.resampler = nn.AvgPool2d(
kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
)
self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
self.num_frames = getattr(self.config, "num_frames", 16)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
)
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
new_image_feature_len = self.image_feature_len
# now only support spatial_unpad + anyres
# if self.mm_patch_merge_type.startswith("spatial"):
# height = width = self.num_patches_per_side
# if pt_shape[0] > 1:
# if self.image_aspect_ratio == "anyres":
# num_patch_width, num_patch_height = get_anyres_image_grid_shape(
# image_size,
# self.image_grid_pinpoints,
# self.vision_tower.config.image_size,
# )
# if "unpad" in self.mm_patch_merge_type:
# h = num_patch_height * height
# w = num_patch_width * width
# new_h, new_w = unpad_image_shape(h, w, image_size)
# new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
# print(input_ids)
offset = input_ids.index(self.config.image_token_index)
# old_len + pad_len - 1, because we need to remove image_token_id
new_input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :]
)
return new_input_ids, offset
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# NOTE: This is not memory efficient; output_hidden_states=True saves all the hidden states.
selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
)
height = width = self.num_patches_per_side
num_of_frames = selected_image_feature.shape[0]
selected_image_feature = selected_image_feature.view(
num_of_frames, height, width, -1
)
selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
selected_image_feature = (
self.resampler(selected_image_feature)
.flatten(2)
.transpose(1, 2)
.contiguous()
)
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
pixel_values: Optional[List[Optional[np.array]]] = None,
image_sizes: Optional[List[List[int]]] = None,
image_offsets: Optional[List[int]] = None,
) -> torch.Tensor:
if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size
# Embed text input
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input
need_vision = (
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
.cpu()
.numpy()
)
# FIXME: We need to subtract the length of the system prompt
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
need_vision = need_vision & has_pixel
if need_vision.any():
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
########## Encode Image ########
if pixel_values[0].ndim == 4:
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
concat_images = torch.tensor(
np.concatenate(pixel_values, axis=0),
device=self.vision_tower.device,
)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# hd image_features: BS, num_patch, 576, 4096
else:
# normal pixel: BS, C=3, H=336, W=336
pixel_values = torch.tensor(
np.array(pixel_values), device=self.vision_tower.device
)
image_features = self.encode_images(pixel_values)
# image_features: BS, 576, 4096
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
new_image_features.append(image_feature.flatten(0, 1))
image_features = new_image_features
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
pt = 0
for i in range(bs):
if not need_vision[i]:
continue
start_idx = extend_start_loc_cpu[i]
pad_len, pad_dim = image_features[pt].shape # 576, 4096
dim = input_embeds.shape[1]
assert (
pad_dim == dim
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
# Fill in the placeholder for the image
try:
input_embeds[
start_idx
+ image_offsets[i] : start_idx
+ image_offsets[i]
+ pad_len
] = image_features[pt]
except RuntimeError as e:
print(f"RuntimeError in llava image encoding: {e}")
print(input_embeds.shape)
print(start_idx, image_offsets[i])
pt += 1
return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds
)
elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model(input_ids, positions, input_metadata)
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
self.image_size = self.vision_tower.config.image_size
self.patch_size = self.vision_tower.config.patch_size
self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
print(f"target_frames: {self.num_frames}")
self.image_feature_len = self.num_frames * int(
(self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
)
if self.vision_feature_select_strategy == "patch":
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
# load mm_projector
projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
"model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
# FIXME: why are projector weights loaded twice?
if "projector" in name or "vision_tower" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
if name in params_dict:
param = params_dict[name]
else:
print(f"Warning: {name} not found in the model")
continue
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# load language model
self.language_model.load_weights(
model_name_or_path, cache_dir, load_format, revision
)
monkey_path_clip_vision_embed_forward()
@property
def num_patches_per_side(self):
return self.image_size // self.patch_size
first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
global first_call
if first_call:
self.patch_embedding.cpu().float()
first_call = False
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
def monkey_path_clip_vision_embed_forward():
import transformers
setattr(
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
"forward",
clip_vision_embed_forward,
)
EntryClass = LlavaVidForCausalLM
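Sizing note: `image_feature_len` above is `num_frames * (image_size / patch_size / mm_spatial_pool_stride) ** 2`; for the default 16 frames with the standard CLIP-L/14-336 tower and stride 2, that is 16 * (336 / 14 / 2) ** 2 = 16 * 144 = 2304 placeholder tokens padded into the prompt.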

View File

@@ -8,34 +8,28 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class MixtralMLP(nn.Module):

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
@@ -10,24 +11,17 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class QWenMLP(nn.Module):
@@ -132,7 +126,12 @@ class QWenAttention(nn.Module):
class QWenBlock(nn.Module):
def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,):
def __init__(
self,
config: PretrainedConfig,
layer_id,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -181,7 +180,11 @@ class QWenBlock(nn.Module):
class QWenModel(nn.Module):
def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
@@ -218,7 +221,11 @@ class QWenModel(nn.Module):
class QWenLMHeadModel(nn.Module):
def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.transformer = QWenModel(config, quant_config=quant_config)
@@ -276,4 +283,4 @@ class QWenLMHeadModel(nn.Module):
weight_loader(param, loaded_weight)
EntryClass = QWenLMHeadModel
EntryClass = QWenLMHeadModel

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Tuple
import torch
from torch import nn
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
@@ -12,24 +13,17 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
Qwen2Config = None
@@ -50,7 +44,10 @@ class Qwen2MLP(nn.Module):
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
)
if hidden_act != "silu":
raise ValueError(

View File

@@ -7,35 +7,31 @@ from typing import Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class StablelmMLP(nn.Module):
def __init__(
self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
@@ -48,7 +44,10 @@ class StablelmMLP(nn.Module):
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config,
config.intermediate_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
)
self.act_fn = SiluAndMul()
@@ -181,7 +180,9 @@ class StablelmDecoderLayer(nn.Module):
class StableLMEpochModel(nn.Module):
def __init__(
self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(

View File

@@ -6,16 +6,13 @@ from typing import List, Optional
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.models.llava import (
LlavaLlamaForCausalLM,
clip_vision_embed_forward,
monkey_path_clip_vision_embed_forward,
)
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class YiVLForCausalLM(LlavaLlamaForCausalLM):

View File

@@ -107,7 +107,7 @@ async def openai_v1_chat_completions(raw_request: Request):
return await v1_chat_completions(tokenizer_manager, raw_request)
def launch_server(server_args: ServerArgs, pipe_finish_writer):
def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
global tokenizer_manager
logging.basicConfig(
@@ -140,17 +140,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
)
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args)
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
proc_router = mp.Process(
target=start_router_process,
args=(
server_args,
port_args,
pipe_router_writer,
),
args=(server_args, port_args, pipe_router_writer, model_overide_args),
)
proc_router.start()
proc_detoken = mp.Process(
@@ -170,8 +166,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
if router_init_state != "init ok" or detoken_init_state != "init ok":
proc_router.kill()
proc_detoken.kill()
print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
print(
f"Initialization failed. router_init_state: {router_init_state}", flush=True
)
print(
f"Initialization failed. detoken_init_state: {detoken_init_state}",
flush=True,
)
sys.exit(1)
assert proc_router.is_alive() and proc_detoken.is_alive()
@@ -189,6 +190,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
time.sleep(0.5)
try:
requests.get(url + "/get_model_info", timeout=5, headers=headers)
success = True # Set flag to True if request succeeds
break
except requests.exceptions.RequestException as e:
pass
@@ -205,7 +207,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
},
},
headers=headers,
timeout=60,
timeout=600,
)
assert res.status_code == 200
except Exception as e:
@@ -235,7 +237,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
class Runtime:
def __init__(
self,
log_evel="error",
log_evel: str = "error",
model_overide_args: Optional[dict] = None,
*args,
**kwargs,
):
@@ -244,7 +247,10 @@ class Runtime:
# Pre-allocate ports
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)
self.server_args.port,
self.server_args.additional_ports,
self.server_args.tp_size,
)
self.url = self.server_args.url()
self.generate_url = (
@@ -253,7 +259,10 @@ class Runtime:
self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
proc = mp.Process(
target=launch_server,
args=(self.server_args, pipe_writer, model_overide_args),
)
proc.start()
pipe_writer.close()
self.pid = proc.pid
@@ -265,7 +274,9 @@ class Runtime:
if init_state != "init ok":
self.shutdown()
raise RuntimeError("Initialization failed. Please see the error messages above.")
raise RuntimeError(
"Initialization failed. Please see the error messages above."
)
self.endpoint = RuntimeEndpoint(self.url)
@@ -317,4 +328,4 @@ class Runtime:
pos += len(cur)
def __del__(self):
self.shutdown()
self.shutdown()

View File

@@ -80,10 +80,12 @@ class ServerArgs:
default=ServerArgs.tokenizer_path,
help="The path of the tokenizer.",
)
parser.add_argument("--host", type=str, default=ServerArgs.host,
help="The host of the server.")
parser.add_argument("--port", type=int, default=ServerArgs.port,
help="The port of the server.")
parser.add_argument(
"--host", type=str, default=ServerArgs.host, help="The host of the server."
)
parser.add_argument(
"--port", type=int, default=ServerArgs.port, help="The port of the server."
)
parser.add_argument(
"--additional-ports",
type=int,
@@ -261,4 +263,4 @@ class PortArgs:
router_port: int
detokenizer_port: int
nccl_port: int
model_rpc_ports: List[int]
model_rpc_ports: List[int]

View File

@@ -131,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()):
continue
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
s.bind(("", port))
s.listen(1) # Attempt to listen on the port
port_list.append(port)
except socket.error:
pass
pass # If any error occurs, this port is not usable
if len(port_list) == num:
return port_list
@@ -265,20 +267,102 @@ def wrap_kernel_launcher(kernel):
def is_multimodal_model(model):
if isinstance(model, str):
return "llava" in model or "yi-vl" in model
from sglang.srt.model_config import ModelConfig
if isinstance(model, str):
model = model.lower()
return "llava" in model or "yi-vl" in model or "llava-next" in model
if isinstance(model, ModelConfig):
model_path = model.path.lower()
return "llava" in model_path or "yi-vl" in model_path
raise Exception("unrecognized type")
return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
raise ValueError("unrecognized type")
def decode_video_base64(video_base64):
from PIL import Image
# Decode the base64 string
video_bytes = base64.b64decode(video_base64)
# Placeholder for the start indices of each PNG image
img_starts = []
frame_format = "PNG" # str(os.getenv('FRAME_FORMAT', "JPEG"))
assert frame_format in [
"PNG",
"JPEG",
], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
if frame_format == "PNG":
# Find each PNG start signature to isolate images
i = 0
while i < len(video_bytes) - 7: # Adjusted for the length of the PNG signature
# Check if we found the start of a PNG file
if (
video_bytes[i] == 0x89
and video_bytes[i + 1] == 0x50
and video_bytes[i + 2] == 0x4E
and video_bytes[i + 3] == 0x47
and video_bytes[i + 4] == 0x0D
and video_bytes[i + 5] == 0x0A
and video_bytes[i + 6] == 0x1A
and video_bytes[i + 7] == 0x0A
):
img_starts.append(i)
i += 8 # Skip the PNG signature
else:
i += 1
else:
# Find each JPEG start (0xFFD8) to isolate images
i = 0
while (
i < len(video_bytes) - 1
): # Adjusted for the length of the JPEG SOI signature
# Check if we found the start of a JPEG file
if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
img_starts.append(i)
# Move to the next byte to continue searching for the next image start
i += 2
else:
i += 1
frames = []
for start_idx in img_starts:
# Assuming each image is back-to-back, the end of one image is the start of another
# The last image goes until the end of the byte string
end_idx = (
img_starts[img_starts.index(start_idx) + 1]
if img_starts.index(start_idx) + 1 < len(img_starts)
else len(video_bytes)
)
img_bytes = video_bytes[start_idx:end_idx]
# Convert bytes to a PIL Image
img = Image.open(BytesIO(img_bytes))
# Convert PIL Image to a NumPy array
frame = np.array(img)
# Append the frame to the list of frames
frames.append(frame)
# Ensure there's at least one frame to avoid errors with np.stack
if frames:
return np.stack(frames, axis=0), img.size
else:
return np.array([]), (
0,
0,
) # Return an empty array and size tuple if no frames were found
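
The byte-by-byte scan above is equivalent to searching for the fixed 8-byte PNG header; a compact sketch of the same idea using `bytes.find`:

```
PNG_SIG = bytes([0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A])

def find_png_starts(video_bytes: bytes) -> list:
    # Collect the offset of every PNG signature in the concatenated stream.
    starts, i = [], video_bytes.find(PNG_SIG)
    while i != -1:
        starts.append(i)
        i = video_bytes.find(PNG_SIG, i + len(PNG_SIG))
    return starts
```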
def load_image(image_file):
from PIL import Image
image = None
image = image_size = None
if image_file.startswith("http://") or image_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
@@ -289,10 +373,13 @@ def load_image(image_file):
elif image_file.startswith("data:"):
image_file = image_file.split(",")[1]
image = Image.open(BytesIO(base64.b64decode(image_file)))
elif image_file.startswith("video:"):
image_file = image_file.replace("video:", "")
image, image_size = decode_video_base64(image_file)
else:
image = Image.open(BytesIO(base64.b64decode(image_file)))
return image
return image, image_size
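
Callers dispatch on the return shape: plain images come back as a PIL Image with `image_size` left as `None`, while "video:" payloads come back as a NumPy frame stack plus the frame size. A self-contained sketch of the `data:` branch:

```
import base64
from io import BytesIO
from PIL import Image

# Build a tiny in-memory PNG and feed it through the data: branch.
buf = BytesIO()
Image.new("RGB", (2, 2), "red").save(buf, format="PNG")
data_uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
img, size = load_image(data_uri)
assert size is None  # only the "video:" branch fills image_size
```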
def assert_pkg_version(pkg: str, min_version: str):
@@ -304,7 +391,9 @@ def assert_pkg_version(pkg: str, min_version: str):
f"is less than the minimum required version {min_version}"
)
except PackageNotFoundError:
raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
raise Exception(
f"{pkg} with minimum required version {min_version} is not installed"
)
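
The same guard in isolation; a sketch assuming the `packaging` distribution is available for robust version comparison:

```
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version  # assumed installed

def has_min_version(pkg: str, min_version: str) -> bool:
    try:
        return Version(version(pkg)) >= Version(min_version)
    except PackageNotFoundError:
        return False

print(has_min_version("requests", "2.0.0"))
```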
API_KEY_HEADER_NAME = "X-API-Key"

View File

@@ -19,11 +19,12 @@ import torch
from huggingface_hub import HfFileSystem, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
get_quantization_config,
)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
logger = init_logger(__name__)
@@ -32,17 +33,21 @@ logger = init_logger(__name__)
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = os.environ.get('TMPDIR') or os.environ.get(
'TEMP') or os.environ.get('TMP') or "/tmp/"
temp_dir = (
os.environ.get("TMPDIR")
or os.environ.get("TEMP")
or os.environ.get("TMP")
or "/tmp/"
)
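
The standard library encodes a near-identical fallback chain, so the explicit expression above can be cross-checked against it:

```
import tempfile

# tempfile.gettempdir() consults TMPDIR, TEMP, and TMP before falling back
# to a platform default such as /tmp, much like the expression above.
print(tempfile.gettempdir())
```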
def enable_hf_transfer():
"""automatically activates hf_transfer
"""
"""automatically activates hf_transfer"""
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
try:
# enable hf hub transfer if available
import hf_transfer # type: ignore # noqa
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
except ImportError:
pass
@@ -65,8 +70,7 @@ def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
# add hash to avoid conflict with old users' lock files
lock_file_name = hash_name + model_name + ".lock"
# mode 0o666 is required for the filelock to be shared across users
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
mode=0o666)
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
return lock
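
Using the resulting lock follows the standard filelock context-manager pattern; a minimal sketch (the lock filename here is hypothetical):

```
import os
import tempfile
import filelock

lock_path = os.path.join(tempfile.gettempdir(), "demo-model.lock")
lock = filelock.FileLock(lock_path, mode=0o666)
with lock:
    # Only one process at a time runs this block (e.g., a weight download).
    pass
```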
@@ -104,10 +108,12 @@ def convert_bin_to_safetensor_file(
sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size
if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(f"""The file size different is more than 1%:
raise RuntimeError(
f"""The file size different is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
""")
"""
)
# check if the tensors are the same
reloaded = load_file(sf_filename)
@@ -122,8 +128,7 @@ def convert_bin_to_safetensor_file(
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
None)
hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
model_name_or_path = model_config.model
@@ -131,26 +136,29 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, model_config.download_dir):
hf_folder = snapshot_download(model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=model_config.download_dir,
tqdm_class=Disabledtqdm)
hf_folder = snapshot_download(
model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=model_config.download_dir,
tqdm_class=Disabledtqdm,
)
else:
hf_folder = model_name_or_path
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_config_files = [
f for f in config_files if any(
f.endswith(x) for x in quant_cls.get_config_filenames())
f
for f in config_files
if any(f.endswith(x) for x in quant_cls.get_config_filenames())
]
if len(quant_config_files) == 0:
raise ValueError(
f"Cannot find the config file for {model_config.quantization}")
raise ValueError(f"Cannot find the config file for {model_config.quantization}")
if len(quant_config_files) > 1:
raise ValueError(
f"Found multiple config files for {model_config.quantization}: "
f"{quant_config_files}")
f"{quant_config_files}"
)
quant_config_file = quant_config_files[0]
with open(quant_config_file, "r") as f:
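
The suffix filter used to select the quantization config is a generic idiom; isolated (paths and names below are hypothetical stand-ins for `get_config_filenames()`):

```
config_files = ["/m/config.json", "/m/quantize_config.json"]
wanted_names = ["quantize_config.json"]
matches = [f for f in config_files if any(f.endswith(x) for x in wanted_names)]
assert matches == ["/m/quantize_config.json"]
```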
@@ -166,8 +174,7 @@ def prepare_hf_model_weights(
revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path) \
and load_format != "tensorizer"
is_local = os.path.isdir(model_name_or_path) and load_format != "tensorizer"
use_safetensors = False
# Some quantized models use .pt files for storing the weights.
if load_format == "auto":
@@ -203,11 +210,13 @@ def prepare_hf_model_weights(
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=Disabledtqdm,
revision=revision)
hf_folder = snapshot_download(
model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=Disabledtqdm,
revision=revision,
)
else:
hf_folder = model_name_or_path
hf_weights_files: List[str] = []
@@ -228,16 +237,14 @@ def prepare_hf_model_weights(
"scaler.pt",
]
hf_weights_files = [
f for f in hf_weights_files
if not any(f.endswith(x) for x in blacklist)
f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
]
if load_format == "tensorizer":
return hf_folder, hf_weights_files, use_safetensors
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`")
return hf_folder, hf_weights_files, use_safetensors
@@ -254,7 +261,8 @@ def hf_model_weights_iterator(
cache_dir=cache_dir,
load_format=load_format,
fall_back_to_pt=fall_back_to_pt,
revision=revision)
revision=revision,
)
if load_format == "npcache":
# Currently np_cache only support *.bin checkpoints
@@ -289,22 +297,25 @@ def hf_model_weights_iterator(
param = np.load(f)
yield name, torch.from_numpy(param)
elif load_format == "tensorizer":
from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
open_stream,
tensorizer_warning)
from vllm.model_executor.tensorizer_loader import (
TensorDeserializer,
open_stream,
tensorizer_warning,
)
tensorizer_args = load_format.params
tensorizer_warning(
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the examples/tensorize_vllm_model.py example "
"script for serializing vLLM models.")
"script for serializing vLLM models."
)
deserializer_args = tensorizer_args.deserializer_params
stream_params = tensorizer_args.stream_params
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
with TensorDeserializer(stream, **deserializer_args,
device="cpu") as state:
with TensorDeserializer(stream, **deserializer_args, device="cpu") as state:
for name, param in state.items():
yield name, param
del state
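
The np_cache branch relies on the zero-copy NumPy-to-torch bridge; in isolation:

```
import numpy as np
import torch

param = np.ones((2, 3), dtype=np.float32)   # stands in for a cached np.load(f)
tensor = torch.from_numpy(param)            # shares memory with the NumPy array
param[0, 0] = 5.0
assert tensor[0, 0].item() == 5.0           # mutation is visible through both views
```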
@@ -324,8 +335,12 @@ def hf_model_weights_iterator(
def kv_cache_scales_loader(
filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
filename: str,
tp_rank: int,
tp_size: int,
num_hidden_layers: int,
model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
"""
A simple utility to read in KV cache scaling factors that have been
previously serialized to disk. Used by the model to populate the appropriate
@@ -343,8 +358,7 @@ def kv_cache_scales_loader(
"tp_size": tp_size,
}
schema_dct = json.load(f)
schema = QuantParamSchema.model_validate(schema_dct,
context=context)
schema = QuantParamSchema.model_validate(schema_dct, context=context)
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
return layer_scales_map.items()
@@ -357,9 +371,11 @@ def kv_cache_scales_loader(
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning("Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading.")
logger.warning(
"Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading."
)
return []
@@ -378,8 +394,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
return x
def default_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
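
`default_weight_loader` is a straight size-checked copy; exercised standalone:

```
import torch

param = torch.empty(2, 2)
loaded_weight = torch.ones(2, 2)
default_weight_loader(param, loaded_weight)  # copies in place after the size assert
assert torch.equal(param, loaded_weight)
```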
@@ -399,4 +414,4 @@ def initialize_dummy_weights(
"""
for param in model.state_dict().values():
if torch.is_floating_point(param):
param.data.uniform_(low, high)
param.data.uniform_(low, high)

View File

@@ -2,13 +2,16 @@
import base64
import json
import os
import sys
import threading
import traceback
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps
import numpy as np
import requests
@@ -110,6 +113,74 @@ def encode_image_base64(image_path):
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def encode_frame(frame):
import cv2 # pip install opencv-python-headless
from PIL import Image
# Convert the frame to RGB (OpenCV uses BGR by default)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Convert the frame to PIL Image to easily convert to bytes
im_pil = Image.fromarray(frame)
# Convert to bytes
buffered = BytesIO()
# frame_format = str(os.getenv('FRAME_FORMAT', "JPEG"))
im_pil.save(buffered, format="PNG")
frame_bytes = buffered.getvalue()
# Return the bytes of the frame
return frame_bytes
def encode_video_base64(video_path, num_frames=16):
import cv2
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise IOError(f"Could not open video file:{video_path}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print(f"target_frames: {num_frames}")
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
frames = []
for i in range(total_frames):
ret, frame = cap.read()
if ret:
frames.append(frame)
else:
# Handle the case where the frame could not be read
# print(f"Warning: Could not read frame at index {i}.")
pass
cap.release()
# Safely select frames based on frame_indices, avoiding IndexError
frames = [frames[i] for i in frame_indices if i < len(frames)]
# If there are not enough frames, duplicate the last frame until we reach the target
while len(frames) < num_frames:
frames.append(frames[-1])
# Use ThreadPoolExecutor to process and encode frames in parallel
with ThreadPoolExecutor() as executor:
encoded_frames = list(executor.map(encode_frame, frames))
# encoded_frames = list(map(encode_frame, frames))
# Concatenate all frames bytes
video_bytes = b"".join(encoded_frames)
# Encode the concatenated bytes to base64
video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
return video_base64
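
Together with `decode_video_base64` and `load_image` earlier in this commit, this gives a full round trip; a sketch assuming a local demo.mp4 exists:

```
payload = encode_video_base64("demo.mp4", num_frames=8)  # "video:..." base64 string
frames, (width, height) = load_image(payload)            # prefix routes to the decoder
print(frames.shape)                                       # e.g. (8, height, width, 3)
```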
def _is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
@@ -170,4 +241,4 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
if not ret_value:
raise RuntimeError()
return ret_value[0]
return ret_value[0]