diff --git a/.gitignore b/.gitignore index 49e810d4e..10e602e83 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,7 @@ tmp*.txt # Plots *.png *.pdf + +# personnal +work_dirs/ +*.csv diff --git a/README.md b/README.md index a699b9c2a..a0df39622 100644 --- a/README.md +++ b/README.md @@ -410,4 +410,4 @@ https://github.com/sgl-project/sglang/issues/157 } ``` -We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql). +We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql). \ No newline at end of file diff --git a/examples/quick_start/srt_example_llava.py b/examples/quick_start/srt_example_llava.py index 97812ea9f..27685b1d2 100644 --- a/examples/quick_start/srt_example_llava.py +++ b/examples/quick_start/srt_example_llava.py @@ -14,7 +14,7 @@ def single(): state = image_qa.run( image_path="images/cat.jpeg", question="What is this?", - max_new_tokens=64) + max_new_tokens=128) print(state["answer"], "\n") @@ -36,7 +36,7 @@ def batch(): {"image_path": "images/cat.jpeg", "question":"What is this?"}, {"image_path": "images/dog.jpeg", "question":"What is this?"}, ], - max_new_tokens=64, + max_new_tokens=128, ) for s in states: print(s["answer"], "\n") diff --git a/examples/usage/llava_video/srt_example_llava_v.py b/examples/usage/llava_video/srt_example_llava_v.py new file mode 100644 index 000000000..e18a81ebb --- /dev/null +++ b/examples/usage/llava_video/srt_example_llava_v.py @@ -0,0 +1,208 @@ +""" +Usage: python3 srt_example_llava.py +""" + +import sglang as sgl +import os +import csv +import time +import argparse + +@sgl.function +def video_qa(s, num_frames, video_path, question): + s += sgl.user(sgl.video(video_path,num_frames) + question) + s += sgl.assistant(sgl.gen("answer")) + + +def single(path, num_frames=16): + state = video_qa.run( + num_frames=num_frames, + video_path=path, + question="Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes", + temperature=0.0, + max_new_tokens=1024, + ) + print(state["answer"], "\n") + + + +def split_into_chunks(lst, num_chunks): + """Split a list into a specified number of chunks.""" + # Calculate the chunk size using integer division. Note that this may drop some items if not evenly divisible. + chunk_size = len(lst) // num_chunks + + if chunk_size == 0: + chunk_size = len(lst) + # Use list comprehension to generate chunks. The last chunk will take any remainder if the list size isn't evenly divisible. + chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)] + # Ensure we have exactly num_chunks chunks, even if some are empty + chunks.extend([[] for _ in range(num_chunks - len(chunks))]) + return chunks + + +def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir): + csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv" + with open(csv_filename, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['video_name', 'answer']) + for video_path, state in zip(batch_video_files, states): + video_name = os.path.basename(video_path) + writer.writerow([video_name, state["answer"]]) + +def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir): + final_csv_filename = f"{save_dir}/final_results_chunk_{cur_chunk}.csv" + with open(final_csv_filename, 'w', newline='') as final_csvfile: + writer = csv.writer(final_csvfile) + writer.writerow(['video_name', 'answer']) + for batch_idx in range(num_batches): + batch_csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv" + with open(batch_csv_filename, 'r') as batch_csvfile: + reader = csv.reader(batch_csvfile) + next(reader) # Skip header row + for row in reader: + writer.writerow(row) + os.remove(batch_csv_filename) + +def find_video_files(video_dir): + # Check if the video_dir is actually a file + if os.path.isfile(video_dir): + # If it's a file, return it as a single-element list + return [video_dir] + + # Original logic to find video files in a directory + video_files = [] + for root, dirs, files in os.walk(video_dir): + for file in files: + if file.endswith(('.mp4', '.avi', '.mov')): + video_files.append(os.path.join(root, file)) + return video_files + +def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64): + video_files = find_video_files(video_dir) + chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk] + num_batches = 0 + + for i in range(0, len(chunked_video_files), batch_size): + batch_video_files = chunked_video_files[i:i + batch_size] + print(f"Processing batch of {len(batch_video_files)} video(s)...") + + if not batch_video_files: + print("No video files found in the specified directory.") + return + + batch_input = [ + { + "num_frames": num_frames, + "video_path": video_path, + "question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.", + } for video_path in batch_video_files + ] + + start_time = time.time() + states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2) + total_time = time.time() - start_time + average_time = total_time / len(batch_video_files) + print(f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds") + + save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir) + num_batches += 1 + + compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir) + + +if __name__ == "__main__": + + # Create the parser + parser = argparse.ArgumentParser(description='Run video processing with specified port.') + + # Add an argument for the port + parser.add_argument('--port', type=int, default=30000, help='The master port for distributed serving.') + parser.add_argument('--chunk-idx', type=int, default=0, help='The index of the chunk to process.') + parser.add_argument('--num-chunks', type=int, default=8, help='The number of chunks to process.') + parser.add_argument('--save-dir', type=str, default="./work_dirs/llava_video", help='The directory to save the processed video files.') + parser.add_argument('--video-dir', type=str, default="./videos/Q98Z4OTh8RwmDonc.mp4", help='The directory or path for the processed video files.') + parser.add_argument('--model-path', type=str, default="lmms-lab/LLaVA-NeXT-Video-7B", help='The model path for the video processing.') + parser.add_argument('--num-frames', type=int, default=16, help='The number of frames to process in each video.' ) + parser.add_argument("--mm_spatial_pool_stride", type=int, default=2) + + # Parse the arguments + args = parser.parse_args() + + cur_port = args.port + + cur_chunk = args.chunk_idx + + num_chunks = args.num_chunks + + num_frames = args.num_frames + + if "34b" in args.model_path.lower(): + tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer" + elif "7b" in args.model_path.lower(): + tokenizer_path = "llava-hf/llava-1.5-7b-hf" + else: + print("Invalid model path. Please specify a valid model path.") + exit() + + model_overide_args = {} + + model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride + model_overide_args["architectures"] = ["LlavaVidForCausalLM"] + model_overide_args["num_frames"] = args.num_frames + model_overide_args["model_type"] = "llava" + + if "34b" in args.model_path.lower(): + model_overide_args["image_token_index"] = 64002 + + + if args.num_frames == 32: + model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} + model_overide_args["max_sequence_length"] = 4096 * 2 + model_overide_args["tokenizer_model_max_length"] = 4096 * 2 + elif args.num_frames < 32: + pass + else: + print("The maximum number of frames to process is 32. Please specify a valid number of frames.") + exit() + + + runtime = sgl.Runtime( + model_path=args.model_path, #"liuhaotian/llava-v1.6-vicuna-7b", + tokenizer_path=tokenizer_path, + port=cur_port, + additional_ports=[cur_port+1,cur_port+2,cur_port+3,cur_port+4], + model_overide_args=model_overide_args, + tp_size=1 + ) + sgl.set_default_backend(runtime) + print(f"chat template: {runtime.endpoint.chat_template.name}") + + + # Run a single request + # try: + print("\n========== single ==========\n") + root = args.video_dir + if os.path.isfile(root): + video_files = [root] + else: + video_files = [os.path.join(root, f) for f in os.listdir(root) if f.endswith(('.mp4', '.avi', '.mov'))] # Add more extensions if needed + start_time = time.time() # Start time for processing a single video + for cur_video in video_files[:1]: + print(cur_video) + single(cur_video, num_frames) + end_time = time.time() # End time for processing a single video + total_time = end_time - start_time + average_time = total_time / len(video_files) # Calculate the average processing time + print(f"Average processing time per video: {average_time:.2f} seconds") + runtime.shutdown() + # except Exception as e: + # print(e) + runtime.shutdown() + + + # # # Run a batch of requests + # print("\n========== batch ==========\n") + # if not os.path.exists(args.save_dir): + # os.makedirs(args.save_dir) + # batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks) + # runtime.shutdown() \ No newline at end of file diff --git a/examples/usage/llava_video/srt_example_llava_v.sh b/examples/usage/llava_video/srt_example_llava_v.sh new file mode 100644 index 000000000..56566de2a --- /dev/null +++ b/examples/usage/llava_video/srt_example_llava_v.sh @@ -0,0 +1,130 @@ +#!/bin/bash + +##### USAGE ##### +# - First node: +# ```sh +# bash examples/quick_start/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO +# ``` +# - Second node: +# ```sh +# bash examples/quick_start/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO +# ``` +# - The K node: +# ```sh +# bash examples/quick_start/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO +# ``` + + +# Replace `K`, `YOUR_VIDEO_PATH`, `YOUR_MODEL_PATH`, and `FRAMES_PER_VIDEO` with your specific details. +CURRENT_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +echo ${CURRENT_ROOT} + +cd ${CURRENT_ROOT} + +export PYTHONWARNINGS=ignore + +START_TIME=$(date +%s) # Capture start time + +NUM_NODES=$1 + +CUR_NODES_IDX=$2 + +VIDEO_DIR=$3 + +MODEL_PATH=$4 + +NUM_FRAMES=$5 + + +# FRAME_FORMAT=$6 + +# FRAME_FORMAT=$(echo $FRAME_FORMAT | tr '[:lower:]' '[:upper:]') + +# # Check if FRAME_FORMAT is either JPEG or PNG +# if [[ "$FRAME_FORMAT" != "JPEG" && "$FRAME_FORMAT" != "PNG" ]]; then +# echo "Error: FRAME_FORMAT must be either JPEG or PNG." +# exit 1 +# fi + +# export TARGET_FRAMES=$TARGET_FRAMES + +echo "Each video you will sample $NUM_FRAMES frames" + +# export FRAME_FORMAT=$FRAME_FORMAT + +# echo "The frame format is $FRAME_FORMAT" + +# Assuming GPULIST is a bash array containing your GPUs +GPULIST=(0 1 2 3 4 5 6 7) +LOCAL_CHUNKS=${#GPULIST[@]} + +echo "Number of GPUs in GPULIST: $LOCAL_CHUNKS" + +ALL_CHUNKS=$((NUM_NODES * LOCAL_CHUNKS)) + +# Calculate GPUs per chunk +GPUS_PER_CHUNK=8 + +echo $GPUS_PER_CHUNK + +for IDX in $(seq 1 $LOCAL_CHUNKS); do + ( + START=$(((IDX-1) * GPUS_PER_CHUNK)) + LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index + + CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH}) + + # Convert the chunk GPUs array to a comma-separated string + CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}") + + LOCAL_IDX=$((CUR_NODES_IDX * LOCAL_CHUNKS + IDX)) + + echo "Chunk $(($LOCAL_IDX - 1)) will run on GPUs $CHUNK_GPUS_STR" + + # Calculate the port for this chunk. Ensure it's incremented by 5 for each chunk. + PORT=$((10000 + RANDOM % 55536)) + + MAX_RETRIES=10 + RETRY_COUNT=0 + COMMAND_STATUS=1 # Initialize as failed + + while [ $RETRY_COUNT -lt $MAX_RETRIES ] && [ $COMMAND_STATUS -ne 0 ]; do + echo "Running chunk $(($LOCAL_IDX - 1)) on GPUs $CHUNK_GPUS_STR with port $PORT. Attempt $(($RETRY_COUNT + 1))" + +#!/bin/bash + CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 examples/usage/llava_video/srt_example_llava_v.py \ + --port $PORT \ + --num-chunks $ALL_CHUNKS \ + --chunk-idx $(($LOCAL_IDX - 1)) \ + --save-dir work_dirs/llava_next_video_inference_results \ + --video-dir $VIDEO_DIR \ + --model-path $MODEL_PATH \ + --num-frames $NUM_FRAMES #& + + wait $! # Wait for the process to finish and capture its exit status + COMMAND_STATUS=$? + + if [ $COMMAND_STATUS -ne 0 ]; then + echo "Execution failed for chunk $(($LOCAL_IDX - 1)), attempt $(($RETRY_COUNT + 1)). Retrying..." + RETRY_COUNT=$(($RETRY_COUNT + 1)) + sleep 180 # Wait a bit before retrying + else + echo "Execution succeeded for chunk $(($LOCAL_IDX - 1))." + fi + done + + if [ $COMMAND_STATUS -ne 0 ]; then + echo "Execution failed for chunk $(($LOCAL_IDX - 1)) after $MAX_RETRIES attempts." + fi + ) #& + sleep 2 # Slight delay to stagger the start times +done + +wait + +cat work_dirs/llava_next_video_inference_results/final_results_chunk_*.csv > work_dirs/llava_next_video_inference_results/final_results_node_${CUR_NODES_IDX}.csv + +END_TIME=$(date +%s) # Capture end time +ELAPSED_TIME=$(($END_TIME - $START_TIME)) +echo "Total execution time: $ELAPSED_TIME seconds." \ No newline at end of file diff --git a/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 b/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 new file mode 100644 index 000000000..32d912dbf Binary files /dev/null and b/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 differ diff --git a/python/pyproject.toml b/python/pyproject.toml index 06a0c5a8c..734e9275d 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ [project.optional-dependencies] srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", - "zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "outlines>=0.0.27", "packaging"] + "zmq", "vllm>=0.4.2", "interegular", "pydantic", "pillow", "packaging", "huggingface_hub", "hf_transfer", "outlines>=0.0.34"] openai = ["openai>=1.0", "numpy", "tiktoken"] anthropic = ["anthropic>=0.20.0", "numpy"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"] @@ -33,4 +33,4 @@ all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"] exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] [tool.wheel] -exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] +exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"] \ No newline at end of file diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index 741d8cc72..f2ddfc641 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -19,6 +19,7 @@ from sglang.api import ( user, user_begin, user_end, + video, ) # SGL Backends @@ -46,6 +47,7 @@ __all__ = [ "gen_int", "gen_string", "image", + "video", "select", "system", "user", diff --git a/python/sglang/api.py b/python/sglang/api.py index f2b92a960..4448333a6 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -15,6 +15,7 @@ from sglang.lang.ir import ( SglRoleBegin, SglRoleEnd, SglSelect, + SglVideo, ) @@ -151,6 +152,10 @@ def image(expr: SglExpr): return SglImage(expr) +def video(path: str, num_frames: int): + return SglVideo(path, num_frames) + + def select( name: Optional[str] = None, choices: List[str] = None, diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 45f5825ce..9b0f10232 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -259,6 +259,8 @@ def match_vicuna(model_path: str): return get_chat_template("vicuna_v1.1") if "llava-v1.5" in model_path.lower(): return get_chat_template("vicuna_v1.1") + if "llava-next-video-7b" in model_path.lower(): + return get_chat_template("vicuna_v1.1") @register_chat_template_matching_function @@ -283,19 +285,24 @@ def match_llama3_instruct(model_path: str): @register_chat_template_matching_function def match_chat_ml(model_path: str): + # import pdb;pdb.set_trace() model_path = model_path.lower() if "tinyllama" in model_path: return get_chat_template("chatml") if "qwen" in model_path and "chat" in model_path: return get_chat_template("chatml") - if "llava-v1.6-34b" in model_path: + if ( + "llava-v1.6-34b" in model_path + or "llava-v1.6-yi-34b" in model_path + or "llava-next-video-34b" in model_path + ): return get_chat_template("chatml-llava") @register_chat_template_matching_function def match_chat_yi(model_path: str): model_path = model_path.lower() - if "yi" in model_path: + if "yi" in model_path and "llava" not in model_path: return get_chat_template("yi") diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 5bc51928c..e33a9760b 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -28,8 +28,9 @@ from sglang.lang.ir import ( SglVariable, SglVarScopeBegin, SglVarScopeEnd, + SglVideo, ) -from sglang.utils import encode_image_base64, get_exception_traceback +from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback def run_internal(state, program, func_args, func_kwargs, sync): @@ -361,6 +362,8 @@ class StreamExecutor: self._execute_role_end(other) elif isinstance(other, SglImage): self._execute_image(other) + elif isinstance(other, SglVideo): + self._execute_video(other) elif isinstance(other, SglVariable): self._execute_variable(other) elif isinstance(other, SglVarScopeBegin): @@ -397,6 +400,16 @@ class StreamExecutor: self.cur_images.append((path, base64_data)) self.text_ += self.chat_template.image_token + def _execute_video(self, expr: SglVideo): + path = expr.path + num_frames = expr.num_frames + + base64_data = encode_video_base64(path, num_frames) + + self.images_.append((path, base64_data)) + self.cur_images.append((path, base64_data)) + self.text_ += self.chat_template.image_token + # if global_config.eager_fill_image: # self.backend.fill_image(self) diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index eaf92070c..3506e6ba3 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -330,6 +330,15 @@ class SglImage(SglExpr): return f"SglImage({self.path})" +class SglVideo(SglExpr): + def __init__(self, path, num_frames): + self.path = path + self.num_frames = num_frames + + def __repr__(self) -> str: + return f"SglVideo({self.path}, {self.num_frames})" + + class SglGen(SglExpr): def __init__( self, diff --git a/python/sglang/lang/tracer.py b/python/sglang/lang/tracer.py index b506c44e1..53f772163 100644 --- a/python/sglang/lang/tracer.py +++ b/python/sglang/lang/tracer.py @@ -110,7 +110,7 @@ class TracerProgramState(ProgramState): ################################## def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None): - assert (size >= 1) + assert size >= 1 if self.only_trace_prefix: raise StopTracing() diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 40db38b9a..9d63a2aed 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -2,11 +2,10 @@ import argparse from sglang.srt.server import ServerArgs, launch_server - if __name__ == "__main__": parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) args = parser.parse_args() server_args = ServerArgs.from_cli_args(args) - launch_server(server_args, None) \ No newline at end of file + launch_server(server_args, None) diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py new file mode 100644 index 000000000..564ead5e4 --- /dev/null +++ b/python/sglang/launch_server_llavavid.py @@ -0,0 +1,31 @@ +import argparse +import multiprocessing as mp + +from sglang.srt.server import ServerArgs, launch_server + +if __name__ == "__main__": + + model_overide_args = {} + + model_overide_args["mm_spatial_pool_stride"] = 2 + model_overide_args["architectures"] = ["LlavaVidForCausalLM"] + model_overide_args["num_frames"] = 16 + model_overide_args["model_type"] = "llavavid" + if model_overide_args["num_frames"] == 32: + model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} + model_overide_args["max_sequence_length"] = 4096 * 2 + model_overide_args["tokenizer_model_max_length"] = 4096 * 2 + model_overide_args["model_max_length"] = 4096 * 2 + + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + args = parser.parse_args() + + if "34b" in args.model_path.lower(): + model_overide_args["image_token_index"] = 64002 + + server_args = ServerArgs.from_cli_args(args) + + pipe_reader, pipe_writer = mp.Pipe(duplex=False) + + launch_server(server_args, pipe_writer, model_overide_args) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 114ae5e1e..9d2f917d8 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -30,10 +30,17 @@ def get_config_json(model_path: str): return config -def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None): +def get_config( + model: str, + trust_remote_code: bool, + revision: Optional[str] = None, + model_overide_args: Optional[dict] = None, +): config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision ) + if model_overide_args: + config.update(model_overide_args) return config diff --git a/python/sglang/srt/managers/router/manager.py b/python/sglang/srt/managers/router/manager.py index 2ee9f3147..66adc2e59 100644 --- a/python/sglang/srt/managers/router/manager.py +++ b/python/sglang/srt/managers/router/manager.py @@ -60,9 +60,7 @@ class RouterManager: def start_router_process( - server_args: ServerArgs, - port_args: PortArgs, - pipe_writer, + server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args ): logging.basicConfig( level=getattr(logging, server_args.log_level.upper()), @@ -70,7 +68,7 @@ def start_router_process( ) try: - model_client = ModelRpcClient(server_args, port_args) + model_client = ModelRpcClient(server_args, port_args, model_overide_args) router = RouterManager(model_client, port_args) except Exception: pipe_writer.send(get_exception_traceback()) diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 46b2c0c61..d94d3997c 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -4,12 +4,13 @@ import multiprocessing import time import warnings from concurrent.futures import ThreadPoolExecutor -from typing import List +from typing import Any, Dict, List, Optional, Tuple, Union import rpyc import torch from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer + try: from vllm.logger import _default_handler as vllm_default_logger except ImportError: @@ -48,6 +49,7 @@ class ModelRpcServer: tp_rank: int, server_args: ServerArgs, port_args: PortArgs, + model_overide_args: Optional[dict] = None, ): server_args, port_args = [obtain(x) for x in [server_args, port_args]] @@ -62,6 +64,7 @@ class ModelRpcServer: server_args.model_path, server_args.trust_remote_code, context_length=server_args.context_length, + model_overide_args=model_overide_args, ) # For model end global settings @@ -673,13 +676,15 @@ class ModelRpcService(rpyc.Service): class ModelRpcClient: - def __init__(self, server_args: ServerArgs, port_args: PortArgs): + def __init__( + self, server_args: ServerArgs, port_args: PortArgs, model_overide_args + ): tp_size = server_args.tp_size if tp_size == 1: # Init model self.model_server = ModelRpcService().exposed_ModelRpcServer( - 0, server_args, port_args + 0, server_args, port_args, model_overide_args ) # Wrap functions @@ -700,7 +705,7 @@ class ModelRpcClient: # Init model def init_model(i): return self.remote_services[i].ModelRpcServer( - i, server_args, port_args + i, server_args, port_args, model_overide_args ) self.model_servers = executor.map(init_model, range(tp_size)) @@ -723,7 +728,11 @@ def _init_service(port): t = ThreadedServer( ModelRpcService(), port=port, - protocol_config={"allow_pickle": True, "sync_request_timeout": 1800}, + protocol_config={ + "allow_public_attrs": True, + "allow_pickle": True, + "sync_request_timeout": 1800, + }, ) t.start() @@ -739,7 +748,11 @@ def start_model_process(port): con = rpyc.connect( "localhost", port, - config={"allow_pickle": True, "sync_request_timeout": 1800}, + config={ + "allow_public_attrs": True, + "allow_pickle": True, + "sync_request_timeout": 1800, + }, ) break except ConnectionRefusedError: diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index 40a32369a..48541932b 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -9,11 +9,11 @@ from typing import List import numpy as np import torch +from vllm.distributed import initialize_model_parallel from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.distributed import initialize_model_parallel from sglang.srt.managers.router.infer_batch import Batch, ForwardMode from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool @@ -143,7 +143,7 @@ class InputMetadata: self.kv_last_page_len, self.model_runner.model_config.num_attention_heads // tp_size, self.model_runner.model_config.num_key_value_heads // tp_size, - self.model_runner.model_config.head_dim + self.model_runner.model_config.head_dim, ] self.prefill_wrapper.begin_forward(*args) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 81f009ce1..282466ea2 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -60,21 +60,29 @@ def get_pixel_values( ): try: processor = processor or global_processor - image = load_image(image_data) - image_hash = hash(image_data) - if image_aspect_ratio == "pad": - image = expand2square( - image, tuple(int(x * 255) for x in processor.image_processor.image_mean) - ) - pixel_values = processor.image_processor(image)["pixel_values"][0] - elif image_aspect_ratio == "anyres": - pixel_values = process_anyres_image( - image, processor.image_processor, image_grid_pinpoints - ) + image, image_size = load_image(image_data) + if image_size != None: + image_hash = hash(image_data) + pixel_values = processor.image_processor(image)["pixel_values"] + for _ in range(len(pixel_values)): + pixel_values[_] = pixel_values[_].astype(np.float16) + pixel_values = np.stack(pixel_values, axis=0) + return pixel_values, image_hash, image_size else: - pixel_values = processor.image_processor(image)["pixel_values"][0] - pixel_values = pixel_values.astype(np.float16) - return pixel_values, image_hash, image.size + image_hash = hash(image_data) + if image_aspect_ratio == "pad": + image = expand2square( + image, tuple(int(x * 255) for x in processor.image_processor.image_mean) + ) + pixel_values = processor.image_processor(image)["pixel_values"][0] + elif image_aspect_ratio == "anyres": + pixel_values = process_anyres_image( + image, processor.image_processor, image_grid_pinpoints + ) + else: + pixel_values = processor.image_processor(image)["pixel_values"][0] + pixel_values = pixel_values.astype(np.float16) + return pixel_values, image_hash, image.size except Exception: print("Exception in TokenizerManager:\n" + get_exception_traceback()) @@ -84,6 +92,7 @@ class TokenizerManager: self, server_args: ServerArgs, port_args: PortArgs, + model_overide_args: dict = None, ): self.server_args = server_args @@ -96,7 +105,9 @@ class TokenizerManager: self.model_path = server_args.model_path self.hf_config = get_config( - self.model_path, trust_remote_code=server_args.trust_remote_code + self.model_path, + trust_remote_code=server_args.trust_remote_code, + model_overide_args=model_overide_args, ) self.context_len = get_context_length(self.hf_config) diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py index 55961ad30..af4a6c103 100644 --- a/python/sglang/srt/model_config.py +++ b/python/sglang/srt/model_config.py @@ -10,12 +10,16 @@ class ModelConfig: trust_remote_code: bool = True, revision: Optional[str] = None, context_length: Optional[int] = None, + model_overide_args: Optional[dict] = None, ) -> None: self.path = path self.trust_remote_code = trust_remote_code self.revision = revision self.hf_config = get_config(self.path, trust_remote_code, revision) + if model_overide_args is not None: + self.hf_config.update(model_overide_args) + if context_length is not None: self.context_len = context_length else: diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 631e8c7f4..b485db264 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -27,29 +27,25 @@ import torch.utils.checkpoint from torch import nn from torch.nn.parameter import Parameter from transformers import PretrainedConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) from vllm.model_executor.utils import set_weight_attrs -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator @torch.compile diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 4b30a1f57..e4bce189d 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -5,37 +5,31 @@ from typing import Optional import torch import torch.nn as nn +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.linear import ( QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) -from vllm.distributed import ( - tensor_model_parallel_all_reduce, -) -from vllm.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) from vllm.model_executor.utils import set_weight_attrs -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata from sglang.srt.models.dbrx_config import DbrxConfig +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class DbrxRouter(nn.Module): @@ -291,7 +285,9 @@ class DbrxBlock(nn.Module): quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, quant_config=quant_config) + self.norm_attn_norm = DbrxFusedNormAttention( + config, layer_id, quant_config=quant_config + ) self.ffn = DbrxExperts(config, quant_config=quant_config) def forward( @@ -322,7 +318,10 @@ class DbrxModel(nn.Module): config.d_model, ) self.blocks = nn.ModuleList( - [DbrxBlock(config, i, quant_config=quant_config) for i in range(config.n_layers)] + [ + DbrxBlock(config, i, quant_config=quant_config) + for i in range(config.n_layers) + ] ) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 4b0b00479..712af0981 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -7,6 +7,7 @@ import torch from torch import nn from transformers import PretrainedConfig from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -14,21 +15,14 @@ from vllm.model_executor.layers.linear import ( QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.distributed import ( - get_tensor_model_parallel_world_size, -) -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class GemmaMLP(nn.Module): @@ -46,7 +40,10 @@ class GemmaMLP(nn.Module): quant_config=quant_config, ) self.down_proj = RowParallelLinear( - intermediate_size, hidden_size, bias=False, quant_config=quant_config, + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, ) self.act_fn = GeluAndMul() diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index 26c412871..fde8ebb06 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Optional, Tuple import torch from torch import nn from transformers import LlamaConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -13,24 +14,17 @@ from vllm.model_executor.layers.linear import ( QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.distributed import ( - get_tensor_model_parallel_world_size, -) -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class LlamaMLP(nn.Module): @@ -49,7 +43,10 @@ class LlamaMLP(nn.Module): quant_config=quant_config, ) self.down_proj = RowParallelLinear( - intermediate_size, hidden_size, bias=False, quant_config=quant_config, + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, ) if hidden_act != "silu": raise ValueError( diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 232aee1d3..abce92061 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -7,12 +7,7 @@ import torch from torch import nn from transformers import CLIPVisionModel, LlavaConfig from transformers.models.llava.modeling_llava import LlavaMultiModalProjector -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.model_runner import InputMetadata @@ -22,6 +17,7 @@ from sglang.srt.mm_utils import ( unpad_image_shape, ) from sglang.srt.models.llama2 import LlamaForCausalLM +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class LlavaLlamaForCausalLM(nn.Module): diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py new file mode 100644 index 000000000..e4205e509 --- /dev/null +++ b/python/sglang/srt/models/llavavid.py @@ -0,0 +1,307 @@ +"""Inference-only LLaVa video model compatible with HuggingFace weights.""" + +import os +from typing import List, Optional + +import numpy as np +import torch +from torch import nn +from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig +from transformers.models.llava.modeling_llava import LlavaMultiModalProjector +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + +from sglang.srt.managers.router.infer_batch import ForwardMode +from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.mm_utils import ( + get_anyres_image_grid_shape, + unpad_image, + unpad_image_shape, +) +from sglang.srt.models.llama2 import LlamaForCausalLM +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator + + +class LlavaVidForCausalLM(nn.Module): + def __init__( + self, + config: LlavaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.vision_tower = None + self.config.vision_config.hidden_size = config.mm_hidden_size + self.config.text_config.hidden_size = config.hidden_size + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2) + self.resampler = nn.AvgPool2d( + kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride + ) + self.language_model = LlamaForCausalLM(config, quant_config=quant_config) + self.num_frames = getattr(self.config, "num_frames", 16) + if "unpad" in getattr(config, "mm_patch_merge_type", ""): + self.language_model.model.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size, dtype=torch.float16) + ) + + def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): + new_image_feature_len = self.image_feature_len + # now only support spatial_unpad + anyres + # if self.mm_patch_merge_type.startswith("spatial"): + # height = width = self.num_patches_per_side + # if pt_shape[0] > 1: + # if self.image_aspect_ratio == "anyres": + # num_patch_width, num_patch_height = get_anyres_image_grid_shape( + # image_size, + # self.image_grid_pinpoints, + # self.vision_tower.config.image_size, + # ) + # if "unpad" in self.mm_patch_merge_type: + # h = num_patch_height * height + # w = num_patch_width * width + # new_h, new_w = unpad_image_shape(h, w, image_size) + # new_image_feature_len += new_h * (new_w + 1) + + pad_ids = pad_value * ( + (new_image_feature_len + len(pad_value)) // len(pad_value) + ) + # print(input_ids) + offset = input_ids.index(self.config.image_token_index) + # old_len + pad_len - 1, because we need to remove image_token_id + new_input_ids = ( + input_ids[:offset] + + pad_ids[:new_image_feature_len] + + input_ids[offset + 1 :] + ) + return new_input_ids, offset + + def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated. + + selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer] + if self.vision_feature_select_strategy in ["default", "patch"]: + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + + height = width = self.num_patches_per_side + num_of_frames = selected_image_feature.shape[0] + selected_image_feature = selected_image_feature.view( + num_of_frames, height, width, -1 + ) + selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous() + selected_image_feature = ( + self.resampler(selected_image_feature) + .flatten(2) + .transpose(1, 2) + .contiguous() + ) + + image_features = self.multi_modal_projector(selected_image_feature) + + return image_features + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + pixel_values: Optional[List[Optional[np.array]]] = None, + image_sizes: Optional[List[List[int]]] = None, + image_offsets: Optional[List[int]] = None, + ) -> torch.Tensor: + if input_metadata.forward_mode == ForwardMode.EXTEND: + bs = input_metadata.batch_size + + # Embed text input + input_embeds = self.language_model.model.embed_tokens(input_ids) + + # Embed vision input + need_vision = ( + (positions[input_metadata.extend_start_loc] < self.image_feature_len) + .cpu() + .numpy() + ) + # FIXME: We need to substract the length of the system prompt + has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) + need_vision = need_vision & has_pixel + + if need_vision.any(): + pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] + image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]] + + ########## Encode Image ######## + + if pixel_values[0].ndim == 4: + # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images + np.concatenate(pixel_values, axis=0) + # ndim=4 + concat_images = torch.tensor( + np.concatenate(pixel_values, axis=0), + device=self.vision_tower.device, + ) + # image_features = self.encode_images(concat_images) + # split_sizes = [image.shape[0] for image in pixel_values] + # image_features = torch.split(image_features, split_sizes, dim=0) + image_features = self.encode_images( + concat_images + ) # , prompts)#, image_counts, long_video=long_video) + split_sizes = [image.shape[0] for image in pixel_values] + image_features = torch.split(image_features, split_sizes, dim=0) + + # hd image_features: BS, num_patch, 576, 4096 + else: + # normal pixel: BS, C=3, H=336, W=336 + pixel_values = torch.tensor( + np.array(pixel_values), device=self.vision_tower.device + ) + image_features = self.encode_images(pixel_values) + # image_features: BS, 576, 4096 + + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + new_image_features.append(image_feature.flatten(0, 1)) + image_features = new_image_features + + extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() + pt = 0 + for i in range(bs): + if not need_vision[i]: + continue + + start_idx = extend_start_loc_cpu[i] + pad_len, pad_dim = image_features[pt].shape # 576, 4096 + dim = input_embeds.shape[1] + assert ( + pad_dim == dim + ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) + # Fill in the placeholder for the image + try: + input_embeds[ + start_idx + + image_offsets[i] : start_idx + + image_offsets[i] + + pad_len + ] = image_features[pt] + except RuntimeError as e: + print(f"RuntimeError in llava image encoding: {e}") + print(input_embeds.shape) + print(start_idx, image_offsets[i]) + pt += 1 + + return self.language_model( + input_ids, positions, input_metadata, input_embeds=input_embeds + ) + elif input_metadata.forward_mode == ForwardMode.DECODE: + return self.language_model(input_ids, positions, input_metadata) + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + # load clip vision model by cfg['mm_vision_tower']: + # huggingface_name or path_of_clip_relative_to_llava_model_dir + vision_path = self.config.mm_vision_tower + self.vision_tower = CLIPVisionModel.from_pretrained( + vision_path, torch_dtype=torch.float16 + ).cuda() + self.vision_tower.eval() + + self.vision_feature_layer = self.config.mm_vision_select_layer + self.vision_feature_select_strategy = self.config.mm_vision_select_feature + self.image_size = self.vision_tower.config.image_size + self.patch_size = self.vision_tower.config.patch_size + + self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") + self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") + self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None) + + print(f"target_frames: {self.num_frames}") + self.image_feature_len = self.num_frames * int( + (self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2 + ) + if self.vision_feature_select_strategy == "patch": + pass + elif self.vision_feature_select_strategy == "cls_patch": + self.image_feature_len += 1 + else: + raise ValueError(f"Unexpected select feature: {self.select_feature}") + + # load mm_projector + projector_weights = { + "model.mm_projector.0": "multi_modal_projector.linear_1", + "model.mm_projector.2": "multi_modal_projector.linear_2", + "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1", + "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2", + "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). + } + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + # FIXME: why projector weights read two times? + if "projector" in name or "vision_tower" in name: + for weight_name, param_name in projector_weights.items(): + if weight_name in name: + name = name.replace(weight_name, param_name) + if name in params_dict: + param = params_dict[name] + else: + print(f"Warning: {name} not found in the model") + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + # load language model + self.language_model.load_weights( + model_name_or_path, cache_dir, load_format, revision + ) + + monkey_path_clip_vision_embed_forward() + + @property + def num_patches_per_side(self): + return self.image_size // self.patch_size + + +first_call = True + + +def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + + # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G. + global first_call + if first_call: + self.patch_embedding.cpu().float() + first_call = False + pixel_values = pixel_values.to(dtype=torch.float32, device="cpu") + patch_embeds = self.patch_embedding(pixel_values).cuda().half() + + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +def monkey_path_clip_vision_embed_forward(): + import transformers + + setattr( + transformers.models.clip.modeling_clip.CLIPVisionEmbeddings, + "forward", + clip_vision_embed_forward, + ) + + +EntryClass = LlavaVidForCausalLM diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 99d81ce74..48e8d37fc 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -8,34 +8,28 @@ import torch import torch.nn.functional as F from torch import nn from transformers import MixtralConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.distributed import ( - tensor_model_parallel_all_reduce, -) -from vllm.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class MixtralMLP(nn.Module): diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 9d157f81f..a242d013b 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional import torch from torch import nn from transformers import PretrainedConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -10,24 +11,17 @@ from vllm.model_executor.layers.linear import ( QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.distributed import ( - get_tensor_model_parallel_world_size, -) -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class QWenMLP(nn.Module): @@ -132,7 +126,12 @@ class QWenAttention(nn.Module): class QWenBlock(nn.Module): - def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,): + def __init__( + self, + config: PretrainedConfig, + layer_id, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -181,7 +180,11 @@ class QWenBlock(nn.Module): class QWenModel(nn.Module): - def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config self.vocab_size = config.vocab_size @@ -218,7 +221,11 @@ class QWenModel(nn.Module): class QWenLMHeadModel(nn.Module): - def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config self.transformer = QWenModel(config, quant_config=quant_config) @@ -276,4 +283,4 @@ class QWenLMHeadModel(nn.Module): weight_loader(param, loaded_weight) -EntryClass = QWenLMHeadModel \ No newline at end of file +EntryClass = QWenLMHeadModel diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index dc1dd0de3..45e5371e7 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Tuple import torch from torch import nn +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -12,24 +13,17 @@ from vllm.model_executor.layers.linear import ( QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.distributed import ( - get_tensor_model_parallel_world_size, -) -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator Qwen2Config = None @@ -50,7 +44,10 @@ class Qwen2MLP(nn.Module): quant_config=quant_config, ) self.down_proj = RowParallelLinear( - intermediate_size, hidden_size, bias=False, quant_config=quant_config, + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, ) if hidden_act != "silu": raise ValueError( diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 7ad495c95..423e603cd 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -7,35 +7,31 @@ from typing import Optional, Tuple import torch from torch import nn from transformers import PretrainedConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from vllm.distributed import ( - get_tensor_model_parallel_world_size, -) -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class StablelmMLP(nn.Module): def __init__( - self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -48,7 +44,10 @@ class StablelmMLP(nn.Module): quant_config=quant_config, ) self.down_proj = RowParallelLinear( - config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config, + config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quant_config, ) self.act_fn = SiluAndMul() @@ -181,7 +180,9 @@ class StablelmDecoderLayer(nn.Module): class StableLMEpochModel(nn.Module): def __init__( - self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py index 6f4a9b59f..3b1b99c8d 100644 --- a/python/sglang/srt/models/yivl.py +++ b/python/sglang/srt/models/yivl.py @@ -6,16 +6,13 @@ from typing import List, Optional import torch import torch.nn as nn from transformers import CLIPVisionModel, LlavaConfig -from sglang.srt.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) from sglang.srt.models.llava import ( LlavaLlamaForCausalLM, clip_vision_embed_forward, monkey_path_clip_vision_embed_forward, ) +from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class YiVLForCausalLM(LlavaLlamaForCausalLM): diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index dae23f7aa..d6eec0c90 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -107,7 +107,7 @@ async def openai_v1_chat_completions(raw_request: Request): return await v1_chat_completions(tokenizer_manager, raw_request) -def launch_server(server_args: ServerArgs, pipe_finish_writer): +def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None): global tokenizer_manager logging.basicConfig( @@ -140,17 +140,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer): ) # Launch processes - tokenizer_manager = TokenizerManager(server_args, port_args) + tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False) pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) proc_router = mp.Process( target=start_router_process, - args=( - server_args, - port_args, - pipe_router_writer, - ), + args=(server_args, port_args, pipe_router_writer, model_overide_args), ) proc_router.start() proc_detoken = mp.Process( @@ -170,8 +166,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer): if router_init_state != "init ok" or detoken_init_state != "init ok": proc_router.kill() proc_detoken.kill() - print(f"Initialization failed. router_init_state: {router_init_state}", flush=True) - print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True) + print( + f"Initialization failed. router_init_state: {router_init_state}", flush=True + ) + print( + f"Initialization failed. detoken_init_state: {detoken_init_state}", + flush=True, + ) sys.exit(1) assert proc_router.is_alive() and proc_detoken.is_alive() @@ -189,6 +190,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer): time.sleep(0.5) try: requests.get(url + "/get_model_info", timeout=5, headers=headers) + success = True # Set flag to True if request succeeds break except requests.exceptions.RequestException as e: pass @@ -205,7 +207,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer): }, }, headers=headers, - timeout=60, + timeout=600, ) assert res.status_code == 200 except Exception as e: @@ -235,7 +237,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer): class Runtime: def __init__( self, - log_evel="error", + log_evel: str = "error", + model_overide_args: Optional[dict] = None, *args, **kwargs, ): @@ -244,7 +247,10 @@ class Runtime: # Pre-allocate ports self.server_args.port, self.server_args.additional_ports = allocate_init_ports( - self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size) + self.server_args.port, + self.server_args.additional_ports, + self.server_args.tp_size, + ) self.url = self.server_args.url() self.generate_url = ( @@ -253,7 +259,10 @@ class Runtime: self.pid = None pipe_reader, pipe_writer = mp.Pipe(duplex=False) - proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer)) + proc = mp.Process( + target=launch_server, + args=(self.server_args, pipe_writer, model_overide_args), + ) proc.start() pipe_writer.close() self.pid = proc.pid @@ -265,7 +274,9 @@ class Runtime: if init_state != "init ok": self.shutdown() - raise RuntimeError("Initialization failed. Please see the error messages above.") + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) self.endpoint = RuntimeEndpoint(self.url) @@ -317,4 +328,4 @@ class Runtime: pos += len(cur) def __del__(self): - self.shutdown() \ No newline at end of file + self.shutdown() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ccf322c0a..65608af89 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -80,10 +80,12 @@ class ServerArgs: default=ServerArgs.tokenizer_path, help="The path of the tokenizer.", ) - parser.add_argument("--host", type=str, default=ServerArgs.host, - help="The host of the server.") - parser.add_argument("--port", type=int, default=ServerArgs.port, - help="The port of the server.") + parser.add_argument( + "--host", type=str, default=ServerArgs.host, help="The host of the server." + ) + parser.add_argument( + "--port", type=int, default=ServerArgs.port, help="The port of the server." + ) parser.add_argument( "--additional-ports", type=int, @@ -261,4 +263,4 @@ class PortArgs: router_port: int detokenizer_port: int nccl_port: int - model_rpc_ports: List[int] \ No newline at end of file + model_rpc_ports: List[int] diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 11bf139bf..09849a547 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -131,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()): continue with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: s.bind(("", port)) + s.listen(1) # Attempt to listen on the port port_list.append(port) except socket.error: - pass + pass # If any error occurs, this port is not usable if len(port_list) == num: return port_list @@ -265,20 +267,102 @@ def wrap_kernel_launcher(kernel): def is_multimodal_model(model): - if isinstance(model, str): - return "llava" in model or "yi-vl" in model from sglang.srt.model_config import ModelConfig + if isinstance(model, str): + model = model.lower() + return "llava" in model or "yi-vl" in model or "llava-next" in model + if isinstance(model, ModelConfig): model_path = model.path.lower() - return "llava" in model_path or "yi-vl" in model_path - raise Exception("unrecognized type") + return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path + + raise ValueError("unrecognized type") + + +def decode_video_base64(video_base64): + from PIL import Image + + # Decode the base64 string + video_bytes = base64.b64decode(video_base64) + + # Placeholder for the start indices of each PNG image + img_starts = [] + + frame_format = "PNG" # str(os.getenv('FRAME_FORMAT', "JPEG")) + + assert frame_format in [ + "PNG", + "JPEG", + ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'" + + if frame_format == "PNG": + # Find each PNG start signature to isolate images + i = 0 + while i < len(video_bytes) - 7: # Adjusted for the length of the PNG signature + # Check if we found the start of a PNG file + if ( + video_bytes[i] == 0x89 + and video_bytes[i + 1] == 0x50 + and video_bytes[i + 2] == 0x4E + and video_bytes[i + 3] == 0x47 + and video_bytes[i + 4] == 0x0D + and video_bytes[i + 5] == 0x0A + and video_bytes[i + 6] == 0x1A + and video_bytes[i + 7] == 0x0A + ): + img_starts.append(i) + i += 8 # Skip the PNG signature + else: + i += 1 + else: + # Find each JPEG start (0xFFD8) to isolate images + i = 0 + while ( + i < len(video_bytes) - 1 + ): # Adjusted for the length of the JPEG SOI signature + # Check if we found the start of a JPEG file + if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8: + img_starts.append(i) + # Move to the next byte to continue searching for the next image start + i += 2 + else: + i += 1 + + frames = [] + for start_idx in img_starts: + # Assuming each image is back-to-back, the end of one image is the start of another + # The last image goes until the end of the byte string + end_idx = ( + img_starts[img_starts.index(start_idx) + 1] + if img_starts.index(start_idx) + 1 < len(img_starts) + else len(video_bytes) + ) + img_bytes = video_bytes[start_idx:end_idx] + + # Convert bytes to a PIL Image + img = Image.open(BytesIO(img_bytes)) + + # Convert PIL Image to a NumPy array + frame = np.array(img) + + # Append the frame to the list of frames + frames.append(frame) + + # Ensure there's at least one frame to avoid errors with np.stack + if frames: + return np.stack(frames, axis=0), img.size + else: + return np.array([]), ( + 0, + 0, + ) # Return an empty array and size tuple if no frames were found def load_image(image_file): from PIL import Image - image = None + image = image_size = None if image_file.startswith("http://") or image_file.startswith("https://"): timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) @@ -289,10 +373,13 @@ def load_image(image_file): elif image_file.startswith("data:"): image_file = image_file.split(",")[1] image = Image.open(BytesIO(base64.b64decode(image_file))) + elif image_file.startswith("video:"): + image_file = image_file.replace("video:", "") + image, image_size = decode_video_base64(image_file) else: image = Image.open(BytesIO(base64.b64decode(image_file))) - return image + return image, image_size def assert_pkg_version(pkg: str, min_version: str): @@ -304,7 +391,9 @@ def assert_pkg_version(pkg: str, min_version: str): f"is less than the minimum required version {min_version}" ) except PackageNotFoundError: - raise Exception(f"{pkg} with minimum required version {min_version} is not installed") + raise Exception( + f"{pkg} with minimum required version {min_version} is not installed" + ) API_KEY_HEADER_NAME = "X-API-Key" diff --git a/python/sglang/srt/weight_utils.py b/python/sglang/srt/weight_utils.py index 0df3468c2..1170c6cfe 100644 --- a/python/sglang/srt/weight_utils.py +++ b/python/sglang/srt/weight_utils.py @@ -19,11 +19,12 @@ import torch from huggingface_hub import HfFileSystem, snapshot_download from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm - from vllm.config import ModelConfig from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import (QuantizationConfig, - get_quantization_config) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + get_quantization_config, +) from vllm.model_executor.layers.quantization.schema import QuantParamSchema logger = init_logger(__name__) @@ -32,17 +33,21 @@ logger = init_logger(__name__) # can share the same lock without error. # lock files in the temp directory will be automatically deleted when the # system reboots, so users will not complain about annoying lock files -temp_dir = os.environ.get('TMPDIR') or os.environ.get( - 'TEMP') or os.environ.get('TMP') or "/tmp/" +temp_dir = ( + os.environ.get("TMPDIR") + or os.environ.get("TEMP") + or os.environ.get("TMP") + or "/tmp/" +) def enable_hf_transfer(): - """automatically activates hf_transfer - """ + """automatically activates hf_transfer""" if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: try: # enable hf hub transfer if available import hf_transfer # type: ignore # noqa + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True except ImportError: pass @@ -65,8 +70,7 @@ def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): # add hash to avoid conflict with old users' lock files lock_file_name = hash_name + model_name + ".lock" # mode 0o666 is required for the filelock to be shared across users - lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), - mode=0o666) + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) return lock @@ -104,10 +108,12 @@ def convert_bin_to_safetensor_file( sf_size = os.stat(sf_filename).st_size pt_size = os.stat(pt_filename).st_size if (sf_size - pt_size) / pt_size > 0.01: - raise RuntimeError(f"""The file size different is more than 1%: + raise RuntimeError( + f"""The file size different is more than 1%: - {sf_filename}: {sf_size} - {pt_filename}: {pt_size} - """) + """ + ) # check if the tensors are the same reloaded = load_file(sf_filename) @@ -122,8 +128,7 @@ def convert_bin_to_safetensor_file( def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: quant_cls = get_quantization_config(model_config.quantization) # Read the quantization config from the HF model config, if available. - hf_quant_config = getattr(model_config.hf_config, "quantization_config", - None) + hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) model_name_or_path = model_config.model @@ -131,26 +136,29 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: if not is_local: # Download the config files. with get_lock(model_name_or_path, model_config.download_dir): - hf_folder = snapshot_download(model_name_or_path, - revision=model_config.revision, - allow_patterns="*.json", - cache_dir=model_config.download_dir, - tqdm_class=Disabledtqdm) + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=model_config.download_dir, + tqdm_class=Disabledtqdm, + ) else: hf_folder = model_name_or_path config_files = glob.glob(os.path.join(hf_folder, "*.json")) quant_config_files = [ - f for f in config_files if any( - f.endswith(x) for x in quant_cls.get_config_filenames()) + f + for f in config_files + if any(f.endswith(x) for x in quant_cls.get_config_filenames()) ] if len(quant_config_files) == 0: - raise ValueError( - f"Cannot find the config file for {model_config.quantization}") + raise ValueError(f"Cannot find the config file for {model_config.quantization}") if len(quant_config_files) > 1: raise ValueError( f"Found multiple config files for {model_config.quantization}: " - f"{quant_config_files}") + f"{quant_config_files}" + ) quant_config_file = quant_config_files[0] with open(quant_config_file, "r") as f: @@ -166,8 +174,7 @@ def prepare_hf_model_weights( revision: Optional[str] = None, ) -> Tuple[str, List[str], bool]: # Download model weights from huggingface. - is_local = os.path.isdir(model_name_or_path) \ - and load_format != "tensorizer" + is_local = os.path.isdir(model_name_or_path) and load_format != "tensorizer" use_safetensors = False # Some quantized models use .pt files for storing the weights. if load_format == "auto": @@ -203,11 +210,13 @@ def prepare_hf_model_weights( # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. with get_lock(model_name_or_path, cache_dir): - hf_folder = snapshot_download(model_name_or_path, - allow_patterns=allow_patterns, - cache_dir=cache_dir, - tqdm_class=Disabledtqdm, - revision=revision) + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=Disabledtqdm, + revision=revision, + ) else: hf_folder = model_name_or_path hf_weights_files: List[str] = [] @@ -228,16 +237,14 @@ def prepare_hf_model_weights( "scaler.pt", ] hf_weights_files = [ - f for f in hf_weights_files - if not any(f.endswith(x) for x in blacklist) + f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist) ] if load_format == "tensorizer": return hf_folder, hf_weights_files, use_safetensors if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") + raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`") return hf_folder, hf_weights_files, use_safetensors @@ -254,7 +261,8 @@ def hf_model_weights_iterator( cache_dir=cache_dir, load_format=load_format, fall_back_to_pt=fall_back_to_pt, - revision=revision) + revision=revision, + ) if load_format == "npcache": # Currently np_cache only support *.bin checkpoints @@ -289,22 +297,25 @@ def hf_model_weights_iterator( param = np.load(f) yield name, torch.from_numpy(param) elif load_format == "tensorizer": - from vllm.model_executor.tensorizer_loader import (TensorDeserializer, - open_stream, - tensorizer_warning) + from vllm.model_executor.tensorizer_loader import ( + TensorDeserializer, + open_stream, + tensorizer_warning, + ) + tensorizer_args = load_format.params tensorizer_warning( "Deserializing HuggingFace models is not optimized for " "loading on vLLM, as tensorizer is forced to load to CPU. " "Consider deserializing a vLLM model instead for faster " "load times. See the examples/tensorize_vllm_model.py example " - "script for serializing vLLM models.") + "script for serializing vLLM models." + ) deserializer_args = tensorizer_args.deserializer_params stream_params = tensorizer_args.stream_params stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) - with TensorDeserializer(stream, **deserializer_args, - device="cpu") as state: + with TensorDeserializer(stream, **deserializer_args, device="cpu") as state: for name, param in state.items(): yield name, param del state @@ -324,8 +335,12 @@ def hf_model_weights_iterator( def kv_cache_scales_loader( - filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int, - model_type: Optional[str]) -> Iterable[Tuple[int, float]]: + filename: str, + tp_rank: int, + tp_size: int, + num_hidden_layers: int, + model_type: Optional[str], +) -> Iterable[Tuple[int, float]]: """ A simple utility to read in KV cache scaling factors that have been previously serialized to disk. Used by the model to populate the appropriate @@ -343,8 +358,7 @@ def kv_cache_scales_loader( "tp_size": tp_size, } schema_dct = json.load(f) - schema = QuantParamSchema.model_validate(schema_dct, - context=context) + schema = QuantParamSchema.model_validate(schema_dct, context=context) layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] return layer_scales_map.items() @@ -357,9 +371,11 @@ def kv_cache_scales_loader( # This section is reached if and only if any of the excepts are hit # Return an empty iterable (list) => no KV cache scales are loaded # which ultimately defaults to 1.0 scales - logger.warning("Defaulting to KV cache scaling factors = 1.0 " - f"for all layers in TP rank {tp_rank} " - "as an error occurred during loading.") + logger.warning( + "Defaulting to KV cache scaling factors = 1.0 " + f"for all layers in TP rank {tp_rank} " + "as an error occurred during loading." + ) return [] @@ -378,8 +394,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: return x -def default_weight_loader(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight) @@ -399,4 +414,4 @@ def initialize_dummy_weights( """ for param in model.state_dict().values(): if torch.is_floating_point(param): - param.data.uniform_(low, high) \ No newline at end of file + param.data.uniform_(low, high) diff --git a/python/sglang/utils.py b/python/sglang/utils.py index bbe9e0844..51bb9b20b 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -2,13 +2,16 @@ import base64 import json +import os import sys import threading import traceback import urllib.request +from concurrent.futures import ThreadPoolExecutor from io import BytesIO from json import dumps +import numpy as np import requests @@ -110,6 +113,74 @@ def encode_image_base64(image_path): return base64.b64encode(buffered.getvalue()).decode("utf-8") +def encode_frame(frame): + import cv2 # pip install opencv-python-headless + from PIL import Image + + # Convert the frame to RGB (OpenCV uses BGR by default) + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Convert the frame to PIL Image to easily convert to bytes + im_pil = Image.fromarray(frame) + + # Convert to bytes + buffered = BytesIO() + + # frame_format = str(os.getenv('FRAME_FORMAT', "JPEG")) + + im_pil.save(buffered, format="PNG") + + frame_bytes = buffered.getvalue() + + # Return the bytes of the frame + return frame_bytes + + +def encode_video_base64(video_path, num_frames=16): + import cv2 + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise IOError(f"Could not open video file:{video_path}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + print(f"target_frames: {num_frames}") + + frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) + + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if ret: + frames.append(frame) + else: + # Handle the case where the frame could not be read + # print(f"Warning: Could not read frame at index {i}.") + pass + + cap.release() + + # Safely select frames based on frame_indices, avoiding IndexError + frames = [frames[i] for i in frame_indices if i < len(frames)] + + # If there are not enough frames, duplicate the last frame until we reach the target + while len(frames) < num_frames: + frames.append(frames[-1]) + + # Use ThreadPoolExecutor to process and encode frames in parallel + with ThreadPoolExecutor() as executor: + encoded_frames = list(executor.map(encode_frame, frames)) + + # encoded_frames = list(map(encode_frame, frames)) + + # Concatenate all frames bytes + video_bytes = b"".join(encoded_frames) + + # Encode the concatenated bytes to base64 + video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8") + + return video_base64 + + def _is_chinese_char(cp): """Checks whether CP is the codepoint of a CJK character.""" # This defines a "chinese character" as anything in the CJK Unicode block: @@ -170,4 +241,4 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None): if not ret_value: raise RuntimeError() - return ret_value[0] \ No newline at end of file + return ret_value[0]