support llava video (#426)
.gitignore (vendored) — 4 changes
@@ -177,3 +177,7 @@ tmp*.txt
 # Plots
 *.png
 *.pdf
+
+# personal
+work_dirs/
+*.csv
@@ -410,4 +410,4 @@ https://github.com/sgl-project/sglang/issues/157
 }
 ```
 
-We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).
+We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).
@@ -14,7 +14,7 @@ def single():
     state = image_qa.run(
         image_path="images/cat.jpeg",
         question="What is this?",
-        max_new_tokens=64)
+        max_new_tokens=128)
     print(state["answer"], "\n")
 
 
@@ -36,7 +36,7 @@ def batch():
             {"image_path": "images/cat.jpeg", "question": "What is this?"},
             {"image_path": "images/dog.jpeg", "question": "What is this?"},
         ],
-        max_new_tokens=64,
+        max_new_tokens=128,
     )
     for s in states:
         print(s["answer"], "\n")
examples/usage/llava_video/srt_example_llava_v.py — 208 lines (new file)
@@ -0,0 +1,208 @@
"""
Usage: python3 examples/usage/llava_video/srt_example_llava_v.py
"""

import argparse
import csv
import os
import time

import sglang as sgl


@sgl.function
def video_qa(s, num_frames, video_path, question):
    s += sgl.user(sgl.video(video_path, num_frames) + question)
    s += sgl.assistant(sgl.gen("answer"))


def single(path, num_frames=16):
    state = video_qa.run(
        num_frames=num_frames,
        video_path=path,
        question="Please provide a detailed description of the video, focusing on the main subjects, their actions, and the background scenes.",
        temperature=0.0,
        max_new_tokens=1024,
    )
    print(state["answer"], "\n")


def split_into_chunks(lst, num_chunks):
    """Split a list into a specified number of chunks."""
    # Chunk size from integer division; when the list is not evenly
    # divisible, the leftover items end up in an extra remainder chunk.
    chunk_size = len(lst) // num_chunks

    if chunk_size == 0:
        chunk_size = len(lst)
    chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
    # Pad with empty chunks so at least num_chunks chunks are returned.
    chunks.extend([[] for _ in range(num_chunks - len(chunks))])
    return chunks
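# Illustration (editor-added example): split_into_chunks(list(range(10)), 4)
# computes chunk_size = 10 // 4 == 2 and returns five chunks
# ([0, 1], [2, 3], [4, 5], [6, 7], [8, 9]) -- one more than num_chunks --
# so an uneven split yields an extra remainder chunk rather than dropping items.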


def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
    csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
    with open(csv_filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['video_name', 'answer'])
        for video_path, state in zip(batch_video_files, states):
            video_name = os.path.basename(video_path)
            writer.writerow([video_name, state["answer"]])


def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
    final_csv_filename = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
    with open(final_csv_filename, 'w', newline='') as final_csvfile:
        writer = csv.writer(final_csvfile)
        writer.writerow(['video_name', 'answer'])
        for batch_idx in range(num_batches):
            batch_csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
            with open(batch_csv_filename, 'r') as batch_csvfile:
                reader = csv.reader(batch_csvfile)
                next(reader)  # Skip the header row.
                for row in reader:
                    writer.writerow(row)
            os.remove(batch_csv_filename)


def find_video_files(video_dir):
    # If video_dir is actually a single file, return it as a one-element list.
    if os.path.isfile(video_dir):
        return [video_dir]

    # Otherwise, walk the directory and collect video files.
    video_files = []
    for root, dirs, files in os.walk(video_dir):
        for file in files:
            if file.endswith(('.mp4', '.avi', '.mov')):
                video_files.append(os.path.join(root, file))
    return video_files


def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
    video_files = find_video_files(video_dir)
    chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk]
    num_batches = 0

    for i in range(0, len(chunked_video_files), batch_size):
        batch_video_files = chunked_video_files[i:i + batch_size]
        print(f"Processing batch of {len(batch_video_files)} video(s)...")

        if not batch_video_files:
            print("No video files found in the specified directory.")
            return

        batch_input = [
            {
                "num_frames": num_frames,
                "video_path": video_path,
                "question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, and the background scenes.",
            } for video_path in batch_video_files
        ]

        start_time = time.time()
        states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
        total_time = time.time() - start_time
        average_time = total_time / len(batch_video_files)
        print(f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds")

        save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir)
        num_batches += 1

    compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Run video processing with the specified port.')

    parser.add_argument('--port', type=int, default=30000, help='The master port for distributed serving.')
    parser.add_argument('--chunk-idx', type=int, default=0, help='The index of the chunk to process.')
    parser.add_argument('--num-chunks', type=int, default=8, help='The number of chunks to process.')
    parser.add_argument('--save-dir', type=str, default="./work_dirs/llava_video", help='The directory to save the results.')
    parser.add_argument('--video-dir', type=str, default="./videos/Q98Z4OTh8RwmDonc.mp4", help='The directory of, or path to, the video files to process.')
    parser.add_argument('--model-path', type=str, default="lmms-lab/LLaVA-NeXT-Video-7B", help='The model path for video processing.')
    parser.add_argument('--num-frames', type=int, default=16, help='The number of frames to sample from each video.')
    parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)

    args = parser.parse_args()

    cur_port = args.port
    cur_chunk = args.chunk_idx
    num_chunks = args.num_chunks
    num_frames = args.num_frames

    if "34b" in args.model_path.lower():
        tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer"
    elif "7b" in args.model_path.lower():
        tokenizer_path = "llava-hf/llava-1.5-7b-hf"
    else:
        print("Invalid model path. Please specify a valid model path.")
        exit()

    model_overide_args = {}
    model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
    model_overide_args["num_frames"] = args.num_frames
    model_overide_args["model_type"] = "llava"

    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    if args.num_frames == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
    elif args.num_frames < 32:
        pass
    else:
        print("The maximum number of frames to process is 32. Please specify a valid number of frames.")
        exit()
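    # Context budget (editor-added illustration; assumes the default CLIP
    # ViT-L/14-336 tower, i.e. image_size=336, patch_size=14, stride 2):
    # each frame costs (336 / 14 / 2) ** 2 = 144 tokens, so 32 frames need
    # ~4608 video tokens, which is why the 4096-token window is doubled via
    # the linear RoPE scaling configured above.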

    runtime = sgl.Runtime(
        model_path=args.model_path,  # e.g., "liuhaotian/llava-v1.6-vicuna-7b"
        tokenizer_path=tokenizer_path,
        port=cur_port,
        additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
        model_overide_args=model_overide_args,
        tp_size=1,
    )
    sgl.set_default_backend(runtime)
    print(f"chat template: {runtime.endpoint.chat_template.name}")

    # Run a single request
    # try:
    print("\n========== single ==========\n")
    root = args.video_dir
    if os.path.isfile(root):
        video_files = [root]
    else:
        video_files = [
            os.path.join(root, f)
            for f in os.listdir(root)
            if f.endswith(('.mp4', '.avi', '.mov'))  # Add more extensions if needed.
        ]
    start_time = time.time()  # Start time for processing a single video
    for cur_video in video_files[:1]:
        print(cur_video)
        single(cur_video, num_frames)
    end_time = time.time()  # End time for processing a single video
    total_time = end_time - start_time
    average_time = total_time / len(video_files)  # Average processing time
    print(f"Average processing time per video: {average_time:.2f} seconds")
    runtime.shutdown()
    # except Exception as e:
    #     print(e)
    #     runtime.shutdown()

    # # Run a batch of requests
    # print("\n========== batch ==========\n")
    # if not os.path.exists(args.save_dir):
    #     os.makedirs(args.save_dir)
    # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)
    # runtime.shutdown()
examples/usage/llava_video/srt_example_llava_v.sh — 130 lines (new file)
@@ -0,0 +1,130 @@
#!/bin/bash

##### USAGE #####
# - First node:
# ```sh
# bash examples/usage/llava_video/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```
# - Second node:
# ```sh
# bash examples/usage/llava_video/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```
# - The K-th node:
# ```sh
# bash examples/usage/llava_video/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
# ```

# Replace `K`, `YOUR_VIDEO_PATH`, `YOUR_MODEL_PATH`, and `FRAMES_PER_VIDEO` with your specific details.
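# Concrete example (editor-added; hypothetical paths): with 2 nodes, run on
# node 0 as
#   bash examples/usage/llava_video/srt_example_llava_v.sh 2 0 ./videos lmms-lab/LLaVA-NeXT-Video-7B 16
# and on node 1 with the second argument set to 1.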
CURRENT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

echo ${CURRENT_ROOT}

cd ${CURRENT_ROOT}

export PYTHONWARNINGS=ignore

START_TIME=$(date +%s) # Capture start time

NUM_NODES=$1
CUR_NODES_IDX=$2
VIDEO_DIR=$3
MODEL_PATH=$4
NUM_FRAMES=$5

# FRAME_FORMAT=$6
# FRAME_FORMAT=$(echo $FRAME_FORMAT | tr '[:lower:]' '[:upper:]')
# # Check if FRAME_FORMAT is either JPEG or PNG
# if [[ "$FRAME_FORMAT" != "JPEG" && "$FRAME_FORMAT" != "PNG" ]]; then
#     echo "Error: FRAME_FORMAT must be either JPEG or PNG."
#     exit 1
# fi
# export TARGET_FRAMES=$TARGET_FRAMES

echo "Sampling $NUM_FRAMES frames from each video"

# export FRAME_FORMAT=$FRAME_FORMAT
# echo "The frame format is $FRAME_FORMAT"

# Assuming GPULIST is a bash array containing your GPUs
GPULIST=(0 1 2 3 4 5 6 7)
LOCAL_CHUNKS=${#GPULIST[@]}

echo "Number of GPUs in GPULIST: $LOCAL_CHUNKS"

ALL_CHUNKS=$((NUM_NODES * LOCAL_CHUNKS))

# GPUs per chunk
GPUS_PER_CHUNK=8

echo $GPUS_PER_CHUNK

for IDX in $(seq 1 $LOCAL_CHUNKS); do
    (
        START=$(((IDX-1) * GPUS_PER_CHUNK))
        LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index

        CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH})

        # Convert the chunk GPUs array to a comma-separated string
        CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}")

        LOCAL_IDX=$((CUR_NODES_IDX * LOCAL_CHUNKS + IDX))

        echo "Chunk $(($LOCAL_IDX - 1)) will run on GPUs $CHUNK_GPUS_STR"

        # Pick a random port for this chunk.
        PORT=$((10000 + RANDOM % 55536))

        MAX_RETRIES=10
        RETRY_COUNT=0
        COMMAND_STATUS=1 # Initialize as failed

        while [ $RETRY_COUNT -lt $MAX_RETRIES ] && [ $COMMAND_STATUS -ne 0 ]; do
            echo "Running chunk $(($LOCAL_IDX - 1)) on GPUs $CHUNK_GPUS_STR with port $PORT. Attempt $(($RETRY_COUNT + 1))"

            CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 examples/usage/llava_video/srt_example_llava_v.py \
                --port $PORT \
                --num-chunks $ALL_CHUNKS \
                --chunk-idx $(($LOCAL_IDX - 1)) \
                --save-dir work_dirs/llava_next_video_inference_results \
                --video-dir $VIDEO_DIR \
                --model-path $MODEL_PATH \
                --num-frames $NUM_FRAMES #&

            wait $! # Wait for the process to finish and capture its exit status
            COMMAND_STATUS=$?

            if [ $COMMAND_STATUS -ne 0 ]; then
                echo "Execution failed for chunk $(($LOCAL_IDX - 1)), attempt $(($RETRY_COUNT + 1)). Retrying..."
                RETRY_COUNT=$(($RETRY_COUNT + 1))
                sleep 180 # Wait a bit before retrying
            else
                echo "Execution succeeded for chunk $(($LOCAL_IDX - 1))."
            fi
        done

        if [ $COMMAND_STATUS -ne 0 ]; then
            echo "Execution failed for chunk $(($LOCAL_IDX - 1)) after $MAX_RETRIES attempts."
        fi
    ) #&
    sleep 2 # Slight delay to stagger the start times
done

wait

cat work_dirs/llava_next_video_inference_results/final_results_chunk_*.csv > work_dirs/llava_next_video_inference_results/final_results_node_${CUR_NODES_IDX}.csv

END_TIME=$(date +%s) # Capture end time
ELAPSED_TIME=$(($END_TIME - $START_TIME))
echo "Total execution time: $ELAPSED_TIME seconds."
examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 — new binary file (not shown)
@@ -20,7 +20,7 @@ dependencies = [
 
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
-       "zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "outlines>=0.0.27", "packaging"]
+       "zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "packaging", "huggingface_hub", "hf_transfer", "outlines>=0.0.34"]
 openai = ["openai>=1.0", "numpy", "tiktoken"]
 anthropic = ["anthropic>=0.20.0", "numpy"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
@@ -33,4 +33,4 @@ all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
 exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
 
 [tool.wheel]
-exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
+exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
@@ -19,6 +19,7 @@ from sglang.api import (
     user,
     user_begin,
     user_end,
+    video,
 )
 
 # SGL Backends
@@ -46,6 +47,7 @@ __all__ = [
     "gen_int",
     "gen_string",
     "image",
+    "video",
     "select",
     "system",
     "user",
@@ -15,6 +15,7 @@ from sglang.lang.ir import (
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
+    SglVideo,
 )
@@ -151,6 +152,10 @@ def image(expr: SglExpr):
     return SglImage(expr)
 
 
+def video(path: str, num_frames: int):
+    return SglVideo(path, num_frames)
+
+
 def select(
     name: Optional[str] = None,
     choices: List[str] = None,
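For context, the new `video` primitive composes like the existing `image` primitive inside an `@sgl.function` program. A minimal usage sketch, mirroring the shipped example (the clip path here is hypothetical):

```python
import sglang as sgl

@sgl.function
def video_qa(s, num_frames, video_path, question):
    # sgl.video(...) places the sampled frames into the user turn.
    s += sgl.user(sgl.video(video_path, num_frames) + question)
    s += sgl.assistant(sgl.gen("answer"))

# With a backend already set via sgl.set_default_backend(runtime):
state = video_qa.run(num_frames=16, video_path="clip.mp4",  # hypothetical clip
                     question="What is happening in this video?")
print(state["answer"])
```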
@@ -259,6 +259,8 @@ def match_vicuna(model_path: str):
         return get_chat_template("vicuna_v1.1")
     if "llava-v1.5" in model_path.lower():
         return get_chat_template("vicuna_v1.1")
+    if "llava-next-video-7b" in model_path.lower():
+        return get_chat_template("vicuna_v1.1")
 
 
 @register_chat_template_matching_function
@@ -283,19 +285,24 @@ def match_llama3_instruct(model_path: str):
 
 @register_chat_template_matching_function
 def match_chat_ml(model_path: str):
+    # import pdb;pdb.set_trace()
     model_path = model_path.lower()
     if "tinyllama" in model_path:
         return get_chat_template("chatml")
     if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
-    if "llava-v1.6-34b" in model_path:
+    if (
+        "llava-v1.6-34b" in model_path
+        or "llava-v1.6-yi-34b" in model_path
+        or "llava-next-video-34b" in model_path
+    ):
         return get_chat_template("chatml-llava")
 
 
 @register_chat_template_matching_function
 def match_chat_yi(model_path: str):
     model_path = model_path.lower()
-    if "yi" in model_path:
+    if "yi" in model_path and "llava" not in model_path:
         return get_chat_template("yi")
@@ -28,8 +28,9 @@ from sglang.lang.ir import (
     SglVariable,
     SglVarScopeBegin,
     SglVarScopeEnd,
+    SglVideo,
 )
-from sglang.utils import encode_image_base64, get_exception_traceback
+from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
 
 
 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -361,6 +362,8 @@ class StreamExecutor:
             self._execute_role_end(other)
         elif isinstance(other, SglImage):
             self._execute_image(other)
+        elif isinstance(other, SglVideo):
+            self._execute_video(other)
         elif isinstance(other, SglVariable):
             self._execute_variable(other)
         elif isinstance(other, SglVarScopeBegin):
@@ -397,6 +400,16 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token
 
+    def _execute_video(self, expr: SglVideo):
+        path = expr.path
+        num_frames = expr.num_frames
+
+        base64_data = encode_video_base64(path, num_frames)
+
+        self.images_.append((path, base64_data))
+        self.cur_images.append((path, base64_data))
+        self.text_ += self.chat_template.image_token
+
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)
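Note that `_execute_video` mirrors `_execute_image` one-for-one: the base64-encoded frames land in the same `images_` / `cur_images` slots and are announced with the chat template's `image_token`, so downstream backends handle an encoded video as a single multi-frame image payload.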
@@ -330,6 +330,15 @@ class SglImage(SglExpr):
         return f"SglImage({self.path})"
 
 
+class SglVideo(SglExpr):
+    def __init__(self, path, num_frames):
+        self.path = path
+        self.num_frames = num_frames
+
+    def __repr__(self) -> str:
+        return f"SglVideo({self.path}, {self.num_frames})"
+
+
 class SglGen(SglExpr):
     def __init__(
         self,
@@ -110,7 +110,7 @@ class TracerProgramState(ProgramState):
     ##################################
 
     def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
-        assert (size >= 1)
+        assert size >= 1
 
         if self.only_trace_prefix:
             raise StopTracing()
@@ -2,11 +2,10 @@ import argparse
 
 from sglang.srt.server import ServerArgs, launch_server
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)
 
-    launch_server(server_args, None)
+    launch_server(server_args, None)
python/sglang/launch_server_llavavid.py — 31 lines (new file)
@@ -0,0 +1,31 @@
import argparse
import multiprocessing as mp

from sglang.srt.server import ServerArgs, launch_server

if __name__ == "__main__":

    model_overide_args = {}

    model_overide_args["mm_spatial_pool_stride"] = 2
    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
    model_overide_args["num_frames"] = 16
    model_overide_args["model_type"] = "llavavid"
    if model_overide_args["num_frames"] == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
        model_overide_args["model_max_length"] = 4096 * 2

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    server_args = ServerArgs.from_cli_args(args)

    pipe_reader, pipe_writer = mp.Pipe(duplex=False)

    launch_server(server_args, pipe_writer, model_overide_args)
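Assuming the standard `ServerArgs` CLI flags, a typical invocation of this launcher would be `python3 python/sglang/launch_server_llavavid.py --model-path lmms-lab/LLaVA-NeXT-Video-7B --port 30000`; the hard-coded `model_overide_args` are then merged into the HuggingFace config through `get_config` (next hunk).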
@@ -30,10 +30,17 @@ def get_config_json(model_path: str):
     return config
 
 
-def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
+def get_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    model_overide_args: Optional[dict] = None,
+):
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if model_overide_args:
+        config.update(model_overide_args)
     return config
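To make the override flow concrete, here is a minimal sketch of a caller supplying `model_overide_args` (the import path is an assumption for illustration; the override keys come from the launcher above):

```python
from sglang.srt.hf_transformers_utils import get_config  # import path assumed

config = get_config(
    "lmms-lab/LLaVA-NeXT-Video-7B",
    trust_remote_code=True,
    model_overide_args={
        "architectures": ["LlavaVidForCausalLM"],  # route to the new model class
        "num_frames": 16,
        "mm_spatial_pool_stride": 2,
    },
)
assert config.num_frames == 16  # overrides are merged into the HF config
```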
@@ -60,9 +60,7 @@ class RouterManager:
 
 
 def start_router_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
@@ -70,7 +68,7 @@ def start_router_process(
     )
 
     try:
-        model_client = ModelRpcClient(server_args, port_args)
+        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
         router = RouterManager(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
@@ -4,12 +4,13 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
 
 try:
     from vllm.logger import _default_handler as vllm_default_logger
 except ImportError:
@@ -48,6 +49,7 @@ class ModelRpcServer:
         tp_rank: int,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: Optional[dict] = None,
     ):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
 
@@ -62,6 +64,7 @@ class ModelRpcServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
+            model_overide_args=model_overide_args,
         )
 
         # For model end global settings
@@ -673,13 +676,15 @@ class ModelRpcService(rpyc.Service):
 
 
 class ModelRpcClient:
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
+    def __init__(
+        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+    ):
         tp_size = server_args.tp_size
 
         if tp_size == 1:
             # Init model
             self.model_server = ModelRpcService().exposed_ModelRpcServer(
-                0, server_args, port_args
+                0, server_args, port_args, model_overide_args
             )
 
             # Wrap functions
@@ -700,7 +705,7 @@ class ModelRpcClient:
             # Init model
             def init_model(i):
                 return self.remote_services[i].ModelRpcServer(
-                    i, server_args, port_args
+                    i, server_args, port_args, model_overide_args
                 )
 
             self.model_servers = executor.map(init_model, range(tp_size))
@@ -723,7 +728,11 @@ def _init_service(port):
     t = ThreadedServer(
         ModelRpcService(),
         port=port,
-        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 1800,
+        },
     )
     t.start()
 
@@ -739,7 +748,11 @@ def start_model_process(port):
         con = rpyc.connect(
             "localhost",
             port,
-            config={"allow_pickle": True, "sync_request_timeout": 1800},
+            config={
+                "allow_public_attrs": True,
+                "allow_pickle": True,
+                "sync_request_timeout": 1800,
+            },
         )
         break
     except ConnectionRefusedError:
@@ -9,11 +9,11 @@ from typing import List
 
 import numpy as np
 import torch
+from vllm.distributed import initialize_model_parallel
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.distributed import initialize_model_parallel
 
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
@@ -143,7 +143,7 @@ class InputMetadata:
             self.kv_last_page_len,
             self.model_runner.model_config.num_attention_heads // tp_size,
             self.model_runner.model_config.num_key_value_heads // tp_size,
-            self.model_runner.model_config.head_dim
+            self.model_runner.model_config.head_dim,
         ]
 
         self.prefill_wrapper.begin_forward(*args)
@@ -60,21 +60,29 @@ def get_pixel_values(
 ):
     try:
         processor = processor or global_processor
-        image = load_image(image_data)
-        image_hash = hash(image_data)
-        if image_aspect_ratio == "pad":
-            image = expand2square(
-                image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
-            )
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        elif image_aspect_ratio == "anyres":
-            pixel_values = process_anyres_image(
-                image, processor.image_processor, image_grid_pinpoints
-            )
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
         else:
-            pixel_values = processor.image_processor(image)["pixel_values"][0]
-        pixel_values = pixel_values.astype(np.float16)
-        return pixel_values, image_hash, image.size
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
     except Exception:
         print("Exception in TokenizerManager:\n" + get_exception_traceback())
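The new `image_size != None` branch is the video path: `load_image` now also returns a size for base64-encoded multi-frame inputs, and the per-frame `pixel_values` are cast to float16 and stacked into one array, so a video flows through the tokenizer as a single multi-frame batch.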
@@ -84,6 +92,7 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: dict = None,
     ):
         self.server_args = server_args
 
@@ -96,7 +105,9 @@ class TokenizerManager:
 
         self.model_path = server_args.model_path
         self.hf_config = get_config(
-            self.model_path, trust_remote_code=server_args.trust_remote_code
+            self.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            model_overide_args=model_overide_args,
         )
         self.context_len = get_context_length(self.hf_config)
@@ -10,12 +10,16 @@ class ModelConfig:
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
+        model_overide_args: Optional[dict] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
 
+        if model_overide_args is not None:
+            self.hf_config.update(model_overide_args)
+
         if context_length is not None:
             self.context_len = context_length
         else:
@@ -27,29 +27,25 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 @torch.compile
@@ -5,37 +5,31 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
 from sglang.srt.models.dbrx_config import DbrxConfig
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class DbrxRouter(nn.Module):
@@ -291,7 +285,9 @@ class DbrxBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, quant_config=quant_config)
+        self.norm_attn_norm = DbrxFusedNormAttention(
+            config, layer_id, quant_config=quant_config
+        )
         self.ffn = DbrxExperts(config, quant_config=quant_config)
 
     def forward(
@@ -322,7 +318,10 @@ class DbrxModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [DbrxBlock(config, i, quant_config=quant_config) for i in range(config.n_layers)]
+            [
+                DbrxBlock(config, i, quant_config=quant_config)
+                for i in range(config.n_layers)
+            ]
         )
         self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
         for module in self.modules():
@@ -7,6 +7,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -14,21 +15,14 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class GemmaMLP(nn.Module):
@@ -46,7 +40,10 @@ class GemmaMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = GeluAndMul()
@@ -6,6 +6,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -13,24 +14,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class LlamaMLP(nn.Module):
@@ -49,7 +43,10 @@ class LlamaMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -7,12 +7,7 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 
 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata
@@ -22,6 +17,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class LlavaLlamaForCausalLM(nn.Module):
python/sglang/srt/models/llavavid.py — 307 lines (new file)
@@ -0,0 +1,307 @@
"""Inference-only LLaVa video model compatible with HuggingFace weights."""

import os
from typing import List, Optional

import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.mm_utils import (
    get_anyres_image_grid_shape,
    unpad_image,
    unpad_image_shape,
)
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


class LlavaVidForCausalLM(nn.Module):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.vision_tower = None
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
        self.resampler = nn.AvgPool2d(
            kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
        )
        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
        self.num_frames = getattr(self.config, "num_frames", 16)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )

    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
        new_image_feature_len = self.image_feature_len
        # Currently only spatial_unpad + anyres is supported.
        # if self.mm_patch_merge_type.startswith("spatial"):
        #     height = width = self.num_patches_per_side
        #     if pt_shape[0] > 1:
        #         if self.image_aspect_ratio == "anyres":
        #             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        #                 image_size,
        #                 self.image_grid_pinpoints,
        #                 self.vision_tower.config.image_size,
        #             )
        #         if "unpad" in self.mm_patch_merge_type:
        #             h = num_patch_height * height
        #             w = num_patch_width * width
        #             new_h, new_w = unpad_image_shape(h, w, image_size)
        #             new_image_feature_len += new_h * (new_w + 1)

        pad_ids = pad_value * (
            (new_image_feature_len + len(pad_value)) // len(pad_value)
        )
        offset = input_ids.index(self.config.image_token_index)
        # old_len + pad_len - 1, because we need to remove image_token_id
        new_input_ids = (
            input_ids[:offset]
            + pad_ids[:new_image_feature_len]
            + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: This is not memory efficient; output_hidden_states=True saves all the hidden states.

        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
            )

        height = width = self.num_patches_per_side
        num_of_frames = selected_image_feature.shape[0]
        selected_image_feature = selected_image_feature.view(
            num_of_frames, height, width, -1
        )
        selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
        selected_image_feature = (
            self.resampler(selected_image_feature)
            .flatten(2)
            .transpose(1, 2)
            .contiguous()
        )

        image_features = self.multi_modal_projector(selected_image_feature)

        return image_features

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.array]]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Embed vision input
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            # FIXME: We need to subtract the length of the system prompt
            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
            need_vision = need_vision & has_pixel

            if need_vision.any():
                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # llava-hd: BS, num_patch, C=3, H=336, W=336; num_patch obtained from process_images
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                    # hd image_features: BS, num_patch, 576, 4096
                else:
                    # normal pixel: BS, C=3, H=336, W=336
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)
                    # image_features: BS, 576, 4096

                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    new_image_features.append(image_feature.flatten(0, 1))
                image_features = new_image_features

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
                    dim = input_embeds.shape[1]
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    try:
                        input_embeds[
                            start_idx
                            + image_offsets[i] : start_idx
                            + image_offsets[i]
                            + pad_len
                        ] = image_features[pt]
                    except RuntimeError as e:
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])
                    pt += 1

            return self.language_model(
                input_ids, positions, input_metadata, input_embeds=input_embeds
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(input_ids, positions, input_metadata)

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        # Load the CLIP vision model given by cfg['mm_vision_tower']:
        # a HuggingFace name, or a path of CLIP relative to the llava model dir.
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        print(f"target_frames: {self.num_frames}")
        self.image_feature_len = self.num_frames * int(
            (self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
        )
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # Load the mm_projector weights.
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
            "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            # FIXME: why are the projector weights read twice?
            if "projector" in name or "vision_tower" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                if name in params_dict:
                    param = params_dict[name]
                else:
                    print(f"Warning: {name} not found in the model")
                    continue
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # Load the language model weights.
        self.language_model.load_weights(
            model_name_or_path, cache_dir, load_format, revision
        )

        monkey_path_clip_vision_embed_forward()

    @property
    def num_patches_per_side(self):
        return self.image_size // self.patch_size


first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    batch_size = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    global first_call
    if first_call:
        self.patch_embedding.cpu().float()
        first_call = False
    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
    patch_embeds = self.patch_embedding(pixel_values).cuda().half()

    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings


def monkey_path_clip_vision_embed_forward():
    import transformers

    setattr(
        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
        "forward",
        clip_vision_embed_forward,
    )


EntryClass = LlavaVidForCausalLM
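As a sanity check on `image_feature_len` above: assuming the usual CLIP ViT-L/14 tower at 336 px (`image_size=336`, `patch_size=14`) and the default `mm_spatial_pool_stride=2`, each frame contributes 144 pooled tokens after the `AvgPool2d` resampler:

```python
image_size, patch_size, stride, num_frames = 336, 14, 2, 16  # assumed defaults
per_frame = int((image_size / patch_size / stride) ** 2)  # (336/14/2)**2 = 144
print(per_frame * num_frames)  # 2304 token positions for 16 frames
```

At 32 frames this grows to 4608 tokens, which is why the example scripts enable linear RoPE scaling (factor 2.0) and an 8192-token window for that setting.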
@@ -8,34 +8,28 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.utils import set_weight_attrs
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class MixtralMLP(nn.Module):
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -10,24 +11,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class QWenMLP(nn.Module):
@@ -132,7 +126,12 @@ class QWenAttention(nn.Module):
 
 
 class QWenBlock(nn.Module):
-    def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -181,7 +180,11 @@ class QWenBlock(nn.Module):
 
 
 class QWenModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -218,7 +221,11 @@ class QWenModel(nn.Module):
 
 
 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.transformer = QWenModel(config, quant_config=quant_config)
@@ -276,4 +283,4 @@ class QWenLMHeadModel(nn.Module):
         weight_loader(param, loaded_weight)
 
 
-EntryClass = QWenLMHeadModel
+EntryClass = QWenLMHeadModel
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Tuple
 
 import torch
 from torch import nn
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -12,24 +13,17 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 Qwen2Config = None
 
@@ -50,7 +44,10 @@ class Qwen2MLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -7,35 +7,31 @@ from typing import Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.distributed import (
    get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


class StablelmMLP(nn.Module):
    def __init__(
        self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
@@ -48,7 +44,10 @@ class StablelmMLP(nn.Module):
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config,
            config.intermediate_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.act_fn = SiluAndMul()

@@ -181,7 +180,9 @@ class StablelmDecoderLayer(nn.Module):

class StableLMEpochModel(nn.Module):
    def __init__(
        self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(

@@ -6,16 +6,13 @@ from typing import List, Optional
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig
from sglang.srt.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)

from sglang.srt.models.llava import (
    LlavaLlamaForCausalLM,
    clip_vision_embed_forward,
    monkey_path_clip_vision_embed_forward,
)
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


class YiVLForCausalLM(LlavaLlamaForCausalLM):

@@ -107,7 +107,7 @@ async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)


def launch_server(server_args: ServerArgs, pipe_finish_writer):
def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
    global tokenizer_manager

    logging.basicConfig(
@@ -140,17 +140,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
    )

    # Launch processes
    tokenizer_manager = TokenizerManager(server_args, port_args)
    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

    proc_router = mp.Process(
        target=start_router_process,
        args=(
            server_args,
            port_args,
            pipe_router_writer,
        ),
        args=(server_args, port_args, pipe_router_writer, model_overide_args),
    )
    proc_router.start()
    proc_detoken = mp.Process(
@@ -170,8 +166,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
    if router_init_state != "init ok" or detoken_init_state != "init ok":
        proc_router.kill()
        proc_detoken.kill()
        print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
        print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
        print(
            f"Initialization failed. router_init_state: {router_init_state}", flush=True
        )
        print(
            f"Initialization failed. detoken_init_state: {detoken_init_state}",
            flush=True,
        )
        sys.exit(1)
    assert proc_router.is_alive() and proc_detoken.is_alive()

@@ -189,6 +190,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
        time.sleep(0.5)
        try:
            requests.get(url + "/get_model_info", timeout=5, headers=headers)
            success = True  # Set flag to True if request succeeds
            break
        except requests.exceptions.RequestException as e:
            pass
@@ -205,7 +207,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
                },
            },
            headers=headers,
            timeout=60,
            timeout=600,
        )
        assert res.status_code == 200
    except Exception as e:
@@ -235,7 +237,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
class Runtime:
    def __init__(
        self,
        log_evel="error",
        log_evel: str = "error",
        model_overide_args: Optional[dict] = None,
        *args,
        **kwargs,
    ):
@@ -244,7 +247,10 @@ class Runtime:

        # Pre-allocate ports
        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
            self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)
            self.server_args.port,
            self.server_args.additional_ports,
            self.server_args.tp_size,
        )

        self.url = self.server_args.url()
        self.generate_url = (
@@ -253,7 +259,10 @@ class Runtime:

        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer, model_overide_args),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid
@@ -265,7 +274,9 @@ class Runtime:

        if init_state != "init ok":
            self.shutdown()
            raise RuntimeError("Initialization failed. Please see the error messages above.")
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

@@ -317,4 +328,4 @@ class Runtime:
            pos += len(cur)

    def __del__(self):
        self.shutdown()
        self.shutdown()

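Taken together, these hunks thread an optional model_overide_args dict from Runtime through launch_server into TokenizerManager and the router process, which is what lets the llava-video example override HuggingFace config fields at load time. A hedged usage sketch (the model path and override key below are illustrative, not taken from this diff):

import sglang as sgl

runtime = sgl.Runtime(
    model_path="liuhaotian/llava-v1.6-vicuna-7b",       # hypothetical model path
    model_overide_args={"max_sequence_length": 4096},   # illustrative override key
)
sgl.set_default_backend(runtime)
# ... run sgl.function programs against the runtime ...
runtime.shutdown()
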
@@ -80,10 +80,12 @@ class ServerArgs:
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument("--host", type=str, default=ServerArgs.host,
                            help="The host of the server.")
        parser.add_argument("--port", type=int, default=ServerArgs.port,
                            help="The port of the server.")
        parser.add_argument(
            "--host", type=str, default=ServerArgs.host, help="The host of the server."
        )
        parser.add_argument(
            "--port", type=int, default=ServerArgs.port, help="The port of the server."
        )
        parser.add_argument(
            "--additional-ports",
            type=int,
@@ -261,4 +263,4 @@ class PortArgs:
    router_port: int
    detokenizer_port: int
    nccl_port: int
    model_rpc_ports: List[int]
    model_rpc_ports: List[int]

@@ -131,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()):
            continue

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                s.bind(("", port))
                s.listen(1)  # Attempt to listen on the port
                port_list.append(port)
            except socket.error:
                pass
                pass  # If any error occurs, this port is not usable

        if len(port_list) == num:
            return port_list
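The two added lines change what the probe proves: SO_REUSEADDR lets the check rebind ports lingering in TIME_WAIT, and listen(1) rejects ports that bind() accepts but that cannot actually serve. A minimal standalone sketch of the same check (port_is_usable is a hypothetical helper, not part of this diff):

import socket

def port_is_usable(port: int) -> bool:
    # SO_REUSEADDR: do not fail on ports stuck in TIME_WAIT
    # listen(1): catch ports that bind successfully but cannot accept connections
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            s.bind(("", port))
            s.listen(1)
            return True
        except socket.error:
            return False

print(port_is_usable(30000))  # e.g. True if nothing is listening there
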
@@ -265,20 +267,102 @@ def wrap_kernel_launcher(kernel):


def is_multimodal_model(model):
    if isinstance(model, str):
        return "llava" in model or "yi-vl" in model
    from sglang.srt.model_config import ModelConfig

    if isinstance(model, str):
        model = model.lower()
        return "llava" in model or "yi-vl" in model or "llava-next" in model

    if isinstance(model, ModelConfig):
        model_path = model.path.lower()
        return "llava" in model_path or "yi-vl" in model_path
    raise Exception("unrecognized type")
        return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path

    raise ValueError("unrecognized type")


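The rewritten matcher lowercases the name before the substring tests and now also accepts a ModelConfig; note that the extra "llava-next" test is redundant once "llava" matches, but harmless. A quick check of the string branch (the import path is assumed from the file this hunk patches):

from sglang.srt.utils import is_multimodal_model

assert is_multimodal_model("llava-hf/LLaVA-NeXT-Video-7B")   # "llava" matches after lowercasing
assert is_multimodal_model("01-ai/Yi-VL-6B")                 # "yi-vl" matches
assert not is_multimodal_model("mistralai/Mistral-7B-v0.1")  # plain LLM
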
def decode_video_base64(video_base64):
    from PIL import Image

    # Decode the base64 string
    video_bytes = base64.b64decode(video_base64)

    # Placeholder for the start indices of each PNG image
    img_starts = []

    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))

    assert frame_format in [
        "PNG",
        "JPEG",
    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"

    if frame_format == "PNG":
        # Find each PNG start signature to isolate images
        i = 0
        while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
            # Check if we found the start of a PNG file
            if (
                video_bytes[i] == 0x89
                and video_bytes[i + 1] == 0x50
                and video_bytes[i + 2] == 0x4E
                and video_bytes[i + 3] == 0x47
                and video_bytes[i + 4] == 0x0D
                and video_bytes[i + 5] == 0x0A
                and video_bytes[i + 6] == 0x1A
                and video_bytes[i + 7] == 0x0A
            ):
                img_starts.append(i)
                i += 8  # Skip the PNG signature
            else:
                i += 1
    else:
        # Find each JPEG start (0xFFD8) to isolate images
        i = 0
        while (
            i < len(video_bytes) - 1
        ):  # Adjusted for the length of the JPEG SOI signature
            # Check if we found the start of a JPEG file
            if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
                img_starts.append(i)
                # Move to the next byte to continue searching for the next image start
                i += 2
            else:
                i += 1

    frames = []
    for start_idx in img_starts:
        # Assuming each image is back-to-back, the end of one image is the start of another
        # The last image goes until the end of the byte string
        end_idx = (
            img_starts[img_starts.index(start_idx) + 1]
            if img_starts.index(start_idx) + 1 < len(img_starts)
            else len(video_bytes)
        )
        img_bytes = video_bytes[start_idx:end_idx]

        # Convert bytes to a PIL Image
        img = Image.open(BytesIO(img_bytes))

        # Convert PIL Image to a NumPy array
        frame = np.array(img)

        # Append the frame to the list of frames
        frames.append(frame)

    # Ensure there's at least one frame to avoid errors with np.stack
    if frames:
        return np.stack(frames, axis=0), img.size
    else:
        return np.array([]), (
            0,
            0,
        )  # Return an empty array and size tuple if no frames were found


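A self-contained sanity check of the signature scanning above, using two tiny synthetic PNG frames instead of a real video (the import path is assumed from the file this hunk patches):

import base64
from io import BytesIO

from PIL import Image

from sglang.srt.utils import decode_video_base64

# Build a fake "video": two 8x8 PNGs concatenated back-to-back
chunks = []
for color in [(255, 0, 0), (0, 255, 0)]:
    buf = BytesIO()
    Image.new("RGB", (8, 8), color).save(buf, format="PNG")
    chunks.append(buf.getvalue())

payload = base64.b64encode(b"".join(chunks)).decode("utf-8")
frames, size = decode_video_base64(payload)
print(frames.shape, size)  # expected: (2, 8, 8, 3) (8, 8)
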
def load_image(image_file):
    from PIL import Image

    image = None
    image = image_size = None

    if image_file.startswith("http://") or image_file.startswith("https://"):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
@@ -289,10 +373,13 @@ def load_image(image_file):
    elif image_file.startswith("data:"):
        image_file = image_file.split(",")[1]
        image = Image.open(BytesIO(base64.b64decode(image_file)))
    elif image_file.startswith("video:"):
        image_file = image_file.replace("video:", "")
        image, image_size = decode_video_base64(image_file)
    else:
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image
    return image, image_size


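With the new tuple return, callers unpack (image, image_size); for a "video:" payload the first element is a stacked frame array rather than a PIL image. A hedged sketch (the clip path is illustrative; encode_video_base64 is added later in this commit, in sglang/utils.py):

from sglang.srt.utils import load_image
from sglang.utils import encode_video_base64

payload = encode_video_base64("videos/example.mp4", num_frames=16)  # hypothetical clip
frames, frame_size = load_image(payload)  # numpy frames, original (width, height)
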
def assert_pkg_version(pkg: str, min_version: str):
@@ -304,7 +391,9 @@ def assert_pkg_version(pkg: str, min_version: str):
                f"is less than the minimum required version {min_version}"
            )
    except PackageNotFoundError:
        raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed"
        )


API_KEY_HEADER_NAME = "X-API-Key"

@@ -19,11 +19,12 @@ import torch
from huggingface_hub import HfFileSystem, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
                                                     get_quantization_config)
from vllm.model_executor.layers.quantization import (
    QuantizationConfig,
    get_quantization_config,
)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema

logger = init_logger(__name__)
@@ -32,17 +33,21 @@ logger = init_logger(__name__)
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = os.environ.get('TMPDIR') or os.environ.get(
    'TEMP') or os.environ.get('TMP') or "/tmp/"
temp_dir = (
    os.environ.get("TMPDIR")
    or os.environ.get("TEMP")
    or os.environ.get("TMP")
    or "/tmp/"
)


def enable_hf_transfer():
    """automatically activates hf_transfer
    """
    """automatically activates hf_transfer"""
    if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
        try:
            # enable hf hub transfer if available
            import hf_transfer  # type: ignore # noqa

            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
        except ImportError:
            pass
@@ -65,8 +70,7 @@ def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    # add hash to avoid conflict with old users' lock files
    lock_file_name = hash_name + model_name + ".lock"
    # mode 0o666 is required for the filelock to be shared across users
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
                             mode=0o666)
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
    return lock


@@ -104,10 +108,12 @@ def convert_bin_to_safetensor_file(
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size different is more than 1%:
        raise RuntimeError(
            f"""The file size different is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
        """)
        """
        )

    # check if the tensors are the same
    reloaded = load_file(sf_filename)
@@ -122,8 +128,7 @@ def convert_bin_to_safetensor_file(
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
    quant_cls = get_quantization_config(model_config.quantization)
    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
                              None)
    hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
    if hf_quant_config is not None:
        return quant_cls.from_config(hf_quant_config)
    model_name_or_path = model_config.model
@@ -131,26 +136,29 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, model_config.download_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          revision=model_config.revision,
                                          allow_patterns="*.json",
                                          cache_dir=model_config.download_dir,
                                          tqdm_class=Disabledtqdm)
            hf_folder = snapshot_download(
                model_name_or_path,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=model_config.download_dir,
                tqdm_class=Disabledtqdm,
            )
    else:
        hf_folder = model_name_or_path
    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    quant_config_files = [
        f for f in config_files if any(
            f.endswith(x) for x in quant_cls.get_config_filenames())
        f
        for f in config_files
        if any(f.endswith(x) for x in quant_cls.get_config_filenames())
    ]
    if len(quant_config_files) == 0:
        raise ValueError(
            f"Cannot find the config file for {model_config.quantization}")
        raise ValueError(f"Cannot find the config file for {model_config.quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}")
            f"{quant_config_files}"
        )

    quant_config_file = quant_config_files[0]
    with open(quant_config_file, "r") as f:
@@ -166,8 +174,7 @@ def prepare_hf_model_weights(
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    # Download model weights from huggingface.
    is_local = os.path.isdir(model_name_or_path) \
        and load_format != "tensorizer"
    is_local = os.path.isdir(model_name_or_path) and load_format != "tensorizer"
    use_safetensors = False
    # Some quantized models use .pt files for storing the weights.
    if load_format == "auto":
@@ -203,11 +210,13 @@ def prepare_hf_model_weights(
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm,
                                          revision=revision)
            hf_folder = snapshot_download(
                model_name_or_path,
                allow_patterns=allow_patterns,
                cache_dir=cache_dir,
                tqdm_class=Disabledtqdm,
                revision=revision,
            )
    else:
        hf_folder = model_name_or_path
    hf_weights_files: List[str] = []
@@ -228,16 +237,14 @@ def prepare_hf_model_weights(
        "scaler.pt",
    ]
    hf_weights_files = [
        f for f in hf_weights_files
        if not any(f.endswith(x) for x in blacklist)
        f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
    ]

    if load_format == "tensorizer":
        return hf_folder, hf_weights_files, use_safetensors

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")
        raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors

@@ -254,7 +261,8 @@ def hf_model_weights_iterator(
        cache_dir=cache_dir,
        load_format=load_format,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)
        revision=revision,
    )

    if load_format == "npcache":
        # Currently np_cache only support *.bin checkpoints
@@ -289,22 +297,25 @@ def hf_model_weights_iterator(
            param = np.load(f)
            yield name, torch.from_numpy(param)
    elif load_format == "tensorizer":
        from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
                                                           open_stream,
                                                           tensorizer_warning)
        from vllm.model_executor.tensorizer_loader import (
            TensorDeserializer,
            open_stream,
            tensorizer_warning,
        )

        tensorizer_args = load_format.params
        tensorizer_warning(
            "Deserializing HuggingFace models is not optimized for "
            "loading on vLLM, as tensorizer is forced to load to CPU. "
            "Consider deserializing a vLLM model instead for faster "
            "load times. See the examples/tensorize_vllm_model.py example "
            "script for serializing vLLM models.")
            "script for serializing vLLM models."
        )

        deserializer_args = tensorizer_args.deserializer_params
        stream_params = tensorizer_args.stream_params
        stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
        with TensorDeserializer(stream, **deserializer_args,
                                device="cpu") as state:
        with TensorDeserializer(stream, **deserializer_args, device="cpu") as state:
            for name, param in state.items():
                yield name, param
            del state
@@ -324,8 +335,12 @@ def hf_model_weights_iterator(


def kv_cache_scales_loader(
        filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
        model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
    filename: str,
    tp_rank: int,
    tp_size: int,
    num_hidden_layers: int,
    model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
    """
    A simple utility to read in KV cache scaling factors that have been
    previously serialized to disk. Used by the model to populate the appropriate
@@ -343,8 +358,7 @@ def kv_cache_scales_loader(
                "tp_size": tp_size,
            }
            schema_dct = json.load(f)
            schema = QuantParamSchema.model_validate(schema_dct,
                                                     context=context)
            schema = QuantParamSchema.model_validate(schema_dct, context=context)
            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
            return layer_scales_map.items()

@@ -357,9 +371,11 @@ def kv_cache_scales_loader(
    # This section is reached if and only if any of the excepts are hit
    # Return an empty iterable (list) => no KV cache scales are loaded
    # which ultimately defaults to 1.0 scales
    logger.warning("Defaulting to KV cache scaling factors = 1.0 "
                   f"for all layers in TP rank {tp_rank} "
                   "as an error occurred during loading.")
    logger.warning(
        "Defaulting to KV cache scaling factors = 1.0 "
        f"for all layers in TP rank {tp_rank} "
        "as an error occurred during loading."
    )
    return []


@@ -378,8 +394,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    return x


def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)
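default_weight_loader keeps its behavior, a size-checked in-place copy; only the signature was collapsed onto one line. Its effect in isolation (using the function defined just above):

import torch

param = torch.empty(2, 3)
default_weight_loader(param, torch.ones(2, 3))
assert bool(param.eq(1).all())  # loaded values were copied into the parameter
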
@@ -399,4 +414,4 @@ def initialize_dummy_weights(
    """
    for param in model.state_dict().values():
        if torch.is_floating_point(param):
            param.data.uniform_(low, high)
            param.data.uniform_(low, high)

@@ -2,13 +2,16 @@

import base64
import json
import os
import sys
import threading
import traceback
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps

import numpy as np
import requests


@@ -110,6 +113,74 @@ def encode_image_base64(image_path):
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def encode_frame(frame):
    import cv2  # pip install opencv-python-headless
    from PIL import Image

    # Convert the frame to RGB (OpenCV uses BGR by default)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Convert the frame to PIL Image to easily convert to bytes
    im_pil = Image.fromarray(frame)

    # Convert to bytes
    buffered = BytesIO()

    # frame_format = str(os.getenv('FRAME_FORMAT', "JPEG"))

    im_pil.save(buffered, format="PNG")

    frame_bytes = buffered.getvalue()

    # Return the bytes of the frame
    return frame_bytes


def encode_video_base64(video_path, num_frames=16):
    import cv2
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video file: {video_path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"target_frames: {num_frames}")

    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)

    frames = []
    for i in range(total_frames):
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
        else:
            # Handle the case where the frame could not be read
            # print(f"Warning: Could not read frame at index {i}.")
            pass

    cap.release()

    # Safely select frames based on frame_indices, avoiding IndexError
    frames = [frames[i] for i in frame_indices if i < len(frames)]

    # If there are not enough frames, duplicate the last frame until we reach the target
    while len(frames) < num_frames:
        frames.append(frames[-1])

    # Use ThreadPoolExecutor to process and encode frames in parallel
    with ThreadPoolExecutor() as executor:
        encoded_frames = list(executor.map(encode_frame, frames))

    # encoded_frames = list(map(encode_frame, frames))

    # Concatenate all frame bytes
    video_bytes = b"".join(encoded_frames)

    # Encode the concatenated bytes to base64
    video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")

    return video_base64


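The resulting payload is a single base64 string with a "video:" prefix, which load_image in sglang/srt/utils.py (earlier in this diff) routes to decode_video_base64. A hedged round-trip sketch (the clip path is illustrative):

payload = encode_video_base64("videos/example.mp4", num_frames=8)  # hypothetical path
assert payload.startswith("video:")
# base64-decoding the rest yields num_frames PNG images back-to-back,
# which decode_video_base64 splits apart again on the 8-byte PNG signature.
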
def _is_chinese_char(cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
@@ -170,4 +241,4 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
    if not ret_value:
        raise RuntimeError()

    return ret_value[0]
    return ret_value[0]