support llava video (#426)
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -177,3 +177,7 @@ tmp*.txt
|
|||||||
# Plots
|
# Plots
|
||||||
*.png
|
*.png
|
||||||
*.pdf
|
*.pdf
|
||||||
|
|
||||||
|
# personal
|
||||||
|
work_dirs/
|
||||||
|
*.csv
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ def single():
|
|||||||
state = image_qa.run(
|
state = image_qa.run(
|
||||||
image_path="images/cat.jpeg",
|
image_path="images/cat.jpeg",
|
||||||
question="What is this?",
|
question="What is this?",
|
||||||
max_new_tokens=64)
|
max_new_tokens=128)
|
||||||
print(state["answer"], "\n")
|
print(state["answer"], "\n")
|
||||||
|
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ def batch():
|
|||||||
{"image_path": "images/cat.jpeg", "question":"What is this?"},
|
{"image_path": "images/cat.jpeg", "question":"What is this?"},
|
||||||
{"image_path": "images/dog.jpeg", "question":"What is this?"},
|
{"image_path": "images/dog.jpeg", "question":"What is this?"},
|
||||||
],
|
],
|
||||||
max_new_tokens=64,
|
max_new_tokens=128,
|
||||||
)
|
)
|
||||||
for s in states:
|
for s in states:
|
||||||
print(s["answer"], "\n")
|
print(s["answer"], "\n")
|
||||||
|
|||||||
208
examples/usage/llava_video/srt_example_llava_v.py
Normal file
208
examples/usage/llava_video/srt_example_llava_v.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
"""
|
||||||
|
Usage: python3 srt_example_llava_v.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
import os
|
||||||
|
import csv
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
@sgl.function
def video_qa(s, num_frames, video_path, question):
    """Ask ``question`` about a video and generate the reply as ``"answer"``.

    The user turn is the ``sgl.video`` expression built from ``video_path``
    and ``num_frames`` followed by ``question``; the assistant turn is a
    single ``sgl.gen`` call whose output is stored under ``"answer"``.
    """
    clip = sgl.video(video_path, num_frames)
    s += sgl.user(clip + question)

    reply = sgl.gen("answer")
    s += sgl.assistant(reply)
|
||||||
|
|
||||||
|
|
||||||
|
def single(path, num_frames=16):
    """Run ``video_qa`` on one video and print the generated answer.

    Args:
        path: Path of the video file to describe.
        num_frames: Value forwarded to ``video_qa`` as ``num_frames``.
    """
    prompt = "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes"
    request = {
        "num_frames": num_frames,
        "video_path": path,
        "question": prompt,
        "temperature": 0.0,
        "max_new_tokens": 1024,
    }
    result = video_qa.run(**request)
    print(result["answer"], "\n")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def split_into_chunks(lst, num_chunks):
    """Split a list into exactly ``num_chunks`` contiguous chunks.

    Items are distributed as evenly as possible: when ``len(lst)`` is not a
    multiple of ``num_chunks`` the first ``len(lst) % num_chunks`` chunks
    receive one extra item.  No item is ever dropped, and trailing chunks
    are empty lists when there are fewer items than chunks.

    Args:
        lst: The list to split.
        num_chunks: Number of chunks to produce; must be at least 1.

    Returns:
        A list of ``num_chunks`` lists whose concatenation equals ``lst``.

    Raises:
        ValueError: If ``num_chunks`` is smaller than 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")

    # Slicing by a fixed ``len(lst) // num_chunks`` would yield an extra
    # remainder chunk beyond ``num_chunks`` (unreachable through a chunk
    # index) and a zero step for an empty list, so sizes are computed per
    # chunk instead.
    base_size, remainder = divmod(len(lst), num_chunks)

    chunks = []
    start = 0
    for chunk_idx in range(num_chunks):
        size = base_size + (1 if chunk_idx < remainder else 0)
        chunks.append(lst[start:start + size])
        start += size
    return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
    """Write the answers of one batch to ``chunk_<cur_chunk>_batch_<batch_idx>.csv``.

    The file is created inside ``save_dir`` with a ``video_name,answer``
    header followed by one row per video.

    Args:
        batch_video_files: Paths of the videos in this batch; only the base
            name of each path is written.
        states: Per-video results, indexable by ``"answer"``, in the same
            order as ``batch_video_files``.
        cur_chunk: Chunk index used in the file name.
        batch_idx: Batch index used in the file name.
        save_dir: Directory receiving the CSV file; created if missing.
    """
    # The output directory may not exist yet when this is the first batch.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
    with open(csv_filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['video_name', 'answer'])
        writer.writerows(
            [os.path.basename(video_path), state["answer"]]
            for video_path, state in zip(batch_video_files, states)
        )
|
||||||
|
|
||||||
|
def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
    """Merge the per-batch CSV files of a chunk into one file and delete them.

    Reads ``chunk_<cur_chunk>_batch_<i>.csv`` for ``i`` in
    ``range(num_batches)``, copies every data row (the header of each batch
    file is skipped) into ``final_results_chunk_<cur_chunk>.csv`` and removes
    each batch file once it has been copied.

    Args:
        cur_chunk: Chunk index used in the file names.
        num_batches: Number of batch files to merge.
        save_dir: Directory containing the batch files and receiving the
            merged file.
    """
    final_csv_filename = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
    with open(final_csv_filename, 'w', newline='') as final_csvfile:
        writer = csv.writer(final_csvfile)
        writer.writerow(['video_name', 'answer'])

        for batch_idx in range(num_batches):
            batch_csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
            # newline='' lets the csv module handle line endings itself, so
            # newlines embedded in quoted answers are copied unchanged.
            with open(batch_csv_filename, 'r', newline='') as batch_csvfile:
                reader = csv.reader(batch_csvfile)
                next(reader, None)  # Skip header row (tolerates an empty file)
                writer.writerows(reader)
            os.remove(batch_csv_filename)
|
||||||
|
|
||||||
|
def find_video_files(video_dir):
    """Return the video files designated by ``video_dir``.

    Args:
        video_dir: Either a single file, which is returned as a one-element
            list regardless of its extension, or a directory that is
            searched recursively for ``.mp4``, ``.avi`` and ``.mov`` files
            (extension matched case-insensitively).

    Returns:
        A sorted list of file paths.  Sorting keeps the order independent of
        the order in which ``os.walk`` happens to report entries, so
        index-based chunking of the result is reproducible.
    """
    # Check if the video_dir is actually a file
    if os.path.isfile(video_dir):
        # If it's a file, return it as a single-element list
        return [video_dir]

    video_extensions = ('.mp4', '.avi', '.mov')
    video_files = [
        os.path.join(root, file)
        for root, _, files in os.walk(video_dir)
        for file in files
        if file.lower().endswith(video_extensions)
    ]
    video_files.sort()
    return video_files
|
||||||
|
|
||||||
|
def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
    """Describe every video of one chunk in batches and save the answers.

    The videos found under ``video_dir`` are split into ``num_chunks``
    chunks; chunk ``cur_chunk`` is processed ``batch_size`` videos at a time
    with ``video_qa.run_batch``.  Each batch is written to its own CSV file,
    and the batch files are merged into one file for the chunk at the end.

    Args:
        video_dir: Directory (or single file) passed to ``find_video_files``.
        save_dir: Directory receiving the CSV files.
        cur_chunk: Index of the chunk to process.
        num_chunks: Total number of chunks the video list is split into.
        num_frames: Value forwarded to ``video_qa`` as ``num_frames``.
        batch_size: Maximum number of videos per ``run_batch`` call.
    """
    video_files = find_video_files(video_dir)
    chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk]

    # This check has to run before the loop: slices taken inside it are
    # never empty, so an empty chunk would otherwise go unreported.
    if not chunked_video_files:
        print("No video files found in the specified directory.")
        return

    num_batches = 0
    for i in range(0, len(chunked_video_files), batch_size):
        batch_video_files = chunked_video_files[i:i + batch_size]
        print(f"Processing batch of {len(batch_video_files)} video(s)...")

        batch_input = [
            {
                "num_frames": num_frames,
                "video_path": video_path,
                "question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.",
            } for video_path in batch_video_files
        ]

        start_time = time.time()
        states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
        total_time = time.time() - start_time
        average_time = total_time / len(batch_video_files)
        print(f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds")

        save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir)
        num_batches += 1

    compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, start a local sglang
    # runtime configured for LLaVA video, describe one video and shut the
    # runtime down again.

    # Create the parser
    parser = argparse.ArgumentParser(description='Run video processing with specified port.')

    # Add an argument for the port
    parser.add_argument('--port', type=int, default=30000, help='The master port for distributed serving.')
    parser.add_argument('--chunk-idx', type=int, default=0, help='The index of the chunk to process.')
    parser.add_argument('--num-chunks', type=int, default=8, help='The number of chunks to process.')
    parser.add_argument('--save-dir', type=str, default="./work_dirs/llava_video", help='The directory to save the processed video files.')
    parser.add_argument('--video-dir', type=str, default="./videos/Q98Z4OTh8RwmDonc.mp4", help='The directory or path for the processed video files.')
    parser.add_argument('--model-path', type=str, default="lmms-lab/LLaVA-NeXT-Video-7B", help='The model path for the video processing.')
    parser.add_argument('--num-frames', type=int, default=16, help='The number of frames to process in each video.')
    parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)

    # Parse the arguments
    args = parser.parse_args()

    cur_port = args.port
    cur_chunk = args.chunk_idx
    num_chunks = args.num_chunks
    num_frames = args.num_frames

    # Validation failures go through parser.error() so the process exits
    # with a non-zero status; a bare exit() reports success to the caller.
    model_path_lower = args.model_path.lower()
    is_34b = "34b" in model_path_lower
    if is_34b:
        tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer"
    elif "7b" in model_path_lower:
        tokenizer_path = "llava-hf/llava-1.5-7b-hf"
    else:
        parser.error("Invalid model path. Please specify a valid model path.")

    if num_frames > 32:
        parser.error("The maximum number of frames to process is 32. Please specify a valid number of frames.")

    model_overide_args = {
        "mm_spatial_pool_stride": args.mm_spatial_pool_stride,
        "architectures": ["LlavaVidForCausalLM"],
        "num_frames": num_frames,
        "model_type": "llava",
    }
    if is_34b:
        model_overide_args["image_token_index"] = 64002
    if num_frames == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2

    # Resolve the input videos before launching the runtime so that a bad
    # path fails immediately instead of after the model has been loaded.
    root = args.video_dir
    if os.path.isfile(root):
        video_files = [root]
    elif os.path.isdir(root):
        video_files = sorted(
            os.path.join(root, f)
            for f in os.listdir(root)
            if f.endswith(('.mp4', '.avi', '.mov'))  # Add more extensions if needed
        )
    else:
        parser.error(f"Video path does not exist: {root}")
    if not video_files:
        parser.error(f"No video files found in {root}")

    runtime = sgl.Runtime(
        model_path=args.model_path,  # "liuhaotian/llava-v1.6-vicuna-7b",
        tokenizer_path=tokenizer_path,
        port=cur_port,
        additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
        model_overide_args=model_overide_args,
        tp_size=1,
    )
    # try/finally guarantees exactly one shutdown, including on failure.
    try:
        sgl.set_default_backend(runtime)
        print(f"chat template: {runtime.endpoint.chat_template.name}")

        # Run a single request
        print("\n========== single ==========\n")
        videos_to_process = video_files[:1]
        start_time = time.time()  # Start time for processing a single video
        for cur_video in videos_to_process:
            print(cur_video)
            single(cur_video, num_frames)
        end_time = time.time()  # End time for processing a single video
        total_time = end_time - start_time
        # Average over the videos actually processed, not all discovered ones.
        average_time = total_time / len(videos_to_process)
        print(f"Average processing time per video: {average_time:.2f} seconds")

        # # # Run a batch of requests
        # print("\n========== batch ==========\n")
        # if not os.path.exists(args.save_dir):
        #     os.makedirs(args.save_dir)
        # batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks)
    finally:
        runtime.shutdown()
|
||||||
130
examples/usage/llava_video/srt_example_llava_v.sh
Normal file
130
examples/usage/llava_video/srt_example_llava_v.sh
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
#!/bin/bash

##### USAGE #####
#    - First node:
#      ```sh
#      bash examples/usage/llava_video/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
#      ```
#    - Second node:
#      ```sh
#      bash examples/usage/llava_video/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
#      ```
#    - The K node:
#      ```sh
#      bash examples/usage/llava_video/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO
#      ```

# Replace `K`, `YOUR_VIDEO_PATH`, `YOUR_MODEL_PATH`, and `FRAMES_PER_VIDEO` with your specific details.

if [ "$#" -lt 5 ]; then
    echo "Usage: $0 NUM_NODES NODE_IDX VIDEO_PATH MODEL_PATH FRAMES_PER_VIDEO" >&2
    exit 1
fi

# Directory containing this script. All relative paths below (the Python
# example, work_dirs/, a relative VIDEO_DIR) are resolved against it.
CURRENT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

echo "${CURRENT_ROOT}"

cd "${CURRENT_ROOT}" || exit 1

export PYTHONWARNINGS=ignore

START_TIME=$(date +%s) # Capture start time

NUM_NODES=$1

CUR_NODES_IDX=$2

VIDEO_DIR=$3

MODEL_PATH=$4

NUM_FRAMES=$5


# FRAME_FORMAT=$6

# FRAME_FORMAT=$(echo $FRAME_FORMAT | tr '[:lower:]' '[:upper:]')

# # Check if FRAME_FORMAT is either JPEG or PNG
# if [[ "$FRAME_FORMAT" != "JPEG" && "$FRAME_FORMAT" != "PNG" ]]; then
#     echo "Error: FRAME_FORMAT must be either JPEG or PNG."
#     exit 1
# fi

# export TARGET_FRAMES=$TARGET_FRAMES

echo "Each video you will sample $NUM_FRAMES frames"

# export FRAME_FORMAT=$FRAME_FORMAT

# echo "The frame format is $FRAME_FORMAT"

# Assuming GPULIST is a bash array containing your GPUs
GPULIST=(0 1 2 3 4 5 6 7)

# GPUs per chunk. The Python example starts its runtime with tp_size=1, so
# each chunk gets one GPU; with the previous value of 8 every chunk after
# the first was handed an empty GPU list.
GPUS_PER_CHUNK=1

LOCAL_CHUNKS=$(( ${#GPULIST[@]} / GPUS_PER_CHUNK ))

echo "Number of GPUs in GPULIST: ${#GPULIST[@]}"

ALL_CHUNKS=$((NUM_NODES * LOCAL_CHUNKS))

echo $GPUS_PER_CHUNK

for IDX in $(seq 1 $LOCAL_CHUNKS); do
    (
        START=$(((IDX-1) * GPUS_PER_CHUNK))
        LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index

        CHUNK_GPUS=("${GPULIST[@]:$START:$LENGTH}")

        # Convert the chunk GPUs array to a comma-separated string
        CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}")

        LOCAL_IDX=$((CUR_NODES_IDX * LOCAL_CHUNKS + IDX))

        echo "Chunk $(($LOCAL_IDX - 1)) will run on GPUs $CHUNK_GPUS_STR"

        # Pick a random base port for this chunk. The Python example also
        # uses the four ports that follow it.
        PORT=$((10000 + RANDOM % 55536))

        MAX_RETRIES=10
        RETRY_COUNT=0
        COMMAND_STATUS=1 # Initialize as failed

        while [ $RETRY_COUNT -lt $MAX_RETRIES ] && [ $COMMAND_STATUS -ne 0 ]; do
            echo "Running chunk $(($LOCAL_IDX - 1)) on GPUs $CHUNK_GPUS_STR with port $PORT. Attempt $(($RETRY_COUNT + 1))"

            CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 "${CURRENT_ROOT}/srt_example_llava_v.py" \
                --port "$PORT" \
                --num-chunks "$ALL_CHUNKS" \
                --chunk-idx $(($LOCAL_IDX - 1)) \
                --save-dir work_dirs/llava_next_video_inference_results \
                --video-dir "$VIDEO_DIR" \
                --model-path "$MODEL_PATH" \
                --num-frames "$NUM_FRAMES" #&

            # The command runs in the foreground, so read its status
            # directly. `wait $!` with no background job degenerates to a
            # bare `wait`, which returns 0 and hid every failure.
            COMMAND_STATUS=$?

            if [ $COMMAND_STATUS -ne 0 ]; then
                echo "Execution failed for chunk $(($LOCAL_IDX - 1)), attempt $(($RETRY_COUNT + 1)). Retrying..."
                RETRY_COUNT=$(($RETRY_COUNT + 1))
                sleep 180 # Wait a bit before retrying
            else
                echo "Execution succeeded for chunk $(($LOCAL_IDX - 1))."
            fi
        done

        if [ $COMMAND_STATUS -ne 0 ]; then
            echo "Execution failed for chunk $(($LOCAL_IDX - 1)) after $MAX_RETRIES attempts."
        fi
    ) #&
    sleep 2 # Slight delay to stagger the start times
done

wait

cat work_dirs/llava_next_video_inference_results/final_results_chunk_*.csv > "work_dirs/llava_next_video_inference_results/final_results_node_${CUR_NODES_IDX}.csv"

END_TIME=$(date +%s) # Capture end time
ELAPSED_TIME=$(($END_TIME - $START_TIME))
echo "Total execution time: $ELAPSED_TIME seconds."
|
||||||
BIN
examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4
Normal file
BIN
examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4
Normal file
Binary file not shown.
@@ -20,7 +20,7 @@ dependencies = [
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
|
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
|
||||||
"zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "outlines>=0.0.27", "packaging"]
|
"zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "packaging", "huggingface_hub", "hf_transfer", "outlines>=0.0.34"]
|
||||||
openai = ["openai>=1.0", "numpy", "tiktoken"]
|
openai = ["openai>=1.0", "numpy", "tiktoken"]
|
||||||
anthropic = ["anthropic>=0.20.0", "numpy"]
|
anthropic = ["anthropic>=0.20.0", "numpy"]
|
||||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from sglang.api import (
|
|||||||
user,
|
user,
|
||||||
user_begin,
|
user_begin,
|
||||||
user_end,
|
user_end,
|
||||||
|
video,
|
||||||
)
|
)
|
||||||
|
|
||||||
# SGL Backends
|
# SGL Backends
|
||||||
@@ -46,6 +47,7 @@ __all__ = [
|
|||||||
"gen_int",
|
"gen_int",
|
||||||
"gen_string",
|
"gen_string",
|
||||||
"image",
|
"image",
|
||||||
|
"video",
|
||||||
"select",
|
"select",
|
||||||
"system",
|
"system",
|
||||||
"user",
|
"user",
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from sglang.lang.ir import (
|
|||||||
SglRoleBegin,
|
SglRoleBegin,
|
||||||
SglRoleEnd,
|
SglRoleEnd,
|
||||||
SglSelect,
|
SglSelect,
|
||||||
|
SglVideo,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -151,6 +152,10 @@ def image(expr: SglExpr):
|
|||||||
return SglImage(expr)
|
return SglImage(expr)
|
||||||
|
|
||||||
|
|
||||||
|
def video(path: str, num_frames: int):
    """Create an ``SglVideo`` expression from ``path`` and ``num_frames``."""
    return SglVideo(path, num_frames)
|
||||||
|
|
||||||
|
|
||||||
def select(
|
def select(
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
choices: List[str] = None,
|
choices: List[str] = None,
|
||||||
|
|||||||
@@ -259,6 +259,8 @@ def match_vicuna(model_path: str):
|
|||||||
return get_chat_template("vicuna_v1.1")
|
return get_chat_template("vicuna_v1.1")
|
||||||
if "llava-v1.5" in model_path.lower():
|
if "llava-v1.5" in model_path.lower():
|
||||||
return get_chat_template("vicuna_v1.1")
|
return get_chat_template("vicuna_v1.1")
|
||||||
|
if "llava-next-video-7b" in model_path.lower():
|
||||||
|
return get_chat_template("vicuna_v1.1")
|
||||||
|
|
||||||
|
|
||||||
@register_chat_template_matching_function
|
@register_chat_template_matching_function
|
||||||
@@ -283,19 +285,24 @@ def match_llama3_instruct(model_path: str):
|
|||||||
|
|
||||||
@register_chat_template_matching_function
|
@register_chat_template_matching_function
|
||||||
def match_chat_ml(model_path: str):
|
def match_chat_ml(model_path: str):
|
||||||
|
# import pdb;pdb.set_trace()
|
||||||
model_path = model_path.lower()
|
model_path = model_path.lower()
|
||||||
if "tinyllama" in model_path:
|
if "tinyllama" in model_path:
|
||||||
return get_chat_template("chatml")
|
return get_chat_template("chatml")
|
||||||
if "qwen" in model_path and "chat" in model_path:
|
if "qwen" in model_path and "chat" in model_path:
|
||||||
return get_chat_template("chatml")
|
return get_chat_template("chatml")
|
||||||
if "llava-v1.6-34b" in model_path:
|
if (
|
||||||
|
"llava-v1.6-34b" in model_path
|
||||||
|
or "llava-v1.6-yi-34b" in model_path
|
||||||
|
or "llava-next-video-34b" in model_path
|
||||||
|
):
|
||||||
return get_chat_template("chatml-llava")
|
return get_chat_template("chatml-llava")
|
||||||
|
|
||||||
|
|
||||||
@register_chat_template_matching_function
|
@register_chat_template_matching_function
|
||||||
def match_chat_yi(model_path: str):
|
def match_chat_yi(model_path: str):
|
||||||
model_path = model_path.lower()
|
model_path = model_path.lower()
|
||||||
if "yi" in model_path:
|
if "yi" in model_path and "llava" not in model_path:
|
||||||
return get_chat_template("yi")
|
return get_chat_template("yi")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -28,8 +28,9 @@ from sglang.lang.ir import (
|
|||||||
SglVariable,
|
SglVariable,
|
||||||
SglVarScopeBegin,
|
SglVarScopeBegin,
|
||||||
SglVarScopeEnd,
|
SglVarScopeEnd,
|
||||||
|
SglVideo,
|
||||||
)
|
)
|
||||||
from sglang.utils import encode_image_base64, get_exception_traceback
|
from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
|
||||||
|
|
||||||
|
|
||||||
def run_internal(state, program, func_args, func_kwargs, sync):
|
def run_internal(state, program, func_args, func_kwargs, sync):
|
||||||
@@ -361,6 +362,8 @@ class StreamExecutor:
|
|||||||
self._execute_role_end(other)
|
self._execute_role_end(other)
|
||||||
elif isinstance(other, SglImage):
|
elif isinstance(other, SglImage):
|
||||||
self._execute_image(other)
|
self._execute_image(other)
|
||||||
|
elif isinstance(other, SglVideo):
|
||||||
|
self._execute_video(other)
|
||||||
elif isinstance(other, SglVariable):
|
elif isinstance(other, SglVariable):
|
||||||
self._execute_variable(other)
|
self._execute_variable(other)
|
||||||
elif isinstance(other, SglVarScopeBegin):
|
elif isinstance(other, SglVarScopeBegin):
|
||||||
@@ -397,6 +400,16 @@ class StreamExecutor:
|
|||||||
self.cur_images.append((path, base64_data))
|
self.cur_images.append((path, base64_data))
|
||||||
self.text_ += self.chat_template.image_token
|
self.text_ += self.chat_template.image_token
|
||||||
|
|
||||||
|
def _execute_video(self, expr: SglVideo):
    """Encode the video of ``expr`` and add it to the current prompt.

    The result of ``encode_video_base64(path, num_frames)`` is appended,
    paired with the path, to both ``self.images_`` and ``self.cur_images``,
    and the chat template's image token is appended to ``self.text_``.
    """
    path = expr.path
    num_frames = expr.num_frames

    base64_data = encode_video_base64(path, num_frames)

    # Videos are stored in the same lists and use the same placeholder
    # token as images.
    self.images_.append((path, base64_data))
    self.cur_images.append((path, base64_data))
    self.text_ += self.chat_template.image_token
|
||||||
|
|
||||||
# if global_config.eager_fill_image:
|
# if global_config.eager_fill_image:
|
||||||
# self.backend.fill_image(self)
|
# self.backend.fill_image(self)
|
||||||
|
|
||||||
|
|||||||
@@ -330,6 +330,15 @@ class SglImage(SglExpr):
|
|||||||
return f"SglImage({self.path})"
|
return f"SglImage({self.path})"
|
||||||
|
|
||||||
|
|
||||||
|
class SglVideo(SglExpr):
    """Expression node holding a video path and a frame count."""

    def __init__(self, path, num_frames):
        # Path of the video, as passed by the caller.
        self.path = path
        # Frame count, as passed by the caller.
        self.num_frames = num_frames

    def __repr__(self) -> str:
        return f"SglVideo({self.path}, {self.num_frames})"
|
||||||
|
|
||||||
|
|
||||||
class SglGen(SglExpr):
|
class SglGen(SglExpr):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ class TracerProgramState(ProgramState):
|
|||||||
##################################
|
##################################
|
||||||
|
|
||||||
def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
|
def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
|
||||||
assert (size >= 1)
|
assert size >= 1
|
||||||
|
|
||||||
if self.only_trace_prefix:
|
if self.only_trace_prefix:
|
||||||
raise StopTracing()
|
raise StopTracing()
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import argparse
|
|||||||
|
|
||||||
from sglang.srt.server import ServerArgs, launch_server
|
from sglang.srt.server import ServerArgs, launch_server
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
ServerArgs.add_cli_args(parser)
|
ServerArgs.add_cli_args(parser)
|
||||||
|
|||||||
31
python/sglang/launch_server_llavavid.py
Normal file
31
python/sglang/launch_server_llavavid.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import argparse
|
||||||
|
import multiprocessing as mp
|
||||||
|
|
||||||
|
from sglang.srt.server import ServerArgs, launch_server
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Launch the server with a fixed set of config overrides that select the
    # "LlavaVidForCausalLM" architecture.

    model_overide_args = {}

    model_overide_args["mm_spatial_pool_stride"] = 2
    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
    model_overide_args["num_frames"] = 16
    model_overide_args["model_type"] = "llavavid"
    # NOTE(review): "num_frames" is hard-coded to 16 just above, so this
    # branch only takes effect if that value is edited to 32.
    if model_overide_args["num_frames"] == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
        model_overide_args["model_max_length"] = 4096 * 2

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    # Model paths containing "34b" additionally override the image token index.
    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    server_args = ServerArgs.from_cli_args(args)

    # NOTE(review): pipe_reader is never read in this script — confirm that
    # launch_server does not block on writing to the pipe.
    pipe_reader, pipe_writer = mp.Pipe(duplex=False)

    launch_server(server_args, pipe_writer, model_overide_args)
|
||||||
@@ -30,10 +30,17 @@ def get_config_json(model_path: str):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
|
def get_config(
|
||||||
|
model: str,
|
||||||
|
trust_remote_code: bool,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
model_overide_args: Optional[dict] = None,
|
||||||
|
):
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model, trust_remote_code=trust_remote_code, revision=revision
|
model, trust_remote_code=trust_remote_code, revision=revision
|
||||||
)
|
)
|
||||||
|
if model_overide_args:
|
||||||
|
config.update(model_overide_args)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -60,9 +60,7 @@ class RouterManager:
|
|||||||
|
|
||||||
|
|
||||||
def start_router_process(
|
def start_router_process(
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
|
||||||
port_args: PortArgs,
|
|
||||||
pipe_writer,
|
|
||||||
):
|
):
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=getattr(logging, server_args.log_level.upper()),
|
level=getattr(logging, server_args.log_level.upper()),
|
||||||
@@ -70,7 +68,7 @@ def start_router_process(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_client = ModelRpcClient(server_args, port_args)
|
model_client = ModelRpcClient(server_args, port_args, model_overide_args)
|
||||||
router = RouterManager(model_client, port_args)
|
router = RouterManager(model_client, port_args)
|
||||||
except Exception:
|
except Exception:
|
||||||
pipe_writer.send(get_exception_traceback())
|
pipe_writer.send(get_exception_traceback())
|
||||||
|
|||||||
@@ -4,12 +4,13 @@ import multiprocessing
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import List
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import rpyc
|
import rpyc
|
||||||
import torch
|
import torch
|
||||||
from rpyc.utils.classic import obtain
|
from rpyc.utils.classic import obtain
|
||||||
from rpyc.utils.server import ThreadedServer
|
from rpyc.utils.server import ThreadedServer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm.logger import _default_handler as vllm_default_logger
|
from vllm.logger import _default_handler as vllm_default_logger
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -48,6 +49,7 @@ class ModelRpcServer:
|
|||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
|
model_overide_args: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
server_args, port_args = [obtain(x) for x in [server_args, port_args]]
|
server_args, port_args = [obtain(x) for x in [server_args, port_args]]
|
||||||
|
|
||||||
@@ -62,6 +64,7 @@ class ModelRpcServer:
|
|||||||
server_args.model_path,
|
server_args.model_path,
|
||||||
server_args.trust_remote_code,
|
server_args.trust_remote_code,
|
||||||
context_length=server_args.context_length,
|
context_length=server_args.context_length,
|
||||||
|
model_overide_args=model_overide_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# For model end global settings
|
# For model end global settings
|
||||||
@@ -673,13 +676,15 @@ class ModelRpcService(rpyc.Service):
|
|||||||
|
|
||||||
|
|
||||||
class ModelRpcClient:
|
class ModelRpcClient:
|
||||||
def __init__(self, server_args: ServerArgs, port_args: PortArgs):
|
def __init__(
|
||||||
|
self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
|
||||||
|
):
|
||||||
tp_size = server_args.tp_size
|
tp_size = server_args.tp_size
|
||||||
|
|
||||||
if tp_size == 1:
|
if tp_size == 1:
|
||||||
# Init model
|
# Init model
|
||||||
self.model_server = ModelRpcService().exposed_ModelRpcServer(
|
self.model_server = ModelRpcService().exposed_ModelRpcServer(
|
||||||
0, server_args, port_args
|
0, server_args, port_args, model_overide_args
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wrap functions
|
# Wrap functions
|
||||||
@@ -700,7 +705,7 @@ class ModelRpcClient:
|
|||||||
# Init model
|
# Init model
|
||||||
def init_model(i):
|
def init_model(i):
|
||||||
return self.remote_services[i].ModelRpcServer(
|
return self.remote_services[i].ModelRpcServer(
|
||||||
i, server_args, port_args
|
i, server_args, port_args, model_overide_args
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_servers = executor.map(init_model, range(tp_size))
|
self.model_servers = executor.map(init_model, range(tp_size))
|
||||||
@@ -723,7 +728,11 @@ def _init_service(port):
|
|||||||
t = ThreadedServer(
|
t = ThreadedServer(
|
||||||
ModelRpcService(),
|
ModelRpcService(),
|
||||||
port=port,
|
port=port,
|
||||||
protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
|
protocol_config={
|
||||||
|
"allow_public_attrs": True,
|
||||||
|
"allow_pickle": True,
|
||||||
|
"sync_request_timeout": 1800,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
@@ -739,7 +748,11 @@ def start_model_process(port):
|
|||||||
con = rpyc.connect(
|
con = rpyc.connect(
|
||||||
"localhost",
|
"localhost",
|
||||||
port,
|
port,
|
||||||
config={"allow_pickle": True, "sync_request_timeout": 1800},
|
config={
|
||||||
|
"allow_public_attrs": True,
|
||||||
|
"allow_pickle": True,
|
||||||
|
"sync_request_timeout": 1800,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
except ConnectionRefusedError:
|
except ConnectionRefusedError:
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ from typing import List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from vllm.distributed import initialize_model_parallel
|
||||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||||
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
||||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
from vllm.distributed import initialize_model_parallel
|
|
||||||
|
|
||||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
@@ -143,7 +143,7 @@ class InputMetadata:
|
|||||||
self.kv_last_page_len,
|
self.kv_last_page_len,
|
||||||
self.model_runner.model_config.num_attention_heads // tp_size,
|
self.model_runner.model_config.num_attention_heads // tp_size,
|
||||||
self.model_runner.model_config.num_key_value_heads // tp_size,
|
self.model_runner.model_config.num_key_value_heads // tp_size,
|
||||||
self.model_runner.model_config.head_dim
|
self.model_runner.model_config.head_dim,
|
||||||
]
|
]
|
||||||
|
|
||||||
self.prefill_wrapper.begin_forward(*args)
|
self.prefill_wrapper.begin_forward(*args)
|
||||||
|
|||||||
@@ -60,21 +60,29 @@ def get_pixel_values(
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
processor = processor or global_processor
|
processor = processor or global_processor
|
||||||
image = load_image(image_data)
|
image, image_size = load_image(image_data)
|
||||||
image_hash = hash(image_data)
|
if image_size != None:
|
||||||
if image_aspect_ratio == "pad":
|
image_hash = hash(image_data)
|
||||||
image = expand2square(
|
pixel_values = processor.image_processor(image)["pixel_values"]
|
||||||
image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
|
for _ in range(len(pixel_values)):
|
||||||
)
|
pixel_values[_] = pixel_values[_].astype(np.float16)
|
||||||
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
pixel_values = np.stack(pixel_values, axis=0)
|
||||||
elif image_aspect_ratio == "anyres":
|
return pixel_values, image_hash, image_size
|
||||||
pixel_values = process_anyres_image(
|
|
||||||
image, processor.image_processor, image_grid_pinpoints
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
image_hash = hash(image_data)
|
||||||
pixel_values = pixel_values.astype(np.float16)
|
if image_aspect_ratio == "pad":
|
||||||
return pixel_values, image_hash, image.size
|
image = expand2square(
|
||||||
|
image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
|
||||||
|
)
|
||||||
|
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
||||||
|
elif image_aspect_ratio == "anyres":
|
||||||
|
pixel_values = process_anyres_image(
|
||||||
|
image, processor.image_processor, image_grid_pinpoints
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
||||||
|
pixel_values = pixel_values.astype(np.float16)
|
||||||
|
return pixel_values, image_hash, image.size
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Exception in TokenizerManager:\n" + get_exception_traceback())
|
print("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||||
|
|
||||||
@@ -84,6 +92,7 @@ class TokenizerManager:
|
|||||||
self,
|
self,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
|
model_overide_args: dict = None,
|
||||||
):
|
):
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
|
|
||||||
@@ -96,7 +105,9 @@ class TokenizerManager:
|
|||||||
|
|
||||||
self.model_path = server_args.model_path
|
self.model_path = server_args.model_path
|
||||||
self.hf_config = get_config(
|
self.hf_config = get_config(
|
||||||
self.model_path, trust_remote_code=server_args.trust_remote_code
|
self.model_path,
|
||||||
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
|
model_overide_args=model_overide_args,
|
||||||
)
|
)
|
||||||
self.context_len = get_context_length(self.hf_config)
|
self.context_len = get_context_length(self.hf_config)
|
||||||
|
|
||||||
|
|||||||
@@ -10,12 +10,16 @@ class ModelConfig:
|
|||||||
trust_remote_code: bool = True,
|
trust_remote_code: bool = True,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
context_length: Optional[int] = None,
|
context_length: Optional[int] = None,
|
||||||
|
model_overide_args: Optional[dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.path = path
|
self.path = path
|
||||||
self.trust_remote_code = trust_remote_code
|
self.trust_remote_code = trust_remote_code
|
||||||
self.revision = revision
|
self.revision = revision
|
||||||
self.hf_config = get_config(self.path, trust_remote_code, revision)
|
self.hf_config = get_config(self.path, trust_remote_code, revision)
|
||||||
|
|
||||||
|
if model_overide_args is not None:
|
||||||
|
self.hf_config.update(model_overide_args)
|
||||||
|
|
||||||
if context_length is not None:
|
if context_length is not None:
|
||||||
self.context_len = context_length
|
self.context_len = context_length
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -27,29 +27,25 @@ import torch.utils.checkpoint
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
from vllm.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
QuantizationConfig)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from vllm.distributed import (
|
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from sglang.srt.weight_utils import (
|
|
||||||
default_weight_loader,
|
|
||||||
hf_model_weights_iterator,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||||
|
|
||||||
|
|
||||||
@torch.compile
|
@torch.compile
|
||||||
|
|||||||
@@ -5,37 +5,31 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from vllm.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
QuantizationConfig)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE,
|
DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.distributed import (
|
|
||||||
tensor_model_parallel_all_reduce,
|
|
||||||
)
|
|
||||||
from vllm.distributed import (
|
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from sglang.srt.weight_utils import (
|
|
||||||
default_weight_loader,
|
|
||||||
hf_model_weights_iterator,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
from sglang.srt.models.dbrx_config import DbrxConfig
|
from sglang.srt.models.dbrx_config import DbrxConfig
|
||||||
|
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||||
|
|
||||||
|
|
||||||
class DbrxRouter(nn.Module):
|
class DbrxRouter(nn.Module):
|
||||||
@@ -291,7 +285,9 @@ class DbrxBlock(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, quant_config=quant_config)
|
self.norm_attn_norm = DbrxFusedNormAttention(
|
||||||
|
config, layer_id, quant_config=quant_config
|
||||||
|
)
|
||||||
self.ffn = DbrxExperts(config, quant_config=quant_config)
|
self.ffn = DbrxExperts(config, quant_config=quant_config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -322,7 +318,10 @@ class DbrxModel(nn.Module):
|
|||||||
config.d_model,
|
config.d_model,
|
||||||
)
|
)
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[DbrxBlock(config, i, quant_config=quant_config) for i in range(config.n_layers)]
|
[
|
||||||
|
DbrxBlock(config, i, quant_config=quant_config)
|
||||||
|
for i in range(config.n_layers)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
||||||
for module in self.modules():
|
for module in self.modules():
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import LoRAConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import GeluAndMul
|
from vllm.model_executor.layers.activation import GeluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
@@ -14,21 +15,14 @@ from vllm.model_executor.layers.linear import (
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
QuantizationConfig)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from vllm.distributed import (
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from sglang.srt.weight_utils import (
|
|
||||||
default_weight_loader,
|
|
||||||
hf_model_weights_iterator,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||||
|
|
||||||
|
|
||||||
class GemmaMLP(nn.Module):
|
class GemmaMLP(nn.Module):
|
||||||
@@ -46,7 +40,10 @@ class GemmaMLP(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
|
intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.act_fn = GeluAndMul()
|
self.act_fn = GeluAndMul()
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
@@ -13,24 +14,17 @@ from vllm.model_executor.layers.linear import (
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
QuantizationConfig)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.distributed import (
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from sglang.srt.weight_utils import (
|
|
||||||
default_weight_loader,
|
|
||||||
hf_model_weights_iterator,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||||
|
|
||||||
|
|
||||||
class LlamaMLP(nn.Module):
|
class LlamaMLP(nn.Module):
|
||||||
@@ -49,7 +43,10 @@ class LlamaMLP(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
|
intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -7,12 +7,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import CLIPVisionModel, LlavaConfig
|
from transformers import CLIPVisionModel, LlavaConfig
|
||||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
QuantizationConfig)
|
|
||||||
from sglang.srt.weight_utils import (
|
|
||||||
default_weight_loader,
|
|
||||||
hf_model_weights_iterator,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.managers.router.infer_batch import ForwardMode
|
from sglang.srt.managers.router.infer_batch import ForwardMode
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
@@ -22,6 +17,7 @@ from sglang.srt.mm_utils import (
|
|||||||
unpad_image_shape,
|
unpad_image_shape,
|
||||||
)
|
)
|
||||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||||
|
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||||
|
|
||||||
|
|
||||||
class LlavaLlamaForCausalLM(nn.Module):
|
class LlavaLlamaForCausalLM(nn.Module):
|
||||||
|
|||||||
307
python/sglang/srt/models/llavavid.py
Normal file
307
python/sglang/srt/models/llavavid.py
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
|
||||||
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
|
|
||||||
|
from sglang.srt.managers.router.infer_batch import ForwardMode
|
||||||
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
from sglang.srt.mm_utils import (
|
||||||
|
get_anyres_image_grid_shape,
|
||||||
|
unpad_image,
|
||||||
|
unpad_image_shape,
|
||||||
|
)
|
||||||
|
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||||
|
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaVidForCausalLM(nn.Module):
    """LLaVA video model: CLIP vision tower -> spatial average pooling ->
    multimodal projector -> Llama language model.

    The vision tower and several derived attributes (``image_size``,
    ``patch_size``, ``image_feature_len``, ...) are only populated by
    :meth:`load_weights`; the other methods must not be called before it.
    """

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the projector, resampler and language model from ``config``.

        Args:
            config: Model configuration. ``mm_hidden_size`` and ``hidden_size``
                are copied into the nested vision/text configs before the
                projector is constructed.
            quant_config: Optional quantization config forwarded to the
                language model.
        """
        super().__init__()
        self.config = config
        # Created lazily in load_weights().
        self.vision_tower = None
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
        # Non-overlapping average pooling that shrinks each frame's patch grid.
        self.resampler = nn.AvgPool2d(
            kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
        )
        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
        self.num_frames = getattr(self.config, "num_frames", 16)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )

    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
        """Replace the image token in ``input_ids`` with placeholder ids.

        Args:
            input_ids: List of token ids containing ``config.image_token_index``.
            pad_value: Non-empty list of ids that is repeated (and truncated)
                to fill ``self.image_feature_len`` positions.
            pt_shape: Unused; kept for interface compatibility.
            image_size: Unused; kept for interface compatibility.

        Returns:
            ``(new_input_ids, offset)`` where ``offset`` is the index at which
            the image token was found.

        Raises:
            ValueError: If the image token is not present in ``input_ids``.
        """
        # NOTE: anyres / spatial-unpad layouts are not handled here; the
        # placeholder length is always the fixed per-video feature length.
        feature_len = self.image_feature_len

        # Repeat pad_value often enough to cover feature_len, then truncate.
        repeats = (feature_len + len(pad_value)) // len(pad_value)
        pad_ids = pad_value * repeats

        offset = input_ids.index(self.config.image_token_index)
        # The image token itself is dropped, so the result has
        # len(input_ids) + feature_len - 1 entries.
        new_input_ids = (
            input_ids[:offset] + pad_ids[:feature_len] + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision tower, pool each frame's patch grid, and project.

        Args:
            pixel_values: Batch of frames passed straight to the vision tower.

        Returns:
            The projector output for the pooled patch features, one row of
            tokens per input frame.

        Raises:
            ValueError: If the configured feature select strategy is unknown.
        """
        # NOTE: This is not memory efficient; output_hidden_states=True keeps
        # the hidden states of every layer.
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        features = image_outputs.hidden_states[self.vision_feature_layer]

        strategy = self.vision_feature_select_strategy
        if strategy in ("default", "patch"):
            # Drop the leading (class) token, keep only patch tokens.
            features = features[:, 1:]
        elif strategy != "full":
            # Report the value that was actually tested.
            raise ValueError(f"Unexpected select feature strategy: {strategy}")

        side = self.num_patches_per_side
        num_of_frames = features.shape[0]
        # (frames, tokens, dim) -> (frames, dim, side, side) for 2-D pooling.
        features = features.view(num_of_frames, side, side, -1)
        features = features.permute(0, 3, 1, 2).contiguous()
        # Pool, then go back to (frames, pooled_tokens, dim).
        features = self.resampler(features).flatten(2).transpose(1, 2).contiguous()

        return self.multi_modal_projector(features)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.ndarray]]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Run the language model, splicing in vision features during EXTEND.

        Args:
            input_ids: Token ids for the whole batch.
            positions: Position of every token.
            input_metadata: Batch metadata (forward mode, batch size,
                per-request start locations).
            pixel_values: Per-request pixel arrays, ``None`` for requests
                without visual input.
            image_sizes: Accepted for interface compatibility; not used here.
            image_offsets: Per-request offset of the image placeholder inside
                the request's tokens.

        Returns:
            The language model output. Forward modes other than EXTEND and
            DECODE fall through and return ``None``.
        """
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Embed vision input: a request needs it only when its first
            # extended position still lies inside the image placeholder span.
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            # FIXME: We need to subtract the length of the system prompt
            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
            need_vision = need_vision & has_pixel

            if need_vision.any():
                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]

                if pixel_values[0].ndim == 4:
                    # Each request holds several frames/patches: encode them
                    # in one pass and split the result back per request.
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                else:
                    # One image per request.
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)

                # Merge the frame and token axes into one token axis. Features
                # that are already 2-D (single image per request) are kept
                # as-is so that the (tokens, dim) unpacking below still holds.
                image_features = [
                    feature.flatten(0, 1) if feature.dim() == 3 else feature
                    for feature in image_features
                ]

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                pt = 0  # index into image_features (only vision requests)
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = image_features[pt].shape
                    dim = input_embeds.shape[1]
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    begin = start_idx + image_offsets[i]
                    try:
                        input_embeds[begin : begin + pad_len] = image_features[pt]
                    except RuntimeError as e:
                        # Best effort: report the mismatch and continue.
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])
                    pt += 1

            return self.language_model(
                input_ids, positions, input_metadata, input_embeds=input_embeds
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(input_ids, positions, input_metadata)

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load the vision tower, projector and language model weights.

        Also derives ``image_feature_len`` and installs the CLIP embedding
        monkey patch.

        Raises:
            ValueError: If ``config.mm_vision_select_feature`` is neither
                ``"patch"`` nor ``"cls_patch"``.
        """
        # load clip vision model by cfg['mm_vision_tower']:
        #   huggingface_name or path_of_clip_relative_to_llava_model_dir
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        print(f"target_frames: {self.num_frames}")
        # Use integer (floor) division so the count matches what the
        # floor-mode AvgPool2d resampler actually produces per frame.
        pooled_side = self.num_patches_per_side // self.mm_spatial_pool_stride
        self.image_feature_len = self.num_frames * pooled_side**2
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
            "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
            # Update the vision tower weights if we find them in the
            # checkpoint (it may be finetuned).
            "model.vision_tower.vision_tower": "vision_tower",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            # FIXME: why projector weights read two times?
            if "projector" in name or "vision_tower" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                if name in params_dict:
                    param = params_dict[name]
                else:
                    print(f"Warning: {name} not found in the model")
                    continue
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # load language model
        self.language_model.load_weights(
            model_name_or_path, cache_dir, load_format, revision
        )

        monkey_path_clip_vision_embed_forward()

    @property
    def num_patches_per_side(self):
        """Number of vision patches along one side of an input image."""
        return self.image_size // self.patch_size
|
||||||
|
|
||||||
|
|
||||||
|
first_call = True
|
||||||
|
|
||||||
|
|
||||||
|
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Replacement ``forward`` for the CLIP vision embedding module.

    The patch-embedding convolution is evaluated on the CPU in float32 and its
    result moved back to the GPU in half precision; class and position
    embeddings are then added as usual.
    """
    global first_call

    num_images = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    # The module only has to be relocated once, hence the module-level flag.
    if first_call:
        self.patch_embedding.cpu().float()
        first_call = False

    cpu_pixels = pixel_values.to(dtype=torch.float32, device="cpu")
    patch_tokens = self.patch_embedding(cpu_pixels).cuda().half()
    patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

    # Prepend one class token per image, then add the position embeddings.
    cls_tokens = self.class_embedding.expand(num_images, 1, -1)
    tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
    return tokens + self.position_embedding(self.position_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def monkey_path_clip_vision_embed_forward():
    """Install ``clip_vision_embed_forward`` as ``CLIPVisionEmbeddings.forward``."""
    from transformers.models.clip.modeling_clip import CLIPVisionEmbeddings

    # Patch the class itself so every instance picks up the new forward.
    CLIPVisionEmbeddings.forward = clip_vision_embed_forward
|
||||||
|
|
||||||
|
|
||||||
|
EntryClass = LlavaVidForCausalLM
|
||||||
@@ -8,34 +8,28 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import MixtralConfig
|
from transformers import MixtralConfig
|
||||||
|
from vllm.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
QuantizationConfig)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.distributed import (
|
|
||||||
tensor_model_parallel_all_reduce,
|
|
||||||
)
|
|
||||||
from vllm.distributed import (
|
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from sglang.srt.weight_utils import (
|
|
||||||
default_weight_loader,
|
|
||||||
hf_model_weights_iterator,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||||
|
|
||||||
|
|
||||||
class MixtralMLP(nn.Module):
|
class MixtralMLP(nn.Module):
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
@@ -10,24 +11,17 @@ from vllm.model_executor.layers.linear import (
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
QuantizationConfig)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.distributed import (
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from sglang.srt.weight_utils import (
|
|
||||||
default_weight_loader,
|
|
||||||
hf_model_weights_iterator,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||||
|
|
||||||
|
|
||||||
class QWenMLP(nn.Module):
|
class QWenMLP(nn.Module):
|
||||||
@@ -132,7 +126,12 @@ class QWenAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class QWenBlock(nn.Module):
|
class QWenBlock(nn.Module):
|
||||||
def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
layer_id,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
@@ -181,7 +180,11 @@ class QWenBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class QWenModel(nn.Module):
|
class QWenModel(nn.Module):
|
||||||
def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
@@ -218,7 +221,11 @@ class QWenModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class QWenLMHeadModel(nn.Module):
|
class QWenLMHeadModel(nn.Module):
|
||||||
def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.transformer = QWenModel(config, quant_config=quant_config)
|
self.transformer = QWenModel(config, quant_config=quant_config)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
@@ -12,24 +13,17 @@ from vllm.model_executor.layers.linear import (
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
QuantizationConfig)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.distributed import (
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from sglang.srt.weight_utils import (
|
|
||||||
default_weight_loader,
|
|
||||||
hf_model_weights_iterator,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||||
|
|
||||||
Qwen2Config = None
|
Qwen2Config = None
|
||||||
|
|
||||||
@@ -50,7 +44,10 @@ class Qwen2MLP(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
|
intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -7,35 +7,31 @@ from typing import Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
QuantizationConfig)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from vllm.distributed import (
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
)
|
|
||||||
from sglang.srt.weight_utils import (
|
|
||||||
default_weight_loader,
|
|
||||||
hf_model_weights_iterator,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||||
|
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||||
|
|
||||||
|
|
||||||
class StablelmMLP(nn.Module):
|
class StablelmMLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -48,7 +44,10 @@ class StablelmMLP(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config,
|
config.intermediate_size,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.act_fn = SiluAndMul()
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
@@ -181,7 +180,9 @@ class StablelmDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
class StableLMEpochModel(nn.Module):
|
class StableLMEpochModel(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
|||||||
@@ -6,16 +6,13 @@ from typing import List, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import CLIPVisionModel, LlavaConfig
|
from transformers import CLIPVisionModel, LlavaConfig
|
||||||
from sglang.srt.weight_utils import (
|
|
||||||
default_weight_loader,
|
|
||||||
hf_model_weights_iterator,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.models.llava import (
|
from sglang.srt.models.llava import (
|
||||||
LlavaLlamaForCausalLM,
|
LlavaLlamaForCausalLM,
|
||||||
clip_vision_embed_forward,
|
clip_vision_embed_forward,
|
||||||
monkey_path_clip_vision_embed_forward,
|
monkey_path_clip_vision_embed_forward,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
||||||
|
|
||||||
|
|
||||||
class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ async def openai_v1_chat_completions(raw_request: Request):
|
|||||||
return await v1_chat_completions(tokenizer_manager, raw_request)
|
return await v1_chat_completions(tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
def launch_server(server_args: ServerArgs, pipe_finish_writer):
|
def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
|
||||||
global tokenizer_manager
|
global tokenizer_manager
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -140,17 +140,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Launch processes
|
# Launch processes
|
||||||
tokenizer_manager = TokenizerManager(server_args, port_args)
|
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
|
||||||
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
|
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
|
||||||
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
||||||
|
|
||||||
proc_router = mp.Process(
|
proc_router = mp.Process(
|
||||||
target=start_router_process,
|
target=start_router_process,
|
||||||
args=(
|
args=(server_args, port_args, pipe_router_writer, model_overide_args),
|
||||||
server_args,
|
|
||||||
port_args,
|
|
||||||
pipe_router_writer,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
proc_router.start()
|
proc_router.start()
|
||||||
proc_detoken = mp.Process(
|
proc_detoken = mp.Process(
|
||||||
@@ -170,8 +166,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
|
|||||||
if router_init_state != "init ok" or detoken_init_state != "init ok":
|
if router_init_state != "init ok" or detoken_init_state != "init ok":
|
||||||
proc_router.kill()
|
proc_router.kill()
|
||||||
proc_detoken.kill()
|
proc_detoken.kill()
|
||||||
print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
|
print(
|
||||||
print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
|
f"Initialization failed. router_init_state: {router_init_state}", flush=True
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Initialization failed. detoken_init_state: {detoken_init_state}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
assert proc_router.is_alive() and proc_detoken.is_alive()
|
assert proc_router.is_alive() and proc_detoken.is_alive()
|
||||||
|
|
||||||
@@ -189,6 +190,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
try:
|
try:
|
||||||
requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
||||||
|
success = True # Set flag to True if request succeeds
|
||||||
break
|
break
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
pass
|
pass
|
||||||
@@ -205,7 +207,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=60,
|
timeout=600,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -235,7 +237,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
|
|||||||
class Runtime:
|
class Runtime:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
log_evel="error",
|
log_evel: str = "error",
|
||||||
|
model_overide_args: Optional[dict] = None,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -244,7 +247,10 @@ class Runtime:
|
|||||||
|
|
||||||
# Pre-allocate ports
|
# Pre-allocate ports
|
||||||
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
|
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
|
||||||
self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)
|
self.server_args.port,
|
||||||
|
self.server_args.additional_ports,
|
||||||
|
self.server_args.tp_size,
|
||||||
|
)
|
||||||
|
|
||||||
self.url = self.server_args.url()
|
self.url = self.server_args.url()
|
||||||
self.generate_url = (
|
self.generate_url = (
|
||||||
@@ -253,7 +259,10 @@ class Runtime:
|
|||||||
|
|
||||||
self.pid = None
|
self.pid = None
|
||||||
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
|
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
|
||||||
proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
|
proc = mp.Process(
|
||||||
|
target=launch_server,
|
||||||
|
args=(self.server_args, pipe_writer, model_overide_args),
|
||||||
|
)
|
||||||
proc.start()
|
proc.start()
|
||||||
pipe_writer.close()
|
pipe_writer.close()
|
||||||
self.pid = proc.pid
|
self.pid = proc.pid
|
||||||
@@ -265,7 +274,9 @@ class Runtime:
|
|||||||
|
|
||||||
if init_state != "init ok":
|
if init_state != "init ok":
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
raise RuntimeError("Initialization failed. Please see the error messages above.")
|
raise RuntimeError(
|
||||||
|
"Initialization failed. Please see the error messages above."
|
||||||
|
)
|
||||||
|
|
||||||
self.endpoint = RuntimeEndpoint(self.url)
|
self.endpoint = RuntimeEndpoint(self.url)
|
||||||
|
|
||||||
|
|||||||
@@ -80,10 +80,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.tokenizer_path,
|
default=ServerArgs.tokenizer_path,
|
||||||
help="The path of the tokenizer.",
|
help="The path of the tokenizer.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--host", type=str, default=ServerArgs.host,
|
parser.add_argument(
|
||||||
help="The host of the server.")
|
"--host", type=str, default=ServerArgs.host, help="The host of the server."
|
||||||
parser.add_argument("--port", type=int, default=ServerArgs.port,
|
)
|
||||||
help="The port of the server.")
|
parser.add_argument(
|
||||||
|
"--port", type=int, default=ServerArgs.port, help="The port of the server."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--additional-ports",
|
"--additional-ports",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
@@ -131,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
try:
|
try:
|
||||||
s.bind(("", port))
|
s.bind(("", port))
|
||||||
|
s.listen(1) # Attempt to listen on the port
|
||||||
port_list.append(port)
|
port_list.append(port)
|
||||||
except socket.error:
|
except socket.error:
|
||||||
pass
|
pass # If any error occurs, this port is not usable
|
||||||
|
|
||||||
if len(port_list) == num:
|
if len(port_list) == num:
|
||||||
return port_list
|
return port_list
|
||||||
@@ -265,20 +267,102 @@ def wrap_kernel_launcher(kernel):
|
|||||||
|
|
||||||
|
|
||||||
def is_multimodal_model(model):
    """Return whether ``model`` refers to a multimodal (vision-language) model.

    Detection is a case-insensitive substring match on the model name/path.

    Args:
        model: Either a model name/path string, or a ``ModelConfig`` whose
            ``path`` attribute is inspected.

    Returns:
        True if the lower-cased name/path contains "llava" or "yi-vl".

    Raises:
        ValueError: If ``model`` is neither a string nor a ``ModelConfig``.
    """
    if isinstance(model, str):
        model_path = model
    else:
        # Imported only on the non-string path, so plain names can be
        # checked without loading the model config module.
        from sglang.srt.model_config import ModelConfig

        if not isinstance(model, ModelConfig):
            raise ValueError("unrecognized type")
        model_path = model.path

    model_path = model_path.lower()
    # "llava" already matches "llava-next" paths, so no separate keyword is
    # needed for them.
    return any(keyword in model_path for keyword in ("llava", "yi-vl"))
|
||||||
|
|
||||||
|
|
||||||
|
def decode_video_base64(video_base64, frame_format="PNG"):
    """Decode a base64 payload of back-to-back encoded images into frames.

    Frame boundaries are found by searching the decoded bytes for the
    start-of-file signature of ``frame_format``. Each frame spans from its
    signature up to the next signature; the last frame runs to the end of the
    payload.

    Args:
        video_base64: Base64-encoded bytes of the concatenated frames.
        frame_format: ``"PNG"`` (default) or ``"JPEG"``; selects which
            signature is used to split the payload.

    Returns:
        A tuple ``(frames, size)``. ``frames`` is the decoded frames stacked
        along a new leading axis with ``np.stack``; ``size`` is the PIL
        ``Image.size`` of the last decoded frame. When no signature is found
        the result is ``(np.array([]), (0, 0))``.

    Raises:
        ValueError: If ``frame_format`` is not ``"PNG"`` or ``"JPEG"``.
    """
    signatures = {
        "PNG": b"\x89PNG\r\n\x1a\n",  # 8-byte PNG file signature
        "JPEG": b"\xff\xd8",  # JPEG start-of-image marker
    }
    if frame_format not in signatures:
        raise ValueError("FRAME_FORMAT must be either 'PNG' or 'JPEG'")
    signature = signatures[frame_format]

    # Decode the base64 string
    video_bytes = base64.b64decode(video_base64)

    # Collect the start offset of every frame. bytes.find does the scan in C
    # and, like the manual loop it replaces, resumes right after each match.
    img_starts = []
    pos = video_bytes.find(signature)
    while pos != -1:
        img_starts.append(pos)
        pos = video_bytes.find(signature, pos + len(signature))

    if not img_starts:
        # Return an empty array and size tuple if no frames were found
        return np.array([]), (0, 0)

    # Imported only once there is something to decode.
    from PIL import Image

    # Frames are assumed to be back-to-back: each one ends where the next
    # begins, and the last one ends with the payload.
    img_ends = img_starts[1:] + [len(video_bytes)]

    frames = []
    frame_size = (0, 0)
    for start_idx, end_idx in zip(img_starts, img_ends):
        img = Image.open(BytesIO(video_bytes[start_idx:end_idx]))
        frames.append(np.array(img))
        frame_size = img.size

    return np.stack(frames, axis=0), frame_size
|
||||||
|
|
||||||
|
|
||||||
def load_image(image_file):
|
def load_image(image_file):
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
image = None
|
image = image_size = None
|
||||||
|
|
||||||
if image_file.startswith("http://") or image_file.startswith("https://"):
|
if image_file.startswith("http://") or image_file.startswith("https://"):
|
||||||
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
|
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
|
||||||
@@ -289,10 +373,13 @@ def load_image(image_file):
|
|||||||
elif image_file.startswith("data:"):
|
elif image_file.startswith("data:"):
|
||||||
image_file = image_file.split(",")[1]
|
image_file = image_file.split(",")[1]
|
||||||
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
||||||
|
elif image_file.startswith("video:"):
|
||||||
|
image_file = image_file.replace("video:", "")
|
||||||
|
image, image_size = decode_video_base64(image_file)
|
||||||
else:
|
else:
|
||||||
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
||||||
|
|
||||||
return image
|
return image, image_size
|
||||||
|
|
||||||
|
|
||||||
def assert_pkg_version(pkg: str, min_version: str):
|
def assert_pkg_version(pkg: str, min_version: str):
|
||||||
@@ -304,7 +391,9 @@ def assert_pkg_version(pkg: str, min_version: str):
|
|||||||
f"is less than the minimum required version {min_version}"
|
f"is less than the minimum required version {min_version}"
|
||||||
)
|
)
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
|
raise Exception(
|
||||||
|
f"{pkg} with minimum required version {min_version} is not installed"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
API_KEY_HEADER_NAME = "X-API-Key"
|
API_KEY_HEADER_NAME = "X-API-Key"
|
||||||
|
|||||||
@@ -19,11 +19,12 @@ import torch
|
|||||||
from huggingface_hub import HfFileSystem, snapshot_download
|
from huggingface_hub import HfFileSystem, snapshot_download
|
||||||
from safetensors.torch import load_file, safe_open, save_file
|
from safetensors.torch import load_file, safe_open, save_file
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization import (QuantizationConfig,
|
from vllm.model_executor.layers.quantization import (
|
||||||
get_quantization_config)
|
QuantizationConfig,
|
||||||
|
get_quantization_config,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -32,17 +33,21 @@ logger = init_logger(__name__)
|
|||||||
# can share the same lock without error.
|
# can share the same lock without error.
|
||||||
# lock files in the temp directory will be automatically deleted when the
|
# lock files in the temp directory will be automatically deleted when the
|
||||||
# system reboots, so users will not complain about annoying lock files
|
# system reboots, so users will not complain about annoying lock files
|
||||||
temp_dir = os.environ.get('TMPDIR') or os.environ.get(
|
temp_dir = (
|
||||||
'TEMP') or os.environ.get('TMP') or "/tmp/"
|
os.environ.get("TMPDIR")
|
||||||
|
or os.environ.get("TEMP")
|
||||||
|
or os.environ.get("TMP")
|
||||||
|
or "/tmp/"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def enable_hf_transfer():
|
def enable_hf_transfer():
|
||||||
"""automatically activates hf_transfer
|
"""automatically activates hf_transfer"""
|
||||||
"""
|
|
||||||
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
|
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
|
||||||
try:
|
try:
|
||||||
# enable hf hub transfer if available
|
# enable hf hub transfer if available
|
||||||
import hf_transfer # type: ignore # noqa
|
import hf_transfer # type: ignore # noqa
|
||||||
|
|
||||||
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
|
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
@@ -65,8 +70,7 @@ def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
|
|||||||
# add hash to avoid conflict with old users' lock files
|
# add hash to avoid conflict with old users' lock files
|
||||||
lock_file_name = hash_name + model_name + ".lock"
|
lock_file_name = hash_name + model_name + ".lock"
|
||||||
# mode 0o666 is required for the filelock to be shared across users
|
# mode 0o666 is required for the filelock to be shared across users
|
||||||
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
|
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
|
||||||
mode=0o666)
|
|
||||||
return lock
|
return lock
|
||||||
|
|
||||||
|
|
||||||
@@ -104,10 +108,12 @@ def convert_bin_to_safetensor_file(
|
|||||||
sf_size = os.stat(sf_filename).st_size
|
sf_size = os.stat(sf_filename).st_size
|
||||||
pt_size = os.stat(pt_filename).st_size
|
pt_size = os.stat(pt_filename).st_size
|
||||||
if (sf_size - pt_size) / pt_size > 0.01:
|
if (sf_size - pt_size) / pt_size > 0.01:
|
||||||
raise RuntimeError(f"""The file size different is more than 1%:
|
raise RuntimeError(
|
||||||
|
f"""The file size different is more than 1%:
|
||||||
- {sf_filename}: {sf_size}
|
- {sf_filename}: {sf_size}
|
||||||
- {pt_filename}: {pt_size}
|
- {pt_filename}: {pt_size}
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# check if the tensors are the same
|
# check if the tensors are the same
|
||||||
reloaded = load_file(sf_filename)
|
reloaded = load_file(sf_filename)
|
||||||
@@ -122,8 +128,7 @@ def convert_bin_to_safetensor_file(
|
|||||||
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
|
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
|
||||||
quant_cls = get_quantization_config(model_config.quantization)
|
quant_cls = get_quantization_config(model_config.quantization)
|
||||||
# Read the quantization config from the HF model config, if available.
|
# Read the quantization config from the HF model config, if available.
|
||||||
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
|
hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
|
||||||
None)
|
|
||||||
if hf_quant_config is not None:
|
if hf_quant_config is not None:
|
||||||
return quant_cls.from_config(hf_quant_config)
|
return quant_cls.from_config(hf_quant_config)
|
||||||
model_name_or_path = model_config.model
|
model_name_or_path = model_config.model
|
||||||
@@ -131,26 +136,29 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
|
|||||||
if not is_local:
|
if not is_local:
|
||||||
# Download the config files.
|
# Download the config files.
|
||||||
with get_lock(model_name_or_path, model_config.download_dir):
|
with get_lock(model_name_or_path, model_config.download_dir):
|
||||||
hf_folder = snapshot_download(model_name_or_path,
|
hf_folder = snapshot_download(
|
||||||
revision=model_config.revision,
|
model_name_or_path,
|
||||||
allow_patterns="*.json",
|
revision=model_config.revision,
|
||||||
cache_dir=model_config.download_dir,
|
allow_patterns="*.json",
|
||||||
tqdm_class=Disabledtqdm)
|
cache_dir=model_config.download_dir,
|
||||||
|
tqdm_class=Disabledtqdm,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hf_folder = model_name_or_path
|
hf_folder = model_name_or_path
|
||||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||||
|
|
||||||
quant_config_files = [
|
quant_config_files = [
|
||||||
f for f in config_files if any(
|
f
|
||||||
f.endswith(x) for x in quant_cls.get_config_filenames())
|
for f in config_files
|
||||||
|
if any(f.endswith(x) for x in quant_cls.get_config_filenames())
|
||||||
]
|
]
|
||||||
if len(quant_config_files) == 0:
|
if len(quant_config_files) == 0:
|
||||||
raise ValueError(
|
raise ValueError(f"Cannot find the config file for {model_config.quantization}")
|
||||||
f"Cannot find the config file for {model_config.quantization}")
|
|
||||||
if len(quant_config_files) > 1:
|
if len(quant_config_files) > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Found multiple config files for {model_config.quantization}: "
|
f"Found multiple config files for {model_config.quantization}: "
|
||||||
f"{quant_config_files}")
|
f"{quant_config_files}"
|
||||||
|
)
|
||||||
|
|
||||||
quant_config_file = quant_config_files[0]
|
quant_config_file = quant_config_files[0]
|
||||||
with open(quant_config_file, "r") as f:
|
with open(quant_config_file, "r") as f:
|
||||||
@@ -166,8 +174,7 @@ def prepare_hf_model_weights(
|
|||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
) -> Tuple[str, List[str], bool]:
|
) -> Tuple[str, List[str], bool]:
|
||||||
# Download model weights from huggingface.
|
# Download model weights from huggingface.
|
||||||
is_local = os.path.isdir(model_name_or_path) \
|
is_local = os.path.isdir(model_name_or_path) and load_format != "tensorizer"
|
||||||
and load_format != "tensorizer"
|
|
||||||
use_safetensors = False
|
use_safetensors = False
|
||||||
# Some quantized models use .pt files for storing the weights.
|
# Some quantized models use .pt files for storing the weights.
|
||||||
if load_format == "auto":
|
if load_format == "auto":
|
||||||
@@ -203,11 +210,13 @@ def prepare_hf_model_weights(
|
|||||||
# Use file lock to prevent multiple processes from
|
# Use file lock to prevent multiple processes from
|
||||||
# downloading the same model weights at the same time.
|
# downloading the same model weights at the same time.
|
||||||
with get_lock(model_name_or_path, cache_dir):
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
hf_folder = snapshot_download(model_name_or_path,
|
hf_folder = snapshot_download(
|
||||||
allow_patterns=allow_patterns,
|
model_name_or_path,
|
||||||
cache_dir=cache_dir,
|
allow_patterns=allow_patterns,
|
||||||
tqdm_class=Disabledtqdm,
|
cache_dir=cache_dir,
|
||||||
revision=revision)
|
tqdm_class=Disabledtqdm,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hf_folder = model_name_or_path
|
hf_folder = model_name_or_path
|
||||||
hf_weights_files: List[str] = []
|
hf_weights_files: List[str] = []
|
||||||
@@ -228,16 +237,14 @@ def prepare_hf_model_weights(
|
|||||||
"scaler.pt",
|
"scaler.pt",
|
||||||
]
|
]
|
||||||
hf_weights_files = [
|
hf_weights_files = [
|
||||||
f for f in hf_weights_files
|
f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
|
||||||
if not any(f.endswith(x) for x in blacklist)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if load_format == "tensorizer":
|
if load_format == "tensorizer":
|
||||||
return hf_folder, hf_weights_files, use_safetensors
|
return hf_folder, hf_weights_files, use_safetensors
|
||||||
|
|
||||||
if len(hf_weights_files) == 0:
|
if len(hf_weights_files) == 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`")
|
||||||
f"Cannot find any model weights with `{model_name_or_path}`")
|
|
||||||
|
|
||||||
return hf_folder, hf_weights_files, use_safetensors
|
return hf_folder, hf_weights_files, use_safetensors
|
||||||
|
|
||||||
@@ -254,7 +261,8 @@ def hf_model_weights_iterator(
|
|||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
load_format=load_format,
|
load_format=load_format,
|
||||||
fall_back_to_pt=fall_back_to_pt,
|
fall_back_to_pt=fall_back_to_pt,
|
||||||
revision=revision)
|
revision=revision,
|
||||||
|
)
|
||||||
|
|
||||||
if load_format == "npcache":
|
if load_format == "npcache":
|
||||||
# Currently np_cache only support *.bin checkpoints
|
# Currently np_cache only support *.bin checkpoints
|
||||||
@@ -289,22 +297,25 @@ def hf_model_weights_iterator(
|
|||||||
param = np.load(f)
|
param = np.load(f)
|
||||||
yield name, torch.from_numpy(param)
|
yield name, torch.from_numpy(param)
|
||||||
elif load_format == "tensorizer":
|
elif load_format == "tensorizer":
|
||||||
from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
|
from vllm.model_executor.tensorizer_loader import (
|
||||||
open_stream,
|
TensorDeserializer,
|
||||||
tensorizer_warning)
|
open_stream,
|
||||||
|
tensorizer_warning,
|
||||||
|
)
|
||||||
|
|
||||||
tensorizer_args = load_format.params
|
tensorizer_args = load_format.params
|
||||||
tensorizer_warning(
|
tensorizer_warning(
|
||||||
"Deserializing HuggingFace models is not optimized for "
|
"Deserializing HuggingFace models is not optimized for "
|
||||||
"loading on vLLM, as tensorizer is forced to load to CPU. "
|
"loading on vLLM, as tensorizer is forced to load to CPU. "
|
||||||
"Consider deserializing a vLLM model instead for faster "
|
"Consider deserializing a vLLM model instead for faster "
|
||||||
"load times. See the examples/tensorize_vllm_model.py example "
|
"load times. See the examples/tensorize_vllm_model.py example "
|
||||||
"script for serializing vLLM models.")
|
"script for serializing vLLM models."
|
||||||
|
)
|
||||||
|
|
||||||
deserializer_args = tensorizer_args.deserializer_params
|
deserializer_args = tensorizer_args.deserializer_params
|
||||||
stream_params = tensorizer_args.stream_params
|
stream_params = tensorizer_args.stream_params
|
||||||
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
|
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
|
||||||
with TensorDeserializer(stream, **deserializer_args,
|
with TensorDeserializer(stream, **deserializer_args, device="cpu") as state:
|
||||||
device="cpu") as state:
|
|
||||||
for name, param in state.items():
|
for name, param in state.items():
|
||||||
yield name, param
|
yield name, param
|
||||||
del state
|
del state
|
||||||
@@ -324,8 +335,12 @@ def hf_model_weights_iterator(
|
|||||||
|
|
||||||
|
|
||||||
def kv_cache_scales_loader(
|
def kv_cache_scales_loader(
|
||||||
filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
|
filename: str,
|
||||||
model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
|
tp_rank: int,
|
||||||
|
tp_size: int,
|
||||||
|
num_hidden_layers: int,
|
||||||
|
model_type: Optional[str],
|
||||||
|
) -> Iterable[Tuple[int, float]]:
|
||||||
"""
|
"""
|
||||||
A simple utility to read in KV cache scaling factors that have been
|
A simple utility to read in KV cache scaling factors that have been
|
||||||
previously serialized to disk. Used by the model to populate the appropriate
|
previously serialized to disk. Used by the model to populate the appropriate
|
||||||
@@ -343,8 +358,7 @@ def kv_cache_scales_loader(
|
|||||||
"tp_size": tp_size,
|
"tp_size": tp_size,
|
||||||
}
|
}
|
||||||
schema_dct = json.load(f)
|
schema_dct = json.load(f)
|
||||||
schema = QuantParamSchema.model_validate(schema_dct,
|
schema = QuantParamSchema.model_validate(schema_dct, context=context)
|
||||||
context=context)
|
|
||||||
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
|
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
|
||||||
return layer_scales_map.items()
|
return layer_scales_map.items()
|
||||||
|
|
||||||
@@ -357,9 +371,11 @@ def kv_cache_scales_loader(
|
|||||||
# This section is reached if and only if any of the excepts are hit
|
# This section is reached if and only if any of the excepts are hit
|
||||||
# Return an empty iterable (list) => no KV cache scales are loaded
|
# Return an empty iterable (list) => no KV cache scales are loaded
|
||||||
# which ultimately defaults to 1.0 scales
|
# which ultimately defaults to 1.0 scales
|
||||||
logger.warning("Defaulting to KV cache scaling factors = 1.0 "
|
logger.warning(
|
||||||
f"for all layers in TP rank {tp_rank} "
|
"Defaulting to KV cache scaling factors = 1.0 "
|
||||||
"as an error occurred during loading.")
|
f"for all layers in TP rank {tp_rank} "
|
||||||
|
"as an error occurred during loading."
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@@ -378,8 +394,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def default_weight_loader(param: torch.Tensor,
|
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||||
loaded_weight: torch.Tensor) -> None:
|
|
||||||
"""Default weight loader."""
|
"""Default weight loader."""
|
||||||
assert param.size() == loaded_weight.size()
|
assert param.size() == loaded_weight.size()
|
||||||
param.data.copy_(loaded_weight)
|
param.data.copy_(loaded_weight)
|
||||||
|
|||||||
@@ -2,13 +2,16 @@
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import traceback
|
import traceback
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from json import dumps
|
from json import dumps
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
@@ -110,6 +113,74 @@ def encode_image_base64(image_path):
|
|||||||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def encode_frame(frame):
    """Serialize a single OpenCV frame to PNG bytes.

    Args:
        frame: Image array in OpenCV's BGR channel order (it is passed
            through ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``).

    Returns:
        bytes: The frame encoded as a PNG image.
    """
    import cv2  # pip install opencv-python-headless
    from PIL import Image

    # OpenCV hands out BGR data, so swap channels to RGB before building
    # the PIL image.
    rgb_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    # Render the PNG into an in-memory buffer and return its raw bytes.
    with BytesIO() as png_buffer:
        rgb_image.save(png_buffer, format="PNG")
        return png_buffer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def encode_video_base64(video_path, num_frames=16):
    """Sample frames evenly from a video and pack them into one base64 string.

    ``num_frames`` positions are spread uniformly over the frame count
    reported by OpenCV. Each selected frame is PNG-encoded with
    ``encode_frame``, the PNG byte strings are concatenated, and the result
    is base64-encoded and prefixed with ``"video:"``.

    Args:
        video_path: Path of the video file, handed to ``cv2.VideoCapture``.
        num_frames: Number of frames to sample. If fewer frames can be
            selected, the last selected frame is repeated to fill the gap.

    Returns:
        str: ``"video:"`` followed by the base64 text of the concatenated
        PNG frames.

    Raises:
        ValueError: If ``num_frames`` is negative.
        IOError: If the video cannot be opened, or no frame can be read
            from it.
    """
    # Fail fast on a bad request before opening the video at all.
    if num_frames < 0:
        raise ValueError(f"num_frames must be non-negative, got {num_frames}")

    import cv2  # pip install opencv-python-headless

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open video file:{video_path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_indices = np.linspace(
            0, total_frames - 1, num_frames, dtype=int
        ).tolist()
        wanted = set(frame_indices)

        # Keep only the frames that will actually be used instead of
        # buffering the whole decoded video in memory.
        sampled = {}
        if wanted:
            last_wanted = max(wanted)
            readable_count = 0  # position among successfully read frames
            for _ in range(total_frames):
                ret, frame = cap.read()
                if not ret:
                    # Unreadable frames are skipped; later frames shift down.
                    continue
                if readable_count in wanted:
                    sampled[readable_count] = frame
                readable_count += 1
                if readable_count > last_wanted:
                    # Nothing past the last requested position is needed.
                    break
    finally:
        # Release the capture handle even if opening or reading failed.
        cap.release()

    frames = [sampled[i] for i in frame_indices if i in sampled]

    if num_frames > 0 and not frames:
        raise IOError(f"Could not read any frames from video file:{video_path}")

    # If there are not enough frames, repeat the last one up to the target.
    if frames and len(frames) < num_frames:
        frames.extend([frames[-1]] * (num_frames - len(frames)))

    # PNG-encode the frames in parallel worker threads.
    with ThreadPoolExecutor() as executor:
        encoded_frames = list(executor.map(encode_frame, frames))

    # Concatenate the per-frame PNG bytes and base64-encode the result.
    video_bytes = b"".join(encoded_frames)
    return "video:" + base64.b64encode(video_bytes).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def _is_chinese_char(cp):
|
def _is_chinese_char(cp):
|
||||||
"""Checks whether CP is the codepoint of a CJK character."""
|
"""Checks whether CP is the codepoint of a CJK character."""
|
||||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||||
|
|||||||
Reference in New Issue
Block a user