Cleanup readme, llava examples, usage examples and nccl init (#1194)

This commit is contained in:
Lianmin Zheng
2024-08-24 08:02:23 -07:00
committed by GitHub
parent c9064e6fd9
commit f6af3a6561
65 changed files with 174 additions and 317 deletions

View File

@@ -22,12 +22,13 @@ The core features include:
## News
- [2024/07] 🔥 Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)).
- [2024/08] 🔥 LLaVA-OneVision with single-image, multi-image and video are supported ([blog](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)).
- [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
<details>
<summary>More</summary>
- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)).
- [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).
- [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
@@ -227,19 +228,14 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- Gemma / Gemma 2
- Qwen / Qwen 2 / Qwen 2 MoE
- DeepSeek / DeepSeek 2
- LLaVA 1.5 / 1.6
- `python -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
- `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
- `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 30000`
- `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --host=127.0.0.1 --tp-size=1 --chat-template=llava_llama_3`
- `python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --host="127.0.0.1" --tp-size=8 --chat-template=chatml-llava`
- LLaVA-NeXT-Video
- see [examples/usage/llava_video](examples/usage/llava_video)
- [LLaVA-OneVision](https://arxiv.org/abs/2408.03326)
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --host=127.0.0.1 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384`
- see [test/srt/test_llava_onevision_openai_server.py](test/srt/test_llava_onevision_openai_server.py)
- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384`
- Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py)
- LLaVA 1.5 / 1.6 / NeXT
- `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3`
- `python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8 --chat-template=chatml-llava`
- Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py)
- Yi-VL
- see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py).
- StableLM
- Command-R
- DBRX
@@ -250,6 +246,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md).
#### Use Models From ModelScope
<details>
To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable SGLANG_USE_MODELSCOPE.
```
export SGLANG_USE_MODELSCOPE=true
@@ -258,21 +256,20 @@ Launch [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instru
```
SGLANG_USE_MODELSCOPE=true python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000
```
</details>
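The same flag can be set from Python before launching the server; a minimal sketch (the exact truthiness parsing inside SGLang is an assumption here):

```python
import os

# Assumption: SGLang treats a "true"-like value as enabling ModelScope downloads
os.environ["SGLANG_USE_MODELSCOPE"] = "true"
use_modelscope = os.environ.get("SGLANG_USE_MODELSCOPE", "false").lower() == "true"
print(use_modelscope)  # True
```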
#### Run Llama 3.1 405B
```bash
## Run 405B (fp8) on a single node
# Run 405B (fp8) on a single node
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8
## Run 405B (fp16) on two nodes
# replace the `172.16.4.52:20000` with your own first node ip address and port, disable CUDA Graph temporarily
# Run 405B (fp16) on two nodes
## on the first node, replace the `172.16.4.52:20000` with your own first node ip address and port
GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 --disable-cuda-graph
# on the first node
GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 --disable-cuda-graph --mem-frac 0.75
# on the second
GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 --disable-cuda-graph --mem-frac 0.75
## on the first node, replace the `172.16.4.52:20000` with your own first node ip address and port
GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 --disable-cuda-graph
```
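As a sketch of what `--nccl-init-addr` does under the hood (based on the `ModelRunner` changes later in this commit), the flag becomes a `torch.distributed` TCP init method; this toy function mirrors that branch:

```python
def nccl_init_method(nccl_init_addr=None, nccl_port=28888):
    # Multi-node: use the first node's addr:port; single node: loopback.
    # Mirrors the branch in ModelRunner.init_torch_distributed.
    if nccl_init_addr:
        return f"tcp://{nccl_init_addr}"
    return f"tcp://127.0.0.1:{nccl_port}"

print(nccl_init_method("172.16.4.52:20000"))  # tcp://172.16.4.52:20000
print(nccl_init_method())                     # tcp://127.0.0.1:28888
```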
### Benchmark Performance

View File

@@ -1,5 +1,8 @@
# Sampling Parameters in SGLang Runtime
This doc describes the sampling parameters of the SGLang Runtime.
It is the low-level endpoint of the runtime.
If you want a high-level endpoint that can automatically handle chat templates, consider using the [OpenAI Compatible API
](https://github.com/sgl-project/sglang?tab=readme-ov-file#openai-compatible-api).
The `/generate` endpoint accepts the following arguments in the JSON format.
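As a quick sketch, a `/generate` request payload can be built like this (the server address and sampling values are illustrative):

```python
import json

# Illustrative /generate payload; field names follow the argument list,
# the values are arbitrary examples
payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "temperature": 0,
        "max_new_tokens": 32,
    },
}
print(json.dumps(payload))

# To send it (requires a running server):
# import requests
# response = requests.post("http://localhost:30000/generate", json=payload)
# print(response.json())
```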
@@ -140,7 +143,7 @@ print("")
Launch a server
```
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --chat-template chatml-llava
```
Download an image
@@ -155,7 +158,9 @@ import requests
response = requests.post(
"http://localhost:30000/generate",
json={
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
"text": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n"
"<|im_start|>assistant\n",
"image_data": "example_image.png",
"sampling_params": {
"temperature": 0,

View File


View File


View File

@@ -1,6 +1,6 @@
"""
Usage:
python3 srt_example_chat.py
python3 local_example_chat.py
"""
import sglang as sgl

View File

@@ -1,6 +1,6 @@
"""
Usage:
python3 srt_example_complete.py
python3 local_example_complete.py
"""
import sglang as sgl

View File

@@ -1,8 +1,14 @@
"""
Usage: python3 srt_example_llava.py
Usage: python3 local_example_llava_next.py
"""
from PIL import ImageFile
import sglang as sgl
from sglang.lang.chat_template import get_chat_template
from sglang.srt.utils import load_image
ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images
@sgl.function
@@ -44,10 +50,17 @@ def batch():
if __name__ == "__main__":
runtime = sgl.Runtime(
model_path="liuhaotian/llava-v1.6-vicuna-7b",
tokenizer_path="llava-hf/llava-1.5-7b-hf",
)
import multiprocessing as mp
mp.set_start_method("spawn", force=True)
runtime = sgl.Runtime(model_path="lmms-lab/llama3-llava-next-8b")
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")
# Or you can use the 72B model
# runtime = sgl.Runtime(model_path="lmms-lab/llava-next-72b", tp_size=8)
# runtime.endpoint.chat_template = get_chat_template("chatml-llava")
sgl.set_default_backend(runtime)
print(f"chat template: {runtime.endpoint.chat_template.name}")

View File

@@ -1,7 +1,8 @@
"""
Usage:
pip install opencv-python-headless
python3 srt_example_llava.py
python3 srt_example_llava_v.py
"""
import argparse
@@ -9,6 +10,8 @@ import csv
import os
import time
import requests
import sglang as sgl

View File

@@ -1,70 +0,0 @@
"""
Usage: python3 srt_example_yi_vl.py
Requirements: transformers==4.38
"""
import sglang as sgl
@sgl.function
def image_qa(s, image_path, question):
s += sgl.user(sgl.image(image_path) + question)
s += sgl.assistant(sgl.gen("answer"))
def single():
state = image_qa.run(
image_path="images/cat.jpeg",
question="What is this?",
max_new_tokens=64,
stop="###",
)
print(state["answer"], "\n")
def stream():
state = image_qa.run(
image_path="images/cat.jpeg",
question="What is this?",
max_new_tokens=64,
stream=True,
stop="###",
)
for out in state.text_iter("answer"):
print(out, end="", flush=True)
print()
def batch():
states = image_qa.run_batch(
[
{"image_path": "images/cat.jpeg", "question": "What is this?"},
{"image_path": "images/dog.jpeg", "question": "What is this?"},
],
max_new_tokens=64,
stop="###",
)
for s in states:
print(s["answer"], "\n")
if __name__ == "__main__":
runtime = sgl.Runtime(model_path="BabyChou/Yi-VL-6B")
# runtime = sgl.Runtime(model_path="BabyChou/Yi-VL-34B")
sgl.set_default_backend(runtime)
# Run a single request
print("\n========== single ==========\n")
single()
# Stream output
print("\n========== stream ==========\n")
stream()
# Run a batch of requests
print("\n========== batch ==========\n")
batch()
runtime.shutdown()

View File

@@ -4,7 +4,7 @@ Usage:
# Installing latest sglang.
# Endpoint Service CLI:
# python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --tokenizer-path lmms-lab/llama3-llava-next-8b-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4
python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000
python3 http_llama3_llava_test.py
@@ -16,7 +16,6 @@ import argparse
import asyncio
import copy
import json
import time
import aiohttp
import requests

View File

@@ -1,3 +1,11 @@
"""
Usage:
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384
python3 http_llava_onevision_test.py
"""
import base64
import io
import os
@@ -74,7 +82,6 @@ def video_stream_request_test(client, video_path):
print("------------------------Video Stream Request Test----------------------")
messages = prepare_video_messages(video_path)
start_time = time.time()
video_request = client.chat.completions.create(
model="default",
messages=messages,

View File

@@ -4,7 +4,7 @@ Usage:
# Installing latest sglang.
# Endpoint Service CLI:
# python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --tokenizer-path lmms-lab/llavanext-qwen-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4
python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8
python3 http_qwen_llava_test.py
@@ -16,7 +16,6 @@ import argparse
import asyncio
import copy
import json
import time
import aiohttp
import requests

View File

@@ -1,90 +0,0 @@
"""
Usage: python3 srt_example_llava.py
"""
from PIL import ImageFile
import sglang as sgl
from sglang.lang.chat_template import get_chat_template
from sglang.srt.utils import load_image
ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images
@sgl.function
def image_qa(s, image, question):
s += sgl.user(sgl.image(image) + question)
s += sgl.assistant(sgl.gen("answer"))
def single():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image, _ = load_image(image_url)
state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512)
print(state["answer"], "\n")
def stream():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image, _ = load_image(image_url)
state = image_qa.run(
image=pil_image,
question="Please generate short caption for this image.",
max_new_tokens=512,
temperature=0,
stream=True,
)
for out in state.text_iter("answer"):
print(out, end="", flush=True)
print()
def batch():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image, _ = load_image(image_url)
states = image_qa.run_batch(
[
{"image": pil_image, "question": "What is this?"},
{"image": pil_image, "question": "What is this?"},
],
max_new_tokens=512,
)
for s in states:
print(s["answer"], "\n")
if __name__ == "__main__":
import multiprocessing as mp
mp.set_start_method("spawn", force=True)
runtime = sgl.Runtime(
model_path="lmms-lab/llama3-llava-next-8b",
tokenizer_path="lmms-lab/llama3-llava-next-8b-tokenizer",
)
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")
# runtime = sgl.Runtime(
# model_path="lmms-lab/llava-next-72b",
# tokenizer_path="lmms-lab/llavanext-qwen-tokenizer",
# )
# runtime.endpoint.chat_template = get_chat_template("chatml-llava")
sgl.set_default_backend(runtime)
print(f"chat template: {runtime.endpoint.chat_template.name}")
# Or you can use API models
# sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
# sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))
# Run a single request
print("\n========== single ==========\n")
single()
# Stream output
print("\n========== stream ==========\n")
stream()
# Run a batch of requests
print("\n========== batch ==========\n")
batch()
runtime.shutdown()

Binary file not shown.


View File

@@ -111,7 +111,11 @@ def load_model(server_args, tp_rank):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
model_config = ModelConfig(path=server_args.model_path)
model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
context_length=server_args.context_length,
)
model_runner = ModelRunner(
model_config=model_config,
mem_fraction_static=server_args.mem_fraction_static,

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Tuple
class ChatTemplateStyle(Enum):

View File

@@ -1,29 +0,0 @@
"""Launch the inference server for Llava-video model."""
import argparse
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
model_overide_args = {}
model_overide_args["mm_spatial_pool_stride"] = 2
model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
model_overide_args["num_frames"] = 16
model_overide_args["model_type"] = "llavavid"
if model_overide_args["num_frames"] == 32:
model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
model_overide_args["max_sequence_length"] = 4096 * 2
model_overide_args["tokenizer_model_max_length"] = 4096 * 2
model_overide_args["model_max_length"] = 4096 * 2
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
if "34b" in args.model_path.lower():
model_overide_args["image_token_index"] = 64002
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args, model_overide_args, None)

View File

@@ -26,7 +26,7 @@ import triton.language as tl
from sglang.srt.managers.schedule_batch import global_server_args_dict
if global_server_args_dict.get("attention_reduce_in_fp32", False):
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
REDUCE_TORCH_TYPE = torch.float32
else:

View File

@@ -239,7 +239,7 @@ class FusedMoE(torch.nn.Module):
weight_name: str,
shard_id: int,
expert_id: int,
pre_sharded: bool,
use_presharded_weights: bool = False,
):
param_data = param.data
@@ -273,7 +273,7 @@ class FusedMoE(torch.nn.Module):
else:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.intermediate_size_per_partition
if pre_sharded:
if use_presharded_weights:
shard = slice(None)
else:
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
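The renamed `use_presharded_weights` flag controls whether the loader slices the checkpoint tensor at all; a toy sketch of the slicing logic (names follow the diff, the data is illustrative):

```python
def weight_shard(use_presharded_weights, tp_rank, shard_size):
    # Presharded checkpoints already contain only this rank's rows,
    # so take everything; otherwise cut out this rank's slice
    if use_presharded_weights:
        return slice(None)
    return slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

loaded_weight = list(range(8))  # stand-in for a checkpoint tensor
print(loaded_weight[weight_shard(False, 1, 2)])  # rank 1's slice: [2, 3]
print(loaded_weight[weight_shard(True, 1, 2)])   # presharded: all values
```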

View File

@@ -180,7 +180,7 @@ class LogitsProcessor(nn.Module):
if hasattr(self.config, "final_logit_softcapping"):
last_logits.div_(self.config.final_logit_softcapping)
last_logits = torch.tanh(last_logits)
torch.tanh(last_logits, out=last_logits)
last_logits.mul_(self.config.final_logit_softcapping)
# Return only last_logits if logprob is not requested
@@ -241,7 +241,7 @@ class LogitsProcessor(nn.Module):
if hasattr(self.config, "final_logit_softcapping"):
all_logits.div_(self.config.final_logit_softcapping)
all_logits = torch.tanh(all_logits)
torch.tanh(all_logits, out=all_logits)
all_logits.mul_(self.config.final_logit_softcapping)
all_logprobs = all_logits
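Functionally, the in-place rewrite keeps the same soft-capping transform, `cap * tanh(logits / cap)`, while avoiding the extra allocation from `torch.tanh`; a scalar sketch:

```python
import math

def softcap(logit: float, cap: float) -> float:
    # Equivalent to the div_/tanh/mul_ sequence in LogitsProcessor:
    # squashes logits smoothly into (-cap, cap)
    return math.tanh(logit / cap) * cap

print(softcap(0.0, 30.0))    # small logits pass through almost unchanged
print(softcap(100.0, 30.0))  # large logits saturate just below the cap
```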

View File

@@ -35,7 +35,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
global_server_args_dict = {
"disable_flashinfer": False,
"disable_flashinfer_sampling": False,
"attention_reduce_in_fp32": False,
"triton_attention_reduce_in_fp32": False,
"enable_mla": False,
}

View File

@@ -606,6 +606,9 @@ class TokenizerManager:
return background_tasks
def create_handle_loop(self):
if not self.to_create_loop:
return
self.to_create_loop = False
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())
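The added guard makes `create_handle_loop` idempotent, so repeated calls spawn only one background task; the pattern in isolation (a minimal sketch, not the real class):

```python
class LoopOwner:
    # Minimal sketch of the early-return guard added above
    def __init__(self):
        self.to_create_loop = True
        self.tasks_created = 0

    def create_handle_loop(self):
        if not self.to_create_loop:
            return  # already created; later calls are no-ops
        self.to_create_loop = False
        self.tasks_created += 1  # stand-in for loop.create_task(...)

owner = LoopOwner()
for _ in range(3):
    owner.create_handle_loop()
print(owner.tasks_created)  # 1
```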

View File

@@ -20,7 +20,6 @@ import importlib
import importlib.resources
import logging
import pkgutil
import warnings
from functools import lru_cache
from typing import Optional, Type
@@ -91,23 +90,35 @@ class ModelRunner:
{
"disable_flashinfer": server_args.disable_flashinfer,
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"enable_mla": server_args.enable_mla,
}
)
min_per_gpu_memory = self.init_torch_distributed()
self.load_model()
self.init_memory_pool(
min_per_gpu_memory,
server_args.max_num_reqs,
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flashinfer()
self.init_cuda_graphs()
def init_torch_distributed(self):
# Init torch distributed
torch.cuda.set_device(self.gpu_id)
logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
if not server_args.enable_p2p_check:
if not self.server_args.enable_p2p_check:
monkey_patch_vllm_p2p_access_check(self.gpu_id)
if server_args.nccl_init_addr:
nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
if self.server_args.nccl_init_addr:
nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
else:
nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
set_custom_all_reduce(not server_args.disable_custom_all_reduce)
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
init_distributed_environment(
backend="nccl",
world_size=self.tp_size,
@@ -116,32 +127,28 @@ class ModelRunner:
distributed_init_method=nccl_init_method,
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
total_gpu_memory = get_available_gpu_memory(
min_per_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
self.tp_group = get_tp_group()
# Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
# so we disable padding in cuda graph.
if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
self.server_args.disable_cuda_graph_padding = True
logger.info(
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
)
# Check memory for tensor parallelism
if self.tp_size > 1:
total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
if total_local_gpu_memory < total_gpu_memory * 0.9:
local_gpu_memory = get_available_gpu_memory(self.gpu_id)
if min_per_gpu_memory < local_gpu_memory * 0.9:
raise ValueError(
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
)
# Load the model and create memory pool
self.load_model()
self.init_memory_pool(
total_gpu_memory,
server_args.max_num_reqs,
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flashinfer()
if self.is_generation:
# FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
# Capture cuda graphs
self.init_cuda_graphs()
return min_per_gpu_memory
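The refactor also makes the balance check explicit: the distributed minimum of free memory per GPU is compared against the local value, and a rank that has lost more than roughly 10% to other processes aborts early. A standalone sketch of that check:

```python
def check_memory_balance(min_per_gpu_memory: float, local_gpu_memory: float):
    # Mirrors the tp_size > 1 check above: every rank should report
    # roughly the same free memory before weights are loaded
    if min_per_gpu_memory < local_gpu_memory * 0.9:
        raise ValueError(
            "The memory capacity is unbalanced. "
            "Some GPUs may be occupied by other processes."
        )
    return True

print(check_memory_balance(72.0, 75.0))  # within 10% of each other: OK
```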
def load_model(self):
logger.info(
@@ -150,7 +157,7 @@ class ModelRunner:
)
if torch.cuda.get_device_capability()[0] < 8:
logger.info(
"Compute capability below sm80 use float16 due to lack of bfloat16 support."
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
)
self.server_args.dtype = "float16"
@@ -168,8 +175,9 @@ class ModelRunner:
skip_tokenizer_init=True,
)
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
# Drop this after Sept, 2024.
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
self.model_config.hf_config.num_key_value_heads = 8
self.vllm_model_config.hf_config.num_key_value_heads = 8
monkey_patch_vllm_qvk_linear_loader()
@@ -191,8 +199,8 @@ class ModelRunner:
cache_config=None,
)
self.sliding_window_size = (
self.model.get_window_size()
if hasattr(self.model, "get_window_size")
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.is_generation = is_generation_model(
@@ -206,7 +214,8 @@ class ModelRunner:
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
def update_weights(self, model_path, load_format):
def update_weights(self, model_path: str, load_format: str):
"""Update weights in-place."""
from vllm.model_executor.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
@@ -222,6 +231,7 @@ class ModelRunner:
target_device = torch.device(self.device_config.device)
try:
# TODO: Use a better method to check this
vllm_model_config = VllmModelConfig(
model=model_path,
quantization=self.server_args.quantization,
@@ -291,7 +301,7 @@ class ModelRunner:
logger.info(f"[gpu={self.gpu_id}] Update weights end.")
return True, "Succeeded to update model weights"
def profile_max_num_token(self, total_gpu_memory):
def profile_max_num_token(self, total_gpu_memory: int):
available_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
@@ -319,7 +329,10 @@ class ModelRunner:
return max_num_token
def init_memory_pool(
self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
self,
total_gpu_memory: int,
max_num_reqs: int = None,
max_total_tokens: int = None,
):
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if max_total_tokens is not None:
@@ -388,6 +401,7 @@ class ModelRunner:
return c
def init_flashinfer(self):
"""Init flashinfer attention kernel wrappers."""
if self.server_args.disable_flashinfer:
assert (
self.sliding_window_size is None
@@ -448,6 +462,11 @@ class ModelRunner:
)
def init_cuda_graphs(self):
"""Capture cuda graphs."""
if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
return
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
@@ -457,7 +476,12 @@ class ModelRunner:
logger.info(
f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
)
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
if self.server_args.disable_cuda_graph_padding:
batch_size_list = list(range(1, 32)) + [64, 128]
else:
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.cuda_graph_runner = CudaGraphRunner(
self,
max_batch_size_to_capture=max(batch_size_list),
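The two capture lists differ because padding lets a decode batch round up to the nearest captured size; without padding, every small batch size must be captured exactly. Expanded, the lists are:

```python
# Sketch of the batch-size selection above
with_padding = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # sparse, up to 160
without_padding = list(range(1, 32)) + [64, 128]          # dense over 1..31
print(len(with_padding), max(with_padding))
print(len(without_padding), max(without_padding))
```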

View File

@@ -46,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import InputMetadata
# Aligned with HF's implementation, using sliding window inclusive with the last token
# SGLang assumes exclusive
def get_window_size(config):
def get_attention_sliding_window_size(config):
return config.sliding_window - 1
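The rename makes the inclusive-vs-exclusive convention explicit; the adjustment itself is just the `-1`:

```python
class DummyConfig:
    sliding_window = 4096  # HF-style inclusive window size

def get_attention_sliding_window_size(config):
    # HF counts the window inclusive of the last token; SGLang assumes
    # exclusive, hence the off-by-one adjustment
    return config.sliding_window - 1

print(get_attention_sliding_window_size(DummyConfig()))  # 4095
```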
@@ -213,7 +213,11 @@ class Gemma2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_idx,
sliding_window_size=get_window_size(config) if use_sliding_window else None,
sliding_window_size=(
get_attention_sliding_window_size(config)
if use_sliding_window
else None
),
logit_cap=self.config.attn_logit_softcapping,
)
@@ -406,8 +410,8 @@ class Gemma2ForCausalLM(nn.Module):
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
)
def get_window_size(self):
return get_window_size(self.config)
def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [

View File

@@ -295,12 +295,14 @@ class Grok1ModelForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = Grok1Model(config, quant_config=quant_config)
# self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
self.use_presharded_weights = True
warnings.filterwarnings("ignore", category=FutureWarning)
def forward(
@@ -356,6 +358,13 @@ class Grok1ModelForCausalLM(nn.Module):
continue
name = name.replace(weight_name, param_name)
if self.use_presharded_weights:
extra_kwargs = {
"use_presharded_weights": self.use_presharded_weights
}
else:
extra_kwargs = {}
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
@@ -364,7 +373,7 @@ class Grok1ModelForCausalLM(nn.Module):
weight_name,
shard_id=shard_id,
expert_id=expert_id,
pre_sharded=get_tensor_model_parallel_world_size() > 1,
**extra_kwargs,
)
break
else:

View File

@@ -81,13 +81,12 @@ class ServerArgs:
disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
disable_disk_cache: bool = False
disable_custom_all_reduce: bool = False
enable_mixed_chunk: bool = False
enable_torch_compile: bool = False
enable_p2p_check: bool = False
enable_mla: bool = False
attention_reduce_in_fp32: bool = False
efficient_weight_load: bool = False
disable_custom_all_reduce: bool = False
triton_attention_reduce_in_fp32: bool = False
# Distributed args
nccl_init_addr: Optional[str] = None
@@ -404,6 +403,12 @@ class ServerArgs:
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--disable-custom-all-reduce",
action="store_true",
default=False,
help="Disable the custom all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
@@ -425,7 +430,7 @@ class ServerArgs:
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
)
parser.add_argument(
"--attention-reduce-in-fp32",
"--triton-attention-reduce-in-fp32",
action="store_true",
help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels.",
@@ -435,12 +440,6 @@ class ServerArgs:
action="store_true",
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
)
parser.add_argument(
"--disable-custom-all-reduce",
action="store_true",
default=False,
help="Disable the custom all-reduce kernel and fall back to NCCL.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):

View File

@@ -347,7 +347,7 @@ def suppress_other_loggers():
logging.WARN
)
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
def assert_pkg_version(pkg: str, min_version: str, message: str):
@@ -451,10 +451,6 @@ def monkey_patch_vllm_dummy_weight_loader():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.

View File

@@ -24,7 +24,6 @@ import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.server import Runtime
from sglang.srt.utils import is_generation_model
DEFAULT_PROMPTS = [
# the output of gemma-2-2b from SRT is unstable on the commented prompt
@@ -63,8 +62,8 @@ class HFRunner:
def __init__(
self,
model_path,
torch_dtype=torch.float16,
is_generation_model=None,
torch_dtype,
is_generation_model,
):
self.in_queue = multiprocessing.Queue()
self.out_queue = multiprocessing.Queue()
@@ -90,11 +89,8 @@ class HFRunner:
trust_remote_code=True,
)
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
self.is_generation_model = is_generation_model
if self.is_generation_model:
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
@@ -176,16 +172,12 @@ class SRTRunner:
def __init__(
self,
model_path,
torch_dtype,
is_generation_model,
tp_size=1,
torch_dtype=torch.float16,
is_generation_model=None,
port=5157,
):
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
self.is_generation_model = is_generation_model
self.runtime = Runtime(
model_path=model_path,
tp_size=tp_size,

View File

@@ -59,7 +59,7 @@ class TestEmbeddingModels(unittest.TestCase):
tolerance = 1e-2
assert torch.all(
abs(similarities - 1) < tolerance
), f"embeddings not all close"
), "embeddings are not all close"
def test_prefill_logits(self):
for model, tp_size in MODELS:

View File

@@ -59,7 +59,7 @@ class TestGenerationModels(unittest.TestCase):
tolerance = 3e-2
assert torch.all(
abs(hf_logprobs - srt_logprobs) < tolerance
), f"prefill logprobs not all close"
), "prefill logprobs are not all close"
print(hf_outputs.output_strs)
print(srt_outputs.output_strs)

View File

@@ -14,7 +14,7 @@ suites = {
"test_torch_compile.py",
"test_triton_attn_backend.py",
"test_vision_openai_server.py",
"test_large_max_new_tokens.py",
"test_update_weights.py",
"models/test_generation_models.py",
"models/test_embedding_models.py",
"sampling/penaltylib",

View File

@@ -2,8 +2,6 @@ import base64
import io
import json
import os
import sys
import time
import unittest
import numpy as np
@@ -12,12 +10,10 @@ import requests
from decord import VideoReader, cpu
from PIL import Image
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import DEFAULT_URL_FOR_UNIT_TEST, popen_launch_server
# python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --tokenizer-path lmms-lab/llavanext-qwen-siglip-tokenizer --port=30000 --host=127.0.0.1 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384
class TestOpenAIVisionServer(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -32,11 +28,9 @@ class TestOpenAIVisionServer(unittest.TestCase):
other_args=[
"--chat-template",
"chatml-llava",
"--tokenizer-path",
"lmms-lab/llavanext-qwen-siglip-tokenizer",
"--chunked-prefill-size",
"16384",
"--log-requests",
# "--log-requests",
],
)
cls.base_url += "/v1"
@@ -132,7 +126,6 @@ class TestOpenAIVisionServer(unittest.TestCase):
messages = self.prepare_video_messages(file_path)
start_time = time.time()
video_request = client.chat.completions.create(
model="default",
messages=messages,
@@ -140,15 +133,14 @@ class TestOpenAIVisionServer(unittest.TestCase):
max_tokens=1024,
stream=True,
)
print("-" * 30)
video_response = ""
for chunk in video_request:
if chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
video_response += content
sys.stdout.write(content)
sys.stdout.flush()
print(content, end="", flush=True)
print("-" * 30)
# Add assertions to validate the video response