adapt to sglang v0.5.2rc1 on dcu

This commit is contained in:
maxiao
2025-09-04 15:56:33 +08:00
commit 909abb58f5
2320 changed files with 489411 additions and 0 deletions

425
3rdparty/amd/profiling/PROFILING.md vendored Normal file
View File

@@ -0,0 +1,425 @@
## Profiling SGLang Infer System with AMD GPUs
This AppNote describes the SGLang profiling technical, code augment and running steps for systems with AMD Instinct GPUs, nevertheless the same procedure may work with Nvidia GPUs too.
Examples and steps are provided in detail, to facilitate easy reproduce and use to localize performance problem towards optimizations.
Two primary methods are covered:
- [RPD](https://github.com/ROCm/rocmProfileData.git)
- [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
### Profiling SGLang Infer System with RPD Profiler
RPD profiler is a low-overhead cross-platform profiler. Therefore, the same RPD code augment not only works for profiling on ROCm/AMD GPUs, but also works for profiling on CUDA/Nvidia GPUs as well. To do RPD profiling on SGLang repository, please use scripts and patch files included in this directory and follow the steps below:
1. Install RPD with rpd.patch applied during installation using install_rpd.sh, both files are in this directory.
install_rpd.sh
```bash
# download and install RPD
apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
# install rpd module
git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
cd rocmProfileData
git checkout 976899e9c6dbc6dd2bccf770818e4e44125590ac
git apply rpd.patch
make && make install
cd rocpd_python && python setup.py install && cd ..
cd rpd_tracer && make clean;make install && python setup.py install && cd ..
```
rpd.patch
```bash
diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
index e9d9feb..b2e9e1a 100644
--- a/rpd_tracer/Makefile
+++ b/rpd_tracer/Makefile
@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
$(info Building with roctracer)
RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
- RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
+ RPD_SRCS += RoctracerDataSource.cpp
RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
endif
```
2. Add loadTracer.sh file included in this directory to /sglang/python/sglang.
loadTracer.sh
```bash
#!/bin/bash
################################################################################
# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
################################################################################
OUTPUT_FILE="trace.rpd"
if [ "$1" = "-o" ] ; then
OUTPUT_FILE=$2
shift
shift
fi
if [ -e ${OUTPUT_FILE} ] ; then
rm ${OUTPUT_FILE}
fi
python3 -m rocpd.schema --create ${OUTPUT_FILE}
if [ $? != 0 ] ; then
echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
exit
fi
export RPDT_FILENAME=${OUTPUT_FILE}
export RPDT_AUTOSTART=0
LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
```
3. Apply patch (provided in this directory) with "git apply rpd_profile_server_enable.patch" if the main profiling purpose is to get info on gpu kernels as well as limited cpu activity info.
#### Common Notes 1
Please note that although we are doing TP=8 in the example, we purposely only log RPD profiling on 2 ranks in the patch file (i.e.tp_rank=0/1) for profiling/visualization convenience, as even Perfetto streaming mode can only load maximal 8GB json file for visualization. With 2 ranks logged in RPD profiling, we could still check whether there are issues among ranks (e.g. load imbalance issue, nccl issue), and at the same time, we could log relatively longer time duration before the json file generated from RPD file hits 8GB size.
rpd_profile_server_enable.patch
```bash
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..9021c01 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()
logger = logging.getLogger(__name__)
@@ -245,6 +247,7 @@ class Scheduler:
],
with_stack=True,
)
+ self.rpd = rpdTracerControl()
@torch.inference_mode()
def event_loop(self):
@@ -1027,15 +1030,24 @@ class Scheduler:
def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
- self.profiler.start()
+ #self.profiler.start() #block pytorch profiler for rpd profiler enabling
+ if self.tp_rank == 0 or self.tp_rank == 1:
+ self.rpd.start()
+ self.rpd.rangePush("", "rpd profile range", "")
+ logger.info("rpd is enabled")
def stop_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
- self.profiler.stop()
- self.profiler.export_chrome_trace(
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
- )
+ #self.profiler.stop()
+ #self.profiler.export_chrome_trace(
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+ #)
+ if self.tp_rank ==0 or self.tp_rank ==1:
+ self.rpd.rangePop()
+ self.rpd.stop()
+ self.rpd.flush()
+ logger.info("rpd is done")
logger.info("Profiler is done")
```
#### Advanced Debugging with RPD Profiler
Sometimes, we want to use rpd profiler to capture more CPU and python activities in order to debug some challenging issues (e.g. root cause of load imbalance across gpu processes, root cause of bubbles, etc). Only in such cases, we need to apply patch "git apply rpd_profile_server_enable_wCPU_activities.patch", where 3 files are modified.
rpd_profile_server_enable_wCPU_activities.patch
```bash
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..2edb427 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()
logger = logging.getLogger(__name__)
@@ -245,6 +247,7 @@ class Scheduler:
],
with_stack=True,
)
+ self.rpd = rpdTracerControl()
@torch.inference_mode()
def event_loop(self):
@@ -1027,15 +1030,26 @@ class Scheduler:
def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
- self.profiler.start()
+ #self.profiler.start()
+ logger.info("torch profiler is disabled")
+ if self.tp_rank == 0 or self.tp_rank == 1:
+ self.rpd.setPythonTrace(True)
+ self.rpd.start()
+ self.rpd.rangePush("", "scheduler", "")
+ logger.info("rpd is enabled inside scheduler profiling")
def stop_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
- self.profiler.stop()
- self.profiler.export_chrome_trace(
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
- )
+ #self.profiler.stop()
+ #self.profiler.export_chrome_trace(
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+ #)
+ if self.tp_rank ==0 or self.tp_rank ==1:
+ self.rpd.rangePop()
+ self.rpd.stop()
+ self.rpd.flush()
+ logger.info("rpd is done inside scheduler")
logger.info("Profiler is done")
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 2621ccd..181df85 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()
+
+
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = logging.getLogger(__name__)
@@ -514,10 +518,20 @@ class TokenizerManager:
self.send_to_scheduler.send_pyobj(req)
def start_profile(self):
+ rpd = rpdTracerControl()
+ rpd.setPythonTrace(True)
+ rpd.start()
+ rpd.rangePush("", "tokenizer_manager", "")
+ logger.info("tokenizer_manager rpd profiling started!")
req = ProfileReq.START_PROFILE
self.send_to_scheduler.send_pyobj(req)
def stop_profile(self):
+ rpd = rpdTracerControl()
+ rpd.rangePop()
+ rpd.stop()
+ rpd.flush()
+ logger.info("rpd profiling is done inside tokenizer_manager!")
req = ProfileReq.STOP_PROFILE
self.send_to_scheduler.send_pyobj(req)
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 7111c93..2bd722c 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -30,6 +30,8 @@ import threading
import time
from http import HTTPStatus
from typing import Dict, List, Optional, Union
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -152,6 +154,11 @@ async def flush_cache():
@app.post("/start_profile")
async def start_profile():
"""Start profiling."""
+ rpd = rpdTracerControl()
+ rpd.setPythonTrace(True)
+ rpd.start()
+ rpd.rangePush("", "server rpd profile range", "")
+ logger.info("rpd profiling started in server.py!")
tokenizer_manager.start_profile()
return Response(
content="Start profiling.\n",
@@ -164,6 +171,11 @@ async def start_profile():
async def stop_profile():
"""Stop profiling."""
tokenizer_manager.stop_profile()
+ rpd = rpdTracerControl()
+ rpd.rangePop()
+ rpd.stop()
+ rpd.flush()
+ logger.info("rpd profiling is done in server.py!")
return Response(
content="Stop profiling. This will take some time.\n",
status_code=200,
```
4. As an example for grok1 profiling, we create a dummy_grok1 directory with config.json (see content below) inside this directory and copy this directory to the right path for "--model-path" if you want to use the example server.sh file provided.
```bash
cat ../dummy_grok1/config.json
{
"architectures": [
"Grok1ModelForCausalLM"
],
"embedding_multiplier_scale": 78.38367176906169,
"output_multiplier_scale": 0.5773502691896257,
"vocab_size": 131072,
"hidden_size": 6144,
"intermediate_size": 32768,
"max_position_embeddings": 8192,
"num_experts_per_tok": 2,
"num_local_experts": 8,
"num_attention_heads": 48,
"num_hidden_layers": 64,
"num_key_value_heads": 8,
"head_dim": 128,
"rms_norm_eps": 1e-05,
"rope_theta": 10000.0,
"model_type": "mixtral",
"torch_dtype": "bfloat16"
}
```
5. Launch server with rpd enabled script ./server.sh in one terminal inside the docker container.
#### Common Notes 2
- Remember to change model-path to the correct path
- loadTracer.sh is needed to conduct profiling
- SGLANG_TORCH_PROFILER_DIR is used for default torch profiler
- Do not use loadTracer.sh if you are using the torch profiler, simply use python3 -m sglang.launch_server.
server.sh
```bash
#!/bin/bash
# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
# Get the current timestamp
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
# Define the log file with a timestamp
LOGFILE="sglang_server_log_$TIMESTAMP.json"
# Run the Python command and save the output to the log file
loadTracer.sh python3 -m sglang.launch_server \
--model-path /sgl-workspace/sglang/dummy_grok1 \
--tokenizer-path Xenova/grok-1-tokenizer \
--load-format dummy \
--quantization fp8 \
--tp 8 \
--port 30000 \
--disable-radix-cache 2>&1 | tee "$LOGFILE"
```
6. Open another terminal for the same docker container, and run the rpd enabled ./client.sh after you see "The server is fired up and is ready to roll!" message from server side terminal.
#### Common Notes 3
- Use curl http://localhost:30000/start_profile & curl http://localhost:30000/stop_profile to control the start and end of profiling. Check sglang/python/sglang/srt/managers/scheduler.py for more details.
- Please don't use RPD profiler together with PyTorch profiler to avoid interference.
- The rocmProfileData/tools/rpd2tracing.py file is used to generate json file from RPD file.
client.sh
```bash
#!/bin/bash
# Start profiling via API
curl http://localhost:30000/start_profile -H "Content-Type: application/json"
# Benchmark serving using sglang with random dataset and tokenizer
# Define the log file with a timestamp
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
LOGFILE="sglang_client_log_$TIMESTAMP.json"
# Run the benchmark with specified parameters and save logs
python3 -m sglang.bench_serving \
--backend sglang \
--tokenizer Xenova/grok-1-tokenizer \
--dataset-name random \
--random-input 1024\
--random-output 1024 \
--num-prompts 120 \
--request-rate 8 \
--output-file online.jsonl 2>&1 | tee "$LOGFILE"
# Stop profiling via API
curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
# Convert tracing file to csv & json
sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
python3 ./rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
```
7. Follow [Perfetto docs](https://perfetto.dev/docs/visualization/large-traces) to visualize large json files. Try to adjust parameters so that the trace.json file size is less than 9GB.
### Profiling SGLang Infer System with PyTorch Profiler
Please use the steps as follows:
1. Apply the patch torch_profiler.patch. Note that you can modify "if self.tp_rank == 0" in the patch to allow more ranks be recorded in profiling.
torch_profiler.patch
```bash
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..6ecd78c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -240,7 +240,6 @@ class Scheduler:
)
self.profiler = torch.profiler.profile(
activities=[
- torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
@@ -1033,9 +1032,11 @@ class Scheduler:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
- self.profiler.export_chrome_trace(
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
- )
+ if self.tp_rank == 0:
+ with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
+ print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
+ print("Profiling stats done.")
+
logger.info("Profiler is done")
```
2. Create the model path directory and copy it to the right path for "--model-path" if you want to use the server.sh file provided.
3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container.
4. Similar to step 6 in RPD profiling section, but remove the last 2 lines in client.sh, which converted rpd file into csv and json files. Run modified client.sh for PyTorch profiling.
-------
- [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)

27
3rdparty/amd/profiling/client.sh vendored Executable file
View File

@@ -0,0 +1,27 @@
#!/bin/bash
# Start profiling via API
curl http://localhost:30000/start_profile -H "Content-Type: application/json"
# Benchmark serving using sglang with random dataset and tokenizer
# Define the log file with a timestamp
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
LOGFILE="sglang_client_log_$TIMESTAMP.json"
# Run the benchmark with specified parameters and save logs
python3 -m sglang.bench_serving \
--backend sglang \
--tokenizer Xenova/grok-1-tokenizer \
--dataset-name random \
--random-input 1024\
--random-output 1024 \
--num-prompts 240 \
--request-rate 8 \
--output-file online.jsonl 2>&1 | tee "$LOGFILE"
# Stop profiling via API
curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
# Convert tracing file to csv & json
sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
python3 /sgl-workspace/rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json

10
3rdparty/amd/profiling/install_rpd.sh vendored Normal file
View File

@@ -0,0 +1,10 @@
# download and install RPD
apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
# install rpd module
git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
cd rocmProfileData
git apply rpd.patch
make && make install
cd rocpd_python && python setup.py install && cd ..
cd rpd_tracer && make clean;make install && python setup.py install && cd ..

43
3rdparty/amd/profiling/loadTracer.sh vendored Executable file
View File

@@ -0,0 +1,43 @@
#!/bin/bash
################################################################################
# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
################################################################################
OUTPUT_FILE="trace.rpd"
if [ "$1" = "-o" ] ; then
OUTPUT_FILE=$2
shift
shift
fi
if [ -e ${OUTPUT_FILE} ] ; then
rm ${OUTPUT_FILE}
fi
python3 -m rocpd.schema --create ${OUTPUT_FILE}
if [ $? != 0 ] ; then
echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
exit
fi
export RPDT_FILENAME=${OUTPUT_FILE}
export RPDT_AUTOSTART=0
LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"

12
3rdparty/amd/profiling/rpd.patch vendored Normal file
View File

@@ -0,0 +1,12 @@
diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
index e9d9feb..b2e9e1a 100644
--- a/rpd_tracer/Makefile
+++ b/rpd_tracer/Makefile
@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
$(info Building with roctracer)
RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
- RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
+ RPD_SRCS += RoctracerDataSource.cpp
RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
endif

View File

@@ -0,0 +1,49 @@
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..9021c01 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()
logger = logging.getLogger(__name__)
@@ -245,6 +247,7 @@ class Scheduler:
],
with_stack=True,
)
+ self.rpd = rpdTracerControl()
@torch.inference_mode()
def event_loop(self):
@@ -1027,15 +1030,24 @@ class Scheduler:
def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
- self.profiler.start()
+ #self.profiler.start() #block pytorch profiler for rpd profiler enabling
+ if self.tp_rank == 0 or self.tp_rank == 1:
+ self.rpd.start()
+ self.rpd.rangePush("", "rpd profile range", "")
+ logger.info("rpd is enabled")
def stop_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
- self.profiler.stop()
- self.profiler.export_chrome_trace(
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
- )
+ #self.profiler.stop()
+ #self.profiler.export_chrome_trace(
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+ #)
+ if self.tp_rank ==0 or self.tp_rank ==1:
+ self.rpd.rangePop()
+ self.rpd.stop()
+ self.rpd.flush()
+ logger.info("rpd is done")
logger.info("Profiler is done")

View File

@@ -0,0 +1,126 @@
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..2edb427 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()
logger = logging.getLogger(__name__)
@@ -245,6 +247,7 @@ class Scheduler:
],
with_stack=True,
)
+ self.rpd = rpdTracerControl()
@torch.inference_mode()
def event_loop(self):
@@ -1027,15 +1030,26 @@ class Scheduler:
def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
- self.profiler.start()
+ #self.profiler.start()
+ logger.info("torch profiler is disabled")
+ if self.tp_rank == 0 or self.tp_rank == 1:
+ self.rpd.setPythonTrace(True)
+ self.rpd.start()
+ self.rpd.rangePush("", "scheduler", "")
+ logger.info("rpd is enabled inside scheduler profiling")
def stop_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
- self.profiler.stop()
- self.profiler.export_chrome_trace(
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
- )
+ #self.profiler.stop()
+ #self.profiler.export_chrome_trace(
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
+ #)
+ if self.tp_rank ==0 or self.tp_rank ==1:
+ self.rpd.rangePop()
+ self.rpd.stop()
+ self.rpd.flush()
+ logger.info("rpd is done inside scheduler")
logger.info("Profiler is done")
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 2621ccd..181df85 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()
+
+
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = logging.getLogger(__name__)
@@ -514,10 +518,20 @@ class TokenizerManager:
self.send_to_scheduler.send_pyobj(req)
def start_profile(self):
+ rpd = rpdTracerControl()
+ rpd.setPythonTrace(True)
+ rpd.start()
+ rpd.rangePush("", "tokenizer_manager", "")
+ logger.info("tokenizer_manager rpd profiling started!")
req = ProfileReq.START_PROFILE
self.send_to_scheduler.send_pyobj(req)
def stop_profile(self):
+ rpd = rpdTracerControl()
+ rpd.rangePop()
+ rpd.stop()
+ rpd.flush()
+ logger.info("rpd profiling is done inside tokenizer_manager!")
req = ProfileReq.STOP_PROFILE
self.send_to_scheduler.send_pyobj(req)
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 7111c93..2bd722c 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -30,6 +30,8 @@ import threading
import time
from http import HTTPStatus
from typing import Dict, List, Optional, Union
+from rpdTracerControl import rpdTracerControl
+rpdTracerControl.skipCreate()
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -152,6 +154,11 @@ async def flush_cache():
@app.post("/start_profile")
async def start_profile():
"""Start profiling."""
+ rpd = rpdTracerControl()
+ rpd.setPythonTrace(True)
+ rpd.start()
+ rpd.rangePush("", "server rpd profile range", "")
+ logger.info("rpd profiling started in server.py!")
tokenizer_manager.start_profile()
return Response(
content="Start profiling.\n",
@@ -164,6 +171,11 @@ async def start_profile():
async def stop_profile():
"""Stop profiling."""
tokenizer_manager.stop_profile()
+ rpd = rpdTracerControl()
+ rpd.rangePop()
+ rpd.stop()
+ rpd.flush()
+ logger.info("rpd profiling is done in server.py!")
return Response(
content="Stop profiling. This will take some time.\n",
status_code=200,

20
3rdparty/amd/profiling/server.sh vendored Executable file
View File

@@ -0,0 +1,20 @@
#!/bin/bash
# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
# Get the current timestamp
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
# Define the log file with a timestamp
LOGFILE="sglang_server_log_$TIMESTAMP.json"
# Run the Python command and save the output to the log file
loadTracer.sh python3 -m sglang.launch_server \
--model-path /sgl-workspace/sglang/dummy_grok1 \
--tokenizer-path Xenova/grok-1-tokenizer \
--load-format dummy \
--quantization fp8 \
--tp 8 \
--port 30000 \
--disable-radix-cache 2>&1 | tee "$LOGFILE"

View File

@@ -0,0 +1,25 @@
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 62d1ff9..6ecd78c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -240,7 +240,6 @@ class Scheduler:
)
self.profiler = torch.profiler.profile(
activities=[
- torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
@@ -1033,9 +1032,11 @@ class Scheduler:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
- self.profiler.export_chrome_trace(
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
- )
+ if self.tp_rank == 0:
+ with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
+ print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
+ print("Profiling stats done.")
+
logger.info("Profiler is done")

118
3rdparty/amd/tuning/TUNING.md vendored Normal file
View File

@@ -0,0 +1,118 @@
## Tuning SGLang Infer System with AMD GPUs
This AppNote describes the SGLang performance tuning technical, code harness and running steps for systems with AMD Instinct GPUs.
Harness code, examples and steps are provided in detail, to facilitate easy reproduce & use to tune performance towards workloads.
Three primary runtime areas are covered:
## 1. Triton Kernels
To maximize Triton kernel efficiency, several strategies can be employed:
### Key Environment Variables:
- **num_stages**: Adjusts the number of pipeline stages to optimize kernel efficiency based on the specific type of operations (e.g., General Matrix Multiplication - GEMM).
- **waves_per_eu**: Controls the usage of Vector General Purpose Registers (VGPR) to enhance occupancy, thereby improving latency or throughput.
- **BLOCK_M, BLOCK_N, BLOCK_K**: Tunable tile sizes that assist in balancing memory transfer and computational efficiency.
- **matrix_instr_nonkdim**: Optimizes the usage of Matrix-Fused Multiply-Add (MFMA) instructions for specific kernel types, such as Flash Attention.
- **OPTIMIZE_EPILOGUE**: An environment variable that can be set to `1` to enhance performance by eliminating the `convert_layout` operation in the kernel's epilogue.
```python
@triton.autotune(configs=[
triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
], key=['BLOCK_N', 'NUM_TOKEN_BLKS'], use_cuda_graph=True)
@triton.jit
def _triton_kernel_funtion():
...
```
## 2. Torch Tunable Operations
**TunableOp** is a feature in PyTorch that allows for the definition and optimization of custom kernels with tunable parameters. This feature is particularly useful for enhancing the performance of kernels by experimenting with different configurations.
### Key Environment Variables:
1. **PYTORCH_TUNABLEOP_ENABLED**:
- Default: `0`
- Set to `1` to enable TunableOp.
2. **PYTORCH_TUNABLEOP_TUNING**:
- Default: `1`
- Set to `0` to disable tuning. If a tuned entry is not found, it will run the tuning step and record the entry when PYTORCH_TUNABLEOP_ENABLED is enabled.
3. **PYTORCH_TUNABLEOP_VERBOSE**:
- Default: `0`
- Set to `1` to enable verbose output for TunableOp.
### Usage Example:
To enable TunableOp and tuning, and optionally enable verbose mode, you can run the following command in your terminal:
```bash
#Tuning
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 your_script.sh
#Inference with tuning op
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_script.sh
#Print out the log
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 PYTORCH_TUNABLEOP_VERBOSE=1 your_script.sh
```
## 3. Torch Compilation
The following are suggestions for optimizing matrix multiplication (GEMM) and convolution (conv) operations in PyTorch using Inductor, a part of the PyTorch compilation framework. The goal is to leverage Triton to achieve better performance.
To tune Triton kernels with GEMM and convolution ops (conv), use the `torch.compile` function with the max-autotune mode. This benchmarks a predefined list of Triton configurations and selects the fastest one for each shape.
### Key Configurations:
1. **Max Autotune**:
- Set `torch._inductor.config.max_autotune = True` or `TORCHINDUCTOR_MAX_AUTOTUNE=1`.
2. **Fine-Grained Control**:
- Enable GEMM tuning: `torch._inductor.config.max_autotune_gemm = True`.
- Enable tuning for pointwise/reduction ops: `torch._inductor.config.max_autotune.pointwise = True`.
3. **Backend Selection**:
- Use `torch._inductor.max_autotune_gemm_backends` to limit backends to TRITON for better performance.
4. **Freezing for Inference**:
- Use `torch._inductor.config.freezing=True` to enable constant folding optimizations.
5. **Debugging**:
- Set `TORCH_COMPILE_DEBUG=1` to extract Triton kernels generated by Inductor.
### Example Code Block:
```bash
#Gemm Tuning
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 your_script.sh
#Specify your backend to TRITON for Gemm Tuning
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON your_script.sh
#Inference with large improvement on AMD GPU
TORCHINDUCTOR_FREEZING=1 your_script.sh
```
## 4. Fused MOE kernel
To maximize moe kernel efficiency, need to use below scripts to find out the best launch configuration
### Key parameters:
- **--model**: what moe model type to do tuning, it will automatically decide the size of d_model, model_intermediate_size, num_layers
- **--tp-size**: simulate the whole model run configuration to set the dimension size using tp correctly
- **--batch**: M dimension size of moe kernel, for prefill moe kernel the value is batch*input_len, for decode moe kernel the value is batch
- **--dtype**: computation type
```bash
#Tuning
#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8" to run, it defined batch-size 32 input length 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
#so we can tune decode moe use below command
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
# and use this command to tune prefill moe
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32768"
```
## Reference
For more detailed information on tuning SGLang performance with AMD GPUs, please refer to the following link:
[ROCm Documentation: Triton Kernel Performance Optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#triton-kernel-performance-optimization)

View File

@@ -0,0 +1,380 @@
import argparse
import json
import os
import sys
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from tqdm import tqdm
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
get_config_file_name,
)
padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
def main(model, tp_size, dtype: str, batches):
method = fused_moe
for bs in batches:
run_grid(int(bs), model=model, method=method, tp_size=tp_size, dtype=dtype)
def prune_configs(M, N, K, configs):
pruned_configs = []
elemBytes_a = 1 # [DV Note] Hard-coded for float16 (2 bytes)
elemBytes_b = 1 # [DV Note] Hard-coded for float16 (2 bytes)
mfma = 16 if M < 32 or N < 32 else 32
# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm = False
if M >= 2048 and N >= 2048:
large_gemm = True
for config in configs:
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
num_warps = config.get("num_warps")
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
# kpack = config.get("kpack")
if matrix_instr_nonkdim > mfma:
continue
if mfma == 4 and BLOCK_SIZE_K < 64:
continue
# some layouts could not work properly in case
# number elements per thread is less 1
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
continue
SPLIT_K = 1 # config.get("SPLIT_K")
GROUP_M = config.get("GROUP_SIZE_M")
if matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N:
continue
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
continue
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
continue
# Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
continue
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
continue
# skip large split_k when not necessary
if SPLIT_K != 1 and not need_split_k(M, N, K):
continue
# skip split_k that leads to EVEN_K = false
leap = SPLIT_K * BLOCK_SIZE_K
modv = K % leap
if modv != 0:
continue
# skip large GROUP_M
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = (
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
if large_gemm:
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
continue
if BLOCK_SIZE_K < 64:
continue
if num_warps < 4:
continue
pruned_configs.append(config)
return pruned_configs
def union_of_list_of_dicts(l1, l2):
result = []
temp_list = l1.copy()
temp_list.extend(l2)
for myDict in temp_list:
if myDict not in result:
result.append(myDict)
return result
def run_grid(bs, model, method, tp_size, dtype: str):
config = AutoConfig.from_pretrained(model)
top_k = config.num_experts_per_tok
d_model = config.hidden_size
model_intermediate_size = config.intermediate_size
num_layers = config.num_hidden_layers
hidden_states_dtype = config.torch_dtype
if config.num_experts_per_tok:
if config.architectures[0] == "Grok1ModelForCausalLM":
num_total_experts = config.num_experts
else:
num_total_experts = config.num_local_experts
else:
raise ValueError(f"Unsupported Mixtral model {model}")
# tp_size = 2
num_warmup_calls = 10
num_calls = 30
num_warmup_trials = 1
num_trials = 1
full_configs = []
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [16, 32, 64, 128, 256]
block_k_range = [32, 64, 128, 256] # MUST >= 32
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
# For now we see better perf with num_stages=0 for all gemm configs we care
# But keep this explicit so that we do not forget we may need to set it to
# other values in the future
num_stage_range = [2]
waves_per_eu_range = [0, 1, 2, 4, 8]
# Remove 32 because of triton compiling error
matrix_instr_nonkdim_range = [16]
kpack_range = [1, 2]
for block_size_m in block_m_range:
for block_size_n in block_n_range:
for block_size_k in block_k_range:
for group_size_m in group_m_range:
for num_warps in num_warps_range:
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
for kpack in kpack_range:
full_configs.append(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu,
"matrix_instr_nonkdim": matrix_instr_nonkdim,
"kpack": kpack,
}
)
M1 = bs * 2
N1 = model_intermediate_size * 2 // tp_size
K1 = d_model
prune_configs_1 = prune_configs(M1, N1, K1, full_configs)
M2 = bs * 2
N2 = d_model
K2 = model_intermediate_size // tp_size
prune_configs_2 = prune_configs(M2, N2, K2, full_configs)
configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
print(
f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
{len(prune_configs_2)=} | {len(configs)=}"
)
best_config = None
best_time_us = 1e20
print(f"{tp_size=} {bs=}")
for config in tqdm(configs):
# warmup
try:
print(config)
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_warmup_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
dtype=dtype,
hidden_states_dtype=hidden_states_dtype,
)
except triton.runtime.autotuner.OutOfResources:
continue
# trial
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
dtype=dtype,
hidden_states_dtype=hidden_states_dtype,
)
kernel_dur_us = 1000 * kernel_dur_ms
model_dur_ms = kernel_dur_ms * num_layers
if kernel_dur_us < best_time_us:
best_config = config
best_time_us = kernel_dur_us
tqdm.write(
f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}"
f" {bs=} {tp_size=} {top_k=} {num_total_experts=} "
f"{d_model=} {model_intermediate_size=} {num_layers=}"
)
print("best_time_us", best_time_us)
print("best_config", best_config)
# holds Dict[str, Dict[str, int]]
filename = get_config_file_name(
num_total_experts,
model_intermediate_size // tp_size,
"float8" if dtype == "float8" else None,
)
print(f"writing config to file {filename}")
existing_content = {}
if os.path.exists(filename):
with open(filename, "r") as f:
existing_content = json.load(f)
existing_content[str(bs)] = best_config
with open(filename, "w") as f:
json.dump(existing_content, f, indent=4)
f.write("\n")
def run_timing(
num_calls: int,
bs: int,
d_model: int,
num_total_experts: int,
top_k: int,
tp_size: int,
model_intermediate_size: int,
method,
config,
dtype: str,
hidden_states_dtype,
) -> float:
shard_intermediate_size = model_intermediate_size // tp_size
hidden_states = torch.rand(
(bs, d_model),
device="cuda:0",
dtype=hidden_states_dtype,
)
w1 = torch.rand(
(num_total_experts, 2 * shard_intermediate_size, d_model + padding_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
w2 = torch.rand(
(num_total_experts, d_model, shard_intermediate_size + padding_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if dtype == "float8":
w1 = w1.to(torch.float8_e4m3fnuz)
w2 = w2.to(torch.float8_e4m3fnuz)
w1_scale = torch.ones(
num_total_experts, device=hidden_states.device, dtype=torch.float32
)
w2_scale = torch.ones(
num_total_experts, device=hidden_states.device, dtype=torch.float32
)
a1_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
a2_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
gating_output = F.softmax(
torch.rand(
(num_calls, bs, num_total_experts),
device=hidden_states.device,
dtype=torch.float32,
),
dim=-1,
)
##################################
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(num_calls):
hidden_states = method(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
gating_output=gating_output[0],
topk=top_k,
renormalize=True,
inplace=True,
override_config=config,
use_fp8=dtype == "float8",
)
end_event.record()
end_event.synchronize()
dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="benchmark_mixtral_moe",
description="Benchmark and tune the fused_moe kernel",
)
parser.add_argument(
"--dtype",
type=str,
default="auto",
choices=["float8", "float16", "bfloat16"],
help="Data type used for fused_moe kernel computations",
)
parser.add_argument("--model", type=str, default="hpcai-tech/grok-1")
parser.add_argument("--tp-size", type=int, default=2, help="Tensor paralleli size")
parser.add_argument("-b", "--batches", type=str)
args = parser.parse_args()
batches = args.batches.split(",")
sys.exit(main(args.model, args.tp_size, args.dtype, batches))

201
LICENSE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

46
Makefile Normal file
View File

@@ -0,0 +1,46 @@
.PHONY: check-deps install-deps format update help
# Show help for each target
help:
@echo "Available targets:"
@grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
check-deps: ## Check and install required Python formatting dependencies
@command -v isort >/dev/null 2>&1 || (echo "Installing isort..." && pip install isort)
@command -v black >/dev/null 2>&1 || (echo "Installing black..." && pip install black)
install-deps: ## Install Python formatting tools (isort and black)
pip install isort black
format: check-deps ## Format modified Python files using isort and black
@echo "Formatting modified Python files..."
git diff --name-only --diff-filter=M | grep '\.py$$' | xargs -I {} sh -c 'isort {} && black {}'
FILES_TO_UPDATE = docker/Dockerfile.rocm \
python/pyproject.toml \
python/sglang/version.py \
docs/developer_guide/setup_github_runner.md \
docs/get_started/install.md \
docs/platforms/amd_gpu.md \
docs/platforms/ascend_npu.md \
benchmark/deepseek_v3/README.md
update: ## Update version numbers across project files. Usage: make update <new_version>
@if [ -z "$(filter-out $@,$(MAKECMDGOALS))" ]; then \
echo "Version required. Usage: make update <new_version>"; \
exit 1; \
fi
@OLD_VERSION=$$(grep "version" python/sglang/version.py | cut -d '"' -f2); \
NEW_VERSION=$(filter-out $@,$(MAKECMDGOALS)); \
echo "Updating version from $$OLD_VERSION to $$NEW_VERSION"; \
for file in $(FILES_TO_UPDATE); do \
if [ "$(shell uname)" = "Darwin" ]; then \
sed -i '' -e "s/$$OLD_VERSION/$$NEW_VERSION/g" $$file; \
else \
sed -i -e "s/$$OLD_VERSION/$$NEW_VERSION/g" $$file; \
fi \
done; \
echo "Version update complete"
%:
@:

78
README.md Normal file
View File

@@ -0,0 +1,78 @@
<div align="center" id="sglangtop">
<img src="https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" alt="logo" width="400" margin="10px"></img>
[![PyPI](https://img.shields.io/pypi/v/sglang)](https://pypi.org/project/sglang)
![PyPI - Downloads](https://static.pepy.tech/badge/sglang?period=month)
[![license](https://img.shields.io/github/license/sgl-project/sglang.svg)](https://github.com/sgl-project/sglang/tree/main/LICENSE)
[![issue resolution](https://img.shields.io/github/issues-closed-raw/sgl-project/sglang)](https://github.com/sgl-project/sglang/issues)
[![open issues](https://img.shields.io/github/issues-raw/sgl-project/sglang)](https://github.com/sgl-project/sglang/issues)
[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/sgl-project/sglang)
</div>
--------------------------------------------------------------------------------
| [**Blog**](https://lmsys.org/blog/2025-05-05-large-scale-ep/)
| [**Documentation**](https://docs.sglang.ai/)
| [**Join Slack**](https://slack.sglang.ai/)
| [**Join Bi-Weekly Development Meeting**](https://meeting.sglang.ai/)
| [**Roadmap**](https://github.com/sgl-project/sglang/issues/7736)
| [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |
## News
- [2025/08] 🔔 SGLang x AMD SF Meetup on 8/22: Hands-on GPU workshop, tech talks by AMD/xAI/SGLang, and networking ([Roadmap](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_sglang_roadmap.pdf), [Large-scale EP](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_sglang_ep.pdf), [Highlights](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_highlights.pdf), [AITER/MoRI](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_aiter_mori.pdf), [Wave](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_wave.pdf)).
- [2025/08] 🔥 SGLang provides day-0 support for OpenAI gpt-oss model ([instructions](https://github.com/sgl-project/sglang/issues/8833))
- [2025/06] 🔥 SGLang, the high-performance serving infrastructure powering trillions of tokens daily, has been awarded the third batch of the Open Source AI Grant by a16z ([a16z blog](https://a16z.com/advancing-open-source-ai-through-benchmarks-and-bold-experimentation/)).
- [2025/06] 🔥 Deploying DeepSeek on GB200 NVL72 with PD and Large Scale EP (Part I): 2.7x Higher Decoding Throughput ([blog](https://lmsys.org/blog/2025-06-16-gb200-part-1/)).
- [2025/05] 🔥 Deploying DeepSeek with PD Disaggregation and Large-scale Expert Parallelism on 96 H100 GPUs ([blog](https://lmsys.org/blog/2025-05-05-large-scale-ep/)).
- [2025/03] Supercharge DeepSeek-R1 Inference on AMD Instinct MI300X ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1-Part2/README.html))
- [2025/03] SGLang Joins PyTorch Ecosystem: Efficient LLM Serving Engine ([PyTorch blog](https://pytorch.org/blog/sglang-joins-pytorch/))
- [2024/12] v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
<details>
<summary>More</summary>
- [2025/02] Unlock DeepSeek-R1 Inference Performance on AMD Instinct™ MI300X GPU ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1_Perf/README.html))
- [2025/01] SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html), [10+ other companies](https://x.com/lmsysorg/status/1887262321636221412))
- [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
- [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
- [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
- [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
- [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).
- [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
</details>
## About
SGLang is a fast serving framework for large language models and vision language models.
It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
The core features include:
- **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, zero-overhead CPU scheduler, prefill-decode disaggregation, speculative decoding, continuous batching, paged attention, tensor/pipeline/expert/data parallelism, structured outputs, chunked prefill, quantization (FP4/FP8/INT4/AWQ/GPTQ), and multi-lora batching.
- **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
- **Extensive Model Support**: Supports a wide range of generative models (Llama, Qwen, DeepSeek, Kimi, GPT, Gemma, Mistral, etc.), embedding models (e5-mistral, gte, mcdse) and reward models (Skywork), with easy extensibility for integrating new models.
- **Active Community**: SGLang is open-source and backed by an active community with wide industry adoption.
## Getting Started
- [Install SGLang](https://docs.sglang.ai/get_started/install.html)
- [Quick Start](https://docs.sglang.ai/basic_usage/send_request.html)
- [Backend Tutorial](https://docs.sglang.ai/basic_usage/openai_api_completions.html)
- [Frontend Tutorial](https://docs.sglang.ai/references/frontend/frontend_tutorial.html)
- [Contribution Guide](https://docs.sglang.ai/developer_guide/contribution_guide.html)
## Benchmark and Performance
Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/), [Large-scale expert parallelism](https://lmsys.org/blog/2025-05-05-large-scale-ep/).
## Roadmap
[Development Roadmap (2025 H2)](https://github.com/sgl-project/sglang/issues/7736)
## Adoption and Sponsorship
SGLang has been deployed at large scale, generating trillions of tokens in production each day. It is trusted and adopted by a wide range of leading enterprises and institutions, including xAI, AMD, NVIDIA, Intel, LinkedIn, Cursor, Oracle Cloud, Google Cloud, Microsoft Azure, AWS, Atlas Cloud, Voltage Park, Nebius, DataCrunch, Novita, InnoMatrix, MIT, UCLA, the University of Washington, Stanford, UC Berkeley, Tsinghua University, Jam & Tea Studios, Baseten, and other major technology organizations across North America and Asia. As an open-source LLM inference engine, SGLang has become the de facto industry standard, with deployments running on over 1,000,000 GPUs worldwide.
<img src="https://raw.githubusercontent.com/sgl-project/sgl-learning-materials/refs/heads/main/slides/adoption.png" alt="logo" width="800" margin="10px"></img>
## Contact Us
For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at contact@sglang.ai.
## Acknowledgment
We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql).

BIN
assets/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 393 KiB

1
assets/logo.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 65 KiB

BIN
assets/logo_square.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

1
assets/logo_square.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 51 KiB

View File

@@ -0,0 +1,250 @@
import argparse
import torch
import triton
from sglang.srt.layers.attention.triton_ops.decode_attention import (
decode_attention_fwd_grouped,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
# gpt oss
head_num = 64
head_dim = 64
head_kv_num = 8
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["S"], # sequence length on x-axis
x_vals=[128, 256, 512, 1024, 2048, 4096],
x_log=True,
line_arg="B", # batch size as different lines
line_vals=[1, 8, 32, 128],
line_names=["B=1", "B=8", "B=32", "B=128"],
styles=[
("blue", "-"),
("green", "-"),
("red", "-"),
("cyan", "-"),
],
ylabel="TFLOPS",
plot_name="attention-sink-triton-decode",
args={},
)
)
def benchmark_decode(B, S, H_Q, H_KV, D):
D_V = D
dtype = torch.bfloat16
seq_len = S
total_tokens = B * seq_len
device = torch.device("cuda")
sm_scale = 1.0 / (D**0.5)
max_kv_splits = 8
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
# k_buffer and v_buffer represent all previous tokens
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len, dim=0)
kv_indices = torch.arange(total_tokens, device="cuda")
attn_logits1 = torch.empty(
(B, H_Q, max_kv_splits, D_V),
dtype=torch.float32,
device="cuda",
)
attn_lse1 = torch.empty(
(B, H_Q, max_kv_splits, D_V),
dtype=torch.float32,
device="cuda",
)
sink = torch.randn(H_Q, device=device, dtype=torch.float32)
# warmup
for _ in range(5):
decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
kv_indptr,
kv_indices,
attn_logits1,
attn_lse1,
num_kv_splits,
max_kv_splits,
sm_scale,
logit_cap=0.0,
sinks=sink,
)
# benchmark
run_step = 500
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(run_step):
decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
kv_indptr,
kv_indices,
attn_logits1,
attn_lse1,
num_kv_splits,
max_kv_splits,
sm_scale,
logit_cap=0.0,
sinks=sink,
)
end_event.record()
end_event.synchronize()
torch.cuda.synchronize()
ms = start_event.elapsed_time(end_event) / run_step
tflops = lambda ms: (2 * B * S * H_Q * D) * 1e-9 / ms # must be causal
return tflops(ms)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["S"], # sequence length on x-axis
x_vals=[128, 256, 512, 1024, 2048, 4096],
x_log=True,
line_arg="B", # batch size as different lines
line_vals=[1, 8, 32, 128],
line_names=["B=1", "B=8", "B=32", "B=128"],
styles=[
("blue", "-"),
("green", "-"),
("red", "-"),
("cyan", "-"),
],
ylabel="TFLOPS",
plot_name="attention-sink-triton-extend",
args={},
)
)
def benchmark_extend(B, S, H_Q, H_KV, D):
# S here represents N_CTX from the test
dtype = torch.bfloat16
device = "cuda"
# Split S into prefix and extend lengths
prefill_len = S // 2 # Similar to test's N_CTX // 2
extend_len = S // 4 # Make extend length smaller than prefix
# Calculate total tokens and extend tokens
total_extend_tokens = B * extend_len
total_prefix_tokens = B * prefill_len
# Create query, key, value tensors for extension
q_extend = torch.randn(total_extend_tokens, H_Q, D, dtype=dtype, device=device)
k_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
v_extend = torch.randn(total_extend_tokens, H_KV, D, dtype=dtype, device=device)
o_extend = torch.empty_like(q_extend)
# Create key-value buffers for prefix
k_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)
v_buffer = torch.randn(total_prefix_tokens, H_KV, D, dtype=dtype, device=device)
# Create index pointers
qo_indptr = torch.arange(0, (B + 1) * extend_len, extend_len, device=device).to(
torch.int32
)
kv_indptr = torch.arange(0, (B + 1) * prefill_len, prefill_len, device=device).to(
torch.int32
)
kv_indices = torch.arange(0, total_prefix_tokens, device=device).to(torch.int32)
sm_scale = 1.0 / (D**0.5)
# sliding_window = 128 # From GPT-OSS config, skip for now
sliding_window = -1
sink = torch.randn(H_Q, device=device, dtype=torch.float32)
# warmup
for _ in range(5):
extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask=None,
is_causal=True,
mask_indptr=None,
max_len_extend=extend_len,
sm_scale=sm_scale,
sliding_window_size=sliding_window,
sinks=sink,
)
# benchmark
run_step = 500
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(run_step):
extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask=None,
is_causal=True,
mask_indptr=None,
max_len_extend=extend_len,
sm_scale=sm_scale,
sliding_window_size=sliding_window,
sinks=sink,
)
end_event.record()
end_event.synchronize()
torch.cuda.synchronize()
ms = start_event.elapsed_time(end_event) / run_step
# FLOPS calculation: each attention operation requires 2 multiplications per element
total_flops = 2 * total_extend_tokens * H_Q * (prefill_len + extend_len / 2) * D
tflops = lambda ms: total_flops * 1e-12 / (ms * 1e-3) # convert to TFLOPS
return tflops(ms)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--bench", type=str, default="all", help="all, extend, decode")
args = parser.parse_args()
kwargs = {
"H_Q": head_num,
"H_KV": head_kv_num,
"D": head_dim,
}
if args.bench in ["all", "decode"]:
benchmark_decode.run(print_data=True, show_plots=False, **kwargs)
if args.bench in ["all", "extend"]:
benchmark_extend.run(print_data=True, show_plots=False, **kwargs)
print("Benchmark finished!")

View File

@@ -0,0 +1,130 @@
# Benchmark with lots of common prefixes. Used to benchmark prefix caching performance.
#
# Launch a server:
# python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --log-level-http warning
import random
import string
import time
from tqdm import tqdm
from transformers import AutoTokenizer
import sglang as sgl
from sglang import set_default_backend
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
def generate_random_string(token_length: int) -> str:
random_string = "".join(
random.choices(string.ascii_letters + string.digits, k=token_length * 100)
)
tokenized_output = tokenizer.encode(random_string, add_special_tokens=False)[
:token_length
]
if len(tokenized_output) < token_length:
tokenized_output = tokenized_output + [tokenizer.pad_token_id] * (
token_length - len(tokenized_output)
)
decoded_string = tokenizer.decode(tokenized_output, skip_special_tokens=False)
return decoded_string
def generate_unique_prefix(base_text, index):
return str(index) + base_text[len(str(index)) :]
@sgl.function
def text_qa(s, question, gen_len):
s += "Q: " + question + "\n"
s += "A:" + sgl.gen("answer", stop="\n", temperature=0, max_tokens=gen_len)
def prepare_prompts(num_prefix, num_samples_per_prefix, prefix_length, suffix_length):
base_prefix = generate_random_string(prefix_length)
tot_input_len = 0
all_prompts = []
for i in tqdm(range(num_prefix), desc="prepare prompts"):
unique_prefix = generate_unique_prefix(base_prefix, i)
prompt_list = []
for j in range(num_samples_per_prefix):
suffix = generate_random_string(suffix_length)
prompt = unique_prefix + suffix
prompt_list.append(prompt)
tot_input_len += len(tokenizer.encode(prompt))
all_prompts.append(prompt_list)
return all_prompts, tot_input_len
def test_batch_by_batch(all_prompts, gen_len):
backend.flush_cache()
tot_time = 0
for i in range(len(all_prompts)):
tic = time.perf_counter()
text_qa.run_batch(
list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))),
)
tot_time += time.perf_counter() - tic
return tot_time
def test_batch_by_batch_with_hint(all_prompts, gen_len):
backend.flush_cache()
tot_time = 0
for i in range(len(all_prompts)):
tic = time.perf_counter()
# Send a hint to cache the prefix
text_qa.run_batch(list(zip(all_prompts[i][:1], [gen_len])))
# Send the batch
text_qa.run_batch(list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))))
tot_time += time.perf_counter() - tic
return tot_time
def test_send_all(all_prompts, gen_len):
backend.flush_cache()
all_prompts = [x for prompt_list in all_prompts for x in prompt_list]
tic = time.perf_counter()
text_qa.run_batch(
list(zip(all_prompts, [gen_len] * len(all_prompts))),
)
tot_time = time.perf_counter() - tic
return tot_time
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
backend = RuntimeEndpoint("http://127.0.0.1:30000")
set_default_backend(backend)
random.seed(0)
num_prefix = 10
num_samples_per_prefix = 32
prefix_length = 1024
suffix_length = 128
gen_len = 1
all_prompts, tot_input_len = prepare_prompts(
num_prefix, num_samples_per_prefix, prefix_length, suffix_length
)
print(f"Total input token length: {tot_input_len}\n")
cost = test_batch_by_batch(all_prompts, gen_len)
print(f"Latency of test_batch_by_batch : {cost:.4f} s\n")
cost = test_batch_by_batch_with_hint(all_prompts, gen_len)
print(f"Latency of test_batch_by_batch_with_hint: {cost:.4f} s\n")
cost = test_send_all(all_prompts, gen_len)
print(f"Latency of test_send_all : {cost:.4f} s\n")

View File

@@ -0,0 +1,193 @@
import concurrent.futures
import os
import random
import time
from concurrent.futures import ProcessPoolExecutor
from statistics import mean
import requests
from tqdm import tqdm
from transformers import AutoTokenizer
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
###############################################################################
# CONFIG
###############################################################################
ENDPOINT_URL = "http://127.0.0.1:30000"
TOKENIZER_DIR = "/models/meta-llama/Llama-3.2-3B"
# Benchmark configurations
NUM_REQUESTS = 10 # Total number of requests (each with BATCH_SIZE prompts)
NUM_TOKENS = 32000 # Tokens per prompt
BATCH_SIZE = 8 # Number of prompts per request
GEN_TOKENS = 0 # Tokens to generate per prompt
###############################################################################
# REQUEST GENERATION (in parallel)
###############################################################################
def generate_random_prompt(index, tokenizer_dir, num_tokens):
"""Generate a single random prompt with specified token count."""
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
vocab_size = tokenizer.vocab_size
def generate_random_text(num_toks):
random_token_ids = [random.randint(0, vocab_size - 1) for _ in range(num_toks)]
return tokenizer.decode(random_token_ids, clean_up_tokenization_spaces=True)
random_text = generate_random_text(num_tokens)
return f"Prompt {index}: {random_text}"
def prepare_all_prompts(num_requests, batch_size, num_tokens, tokenizer_dir):
"""Generate prompts for all requests in parallel."""
total_prompts = num_requests * batch_size
all_prompts = [None] * total_prompts
max_workers = min(os.cpu_count() or 1, total_prompts)
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = [
executor.submit(generate_random_prompt, i, tokenizer_dir, num_tokens)
for i in range(total_prompts)
]
for future in tqdm(
concurrent.futures.as_completed(futures),
total=total_prompts,
desc="Generating prompts",
):
index = futures.index(future)
all_prompts[index] = future.result()
batched_prompts = [
all_prompts[i * batch_size : (i + 1) * batch_size] for i in range(num_requests)
]
print(
f"Generated {total_prompts} prompts with {num_tokens} tokens each, grouped into {num_requests} requests of {batch_size} prompts.\n"
)
return batched_prompts
###############################################################################
# HTTP CALLS
###############################################################################
def send_batch_request(endpoint, prompts, gen_tokens, request_id):
"""Send a batch of prompts to the /generate endpoint synchronously."""
sampling_params = {
"max_new_tokens": gen_tokens,
"temperature": 0.7,
"stop": "\n",
}
data = {"text": prompts, "sampling_params": sampling_params}
start_time = time.perf_counter()
try:
response = requests.post(
endpoint.base_url + "/generate", json=data, timeout=3600
)
if response.status_code != 200:
error = response.json()
raise RuntimeError(f"Request {request_id} failed: {error}")
result = response.json()
elapsed_time = (time.perf_counter() - start_time) * 1000 # Convert to ms
avg_per_prompt = elapsed_time / len(prompts) if prompts else 0
return request_id, elapsed_time, avg_per_prompt, True, len(prompts)
except Exception as e:
print(f"[Request] Error for request {request_id}: {e}")
return request_id, 0, 0, False, len(prompts)
def run_benchmark(endpoint, batched_prompts, batch_size, gen_tokens):
"""Run the benchmark sequentially."""
results = []
num_requests = len(batched_prompts)
# Record start time for total latency
benchmark_start_time = time.perf_counter()
for i, batch_prompts in enumerate(batched_prompts):
request_id = i + 1
assert (
len(batch_prompts) == batch_size
), f"Request {request_id} should have {batch_size} prompts, got {len(batch_prompts)}"
print(
f"[Request] Sending request {request_id}/{num_requests} with {len(batch_prompts)} prompts at {int(time.time()*1000)}"
)
result = send_batch_request(endpoint, batch_prompts, gen_tokens, request_id)
results.append(result)
# Calculate total latency
total_latency = (time.perf_counter() - benchmark_start_time) * 1000 # Convert to ms
return results, total_latency
###############################################################################
# RESULTS
###############################################################################
def process_results(results, total_latency, num_requests):
"""Process and display benchmark results."""
total_time = 0
successful_requests = 0
failed_requests = 0
request_latencies = []
per_prompt_latencies = []
total_prompts = 0
for request_id, elapsed_time, avg_per_prompt, success, batch_size in results:
if success:
successful_requests += 1
total_prompts += batch_size
request_latencies.append(elapsed_time)
per_prompt_latencies.append(avg_per_prompt)
total_time += elapsed_time / 1000 # Convert to seconds
else:
failed_requests += 1
avg_request_latency = mean(request_latencies) if request_latencies else 0
avg_per_prompt_latency = mean(per_prompt_latencies) if per_prompt_latencies else 0
throughput = total_prompts / total_time if total_time > 0 else 0
print("\nBenchmark Summary:")
print(f" Total requests sent: {len(results)}")
print(f" Total prompts sent: {total_prompts}")
print(f" Successful requests: {successful_requests}")
print(f" Failed requests: {failed_requests}")
print(f" Total latency (all requests): {total_latency:.2f} ms")
print(f" Avg per request latency: {avg_request_latency:.2f} ms")
print(f" Avg per prompt latency: {avg_per_prompt_latency:.2f} ms")
print(f" Throughput: {throughput:.2f} prompts/second\n")
###############################################################################
# MAIN
###############################################################################
def main():
# Initialize endpoint
endpoint = RuntimeEndpoint(ENDPOINT_URL)
# Generate prompts
batched_prompts = prepare_all_prompts(
NUM_REQUESTS, BATCH_SIZE, NUM_TOKENS, TOKENIZER_DIR
)
# Flush cache before benchmark
# endpoint.flush_cache()
# Run benchmark
print(
f"Starting benchmark: NUM_TOKENS={NUM_TOKENS}, BATCH_SIZE={BATCH_SIZE}, NUM_REQUESTS={NUM_REQUESTS}\n"
)
results, total_latency = run_benchmark(
endpoint, batched_prompts, BATCH_SIZE, GEN_TOKENS
)
# Process and display results
process_results(results, total_latency, NUM_REQUESTS)
if __name__ == "__main__":
random.seed(0)
main()

View File

@@ -0,0 +1,126 @@
import random
import time
from statistics import mean
from transformers import AutoTokenizer
# CONFIG
TOKENIZER_DIR = (
"/shared/public/sharing/fait360brew/training/models/meta-llama/Llama-3.2-3B"
)
NUM_TOKENS = 20000 # Each prompt should contain this many tokens
BATCH_SIZES = [1, 2, 4, 8] # Test different batch sizes
NUM_RUNS = 5 # Number of runs for each batch size to get reliable measurements
def generate_random_prompts(num_prompts, num_tokens, tokenizer):
"""Generate random prompts with specified token count."""
vocab_size = tokenizer.vocab_size
all_prompts = []
print(f"Generating {num_prompts} random prompts with {num_tokens} tokens each...")
for i in range(num_prompts):
# Generate random token IDs - this directly gives us the exact token count
random_token_ids = [
random.randint(0, vocab_size - 1) for _ in range(num_tokens)
]
random_text = tokenizer.decode(
random_token_ids, clean_up_tokenization_spaces=True
)
prompt = f"Prompt {i}: {random_text}"
tokens = tokenizer.encode(prompt)
print(f" Prompt {i}: {len(tokens)} tokens")
all_prompts.append(prompt)
return all_prompts
def benchmark_sequential_vs_batch(prompts, batch_size, tokenizer):
"""Compare sequential vs batch tokenization for a given batch size."""
# Sequential tokenization using encode()
sequential_times = []
for run in range(NUM_RUNS):
batch_prompts = prompts[:batch_size] # Use same prompts for fair comparison
start_time = time.perf_counter()
for prompt in batch_prompts:
tokens = tokenizer.encode(prompt)
sequential_time = (time.perf_counter() - start_time) * 1000
sequential_times.append(sequential_time)
# Batch tokenization using tokenizer()
batch_times = []
for run in range(NUM_RUNS):
batch_prompts = prompts[:batch_size] # Use same prompts for fair comparison
start_time = time.perf_counter()
tokens = tokenizer(batch_prompts)
batch_time = (time.perf_counter() - start_time) * 1000
batch_times.append(batch_time)
return {
"batch_size": batch_size,
"avg_sequential_ms": mean(sequential_times),
"avg_batch_ms": mean(batch_times),
"speedup_factor": (
mean(sequential_times) / mean(batch_times) if mean(batch_times) > 0 else 0
),
"sequential_runs": sequential_times,
"batch_runs": batch_times,
}
def main():
print("Tokenizer Benchmark: Sequential vs Batch Processing")
print("-" * 60)
print(f"Tokenizer: {TOKENIZER_DIR}")
print(f"Tokens per prompt: {NUM_TOKENS}")
print(f"Number of runs per batch size: {NUM_RUNS}")
print("-" * 60)
# Load tokenizer once for all operations
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
# The largest batch size determines how many prompts we need
max_batch_size = max(BATCH_SIZES)
all_prompts = generate_random_prompts(max_batch_size, NUM_TOKENS, tokenizer)
results = []
print("\nRunning benchmark...")
for batch_size in BATCH_SIZES:
print(f"\nBenchmarking batch size: {batch_size}")
result = benchmark_sequential_vs_batch(all_prompts, batch_size, tokenizer)
results.append(result)
print(f" Sequential tokenization (encode):")
for i, run_time in enumerate(result["sequential_runs"]):
print(f" Run {i+1}: {run_time:.2f} ms")
print(f" Average: {result['avg_sequential_ms']:.2f} ms")
print(f" Batch tokenization (tokenizer):")
for i, run_time in enumerate(result["batch_runs"]):
print(f" Run {i+1}: {run_time:.2f} ms")
print(f" Average: {result['avg_batch_ms']:.2f} ms")
print(f" Speedup factor: {result['speedup_factor']:.2f}x")
print("\n" + "=" * 60)
print("SUMMARY OF RESULTS")
print("=" * 60)
print(
f"{'Batch Size':<10} {'Sequential (ms)':<18} {'Batch (ms)':<18} {'Speedup':<10}"
)
print("-" * 60)
for result in results:
print(
f"{result['batch_size']:<10} {result['avg_sequential_ms']:.2f} ms{' ' * 8} {result['avg_batch_ms']:.2f} ms{' ' * 8} {result['speedup_factor']:.2f}x"
)
if __name__ == "__main__":
random.seed(0)
main()

View File

@@ -0,0 +1,89 @@
## How to reproduce the benchmark results for SGLang v0.3.0 compared to vLLM v0.6.0
In short, with multi step enabled, in online scenarios that we benchmarked, the Median TTFT of vLLM is **3 times** that of SGLang, and the Median ITL is **10 times** that of SGLang. Lower Median TTFT and ITL are better. vLLM's multi-step optimization did not improve throughput while ensuring lower Median TTFT and ITL. Also, under maximum throughput benchmark, if vLLM does not set gpu util to 0.95 separately and uses the default configuration instead, its maximum throughput is **lower** than that of SGLang.
## Online benchmark results
### Llama 3.1 8B Instruct 1 x A100 80G
| RPS | Num prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL |
|------|-------------|--------|--------------------|-------------|-------------|------------|
| 4 | 1200 | SGLang | 1564.17 | **31.98** | 13.17 | **11.93** |
| 4 | 1200 | vLLM | 1691.97 | **100.48** | 14.14 | **129.32** |
| 8 | 2400 | SGLang | 2175.02 | **35.68** | 17.85 | **14.41** |
| 8 | 2400 | vLLM | 2137.16 | **120.39** | 17.09 | **158.63** |
### Llama 3.1 70B Insruct 4 x H100 80G
| RPS | Num Prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL |
|------|-------------|--------|--------------------|-------------|-------------|------------|
| 4 | 1200 | SGLang | 3005.24 | **53.94** | 25.03 | **21.67** |
| 4 | 1200 | vLLM | 2915.60 | **179.15** | 23.58 | **231.23** |
| 8 | 2400 | SGLang | 4064.98 | **58.11** | 33.07 | **24.45** |
| 8 | 2400 | vLLM | 3752.38 | **207.12** | 29.15 | **275.32** |
## Offline benchmark results
### Llama 3.1 8B Instruct 1 x A100 80G
| RPS | Num Prompts | Engine | Request throughput | Output token throughput |
|------|-------------|--------|--------------------|-------------------------|
| inf | 5000 | SGLang | 22.03 | **4281.51** |
| inf | 5000 | vLLM | 21.27 | **4132.37** |
### Llama 3.1 70B Insruct 4 x H100 80G
| RPS | Num Prompts | Engine | Request throughput | Output token throughput |
|------|-------------|--------|--------------------|-------------------------|
| inf | 5000 | SGLang | 19.84 | **3856.01** |
| inf | 5000 | vLLM | 19.04 | **3700.64** |
## Installation
```bash
# install sglang v0.3.0
pip install --upgrade pip
pip install "sglang[all]"==0.3.0
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
# install vllm v0.6.0
pip install vllm==0.6.0
```
## Notes
We referred to the reproduction method in https://github.com/vllm-project/vllm/issues/8176, and added the `--num-scheduler-steps 10` parameter when starting the vLLM server. The `gpu_memory_utilization` of vLLM is by default 0.9 at both TP 1 and TP 4, while SGLang's `mem_frac` is 0.88 at TP 1 and 0.85 at TP 4, so we manually set it to 0.88 at TP 4.
## Online benchmarks
```bash
# Llama 3.1 8B Instruct on 1 x A100
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096
# Llama 3.1 70B Instruct on 4 x H100
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096
# bench serving
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 1200 --request-rate 4
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 2400 --request-rate 8
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 1200 --request-rate 4
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 2400 --request-rate 8
```
## Offline benchmarks
```bash
# Llama 3.1 8B Instruct on 1 x A100
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096
# Llama 3.1 70B Instruct on 4 x H100
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 4 --mem-frac 0.88
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096
# bench serving
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 5000
python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 5000
```

View File

@@ -0,0 +1,24 @@
# Create dummy weights:
# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder.
# 2. Get `config.json`` from ./config.md
# 3. Download the tokenizer
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
# Launch sglang
# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87
# offline
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > sglang_log12
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > sglang_log13
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > sglang_log14
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > sglang_log15
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompt 2000 > sglang_log21
# online
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > sglang_log31
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > sglang_log32
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > sglang_log33
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > sglang_log34
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > sglang_log35

View File

@@ -0,0 +1,17 @@
# Launch trtllm
# https://github.com/sgl-project/tensorrt-demo
# offline
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log11
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log12
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log13
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log14
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log15
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name sharegpt --num-prompt 2000 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log21
# online
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log31
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log32
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log33
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log34
python3 ../../python/sglang/bench_serving.py --backend trt --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 --model /root/Meta-Llama-3-8B-Instruct > trtllm_log35

View File

@@ -0,0 +1,24 @@
# Create dummy weights:
# 1. Create a folder `~/llama-3.1-405b-fp8-dummy` and create `config.json` and tokenizer under this folder.
# 2. Get `config.json`` from ./config.md
# 3. Download the tokenizer
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer.json
# wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json
# Launch vllm
# python3 -m vllm.entrypoints.openai.api_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --disable-log-requests --tensor-parallel-size 8 --max-model-len 10000
# offline
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > vllm_log11
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 4000 --random-input 1024 --random-output 512 > vllm_log12
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 800 --random-input 4096 --random-output 2048 > vllm_log13
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1500 --random-input 4096 --random-output 1024 > vllm_log14
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 6000 --random-input 256 --random-output 512 > vllm_log15
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name sharegpt --num-prompt 2000 > vllm_log21
# online
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 300 --request-rate 1 --random-input 1024 --random-output 1024 > vllm_log31
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 600 --request-rate 2 --random-input 1024 --random-output 1024 > vllm_log32
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 1200 --request-rate 4 --random-input 1024 --random-output 1024 > vllm_log33
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 2400 --request-rate 8 --random-input 1024 --random-output 1024 > vllm_log34
python3 ../../python/sglang/bench_serving.py --backend vllm --dataset-name random --num-prompt 3200 --request-rate 16 --random-input 1024 --random-output 1024 > vllm_log35

View File

@@ -0,0 +1,164 @@
# How to reproduce the benchmark results of SGLang
## Prerequisite
### Install the latest SGLang
```bash
git clone https://github.com/sgl-project/sglang.git
cd sglang
git checkout v0.2.7
pip install --upgrade pip
pip install -e "python[all]"
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
```
### Set up ulimit and HF_TOKEN
```bash
ulimit -n 65535
# Change the token to a real and usable one, with access permissions for the Llama 3 models.
export HF_TOKEN=hf_token
```
### Launch the server
```bash
# Meta-Llama-3.1-8B-Instruct
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache
# Meta-Llama-3.1-70B-Instruct
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-70B-Instruct --disable-radix-cache --tp 8
# Meta-Llama-3-70B-Instruct-FP8
python -m sglang.launch_server --model-path neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-radix-cache --tp 8
```
## Benchmark
### Hardware Requirements
- 8B models: Single NVIDIA A100 80GB GPU
- 70B models: 8 x NVIDIA A100 80GB GPUs with Tensor Parallelism (TP) 8
- 70B FP8 models: 8 x NVIDIA H100 GPUs with Tensor Parallelism (TP) 8
Please ensure you have the appropriate hardware before running the benchmarks.
#### Offline benchmark
```bash
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 3000 --output-file offline.jsonl
cat offline.jsonl | cut -d':' -f12 | cut -d',' -f1
```
#### Online benchmark
```bash
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online.jsonl
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online.jsonl
cat online.jsonl | cut -d':' -f9 | cut -d',' -f1
```
## Other
We tried using vLLM 0.5.3.post1, but it often crashes under high loads, and it seems to have similar or worse performance compared to vLLM 0.5.2 from our partial benchmarking, so we are using the older version, vLLM 0.5.2.
Preparation for TensorRT LLM can refer to https://github.com/sgl-project/tensorrt-demo. Specifically, we used a batch size of 512, a max input length of 8192, and a max number of tokens of 8192. The instance count for preprocessing and postprocessing in Triton Server is 16.
```bash
# vLLM
pip install vllm==0.5.2
pip install jsonschema==4.21.1
# Meta-Llama-3-8B-Instruct
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B-Instruct --disable-log-requests
# meta-llama/Meta-Llama-3-70B-Instruct
python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B-Instruct --disable-log-requests --tensor 8
# neuralmagic/Meta-Llama-3-70B-Instruct-FP8
python -m vllm.entrypoints.openai.api_server --model neuralmagic/Meta-Llama-3-70B-Instruct-FP8 --disable-log-requests --tensor 8
```
```bash
wget https://raw.githubusercontent.com/sgl-project/sglang/main/python/sglang/bench_serving.py
```
```bash
# vLLM Offline
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name sharegpt --num-prompts 3000 --output-file offline_vllm.jsonl
cat offline_vllm.jsonl | cut -d':' -f12 | cut -d',' -f1
```
```bash
# vLLM Online
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_vllm.jsonl
python3 bench_serving.py --backend vllm --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_vllm.jsonl
cat online_vllm.jsonl | cut -d':' -f9 | cut -d',' -f1
```
```bash
# TensorRT LLM Offline 8B
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_8b.jsonl
python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_8b.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_8b.jsonl
cat offline_trt_8b.jsonl | cut -d':' -f12 | cut -d',' -f1
```
```bash
# TensorRT LLM Online 8B
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_8b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_8b.jsonl
cat online_trt_8b.jsonl | cut -d':' -f9 | cut -d',' -f1
```
```bash
# TensorRT LLM Offline 70B
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 4000 --random-input 1024 --random-output 1024 --output-file offline_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 5000 --random-input 1024 --random-output 512 --output-file offline_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 1000 --random-input 4096 --random-output 2048 --output-file offline_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --num-prompts 2000 --random-input 4096 --random-output 1024 --output-file offline_trt_70b.jsonl
python3 bench_serving.py --backend trt --dataset-name random --num-prompts 6000 --random-input 256 --random-output 512 --output-file offline_trt_70b.jsonl --model meta-llama/Meta-Llama-3-70B-Instruct
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name sharegpt --num-prompts 3000 --output-file offline_trt_70b.jsonl
cat offline_trt_70b.jsonl | cut -d':' -f12 | cut -d',' -f1
```
```bash
# TensorRT LLM Online 70B
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 300 --request-rate 1 --output-file online_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 600 --request-rate 2 --output-file online_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 1200 --request-rate 4 --output-file online_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 2400 --request-rate 8 --output-file online_trt_70b.jsonl
python3 bench_serving.py --backend trt --model meta-llama/Meta-Llama-3-70B-Instruct --dataset-name random --random-input 1024 --random-output 1024 --num-prompts 3200 --request-rate 16 --output-file online_trt_70b.jsonl
cat online_trt_70b.jsonl | cut -d':' -f9 | cut -d',' -f1
```

View File

@@ -0,0 +1,100 @@
### used for TensorRT LLM
```
{
"architecture": "LlamaForCausalLM",
"dtype": "float16",
"logits_dtype": "float32",
"vocab_size": 128256,
"max_position_embeddings": 8192,
"hidden_size": 16384,
"num_hidden_layers": 126,
"num_attention_heads": 128,
"num_key_value_heads": 16,
"head_size": 128,
"qk_layernorm": false,
"hidden_act": "silu",
"intermediate_size": 53248,
"norm_epsilon": 1e-05,
"position_embedding_type": "rope_gpt_neox",
"use_parallel_embedding": false,
"embedding_sharding_dim": 0,
"share_embedding_table": false,
"mapping": {
"world_size": 8,
"tp_size": 8,
"pp_size": 1,
"gpus_per_node": 8
},
"quantization": {
"quant_algo": "FP8",
"kv_cache_quant_algo": null,
"group_size": 128,
"smoothquant_val": null,
"has_zero_point": false,
"pre_quant_scale": false,
"exclude_modules": [
"lm_head"
]
},
"kv_dtype": "float16",
"rotary_scaling": null,
"residual_mlp": false,
"moe_normalization_mode": null,
"rotary_base": 500000.0,
"moe_num_experts": 0,
"moe_top_k": 0,
"moe_tp_mode": 2,
"attn_bias": false,
"disable_weight_only_quant_plugin": false,
"mlp_bias": false
}
```
### used for vLLM and SGLang
```
{
"_name_or_path": "dummy_fp8",
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": 128009,
"hidden_act": "silu",
"hidden_size": 16384,
"initializer_range": 0.02,
"intermediate_size": 53248,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 128,
"num_hidden_layers": 126,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"quantization_config": {
"activation_scheme": "static",
"ignored_layers": [
"lm_head"
],
"quant_method": "fp8"
},
"rope_scaling": {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"max_position_embeddings": 131072,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 500000.0,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.41.1",
"use_cache": true,
"vocab_size": 128256
}
```

19
benchmark/boolq/README.md Normal file
View File

@@ -0,0 +1,19 @@
## Download data
```
git clone https://hf-mirror.com/datasets/google/boolq
```
## Convert parquet to json
```
bash parquet_to_json.sh
```
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path ramblingpolymath/Qwen3-32B-W8A8 --port 30000
```
```
python3 bench_sglang.py
```

View File

@@ -0,0 +1,124 @@
import argparse
import json
import time
import numpy as np
from sglang.api import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import read_jsonl
def get_example(lines, i, answer):
prompt = "Question: " + lines[i]["question"] + lines[i]["passage"] + "\nAnswer:"
if answer:
prompt += str(lines[i]["answer"])
return prompt
def few_shot_examples(lines, k):
prompts = ""
for i in range(k):
prompts += get_example(lines, i, True) + "\n\n"
return prompts
def main(args):
# Select backend
set_default_backend(select_sglang_backend(args))
# Read data
train_data_path = args.train_data_path
test_data_path = args.test_data_path
lines_train = list(read_jsonl(train_data_path))
lines_test = list(read_jsonl(test_data_path))
# Construct prompts
num_questions = args.num_questions
num_shots = args.num_shots
few_shots = few_shot_examples(lines_train, num_shots)
questions = []
answer = []
for i in range(len(lines_test[:num_questions])):
questions.append(get_example(lines_test, i, False))
answer.append(str(lines_test[i]["answer"]))
arguments = [{"question": q} for q in questions]
#####################################
######### SGL Program Begin #########
#####################################
import sglang as sgl
@sgl.function
def few_shot_boolq(s, question):
s += few_shots + question
s += sgl.gen("answer", max_tokens=5, stop=["\n"])
#####################################
########## SGL Program End ##########
#####################################
# Run requests
tic = time.perf_counter()
states = few_shot_boolq.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
preds = []
for i in range(len(states)):
preds.append(states[i]["answer"])
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(answer))
# Compute speed
num_output_tokens = sum(
s.get_meta_info("answer")["completion_tokens"] for s in states
)
output_throughput = num_output_tokens / latency
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Latency: {latency:.3f} s")
print(f"Output throughput: {output_throughput:.3f} token/s")
# Results
with open(args.result_file, "a") as fout:
value = {
"task": "boolq",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument(
"--train-data-path", type=str, default="./boolq/data/train-00000-of-00001.json"
)
parser.add_argument(
"--test-data-path",
type=str,
default="./boolq/data/validation-00000-of-00001.json",
)
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,28 @@
import sys
import pyarrow.parquet as pq
def convert_parquet_to_json(input_file, output_file):
# read parquet file
table = pq.read_table(input_file)
# turn parquet data to dataframe
df = table.to_pandas()
# turn dataframe to json form
json_data = df.to_json(orient="records", lines=True)
# write json to file
with open(output_file, "w") as f:
f.write(json_data)
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage:python convert_parquet_to_json.py <input_file> <output_file>")
input_file = sys.argv[1]
output_file = sys.argv[2]
convert_parquet_to_json(input_file, output_file)

View File

@@ -0,0 +1,26 @@
#!/bin/bash
#define input and output direction
input_dir="./boolq/data"
output_dir="./boolq/data"
#define files needed to be handled
files=(
"train-00000-of-00001.parquet"
"validation-00000-of-00001.parquet"
)
#foe files above, use python script to convert the form
for file in "${files[@]}"; do
input_file="${input_dir}/${file}"
output_file="${output_dir}/${file%.parquet}.json"
echo "Converting ${input_file} to ${output_file} ..."
python3 convert_parquet_to_json.py "${input_file}" "${output_file}"
if [ $? -eq 0 ]; then
echo "Conversion successful: ${output_file}"
else
echo "Conversion failed: ${input_file}"
fi
done

15
benchmark/ceval/README.md Normal file
View File

@@ -0,0 +1,15 @@
## Download data
```
git lfs clone https://huggingface.co/datasets/ceval/ceval-exam
```
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path ramblingpolymath/Qwen3-32B-W8A8 --port 30000
```
```
python3 bench_sglang.py
```

View File

@@ -0,0 +1,138 @@
import argparse
import json
import os
import random
import re
import time
import numpy as np
from datasets import load_dataset
from sglang.api import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
choices = ["A", "B", "C", "D"]
def get_one_example(line, include_answer):
res = line["question"]
res += f"\nA. {line['A']}"
res += f"\nB. {line['B']}"
res += f"\nC. {line['C']}"
res += f"\nD. {line['D']}"
if include_answer:
res += f"\nAnswer: {line['answer']} \n\n"
return res
def get_few_shot_examples(lines):
res = ""
for line in lines:
res += get_one_example(line, True) + "\n\n"
return res
def get_answer_value(response):
pattern = r"(Answer:|answer:|答案是|答案是:|正确答案是:|答案:|Assistant:)\s*([A-D])(?![\w])"
match = re.search(pattern, response)
if match:
return match.group(2)
return random.choice(choices)
def main(args):
# Read data && Construct prompts
arguments = []
labels = []
examples = "examples:\n"
data_path = args.data_path
for subject in os.listdir(data_path):
subject_path = os.path.join(data_path, subject)
if os.path.isdir(subject_path) and subject != ".git":
dataset = load_dataset(data_path, name=subject)
dev_lines_temp = dataset["dev"]
val_lines_temp = dataset["val"]
few_shot_examples = get_few_shot_examples(dev_lines_temp, subject)
examples += f"{few_shot_examples}"
for val_line in val_lines_temp:
arguments.append(
{
"examples": few_shot_examples,
"question": get_one_example(val_line, False),
}
)
labels.append(val_line["answer"])
#####################################
######### SGL Program Begin #########
#####################################
import sglang as sgl
@sgl.function
def few_shot_ceval(s, examples, question):
s += examples + question + sgl.gen("Answer")
#####################################
########## SGL Program End ##########
#####################################
num_questions = args.num_questions if args.num_questions else len(arguments)
# Select backend
set_default_backend(select_sglang_backend(args))
# Run requests
tic = time.perf_counter()
states = few_shot_ceval.run_batch(
arguments[:num_questions],
temperature=0,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
preds = [get_answer_value(states[i]["Answer"]) for i in range(num_questions)]
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels[:num_questions]))
# Compute speed
num_output_tokens = sum(
s.get_meta_info("Answer")["completion_tokens"] for s in states
)
output_throughput = num_output_tokens / latency
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Latency: {latency:.3f} s")
print(f"Output throughput: {output_throughput:.3f} token/s")
# Write results
with open(args.result_file, "a") as fout:
value = {
"task": "ceval",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="ceval-exam")
parser.add_argument("--num-questions", type=int, default=None)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,373 @@
# DeepSeek V3.1/V3/R1 Support
The SGLang and DeepSeek teams collaborated to get DeepSeek V3 FP8 running on NVIDIA and AMD GPUs **from day one**. SGLang also supports [MLA optimization](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [DP attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models), making SGLang one of the best open-source LLM engines for running DeepSeek models. SGLang is the inference engine recommended by the official [DeepSeek team](https://github.com/deepseek-ai/DeepSeek-V3/tree/main?tab=readme-ov-file#62-inference-with-sglang-recommended).
Special thanks to Meituan's Search & Recommend Platform Team and Baseten's Model Performance Team for implementing the model, and DataCrunch for providing GPU resources.
For optimizations made on the DeepSeek series models regarding SGLang, please refer to [DeepSeek Model Optimizations in SGLang](https://docs.sglang.ai/basic_usage/deepseek.html).
## Installation & Launch
If you encounter errors when starting the server, ensure the weights have finished downloading. It's recommended to download them beforehand or restart multiple times until all weights are downloaded.
### Using Docker (Recommended)
```bash
# Pull latest image
# https://hub.docker.com/r/lmsysorg/sglang/tags
docker pull lmsysorg/sglang:latest
# Launch
docker run --gpus all --shm-size 32g -p 30000:30000 -v ~/.cache/huggingface:/root/.cache/huggingface --ipc=host --network=host --privileged lmsysorg/sglang:latest \
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code --port 30000
```
If you are using RDMA, please note that:
1. `--network host` and `--privileged` are required by RDMA. If you don't need RDMA, you can remove them.
2. You may need to set `NCCL_IB_GID_INDEX` if you are using RoCE, for example: `export NCCL_IB_GID_INDEX=3`.
Add [performance optimization options](#performance-optimization-options) as needed.
### Using pip
```bash
# Installation
pip install "sglang[all]>=0.5.2rc1"
# Launch
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
```
Add [performance optimization options](#performance-optimization-options) as needed.
<a id="option_args"></a>
### Performance Optimization Options
[MLA optimizations](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) are enabled by default. Here are some optional optimizations can be enabled as needed.
- [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models): For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput.
- [Torch.compile Optimization](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#torchcompile-latency-optimizations): Add `--enable-torch-compile` argument to enable it. This will take some time while server starts. The maximum batch size for torch.compile optimization can be controlled with `--torch-compile-max-bs`. It's recommended to set it between `1` and `8`. (e.g., `--torch-compile-max-bs 8`)
### Usage: Chat with DeepSeek
#### DeepSeek V3/R1
```python3
import openai
client = openai.Client(
base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
# Chat completion
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print(response)
```
#### DeepSeek V3.1
On top of the basic usage similar to the DeepSeek V3/R1 example, DeepSeek V3.1 supports a request-level thinking/non-thinking toggle. Simply switch the `"thinking"` field in `extra_body={"chat_template_kwargs": {"thinking": True}}` to enable/disable the thinking mode.
##### Non Thinking
```python3
import openai
client = openai.Client(
base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
# Chat completion
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Answer the following with the second letter of the correct answer only: What is the capital of France?"},
],
temperature=0,
max_tokens=1024,
extra_body = {"chat_template_kwargs": {"thinking": False}}
)
print(response.choices[0].message.content)
```
Answer:
```
h
```
* The correct response should be 'A', as the correct answer to the question is 'Paris'.
##### Thinking
```python3
import openai
client = openai.Client(
base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
# Chat completion
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Answer the following with the second letter of the correct answer only: What is the capital of France?"},
],
temperature=0,
max_tokens=1024,
extra_body = {"chat_template_kwargs": {"thinking": True}}
)
print(response)
```
Answer:
```
First, the question is: "What is the capital of France?" I know that the capital of France is Paris.
The user says: "Answer the following with the second letter of the correct answer only." So, I need to provide only the second letter of the correct answer.
The correct answer is "Paris". Now, I need to find the second letter of "Paris".
Let's spell it out: P-A-R-I-S.
- First letter: P
- Second letter: A
- Third letter: R
- Fourth letter: I
- Fifth letter: S
So, the second letter is "A".
I should only output the second letter, which is "A". No additional text or explanation, just the letter.
The user emphasized "the second letter of the correct answer only", so my response should be just "A".
Finally, I need to make sure that this is the correct answer. Yes, Paris is indeed the capital of France.</think>A
```
* The response contains `</think>` thinking trace and model was able to derive the correct answer from it.
### Example: Serving with two H20\*8 nodes
For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`, and the second node's IP is `10.0.0.2`. Please **use the first node's IP** for both commands.
If the command fails, try setting the `GLOO_SOCKET_IFNAME` parameter. For more information, see [Common Environment Variables](https://pytorch.org/docs/stable/distributed.html#common-environment-variables).
If the multi nodes support NVIDIA InfiniBand and encounter hanging issues during startup, consider adding the parameter `export NCCL_IB_GID_INDEX=3`. For more information, see [this](https://github.com/sgl-project/sglang/issues/3516#issuecomment-2668493307).
```bash
# node 1
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
# node 2
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
```
If you have two H100 nodes, the usage is similar to the aforementioned H20.
> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
### Example: Serving with two H200\*8 nodes and docker
There are two H200 nodes, each with 8 GPUs. The first node's IP is `192.168.114.10`, and the second node's IP is `192.168.114.11`. Configure the endpoint to expose it to another Docker container using `--host 0.0.0.0` and `--port 40000`, and set up communications with `--dist-init-addr 192.168.114.10:20000`.
A single H200 with 8 devices can run DeepSeek V3, the dual H200 setup is just to demonstrate multi-node usage.
```bash
# node 1
docker run --gpus all \
--shm-size 32g \
--network=host \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--name sglang_multinode1 \
-it \
--rm \
--env "HF_TOKEN=$HF_TOKEN" \
--ipc=host \
lmsysorg/sglang:latest \
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 0 --trust-remote-code --host 0.0.0.0 --port 40000
```
```bash
# node 2
docker run --gpus all \
--shm-size 32g \
--network=host \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--name sglang_multinode2 \
-it \
--rm \
--env "HF_TOKEN=$HF_TOKEN" \
--ipc=host \
lmsysorg/sglang:latest \
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 1 --trust-remote-code --host 0.0.0.0 --port 40000
```
To ensure functionality, we include a test from a client Docker container.
```bash
docker run --gpus all \
--shm-size 32g \
--network=host \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--name sglang_multinode_client \
-it \
--rm \
--env "HF_TOKEN=$HF_TOKEN" \
--ipc=host \
lmsysorg/sglang:latest \
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1 --random-output 512 --random-range-ratio 1 --num-prompts 1 --host 0.0.0.0 --port 40000 --output-file "deepseekv3_multinode.jsonl"
```
> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
### Example: Serving with four A100\*8 nodes
To serve DeepSeek-V3 with A100 GPUs, we need to convert the [FP8 model checkpoints](https://huggingface.co/deepseek-ai/DeepSeek-V3) to BF16 with [script](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) mentioned [here](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) first.
Since the BF16 model is over 1.3 TB, we need to prepare four A100 nodes, each with 8 80GB GPUs. Assume the first node's IP is `10.0.0.1`, and the converted model path is `/path/to/DeepSeek-V3-BF16`, we can have following commands to launch the server.
```bash
# node 1
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 0 --trust-remote-code --host 0.0.0.0 --port 30000
# node 2
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 1 --trust-remote-code
# node 3
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 2 --trust-remote-code
# node 4
python3 -m sglang.launch_server --model-path /path/to/DeepSeek-V3-BF16 --tp 32 --dist-init-addr 10.0.0.1:5000 --nnodes 4 --node-rank 3 --trust-remote-code
```
> **Note that the launch command here does not enable Data Parallelism Attention or `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
Then we can benchmark the accuracy and latency by accessing the first node's exposed port with the following example commands.
```bash
# bench accuracy
python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --host http://10.0.0.1 --port 30000
# bench latency
python3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1:30000 --batch-size 1 --input-len 128 --output-len 128
```
### Example: Serving with 8 A100/A800 with AWQ Quantization
**Recommended Usage**
Add `--quantization moe_wna16` flag to enable moe wna16 kernel for better performance.
One example is as follows:
```bash
python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization moe_wna16
```
Alternatively, you can use `--quantization awq_marlin` as follows:
```bash
python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization awq_marlin --dtype float16
```
Note that `awq_marlin` only supports `float16` now, which may lead to some precision loss.
### Example: Serving with 16 A100/A800 with int8 Quantization
There are block-wise and per-channel quantization methods, and the quantization parameters have already been uploaded to Huggingface. One example is as follows:
- [meituan/DeepSeek-R1-Block-INT8](https://huggingface.co/meituan/DeepSeek-R1-Block-INT8)
- [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8)
Assuming that master node IP is `MASTER_IP`, checkpoint path is `/path/to/DeepSeek-R1-INT8` and port=5000, we can have following commands to launch the server:
```bash
#master
python3 -m sglang.launch_server \
--model meituan/DeepSeek-R1-Block-INT8 --tp 16 --dist-init-addr \
MASTER_IP:5000 --nnodes 2 --node-rank 0 --trust-remote-code --enable-torch-compile --torch-compile-max-bs 8
#cluster
python3 -m sglang.launch_server \
--model meituan/DeepSeek-R1-Block-INT8 --tp 16 --dist-init-addr \
MASTER_IP:5000 --nnodes 2 --node-rank 1 --trust-remote-code --enable-torch-compile --torch-compile-max-bs 8
```
> **Note that the launch command here enables `torch.compile` Optimization**. For optimal performance, please refer to the command options in [Performance Optimization Options](#option_args).
Then on the **master node**, supposing the ShareGPT data is located at `/path/to/ShareGPT_V3_unfiltered_cleaned_split.json`, you can run the following commands to benchmark the launched server:
```bash
# bench accuracy
python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319
# bench serving
python3 -m sglang.bench_serving --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json --dataset-name random --random-input 128 --random-output 128 --num-prompts 1000 --request-rate 128 --random-range-ratio 1.0
```
> **Note: using `--parallel 200` can accelerate accuracy benchmarking**.
### Example: Serving with 32 L40S with int8 Quantization
Running with per-channel quantization model:
- [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8)
Assuming that master node IP is `MASTER_IP`, checkpoint path is `/path/to/DeepSeek-R1-Channel-INT8` and port=5000, we can have following commands to launch the server:
```bash
#master
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 0 --trust-remote \
--enable-torch-compile --torch-compile-max-bs 32
#cluster
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 1 --trust-remote \
--enable-torch-compile --torch-compile-max-bs 32
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 2 --trust-remote \
--enable-torch-compile --torch-compile-max-bs 32
python3 -m sglang.launch_server --model meituan/DeepSeek-R1-Channel-INT8 --tp 32 --quantization w8a8_int8 \
--dist-init-addr MASTER_IP:5000 --nnodes 4 --node-rank 3 --trust-remote \
--enable-torch-compile --torch-compile-max-bs 32
```
The benchmarking method is the same as describted in the previous [16 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization) example.
### Example: Serving on any cloud or Kubernetes with SkyPilot
SkyPilot helps find cheapest available GPUs across any cloud or existing Kubernetes clusters and launch distributed serving with a single command. See details [here](https://github.com/skypilot-org/skypilot/tree/master/llm/deepseek-r1).
To serve on multiple nodes:
```bash
git clone https://github.com/skypilot-org/skypilot.git
# Serve on 2 H100/H200x8 nodes
sky launch -c r1 llm/deepseek-r1/deepseek-r1-671B.yaml --retry-until-up
# Serve on 4 A100x8 nodes
sky launch -c r1 llm/deepseek-r1/deepseek-r1-671B-A100.yaml --retry-until-up
```
#### Troubleshooting
If you encounter the following error with fp16/bf16 checkpoint:
```bash
ValueError: Weight output_partition_size = 576 is not divisible by weight quantization block_n = 128.
```
edit your `config.json` and remove the `quantization_config` block. For example:
```json
"quantization_config": {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [128, 128]
},
```
Removing this block typically resolves the error. For more details, see the discussion in [sgl-project/sglang#3491](https://github.com/sgl-project/sglang/issues/3491#issuecomment-2650779851).
## DeepSeek V3 Optimization Plan
https://github.com/sgl-project/sglang/issues/2591

51
benchmark/dspy/README.md Normal file
View File

@@ -0,0 +1,51 @@
## Install
```
pip3 install dspy-ai
```
Turn off cache at https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/dsp/modules/cache_utils.py#L10.
```
cache_turn_on = False
```
or set the environment variable
```
export DSP_CACHEBOOL=false
```
## Benchmark SGLang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_dspy_intro.py --backend sglang
```
## Benchmark TGI
```
docker run --name tgi --rm -ti --gpus all --network host \
-v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
ghcr.io/huggingface/text-generation-inference:1.3.0 \
--model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
--max-input-length 2048 --max-total-tokens 4096 \
--port 24000
```
```
python3 bench_dspy_intro.py --backend tgi
```
## Benchmark vLLM
```
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_dspy_intro.py --backend vllm
```

View File

@@ -0,0 +1,192 @@
"""
Adapted from
https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
"""
import argparse
import dspy
from dspy.datasets import HotPotQA
class BasicQA(dspy.Signature):
"""Answer questions with short factoid answers."""
question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")
class GenerateAnswer(dspy.Signature):
"""Answer questions with short factoid answers."""
context = dspy.InputField(desc="may contain relevant facts")
question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")
class RAG(dspy.Module):
def __init__(self, num_passages=3):
super().__init__()
self.retrieve = dspy.Retrieve(k=num_passages)
self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
def forward(self, question):
context = self.retrieve(question).passages
prediction = self.generate_answer(context=context, question=question)
return dspy.Prediction(context=context, answer=prediction.answer)
def main(args):
# lm = dspy.OpenAI(model='gpt-3.5-turbo')
if args.backend == "tgi":
lm = dspy.HFClientTGI(
model="meta-llama/Llama-2-7b-chat-hf",
port=args.port,
url="http://localhost",
)
elif args.backend == "sglang":
lm = dspy.HFClientSGLang(
model="meta-llama/Llama-2-7b-chat-hf",
port=args.port,
url="http://localhost",
)
elif args.backend == "vllm":
lm = dspy.HFClientVLLM(
model="meta-llama/Llama-2-7b-chat-hf",
port=args.port,
url="http://localhost",
)
else:
raise ValueError(f"Invalid backend: {args.backend}")
colbertv2_wiki17_abstracts = dspy.ColBERTv2(
url="http://20.102.90.50:2017/wiki17_abstracts"
)
dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)
# Load the dataset.
dataset = HotPotQA(
train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size, test_size=0
)
# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
trainset = [x.with_inputs("question") for x in dataset.train]
devset = [x.with_inputs("question") for x in dataset.dev]
print(len(trainset), len(devset))
train_example = trainset[0]
print(f"Question: {train_example.question}")
print(f"Answer: {train_example.answer}")
dev_example = devset[18]
print(f"Question: {dev_example.question}")
print(f"Answer: {dev_example.answer}")
print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")
print(
f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}"
)
print(
f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}"
)
# Define the predictor.
generate_answer = dspy.Predict(BasicQA)
# Call the predictor on a particular input.
pred = generate_answer(question=dev_example.question)
# Print the input and the prediction.
print(f"Question: {dev_example.question}")
print(f"Predicted Answer: {pred.answer}")
lm.inspect_history(n=1)
# Define the predictor. Notice we're just changing the class. The signature BasicQA is unchanged.
generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)
# Call the predictor on the same input.
pred = generate_answer_with_chain_of_thought(question=dev_example.question)
# Print the input, the chain of thought, and the prediction.
print(f"Question: {dev_example.question}")
print(f"Thought: {pred.rationale.split('.', 1)[1].strip()}")
print(f"Predicted Answer: {pred.answer}")
retrieve = dspy.Retrieve(k=3)
topK_passages = retrieve(dev_example.question).passages
print(
f"Top {retrieve.k} passages for question: {dev_example.question} \n",
"-" * 30,
"\n",
)
for idx, passage in enumerate(topK_passages):
print(f"{idx+1}]", passage, "\n")
retrieve("When was the first FIFA World Cup held?").passages[0]
from dspy.teleprompt import BootstrapFewShot
# Validation logic: check that the predicted answer is correct.
# Also check that the retrieved context does actually contain that answer.
def validate_context_and_answer(example, pred, trace=None):
answer_EM = dspy.evaluate.answer_exact_match(example, pred)
answer_PM = dspy.evaluate.answer_passage_match(example, pred)
return answer_EM and answer_PM
# Set up a basic teleprompter, which will compile our RAG program.
teleprompter = BootstrapFewShot(metric=validate_context_and_answer)
# Compile!
compiled_rag = teleprompter.compile(RAG(), trainset=trainset)
# Ask any question you like to this simple RAG program.
my_question = "What castle did David Gregory inherit?"
# Get the prediction. This contains `pred.context` and `pred.answer`.
pred = compiled_rag(my_question)
# Print the contexts and the answer.
print(f"Question: {my_question}")
print(f"Predicted Answer: {pred.answer}")
print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")
from dspy.evaluate.evaluate import Evaluate
# Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
evaluate_on_hotpotqa = Evaluate(
devset=devset,
num_threads=args.num_threads,
display_progress=True,
display_table=5,
)
# Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
metric = dspy.evaluate.answer_exact_match
evaluate_on_hotpotqa(compiled_rag, metric=metric)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int)
parser.add_argument("--num-threads", type=int, default=32)
parser.add_argument("--dev-size", type=int, default=150)
parser.add_argument(
"--backend", type=str, choices=["sglang", "tgi", "vllm"], default="sglang"
)
args = parser.parse_args()
if args.port is None:
default_port = {
"vllm": 21000,
"lightllm": 22000,
"tgi": 24000,
"sglang": 30000,
}
args.port = default_port.get(args.backend, None)
main(args)

View File

@@ -0,0 +1,38 @@
## Download the dataset
```
wget -O agent_calls.jsonl https://drive.google.com/uc?export=download&id=19qLpD45e9JGTKF2cUjJJegwzSUEZEKht
```
## Run benchmark
Ensure that this benchmark is run in a serial manner (using --parallel 1) to preserve any potential dependencies between requests.
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-events 1000 --parallel 1
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-events 1000 --backend vllm --parallel 1
```
### Benchmark guidance
```
python3 bench_other.py --num-events 1000 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
python3 bench_other.py --num-events 1000 --backend lmql --parallel 1
```

View File

@@ -0,0 +1,300 @@
import sglang as sgl
# here are the top five agent functions contributing ~70% LLM calls
# reference: https://github.com/joonspk-research/generative_agents/
@sgl.function
def poignancy_event(s, persona_name, persona_iss, event):
s += "Here is a brief description of " + persona_name + ".\n"
s += persona_iss + "\n"
s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
s += persona_name + ".\n\n"
s += "Event: " + event
s += "Rate (return a number between 1 to 10):"
s += sgl.gen(name="Rate", max_tokens=2)
def poignancy_event_prompt(persona_name, persona_iss, event):
# return prompt and max_tokens
s = ""
s += "Here is a brief description of " + persona_name + ".\n"
s += persona_iss + "\n"
s += "On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following event for"
s += persona_name + ".\n\n"
s += "Event: " + event
s += "Rate (return a number between 1 to 10):"
return {"prompt": s, "max_tokens": 2, "stop": None}
@sgl.function
def generate_event_triple(s, persona_name, action):
s += """Task: Turn the input into (subject, predicate, object).
Input: Sam Johnson is eating breakfast.
Output: (Dolores Murphy, eat, breakfast)
---
Input: Joon Park is brewing coffee.
Output: (Joon Park, brew, coffee)
---
Input: Jane Cook is sleeping.
Output: (Jane Cook, is, sleep)
---
Input: Michael Bernstein is writing email on a computer.
Output: (Michael Bernstein, write, email)
---
Input: Percy Liang is teaching students in a classroom.
Output: (Percy Liang, teach, students)
---
Input: Merrie Morris is running on a treadmill.
Output: (Merrie Morris, run, treadmill)
---"""
s += persona_name + "is" + action + ".\n"
s += "(" + persona_name + ","
s += sgl.gen(name="Triple", max_tokens=20, stop=")")
def generate_event_triple_prompt(persona_name, action):
s = ""
s += """Task: Turn the input into (subject, predicate, object).
Input: Sam Johnson is eating breakfast.
Output: (Dolores Murphy, eat, breakfast)
---
Input: Joon Park is brewing coffee.
Output: (Joon Park, brew, coffee)
---
Input: Jane Cook is sleeping.
Output: (Jane Cook, is, sleep)
---
Input: Michael Bernstein is writing email on a computer.
Output: (Michael Bernstein, write, email)
---
Input: Percy Liang is teaching students in a classroom.
Output: (Percy Liang, teach, students)
---
Input: Merrie Morris is running on a treadmill.
Output: (Merrie Morris, run, treadmill)
---"""
s += persona_name + "is" + action + ".\n"
s += "(" + persona_name + ","
return {"prompt": s, "max_tokens": 20, "stop": ")"}
@sgl.function
def generate_pronunciatio(s, action):
s += "Convert an action description to an emoji (important: use two or less emojis).\n"
s += "Action description: " + action + ".\n"
s += "Emoji:" + sgl.gen(name="Emoji", max_tokens=6)
def generate_pronunciatio_prompt(action):
s = ""
s += "Convert an action description to an emoji (important: use two or less emojis).\n"
s += "Action description: " + action + ".\n"
s += "Emoji:"
return {"prompt": s, "max_tokens": 6, "stop": None}
@sgl.function
def action_location_sector(
s,
persona_name,
living_sector,
living_sector_areas,
current_sector,
current_sector_areas,
daily_plan,
sector_options,
current_action,
next_action,
):
s += """Task -- choose an appropriate area from the area options for a task at hand.
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
---
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
---"""
s += (
persona_name
+ " lives in "
+ living_sector
+ " that has "
+ living_sector_areas
+ ".\n"
)
s += (
persona_name
+ " is currently in "
+ current_sector
+ " that has "
+ current_sector_areas
+ ".\n"
)
s += daily_plan + ".\n"
s += "Area options: " + sector_options + ".\n"
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.\n"""
s += (
persona_name
+ " is "
+ current_action
+ ". For "
+ next_action
+ ", "
+ persona_name
+ " should go to the following area: {"
)
s += sgl.gen(name="Location", max_tokens=10, stop="}")
def action_location_sector_prompt(
persona_name,
living_sector,
living_sector_areas,
current_sector,
current_sector_areas,
daily_plan,
sector_options,
current_action,
next_action,
):
s = ""
s += """Task -- choose an appropriate area from the area options for a task at hand.
Sam Kim lives in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Sam Kim is currently in {Sam Kim's house} that has Sam Kim's room, bathroom, kitchen.
Area options: {Sam Kim's house, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For taking a walk, Sam Kim should go to the following area: {Johnson Park}
---
Jane Anderson lives in {Oak Hill College Student Dormatory} that has Jane Anderson's room.
Jane Anderson is currently in {Oak Hill College} that has a classroom, library
Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs Cafe, Oak Hill College, Johnson Park, Harvey Oak Supply Store, The Willows Market and Pharmacy}.
* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
---"""
s += (
persona_name
+ " lives in "
+ living_sector
+ " that has "
+ living_sector_areas
+ ".\n"
)
s += (
persona_name
+ " is currently in "
+ current_sector
+ " that has "
+ current_sector_areas
+ ".\n"
)
s += daily_plan + ".\n"
s += "Area options: " + sector_options + ".\n"
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
* Must be one of the "Area options," verbatim.\n"""
s += (
persona_name
+ " is "
+ current_action
+ ". For "
+ next_action
+ ", "
+ persona_name
+ " should go to the following area: {"
)
return {"prompt": s, "max_tokens": 10, "stop": "}"}
@sgl.function
def action_location_object(
s, persona_name, target_sector, target_sector_areas, current_action, next_action
):
s += """
Jane Anderson is in kitchen in Jane Anderson's house.
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
Answer: {kitchen}
---
Tom Watson is in common room in Tom Watson's apartment.
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
Answer: {cafe}
---"""
s += (
persona_name
+ " is going to "
+ target_sector
+ " that has the following areas: {"
+ target_sector_areas
+ "}\n"
)
s += """* Stay in the current area if the activity can be done there.
* NEVER go into other people's rooms unless necessary."""
s += (
persona_name
+ " is "
+ current_action
+ ". For "
+ next_action
+ ", "
+ persona_name
+ "should go to the following area in "
+ target_sector
)
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}")
def action_location_object_prompt(
persona_name, target_sector, target_sector_areas, current_action, next_action
):
s = ""
s += """
Jane Anderson is in kitchen in Jane Anderson's house.
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For cooking, Jane Anderson should go to the following area in Jane Anderson's house:
Answer: {kitchen}
---
Tom Watson is in common room in Tom Watson's apartment.
Tom Watson is going to Hobbs Cafe that has the following areas: {cafe}
Stay in the current area if the activity can be done there. Never go into other people's rooms unless necessary.
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
Answer: {cafe}
---"""
s += (
persona_name
+ " is going to "
+ target_sector
+ " that has the following areas: {"
+ target_sector_areas
+ "}\n"
)
s += """* Stay in the current area if the activity can be done there.
* NEVER go into other people's rooms unless necessary."""
s += (
persona_name
+ " is "
+ current_action
+ ". For "
+ next_action
+ ", "
+ persona_name
+ "should go to the following area in "
+ target_sector
)
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
s += "Answer: {"
return {"prompt": s, "max_tokens": 5, "stop": "}"}

View File

@@ -0,0 +1,80 @@
import argparse
import json
import time
from agent_functions import (
action_location_object_prompt,
action_location_sector_prompt,
generate_event_triple_prompt,
generate_pronunciatio_prompt,
poignancy_event_prompt,
)
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
def main(args):
lines = read_jsonl(args.data_path)[: args.num_events]
mapping = {
"poignancy_event": poignancy_event_prompt,
"generate_event_triple": generate_event_triple_prompt,
"generate_pronunciatio": generate_pronunciatio_prompt,
"action_location_sector": action_location_sector_prompt,
"action_location_object": action_location_object_prompt,
}
arguments = [mapping[k](**v) for l in lines for k, v in l.items()]
states = []
# Select backend
call_generate = get_call_generate(args)
def get_one_answer(arg):
answer = call_generate(**arg, temperature=0)
states.append(answer)
async def get_one_answer_async(arg):
answer = await call_generate(**arg, temperature=0)
states.append(answer)
tic = time.perf_counter()
# we always sequentially execute agent calls to maintain its dependency
if args.backend != "lmql":
for arg in tqdm(arguments):
get_one_answer(arg)
else:
import asyncio
loop = asyncio.get_event_loop()
for arg in tqdm(arguments):
loop.run_until_complete(get_one_answer_async(arg))
latency = time.perf_counter() - tic
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "Generative Agents",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
# to pack weighted functions as a single agent
"num_requests": len(arguments) / len(mapping),
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
parser.add_argument("--num-events", type=int, default=10)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,74 @@
import argparse
import json
import time
from agent_functions import (
action_location_object,
action_location_sector,
generate_event_triple,
generate_pronunciatio,
poignancy_event,
)
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
def main(args):
lines = read_jsonl(args.data_path)[: args.num_events]
mapping = {
"poignancy_event": poignancy_event,
"generate_event_triple": generate_event_triple,
"generate_pronunciatio": generate_pronunciatio,
"action_location_sector": action_location_sector,
"action_location_object": action_location_object,
}
arguments = [{mapping[k]: v for k, v in l.items()} for l in lines]
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
states = []
# Run requests
tic = time.perf_counter()
for a in arguments:
# only a single key in the dict
for func, arg in a.items():
result = func.run(**arg)
result.sync()
states.append(result)
latency = time.perf_counter() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "Generative Agents",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
# to pack weighted functions as a single agent
"num_requests": len(arguments) / len(mapping),
"other": {
"num_events": args.num_events,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="agent_calls.jsonl")
parser.add_argument("--num-events", type=int, default=10)
args = add_common_sglang_args_and_parse(parser)
main(args)

163
benchmark/gpt_oss/README.md Normal file
View File

@@ -0,0 +1,163 @@
# How to reproduce the result of GPT-OSS with SGLang
### Install the latest SGLang
```bash
git clone https://github.com/sgl-project/sglang.git
cd sglang
git checkout v0.5.1.post3
pip install --upgrade pip
pip install -e "python[all]"
```
### Reproduce the benchmark throughput result (Batch Size 1)
Launch Command
```bash
# MXFP4 120B on H100
python3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 8 --attention-backend triton
# BF16 120B on H100
python3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 8 --attention-backend triton
# MXFP4 120B on B200
python3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 4
# BF16 120B on B200
python3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 4
```
Benchmark Command
```bash
# MXFP4 120B on H100
python3 -m sglang.bench_one_batch_server --model openai/gpt-oss-120b --base-url http://localhost:30000 --batch-size 1 --input-len 1024 --output-len 512 --show-report
```
### Reproduce the benchmark throughput result (Batch Size 32)
Launch Command
```bash
# MXFP4 120B on H100
python3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 8
# BF16 120B on H100
python3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 8
# MXFP4 120B on B200
python3 -m sglang.launch_server --model openai/gpt-oss-120b --tp 4
# BF16 120B on B200
python3 -m sglang.launch_server --model lmsys/gpt-oss-120b-bf16 --tp 4
```
Benchmark Command
```bash
python3 -m sglang.bench_one_batch_server --model openai/gpt-oss-120b --base-url http://localhost:30000 --batch-size 32 --input-len 1024 8192 --output-len 512 --show-report
```
### Reproduce the evaluation result
Install gpt-oss
```bash
git clone https://github.com/openai/gpt-oss.git
cd gpt-oss
pip install -e .
```
Evaluation Command
```bash
DATASET=gpqa
BASE_URL=YOUR_BASE_URL
OPENAI_API_KEY=dummy python -m gpt_oss.evals \
--base-url ${BASE_URL}/v1 \
--model dummy \
--reasoning-effort low,medium,high \
--eval $DATASET \
--n-threads 1000
```
### Reproduce the benchmark result of acceptance length
> Note: On B200, if top k is 1, set `--attention-backend trtllm_mha`
```bash
git clone https://github.com/sgl-project/SpecForge.git
cd SpecForge/benchmarks
config_list=(
"1,0,0,0"
"1,3,1,4"
"1,5,4,8"
)
python3 bench_model_speedup.py \
--model-path openai/gpt-oss-120b \
--speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 \
--port 20001 \
--trust-remote-code \
--mem-fraction-static 0.8 \
--tp-size 4 \
--attention-backend fa3 \
--config-list "${config_list[@]}" \
--benchmark-list mtbench:80 gsm8k:200 humaneval:200 math500:200 \
--output lmsys_gpt-oss-120b_Eagle3_result.jsonl
python3 bench_model_speedup.py \
--model-path openai/gpt-oss-120b \
--speculative-draft-model-path nvidia/gpt-oss-120b-Eagle3 \
--port 20001 \
--trust-remote-code \
--mem-fraction-static 0.8 \
--tp-size 4 \
--attention-backend fa3 \
--config-list "${config_list[@]}" \
--benchmark-list mtbench:80 gsm8k:200 humaneval:200 math500:200 \
--output nv_gpt-oss-120b_Eagle3_result.jsonl
```
### Reproduce the result of speculative decoding speedup
Launch Command
```bash
# On Hopper:
# - Tree decoding (topk > 1) and chain decoding (topk = 1) are supported on both FA3 and Triton backends.
python3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algorithm EAGLE3 --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --tp 4
python3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algorithm EAGLE3 --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 5 --speculative-eagle-topk 4 --speculative-num-draft-tokens 8 --tp 4
# On Blackwell:
# - Chain decoding (topk = 1) is supported on TRTLLM-MHA backend. Tree decoding (topk > 1) is in progress, stay tuned!
# - Both tree decoding (topk > 1) and chain decoding (topk = 1) are supported on the Triton backend.
python3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algo EAGLE3 --speculative-draft lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --tp 4
python3 -m sglang.launch_server --model openai/gpt-oss-120b --speculative-algo EAGLE3 --speculative-draft lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 5 --speculative-eagle-topk 4 --speculative-num-draft-tokens 8 --attention-backend triton --tp 4
```
Benchmark Command
```bash
config_list=(
"1,0,0,0"
"1,3,1,4"
"1,5,4,8"
)
python3 bench_model_speedup.py \
--model-path openai/gpt-oss-120b \
--speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 \
--port 20001 \
--trust-remote-code \
--mem-fraction-static 0.8 \
--tp-size 4 \
--attention-backend fa3 \
--config-list "${config_list[@]}" \
--benchmark-list gsm8k:200 humaneval:200 math500:200 \
--output lmsys_gpt-oss-120b_Eagle3_result.jsonl
```
We can gain the best speedup with the following settings:
- **1.39x** speedup with the `--speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4` setting.
- **1.52x** speedup with the `--speculative-num-steps 5 --speculative-eagle-topk 4 --speculative-num-draft-tokens 8` setting.

47
benchmark/gsm8k/README.md Normal file
View File

@@ -0,0 +1,47 @@
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 200
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 200 --backend vllm
```
### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 200 --backend lightllm
```
### Benchmark guidance
```
python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
CUDA_VISIBLE_DEVICES=0,1 lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
```
```
python3 bench_other.py --num-questions 100 --backend lmql --parallel 2
```

View File

@@ -0,0 +1,151 @@
import argparse
import ast
import asyncio
import json
import re
import time
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
INVALID = -9999999
def get_one_example(lines, i, include_answer):
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
if include_answer:
ret += " " + lines[i]["answer"]
return ret
def get_few_shot_examples(lines, k):
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def main(args):
# Select backend
call_generate = get_call_generate(args)
# Read data
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
labels = []
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
states = [None] * len(labels)
# Run requests
if args.backend != "lmql":
# Use thread pool
def get_one_answer(i):
answer = call_generate(
prompt=few_shot_examples + questions[i],
temperature=0,
max_tokens=256,
stop=["Question", "Assistant:", "<|separator|>"],
)
states[i] = answer
tic = time.perf_counter()
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
else:
# Use asyncio
async def batched_call(batch_size):
for i in range(0, len(questions), batch_size):
tasks = []
for q in questions[i : i + batch_size]:
tasks.append(
call_generate(
few_shot_examples + q,
temperature=0,
max_tokens=256,
stop="Question",
)
)
rets = await asyncio.gather(*tasks)
for j in range(len(rets)):
states[i + j] = rets[j]
tic = time.perf_counter()
asyncio.run(batched_call(batch_size=args.parallel))
latency = time.perf_counter() - tic
preds = []
for i in range(len(states)):
preds.append(get_answer_value(states[i]))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
# Dump results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,148 @@
import argparse
import ast
import json
import os
import re
import time
import numpy as np
from sglang.lang.api import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
dump_bench_raw_result,
select_sglang_backend,
)
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
INVALID = -9999999
def get_one_example(lines, i, include_answer):
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
if include_answer:
ret += " " + lines[i]["answer"]
return ret
def get_few_shot_examples(lines, k):
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def get_answer_value(answer_str):
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def main(args):
# Select backend
set_default_backend(select_sglang_backend(args))
# Read data
data_path = args.data_path
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
if not os.path.isfile(data_path):
data_path = download_and_cache_file(url)
lines = list(read_jsonl(data_path))
# Construct prompts
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
labels = []
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
arguments = [{"question": q} for q in questions]
#####################################
######### SGL Program Begin #########
#####################################
import sglang as sgl
@sgl.function
def few_shot_gsm8k(s, question):
s += few_shot_examples + question
s += sgl.gen(
"answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
)
#####################################
########## SGL Program End ##########
#####################################
# Run requests
tic = time.perf_counter()
states = few_shot_gsm8k.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
preds = []
for i in range(len(states)):
preds.append(get_answer_value(states[i]["answer"]))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
# Compute speed
num_output_tokens = sum(
s.get_meta_info("answer")["completion_tokens"] for s in states
)
output_throughput = num_output_tokens / latency
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
print(f"Output throughput: {output_throughput:.3f} token/s")
# Dump results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
dump_bench_raw_result(
path=args.raw_result_file,
states=states,
preds=preds,
labels=labels,
)
with open(args.result_file, "a") as fout:
value = {
"task": "gsm8k",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,47 @@
## Run benchmark
### Benchmark sglang
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 200
```
### Benchmark vllm
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 200 --backend vllm
```
### Benchmark lightllm
```
# A10G
python -m lightllm.server.api_server --tokenizer_mode auto --model_dir ~/model_weights/llama-2-7b-chat-hf --max_total_token_num 16000 --port 22000
```
```
python3 bench_other.py --num-questions 200 --backend lightllm
```
### Benchmark guidance
```
CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
lmql serve-model meta-llama/Llama-2-7b-chat-hf --cuda --port 23000
```
```
python3 bench_other.py --num-questions 200 --backend lmql --port 23000 --parallel 1
```

View File

@@ -0,0 +1,118 @@
import argparse
import asyncio
import json
import time
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
from sglang.utils import download_and_cache_file, read_jsonl
def get_one_example(lines, i, include_answer):
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
if include_answer:
ret += lines[i]["endings"][lines[i]["label"]]
return ret
def get_few_shot_examples(lines, k):
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def main(args):
# Select backend
call_select = get_call_select(args)
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
choices = []
labels = []
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
preds = [None] * len(labels)
# Run requests
if args.backend != "lmql":
# Use thread pool
def get_one_answer(i):
preds[i] = call_select(
context=few_shot_examples + questions[i], choices=choices[i]
)
tic = time.perf_counter()
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
else:
# Use asyncio
async def batched_call(batch_size):
for i in range(0, len(questions), batch_size):
tasks = []
for q, c in zip(
questions[i : i + batch_size], choices[i : i + batch_size]
):
tasks.append(call_select(context=few_shot_examples + q, choices=c))
rets = await asyncio.gather(*tasks)
for j in range(len(rets)):
preds[i + j] = rets[j]
tic = time.perf_counter()
asyncio.run(batched_call(batch_size=args.parallel))
latency = time.perf_counter() - tic
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
print(f"Latency: {latency:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
with open(args.result_file, "a") as fout:
value = {
"task": "hellaswag",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shots", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,109 @@
import argparse
import json
import os
import time
import numpy as np
from sglang.lang.api import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import download_and_cache_file, read_jsonl
def get_one_example(lines, i, include_answer):
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
if include_answer:
ret += lines[i]["endings"][lines[i]["label"]]
return ret
def get_few_shot_examples(lines, k):
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def main(args):
# Select backend
set_default_backend(select_sglang_backend(args))
# Read data
data_path = args.data_path
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
if not os.path.isfile(data_path):
data_path = download_and_cache_file(url)
lines = list(read_jsonl(data_path))
# Construct prompts
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
choices = []
labels = []
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]
#####################################
######### SGL Program Begin #########
#####################################
import sglang as sgl
@sgl.function
def few_shot_hellaswag(s, question, choices):
s += few_shot_examples + question
s += sgl.select("answer", choices=choices)
#####################################
########## SGL Program End ##########
#####################################
# Run requests
tic = time.perf_counter()
rets = few_shot_hellaswag.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=True,
)
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
latency = time.perf_counter() - tic
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
print(f"Latency: {latency:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
with open(args.result_file, "a") as fout:
value = {
"task": "hellaswag",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shots", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)
main(args)

59
benchmark/hf3fs/bench.sh Normal file
View File

@@ -0,0 +1,59 @@
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
python3 benchmark/hf3fs/bench_client.py
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \
python3 benchmark/hf3fs/bench_storage.py
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json
echo '{"file_path_prefix": "/data/hf3fs-test-0", "file_size": 1099511627776, "numjobs": 16, "entries": 8}' > \
${SGLANG_HICACHE_HF3FS_CONFIG_PATH}
python3 benchmark/hf3fs/bench_zerocopy.py
####################################################################################################
rm -rf nohup.out && \
nohup python3 -m sglang.launch_server \
--model-path /code/models/Qwen3-32B/ \
--host 0.0.0.0 --port 33301 \
--page-size 64 \
--enable-hierarchical-cache \
--hicache-ratio 2 --hicache-size 0 \
--hicache-write-policy write_through \
--hicache-storage-backend hf3fs &
rm -rf bench_multiturn.out && \
nohup python3 benchmark/hicache/bench_multiturn.py \
--model-path /code/models/Qwen3-32B \
--dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
--port 33301 \
--request-length 2048 --num-clients 512 --num-rounds 3 --max-parallel 8 \
> bench_multiturn.out &
####################################################################################################
rm -rf nohup.out && \
nohup python3 -m sglang.launch_server \
--model-path /code/models/DeepSeek-R1/ \
--tp 16 --nnodes 2 --node-rank 0 \
--dist-init-addr 10.74.249.153:5000 \
--host 0.0.0.0 --port 33301 \
--page-size 64 \
--enable-hierarchical-cache \
--hicache-ratio 2 --hicache-size 60 \
--hicache-write-policy write_through \
--hicache-storage-backend hf3fs &
rm -rf bench_multiturn.out && \
nohup python3 benchmark/hicache/bench_multiturn.py \
--model-path /code/models/Qwen3-32B \
--dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
--port 33301 \
--request-length 2048 --num-clients 1024 --num-rounds 3 --max-parallel 8 \
> bench_multiturn.out &
####################################################################################################
ps aux | grep "sglang.launch_server" | grep -v grep | awk '{print $2}' | xargs kill -9
ps aux | grep "bench_multiturn.py" | grep -v grep | awk '{print $2}' | xargs kill -9

View File

@@ -0,0 +1,162 @@
import concurrent.futures
import logging
import random
import time
from typing import List
import torch
from tqdm import tqdm
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
def print_stats(x: List[int]):
x = sorted(x)
lenx = len(x)
print(
f"mean = {sum(x)/len(x):.2f}, "
f"min = {min(x):.2f}, "
f"p25 = {x[int(lenx*0.25)]:.2f}, "
f"p50 = {x[int(lenx*0.5)]:.2f}, "
f"p75 = {x[int(lenx*0.75)]:.2f}, "
f"max = {max(x):.2f}"
)
def test():
# /path/to/hf3fs
file_path = "/data/bench.bin"
file_size = 1 << 40
bytes_per_page = 16 << 20
entries = 32
file_ops = Hf3fsClient(file_path, file_size, bytes_per_page, entries)
print("test batch_read / batch_write")
num_pages = 128
dtype = torch.bfloat16
numel = bytes_per_page // dtype.itemsize
offsets = list(range(file_size // bytes_per_page))
random.shuffle(offsets)
offsets = offsets[:num_pages]
offsets = [i * bytes_per_page for i in offsets]
tensor_writes = [
torch.randn(numel, dtype=dtype)
for _ in tqdm(range(num_pages), desc="prepare tensor")
]
for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_write"):
results = file_ops.batch_write(
offsets[i : i + file_ops.entries], tensor_writes[i : i + file_ops.entries]
)
assert all([result == numel * dtype.itemsize for result in results])
tensor_reads = [
torch.empty(numel, dtype=dtype)
for _ in tqdm(range(num_pages), desc="prepare tensor")
]
for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_read"):
results = file_ops.batch_read(
offsets[i : i + file_ops.entries], tensor_reads[i : i + file_ops.entries]
)
assert all([result == numel * dtype.itemsize for result in results])
assert all([torch.allclose(r, w) for r, w in zip(tensor_reads, tensor_writes)])
file_ops.close()
print("test done")
def bench():
file_path = "/data/bench.bin"
file_size = 1 << 40
bytes_per_page = 16 << 20
entries = 8
numjobs = 16
dtype = torch.bfloat16
numel = bytes_per_page // dtype.itemsize
file_ops = [
Hf3fsClient(file_path, file_size, bytes_per_page, entries)
for _ in range(numjobs)
]
num_page = entries
offsets = list(range(file_size // bytes_per_page))
tensors_write = [torch.randn(numel, dtype=dtype)] * num_page
tensors_read = [torch.empty(numel, dtype=dtype)] * num_page
random.shuffle(offsets)
warmup = 50
iteration = 100
executor = concurrent.futures.ThreadPoolExecutor(max_workers=numjobs)
w_bw = []
w_size = num_page * numjobs * bytes_per_page / (1 << 30)
for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
_offsets = [
[
offset * bytes_per_page
for offset in offsets[
(i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
]
]
for j in range(numjobs)
]
tik = time.perf_counter()
futures = [
executor.submit(file_ops[j].batch_write, offset, tensors_write)
for j, offset in enumerate(_offsets)
]
results = [future.result() for future in futures]
tok = time.perf_counter()
if i < warmup:
continue
w_bw.append(w_size / (tok - tik))
results = [
_result == bytes_per_page for result in results for _result in result
]
assert all(results)
print_stats(w_bw)
r_bw = []
r_size = w_size
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
_offsets = [
[
offset * bytes_per_page
for offset in offsets[
(i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
]
]
for j in range(numjobs)
]
tik = time.perf_counter()
futures = [
executor.submit(file_ops[j].batch_read, offset, tensors_read)
for j, offset in enumerate(_offsets)
]
results = [future.result() for future in futures]
tok = time.perf_counter()
if i < warmup:
continue
r_bw.append(r_size / (tok - tik))
results = [
_result == bytes_per_page for result in results for _result in result
]
assert all(results)
print_stats(r_bw)
executor.shutdown(wait=True)
for _file_ops in file_ops:
_file_ops.close()
print("bench done")
def main():
logging.basicConfig(level=logging.INFO)
test()
bench()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,258 @@
import json
import logging
import os
import random
import time
from typing import List
import torch
from tqdm import tqdm
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
Hf3fsLocalMetadataClient,
)
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
def print_stats(x: List[int]):
x = sorted(x)
lenx = len(x)
print(
f"mean = {sum(x)/len(x):.2f}, "
f"min = {min(x):.2f}, "
f"p25 = {x[int(lenx*0.25)]:.2f}, "
f"p50 = {x[int(lenx*0.5)]:.2f}, "
f"p75 = {x[int(lenx*0.75)]:.2f}, "
f"max = {max(x):.2f}"
)
def test():
# Qwen3-32B
layer_num = 64
head_num, head_dim = 8, 128
kv_lora_rank, qk_rope_head_dim = 0, 0
store_dtype = torch.bfloat16
tokens_per_page = 64
file_path_prefix = "/data/test"
file_size = 128 << 20
numjobs = 16
bytes_per_page = 16 << 20
entries = 2
dtype = store_dtype
config_path = os.getenv(HiCacheHF3FS.default_env_var)
assert config_path
try:
with open(config_path, "w") as f:
json.dump(
{
"file_path_prefix": file_path_prefix,
"file_size": file_size,
"numjobs": numjobs,
"entries": entries,
},
f,
)
except Exception as e:
raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}")
hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype)
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
assert numel * dtype.itemsize == bytes_per_page
num_pages = 10
tensors = {}
for i in range(num_pages):
k = f"key_{i}"
v = torch.randn((numel,)).to(dtype=dtype)
ok = hicache_hf3fs.set(k, v)
if i < (file_size // bytes_per_page):
assert ok, f"Failed to insert {k}"
else:
assert not ok
tensors[k] = v
assert hicache_hf3fs.get("key_8") is None
assert hicache_hf3fs.get("key_9") is None
start = 0
for i in range(start, start + hicache_hf3fs.num_pages):
k = f"key_{i}"
assert hicache_hf3fs.exists(k)
out = hicache_hf3fs.get(k)
assert out is not None
v = tensors[k]
assert torch.allclose(v, out, atol=1e-3), f"Tensor mismatch for {k}"
assert not hicache_hf3fs.exists("not_exists")
hicache_hf3fs.delete("key_7")
v2 = torch.randn((numel,)).to(dtype=dtype)
assert hicache_hf3fs.set("key_new", v2)
assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)
hicache_hf3fs.clear()
assert (
len(hicache_hf3fs.metadata_client.rank_metadata.free_pages)
== hicache_hf3fs.metadata_client.rank_metadata.num_pages
)
# batch
num_pages = 10
tensors = {}
keys = []
values = []
for i in range(num_pages):
k = f"key_{i}"
keys.append(k)
v = torch.randn((numel,)).to(dtype=dtype)
values.append(v)
ok = hicache_hf3fs.batch_set(keys, values)
assert not ok
assert hicache_hf3fs.get("key_8") is None
assert hicache_hf3fs.get("key_9") is None
results = hicache_hf3fs.batch_get(keys[: hicache_hf3fs.num_pages])
for result, key, value in zip(
results, keys[: hicache_hf3fs.num_pages], values[: hicache_hf3fs.num_pages]
):
assert torch.allclose(value, result, atol=1e-3), f"Tensor mismatch for {key}"
hicache_hf3fs.close()
os.remove(hicache_hf3fs.file_path)
print("All test cases passed.")
def bench():
# Qwen3-32B
layer_num = 64
head_num, head_dim = 8, 128
kv_lora_rank, qk_rope_head_dim = 0, 0
store_dtype = torch.bfloat16
tokens_per_page = 64
file_path = "/data/test.bin"
file_size = 1 << 40
numjobs = 16
bytes_per_page = 16 << 20
entries = 8
dtype = store_dtype
hicache_hf3fs = HiCacheHF3FS(
rank=0,
file_path=file_path,
file_size=file_size,
numjobs=numjobs,
bytes_per_page=bytes_per_page,
entries=entries,
dtype=dtype,
metadata_client=Hf3fsLocalMetadataClient(),
)
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
assert numel * dtype.itemsize == bytes_per_page
num_page = 128
values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]
warmup = 50
iteration = 100
w_bw = []
w_size = num_page * bytes_per_page / (1 << 30)
for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
tik = time.perf_counter()
ok = hicache_hf3fs.batch_set(keys, values)
tok = time.perf_counter()
if i < warmup:
continue
w_bw.append(w_size / (tok - tik))
assert ok
print_stats(w_bw)
r_bw = []
r_size = num_page * bytes_per_page / (1 << 30)
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
keys = random.sample(
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
num_page,
)
tik = time.perf_counter()
results = hicache_hf3fs.batch_get(keys)
tok = time.perf_counter()
if i < warmup:
continue
r_bw.append(r_size / (tok - tik))
assert all([r is not None for r in results])
print_stats(r_bw)
hicache_hf3fs.close()
def allclose():
# Qwen3-32B
layer_num = 64
head_num, head_dim = 8, 128
kv_lora_rank, qk_rope_head_dim = 0, 0
store_dtype = torch.bfloat16
tokens_per_page = 64
file_path = "/data/test.bin"
file_size = 1 << 40
numjobs = 16
bytes_per_page = 16 << 20
entries = 8
dtype = store_dtype
hicache_hf3fs = HiCacheHF3FS(
rank=0,
file_path=file_path,
file_size=file_size,
numjobs=numjobs,
bytes_per_page=bytes_per_page,
entries=entries,
dtype=dtype,
metadata_client=Hf3fsLocalMetadataClient(),
)
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
assert numel * dtype.itemsize == bytes_per_page
num_page = 128
values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]
iteration = 100
for i in tqdm(range(iteration), desc="Benchmarking write (GB/s)"):
keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
ok = hicache_hf3fs.batch_set(keys, values)
assert ok
read_keys, read_results = [], []
for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"):
keys = random.sample(
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
num_page,
)
results = hicache_hf3fs.batch_get(keys)
read_keys.extend(keys)
read_results.extend(results)
assert all([r is not None for r in results])
for key, result in tqdm(zip(read_keys, read_results)):
assert torch.allclose(values[int(key) % num_page], result, atol=1e-3)
hicache_hf3fs.close()
def main():
logging.basicConfig(level=logging.INFO)
test()
bench()
allclose()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,140 @@
import threading
import time
import torch
from tqdm import tqdm
from sglang.srt.distributed import (
get_world_group,
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.managers.cache_controller import (
HiCacheController,
PrefetchOperation,
StorageOperation,
)
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method="tcp://127.0.0.1:23456",
local_rank=0,
backend="gloo",
)
initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
group = get_world_group().cpu_group
max_total_num_tokens = 524288
page_size = 64
kv_cache_dtype = torch.bfloat16
layer_num = 64
head_num, head_dim = 8, 128
device = "cuda"
hicache_ratio = 2
hicache_size = 0
hicache_mem_layout = "page_first"
# hicache_mem_layout = "layer_first"
hicache_write_policy = "write_through"
hicache_io_backend = "kernel"
hicache_storage_backend = "hf3fs"
prefetch_threshold = 256
op_size = 1024
op_num = 16
token_to_kv_pool = MHATokenToKVPool(
max_total_num_tokens,
page_size=page_size,
dtype=kv_cache_dtype,
head_num=head_num,
head_dim=head_dim,
layer_num=layer_num,
device=device,
enable_memory_saver=True,
)
token_to_kv_pool_allocator = TokenToKVPoolAllocator(
max_total_num_tokens,
dtype=kv_cache_dtype,
device=device,
kvcache=token_to_kv_pool,
need_sort=False,
)
kv_cache = token_to_kv_pool_allocator.get_kvcache()
token_to_kv_pool_host = MHATokenToKVPoolHost(
kv_cache,
hicache_ratio,
hicache_size,
page_size,
hicache_mem_layout,
)
load_cache_event = threading.Event()
cache_controller = HiCacheController(
token_to_kv_pool_allocator,
token_to_kv_pool_host,
page_size,
group,
load_cache_event=load_cache_event,
write_policy=hicache_write_policy,
io_backend=hicache_io_backend,
storage_backend=hicache_storage_backend,
prefetch_threshold=prefetch_threshold,
)
operations = [
StorageOperation(
torch.tensor(list(range(i, i + op_size))),
list(range(i, i + op_size)),
hash_value=[f"{j}" for j in range(i, i + op_size, page_size)],
)
for i in tqdm(range(0, op_num * op_size, op_size))
]
tik = time.monotonic()
if hicache_mem_layout == "page_first":
for operation in operations:
cache_controller.zerocopy_page_backup(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
for operation in operations:
cache_controller.generic_page_backup(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok-tik:.6f} s")
operations = [
PrefetchOperation(
f"{i}",
torch.tensor(list(range(i, i + op_size))),
list(range(i, i + op_size)),
f"{i}",
)
for i in tqdm(range(0, op_num * op_size, op_size))
]
for operation in operations:
operation.hash_value = [
f"{j}"
for j in range(
int(operation.last_hash), int(operation.last_hash) + op_size, page_size
)
]
tik = time.monotonic()
if hicache_mem_layout == "page_first":
for operation in operations:
cache_controller.zerocopy_page_transfer(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
for operation in operations:
cache_controller.generic_page_transfer(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok-tik:.6f} s")

View File

@@ -0,0 +1,91 @@
## Run synthetic multi-turn benchmark
```
# SGLang server with radix cache disabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --disable-radix-cache
# SGLang server with radix cache on and first-come-first-serve policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --schedule-policy fcfs
# The default SGLang server with radix cache on and long-prefix-match policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000
# SGLang server with hierarchical radix cache enabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --enable-hierarchical-cache
```
```
python bench_multiturn.py --model-path Qwen/Qwen2.5-14B-Instruct
```
Note: The performance gain of hierarchical caching depends on the ratio of reusable tokens to GPU memory capacity. The more tokens to be reused, the larger the model, and the more constrained the GPU memory size, the greater the benefit one can expect from hierarchical caching.
# Benchmark with more datasets
## Download Dataset
```bash
./download.sh {sharegpt|ultragpt|loogle|nextqa|all}
```
This script will automatically download the required dataset to the current working directory
## Multiturn Benchmark
### Supported Datasets
- sharegpt
- ultrachat
- loogle
### Example Usage:
```bash
python3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \
--dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10 \
--port 8001 --enable-multiturn --disable-shuffle
```
This uses `mistralai/Mistral-7B-Instruct-v0.3` model with `sglang` as backend. The dataset
is `longdep_qa.json`. We send `10 conversations` with `10 req/s` to port 8001. We enable
multiturn chat without shuffling the order of conversations (i.e. following the original
order in the dataset file).
### Note:
The requests of multiple conversations are sent in a round robin fashion.
For example, if we have 3 conversations A, B, C whose rounds are `[2, 3, 4]` correspondingly,
multiturn chat will send the requests to the backend in the following order: `[A1, B1, C1, A2, B2, C2, B3, C3, C4]`
This has implications on the cache reuse patterns: the cache reuse distance is the largest
under this request pattern (which means a prefix-aware local scheduler in the backend can
yield the most benefit compared to a FIFO scheduler)
## Shared Prefix Benchmark
### Supported Datasets
- loogle
### Example Usage:
```bash
python3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \
--dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10 \
--port 8001 --enable-shared-prefix --disable-shuffle
```
### Note:
Shared Prefix benchmark sends the questions for the same prompt together. For example,
if we have 3 shared prefix A, B, C, which have [2, 3, 4] questions correspondingly,
the shared prefix benchmark will send the requests to the
backend in the following order: `[A+Q1, A+Q2, B+Q1, B+Q2, B+Q3, C+Q1, C+Q2, C+Q3]`.
## Multi Modality Benchmark (WIP)
### Supported Datasets:
- nextqa
### Example Usage:
```bash
Server:
python3 -m sglang.launch_server --model-path lmms-lab/LLaVA-NeXT-Video-7B --tp 2 --dp 1 --port 8001 \
--host 0.0.0.0 --mem-fraction-static 0.9 --tokenizer-path llava-hf/llava-1.5-7b-hf \
--json-model-override-args "{\"architectures\": [\"LlavaVidForCausalLM\"], \"model_type\":\"llava\", \"mm_spatial_pool_stride\":2}"
Client:
python3 bench_serving.py --model lmms-lab/LLaVA-NeXT-Video-7B --backend sglang --dataset-path \
NExTVideo --dataset-name nextqa --request-rate 10 --num-prompts 1 --disable-shuffle --port 8001 \ --enable-multiturn --max-frames 16 --tokenizer llava-hf/llava-1.5-7b-hf --fixed-output-len 2048
```
Note: for the server args, `tokenizer-path`, overriding architecture are necessary.
## Supported Backend
- sglang (oai)
- vllm (oai)
- lmdeploy (oai)

View File

@@ -0,0 +1,101 @@
import json
import queue
import time
import requests
from bench_multiturn import (
ReadyQueue,
WorkloadGenerator,
gen_payload,
log_to_jsonl_file,
parse_args,
)
from tqdm.asyncio import tqdm
from sglang.bench_serving import get_tokenizer
class ContextWorkloadGenerator(WorkloadGenerator):
def __init__(self, args):
# Construct the base URL for requests
self.baseurl = f"http://{args.host}:{args.port}/"
self.url = self.baseurl + "generate"
self.tokenizer = get_tokenizer(args.model_path)
self.distribution = args.distribution
self.request_rate = args.request_rate
self.start_time = None
self.finished_time = None
self.sent_requests = 0
self.completed_requests = 0
self.dataset = json.load(open(args.dataset_path))
num_requests = min(args.num_clients, len(self.dataset["queries"]))
init_requests = []
for i in range(num_requests):
context_id = self.dataset["queries"][i]["context"]
init_requests.append(
(
i,
gen_payload(
self.dataset["contexts"][context_id]
+ self.dataset["queries"][i]["question"],
len(
self.tokenizer(
self.dataset["queries"][i]["reference_answer"]
)["input_ids"]
),
),
)
)
self.ready_queue = ReadyQueue(init_requests=init_requests)
self.response_queue = queue.Queue()
self.pbar = tqdm(total=num_requests)
self.performance_metrics = {
"ttft": [],
"latency": [],
"itl": [],
"prompt_len": [],
"cached_tokens": [],
"generated_len": [],
}
self.max_parallel = args.max_parallel
self.logfile = args.log_file
def response_handler(self):
while True:
try:
client_id, response = self.response_queue.get(
timeout=10
) # Block until response is available
if not response.success:
raise ValueError(f"Request failed with error: {response.error}")
self.performance_metrics["ttft"].append(response.ttft)
self.performance_metrics["itl"].extend(response.itl)
self.performance_metrics["latency"].append(response.latency)
self.performance_metrics["prompt_len"].append(response.prompt_len)
self.performance_metrics["cached_tokens"].append(response.cached_tokens)
self.performance_metrics["generated_len"].append(response.generated_len)
self.completed_requests += 1
except queue.Empty:
if self.pbar.n == self.pbar.total:
break
if __name__ == "__main__":
args = parse_args()
args.num_rounds = 1
args.max_parallel = 24
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
for request_rate in [24, 16, 12, 8, 4, 2, 1]:
args.request_rate = request_rate
requests.post(flush_cache_url)
time.sleep(1)
performance_data = ContextWorkloadGenerator(args).run()
log_to_jsonl_file(performance_data, args.log_file, args.tag)

View File

@@ -0,0 +1,567 @@
import argparse
import asyncio
import json
import logging
import os
import queue
import random
import threading
import time
from dataclasses import dataclass
from functools import wraps
import aiohttp
from sglang.bench_serving import (
RequestFuncOutput,
get_tokenizer,
remove_prefix,
sample_random_requests,
)
# Set up logger
logger = logging.getLogger(__name__)
# Set up JSONL file for debug logging
debug_log_file = None
# Create a lock for thread-safe debug log writing
debug_log_lock = threading.Lock()
def write_debug_log(data):
global debug_log_file
"""Write debug information to a JSONL file"""
if debug_log_file is None:
return
# Acquire lock for thread-safe writing
with debug_log_lock:
# Write as JSONL (JSON Line format)
debug_log_file.write(json.dumps(data) + "\n")
debug_log_file.flush()
def parse_args():
parser = argparse.ArgumentParser(
description="Script to benchmark concurrent requests to a server."
)
parser.add_argument(
"--model-path",
type=str,
default="/data/models/Qwen3-0.6B",
help="model path compatible with Hugging Face Transformers",
)
parser.add_argument(
"--dataset-path",
type=str,
default="/data/models/ShareGPT_V3_unfiltered_cleaned_split/ShareGPT_V3_unfiltered_cleaned_split.json",
help="local dataset to sample tokens from",
)
parser.add_argument(
"--host",
type=str,
default="localhost",
help="Server hostname or IP (default: localhost)",
)
parser.add_argument(
"--port",
type=int,
default=30000,
help="Server port (default: 30000)",
)
parser.add_argument(
"--duration",
type=int,
default=600,
help="Duration to run the benchmark in seconds (default: 300 seconds)",
)
parser.add_argument(
"--log-level",
type=str,
default="info",
choices=["debug", "info"],
help="Set the logging level (default: info)",
)
parser.add_argument(
"--debug-log-file",
type=str,
default="debug.log.jsonl",
help="File to write debug logs in JSONL format",
)
return parser.parse_args()
def load_config():
config_path = os.getenv("CONFIG_PATH")
if not config_path:
raise ValueError("Environment variable 'CONFIG_PATH' is not set.")
with open(config_path, "r") as f:
config = json.load(f)
required_keys = [
"num_rounds",
"num_clients",
"round_ratios",
"mean_new_tokens_per_round",
"mean_return_tokens_per_round",
"mean_inter_round_interval",
]
for key in required_keys:
if key not in config:
raise KeyError(f"Missing required configuration key: {key}")
num_rounds = config["num_rounds"]
assert len(config["round_ratios"]) == num_rounds
assert len(config["mean_new_tokens_per_round"]) == num_rounds
assert len(config["mean_return_tokens_per_round"]) == num_rounds
assert len(config["mean_inter_round_interval"]) == num_rounds
print(config)
return config
@dataclass
class UserData:
user_id: int
current_round: int
total_rounds: int
prompt: str
return_tokens: int
start: int
def synchronized():
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with self.lock:
return func(self, *args, **kwargs)
return wrapper
return _decorator
class UserGenerator:
def __init__(self, config, model_path, dataset_path):
self.tokenizer_path = model_path
self.tokenizer = get_tokenizer(self.tokenizer_path)
self.dataset_path = dataset_path
self.user_id = 0
self.lock = threading.Lock()
self.num_rounds = config["num_rounds"]
self.cumulative_ratios = [
sum(config["round_ratios"][: i + 1])
for i in range(len(config["round_ratios"]))
]
self.mean_new_tokens_per_round = config["mean_new_tokens_per_round"]
self.mean_return_tokens_per_round = config["mean_return_tokens_per_round"]
self.mean_inter_round_interval = config["mean_inter_round_interval"]
self.sigma = 100
self.range_ratio = 0.8
assert self.range_ratio <= 1
self.candidate_inputs = [
[
r
for r in sample_random_requests(
input_len=(
self.mean_new_tokens_per_round[i] * (2 - self.range_ratio)
),
output_len=(
self.mean_return_tokens_per_round[i] * (2 - self.range_ratio)
),
num_prompts=config["num_clients"],
range_ratio=self.range_ratio / (2 - self.range_ratio),
tokenizer=self.tokenizer,
dataset_path=self.dataset_path,
random_sample=False,
)
]
for i in range(self.num_rounds)
]
self.multiturn_queue = []
self.user_stats = [0 for _ in range(self.num_rounds)]
self.input_stats = [[0, 0] for _ in range(self.num_rounds)]
self.output_stats = [[0, 0] for _ in range(self.num_rounds)]
def gen(self):
user_id = self.user_id
self.user_id += 1
rand_ratio = random.randint(0, self.cumulative_ratios[-1])
i = len(self.cumulative_ratios)
for idx, cumulative_ratio in enumerate(self.cumulative_ratios):
if rand_ratio >= cumulative_ratio:
continue
else:
i = idx + 1
break
total_rounds = i
current_round = 0
candidate_input = random.sample(self.candidate_inputs[current_round], 1)[0]
self.input_stats[0][0] += candidate_input.prompt_len
self.input_stats[0][1] += 1
prompt = f"{user_id} " + candidate_input.prompt
return_tokens = int(
random.gauss(self.mean_return_tokens_per_round[current_round], self.sigma)
)
if return_tokens <= 0:
return_tokens = self.mean_return_tokens_per_round[current_round]
start = 0
user_data = UserData(
user_id, current_round, total_rounds, prompt, return_tokens, start
)
self.user_stats[total_rounds - 1] += 1
return user_data
@synchronized()
def push(self, user_data, generated_text, len_itl):
self.output_stats[user_data.current_round][0] += len_itl + 1
self.output_stats[user_data.current_round][1] += 1
user_data.current_round += 1
if user_data.current_round >= user_data.total_rounds:
return
candidate_input = random.sample(
self.candidate_inputs[user_data.current_round], 1
)[0]
self.input_stats[user_data.current_round][0] += candidate_input.prompt_len
self.input_stats[user_data.current_round][1] += 1
user_data.prompt += generated_text + candidate_input.prompt
user_data.return_tokens = int(
random.gauss(
self.mean_return_tokens_per_round[user_data.current_round], self.sigma
)
)
if user_data.return_tokens <= 0:
user_data.return_tokens = self.mean_return_tokens_per_round[
user_data.current_round
]
interval = random.gauss(
self.mean_inter_round_interval[user_data.current_round], self.sigma
)
if interval <= 0:
interval = self.mean_inter_round_interval[user_data.current_round]
user_data.start = time.perf_counter() + interval
if len(self.multiturn_queue) == 0:
self.multiturn_queue.append(user_data)
else:
i = len(self.multiturn_queue)
for idx, d in enumerate(self.multiturn_queue):
if user_data.start < d.start:
i = idx
break
self.multiturn_queue.insert(idx, user_data)
@synchronized()
def pop(self):
if (
len(self.multiturn_queue)
and time.perf_counter() > self.multiturn_queue[0].start
):
return self.multiturn_queue.pop(0)
return self.gen()
def gen_payload(prompt, output_len):
payload = {
"text": prompt,
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": output_len,
"ignore_eos": True,
},
"stream": True,
"stream_options": {"include_usage": True},
"lora_path": "",
"return_logprob": False,
"logprob_start_len": -1,
}
return payload
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
async def async_request_sglang_generate(
user_data,
url,
atomic_counter,
):
"""
Sends a streaming request to the server. Gathers text token-by-token.
"""
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {}
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
output = RequestFuncOutput()
payload = gen_payload(user_data.prompt, user_data.return_tokens)
write_debug_log({"timestamp": st, "user_data": user_data.__dict__})
try:
async with session.post(url=url, json=payload, headers=headers) as response:
if response.status == 200:
prompt_tokens = 0
cached_tokens = 0
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
if data.get("text"):
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
prompt_tokens = (data.get("meta_info") or {}).get(
"prompt_tokens", 0
)
cached_tokens = (data.get("meta_info") or {}).get(
"cached_tokens", 0
)
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text = data["text"]
output.generated_text = generated_text
output.success = True
output.latency = latency
output.prompt_len = prompt_tokens
output.cached_tokens = cached_tokens
else:
output.error = response.reason or ""
output.success = False
except Exception as e:
output.success = False
output.error = str(e)
print(f"Request failed: {e}")
atomic_counter.increment(1)
return output
class AtomicCounter:
def __init__(self, initial_value=0):
self._value = initial_value
self.lock = threading.Lock()
@synchronized()
def increment(self, amount=1):
self._value += amount
@synchronized()
def get(self):
return self._value
class WorkloadGenerator:
def __init__(self, args):
config = load_config()
user_generator = UserGenerator(
config,
args.model_path,
args.dataset_path,
)
self.url = f"http://{args.host}:{args.port}/generate"
self.tokenizer = user_generator.tokenizer
self.start_time = None
self.finished_time = None
self.duration = args.duration
self.done = False
self.sent_requests = 0
self.completed_requests = 0
self.user_generator = user_generator
self.response_queue = queue.Queue()
self.performance_metrics = {
"ttft": [],
"latency": [],
"prompt_len": [],
"cached_tokens": [],
}
self.max_parallel = config["num_clients"]
self.atomic_counter = AtomicCounter()
async def handle_request(self, user_data):
try:
response = await async_request_sglang_generate(
user_data, self.url, self.atomic_counter
)
self.response_queue.put((user_data, response))
except Exception as e:
print(f"Request failed: {e}")
self.completed_requests += 1
def request_sender(self):
async def request_loop():
while True:
if self.sent_requests - self.completed_requests < self.max_parallel:
new_request = self.user_generator.pop()
if new_request:
asyncio.create_task(self.handle_request(new_request))
self.sent_requests += 1
else:
await asyncio.sleep(0.05)
continue
if time.perf_counter() - self.start_time > self.duration:
self.done = True
break
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(request_loop())
loop.close()
def response_handler(self):
while True:
try:
user_data, response = self.response_queue.get(timeout=10)
logger.info(
f"{((time.perf_counter()-self.start_time)/self.duration*100):.2f}%"
)
if not response.success:
raise ValueError(f"Request failed with error: {response.error}")
self.user_generator.push(
user_data, response.generated_text, len(response.itl)
)
self.performance_metrics["ttft"].append(response.ttft)
self.performance_metrics["latency"].append(response.latency)
self.performance_metrics["prompt_len"].append(response.prompt_len)
self.performance_metrics["cached_tokens"].append(response.cached_tokens)
self.completed_requests += 1
self.finished_time = time.perf_counter()
except queue.Empty:
if self.done:
break
except ValueError as e:
print(f"Error processing response for client {user_data}: {e}")
continue
def run(self):
request_thread = threading.Thread(target=self.request_sender, daemon=True)
response_thread = threading.Thread(target=self.response_handler, daemon=True)
self.start_time = time.perf_counter()
request_thread.start()
response_thread.start()
request_thread.join()
response_thread.join()
performance_data = {
"summary": {
"total_requests": len(self.performance_metrics["ttft"]),
"average_ttft": sum(self.performance_metrics["ttft"])
/ len(self.performance_metrics["ttft"]),
"p90_ttft": sorted(self.performance_metrics["ttft"])[
int(0.9 * len(self.performance_metrics["ttft"]))
],
"median_ttft": sorted(self.performance_metrics["ttft"])[
len(self.performance_metrics["ttft"]) // 2
],
"average_latency": sum(self.performance_metrics["latency"])
/ len(self.performance_metrics["latency"]),
"p90_latency": sorted(self.performance_metrics["latency"])[
int(0.9 * len(self.performance_metrics["latency"]))
],
"median_latency": sorted(self.performance_metrics["latency"])[
len(self.performance_metrics["latency"]) // 2
],
"throughput": self.atomic_counter.get()
/ (self.finished_time - self.start_time),
"cache_hit_rate": (
0
if sum(self.performance_metrics["prompt_len"]) == 0
else sum(self.performance_metrics["cached_tokens"])
/ sum(self.performance_metrics["prompt_len"])
),
},
}
print("All requests completed")
print("Performance metrics summary:")
print(f" Total requests: {performance_data['summary']['total_requests']}")
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
print(
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
)
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
print(
f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
)
print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
user_stats = self.user_generator.user_stats
input_stats = self.user_generator.input_stats
output_stats = self.user_generator.output_stats
print(f"round_ratios: {user_stats}")
print(
f"mean_new_tokens_per_round: {[int(a/b) if b > 0 else 0 for a, b in input_stats]}"
)
print(
f"mean_return_tokens_per_round: {[int(a/b) if b > 0 else 0 for a, b in output_stats]}"
)
return performance_data
def main():
global debug_log_file
args = parse_args()
if args.log_level == "debug":
logging.basicConfig(level=logging.DEBUG)
logger.info("use log_level debug")
# Initialize debug log file
debug_log_file = open(args.debug_log_file, "w")
else:
logging.basicConfig(level=logging.INFO)
logger.info("use log_level info")
performance_data = WorkloadGenerator(args).run()
# Close debug log file if it was opened
if debug_log_file:
debug_log_file.close()
if __name__ == "__main__":
main()

42
benchmark/hicache/bench_mix.sh Executable file
View File

@@ -0,0 +1,42 @@
#!/bin/bash
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
rm -rf nohup.out && \
nohup python3 -m sglang.launch_server \
--attention-backend triton \
--model-path /code/models/Qwen3-32B/ \
--log-level info \
--tp 4 --mem-frac 0.25 \
--host 0.0.0.0 --port 33301 \
--enable-metrics --enable-cache-report \
--page-size 64 \
--enable-hierarchical-cache \
--hicache-ratio 2.5 --hicache-size 0 \
--hicache-io-backend kernel \
--hicache-mem-layout layer_first \
--hicache-write-policy write_through \
&
##################################################
export CONFIG_PATH=/tmp/bench_mix_config.json
# num_clients: Maximum number of concurrent client requests to be simulated
# round_ratios: Distribution of requests across rounds. Given sum(round_ratios) total requests,
# round_ratios[i] denotes the number of requests that will execute for (i+1) rounds
echo '{
"num_rounds": 10,
"num_clients": 60,
"round_ratios": [50, 25, 15, 15, 10, 10, 9, 8, 7, 6],
"mean_new_tokens_per_round": [1000, 400, 350, 300, 280, 260, 240, 220, 210, 200],
"mean_return_tokens_per_round": [100, 100, 100, 100, 100, 100, 100, 100, 100, 100],
"mean_inter_round_interval": [30, 30, 30, 30, 30, 30, 30, 30, 30, 30]
}' > ${CONFIG_PATH}
rm -rf bench_mix.out && \
nohup python3 /sgl-workspace/sglang/benchmark/hicache/bench_mix.py \
--model-path /code/models/Qwen3-32B/ \
--dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
--port 33301 \
--duration 600 \
> bench_mix.out &

View File

@@ -0,0 +1,505 @@
import argparse
import asyncio
import json
import queue
import random
import threading
import time
from datetime import datetime
from typing import Optional
import aiohttp
import numpy as np
import requests
from tqdm.asyncio import tqdm
from sglang.bench_serving import (
RequestFuncOutput,
get_tokenizer,
remove_prefix,
sample_random_requests,
)
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
def parse_args():
parser = argparse.ArgumentParser(
description="Script to benchmark concurrent requests to a server."
)
parser.add_argument(
"--num-clients",
type=int,
default=256,
help="Number of concurrent clients",
)
parser.add_argument(
"--max-parallel",
type=int,
default=128,
help="Maximum number of parallel requests",
)
parser.add_argument(
"--request-length",
type=int,
default=512,
help="Length of each new request",
)
parser.add_argument(
"--output-length",
type=int,
default=64,
help="Length of each output",
)
parser.add_argument(
"--num-rounds",
type=int,
default=5,
help="Number of rounds per client",
)
parser.add_argument(
"--distribution",
type=str,
default="poisson",
choices=["poisson", "uniform"],
help="Distribution type for request intervals (poisson or uniform)",
)
parser.add_argument(
"--request-rate",
type=float,
default=1.0,
help="Average number of requests per second",
)
parser.add_argument(
"--host",
type=str,
default="localhost",
help="Server hostname or IP (default: localhost)",
)
parser.add_argument(
"--port",
type=int,
default=30000,
help="Server port (default: 30000)",
)
parser.add_argument(
"--model-path",
type=str,
default="meta-llama/Llama-3.1-8B-Instruct",
help="model path compatible with Hugging Face Transformers",
)
parser.add_argument(
"--dataset-path",
type=str,
default="",
help="local dataset to sample tokens from",
)
parser.add_argument(
"--log-file",
type=str,
default="performance_metrics.jsonl",
help="File to log performance metrics",
)
parser.add_argument(
"--disable-auto-run",
action="store_true",
help="If set, disable automatically testing with a range of request rates.",
)
parser.add_argument(
"--disable-random-sample",
action="store_true",
help="If set, disable random sampling of requests from the ShareGPT dataset.",
)
parser.add_argument(
"--sub-question-input-length",
type=int,
default=0,
help="Length of the sub question input for each request, if set 0 use request_length",
)
parser.add_argument(
"--ready-queue-policy",
type=str,
default="random",
help="Policy for popping requests from the ready queue (random or fifo)",
)
parser.add_argument(
"--tag",
type=str,
default="",
help="Tag of a certain run in the log file",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
return parser.parse_args()
async def async_request_sglang_generate(
payload,
url,
pbar: Optional[tqdm] = None,
):
"""
Sends a streaming request to the server. Gathers text token-by-token.
"""
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {}
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
output = RequestFuncOutput()
try:
async with session.post(url=url, json=payload, headers=headers) as response:
if response.status == 200:
prompt_tokens = 0
cached_tokens = 0
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
if data["text"]:
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
prompt_tokens = (data.get("meta_info") or {}).get(
"prompt_tokens", 0
)
cached_tokens = (data.get("meta_info") or {}).get(
"cached_tokens", 0
)
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text = data["text"]
output.generated_text = generated_text
output.success = True
output.latency = latency
output.prompt_len = prompt_tokens
output.cached_tokens = cached_tokens
output.generated_len = len(output.itl) + 1
else:
output.error = response.reason or ""
output.success = False
except Exception as e:
output.success = False
output.error = str(e)
print(f"Request failed: {e}")
if pbar:
pbar.update(1)
return output
def gen_payload(prompt, output_len):
payload = {
"text": prompt,
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": output_len,
"ignore_eos": True,
},
"stream": True,
"stream_options": {"include_usage": True},
"lora_path": "",
"return_logprob": False,
"logprob_start_len": -1,
}
return payload
def log_to_jsonl_file(data, file_path="performance_metrics.jsonl", tag=""):
"""Append the data with a timestamp and tag to the specified JSONL file."""
timestamped_data = {"timestamp": datetime.now().isoformat(), "tag": tag, **data}
try:
with open(file_path, "a") as file:
file.write(
json.dumps(timestamped_data) + "\n"
) # Write as a single line in JSONL format
except IOError as e:
print(f"Error writing to JSONL file: {e}")
class ReadyQueue:
"""
Thread-safe queue that can pop requests in different orders based on given policy.
"""
def __init__(self, init_requests=None, policy="random"):
self.lock = threading.Lock()
self.requests = init_requests or []
self.policy = policy
def append(self, item):
with self.lock:
self.requests.append(item)
def pop(self):
with self.lock:
if not self.requests:
return None
if self.policy == "random":
index = random.randrange(len(self.requests))
return self.requests.pop(index)
elif self.policy == "fifo":
return self.requests.pop(0)
else:
# todo, varying thinking time of clients
raise ValueError(f"{self.policy} not implemented")
class WorkloadGenerator:
def __init__(self, args):
# Construct the base URL for requests
self.url = f"http://{args.host}:{args.port}/generate"
self.tokenizer = get_tokenizer(args.model_path)
self.distribution = args.distribution
self.request_rate = args.request_rate
self.start_time = None
self.finished_time = None
self.sent_requests = 0
self.completed_requests = 0
self.candidate_inputs = sample_random_requests(
input_len=args.request_length,
output_len=args.output_length,
num_prompts=args.num_clients,
range_ratio=1.0,
tokenizer=self.tokenizer,
dataset_path=args.dataset_path,
random_sample=not args.disable_random_sample,
)
self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
if args.sub_question_input_length != 0:
sub_question_input_length = args.sub_question_input_length
else:
sub_question_input_length = args.request_length
self.sub_question_inputs = sample_random_requests(
input_len=sub_question_input_length,
output_len=args.output_length,
num_prompts=args.num_clients * max(args.num_rounds - 1, 1),
range_ratio=1.0,
tokenizer=self.tokenizer,
dataset_path=args.dataset_path,
random_sample=not args.disable_random_sample,
)
init_requests = [
(i, gen_payload(self.candidate_inputs[i], args.output_length))
for i in range(args.num_clients)
]
self.client_records = {
i: {"round": 0, "history": init_requests[i][1]["text"]}
for i in range(args.num_clients)
}
self.ready_queue = ReadyQueue(
init_requests=init_requests, policy=args.ready_queue_policy
)
self.candidate_inputs = self.candidate_inputs[args.num_clients :]
self.response_queue = queue.Queue()
self.pbar = tqdm(total=args.num_clients * args.num_rounds)
self.performance_metrics = {
"ttft": [],
"latency": [],
"prompt_len": [],
"cached_tokens": [],
"generated_len": [],
}
self.num_rounds = args.num_rounds
self.max_parallel = args.max_parallel
self.output_length = args.output_length
async def handle_request(self, item):
try:
client_id, payload = item
response = await async_request_sglang_generate(payload, self.url, self.pbar)
if self.pbar.n == self.pbar.total:
self.finished_time = time.perf_counter()
self.response_queue.put((client_id, response))
except Exception as e:
print(f"Request failed: {e}")
def request_sender(self):
async def request_loop():
while True:
if self.sent_requests - self.completed_requests < self.max_parallel:
new_request = self.ready_queue.pop()
if new_request:
asyncio.create_task(self.handle_request(new_request))
self.sent_requests += 1
else:
await asyncio.sleep(0.05)
continue
if self.pbar.n == self.pbar.total:
break
# Calculate Poisson-distributed wait time
if self.distribution == "poisson":
sleep_time = random.expovariate(self.request_rate)
elif self.distribution == "uniform":
avg_interval = (
1.0 / self.request_rate if self.request_rate > 0 else 1.0
)
sleep_time = random.uniform(0, 2 * avg_interval)
else:
raise ValueError("Invalid distribution type")
await asyncio.sleep(sleep_time) # Wait before sending the next request
# Create and run the event loop for asynchronous requests
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(request_loop())
loop.close()
def response_handler(self):
while True:
try:
client_id, response = self.response_queue.get(
timeout=10
) # Block until response is available
if not response.success:
raise ValueError(f"Request failed with error: {response.error}")
self.client_records[client_id]["history"] += response.generated_text
self.client_records[client_id]["round"] += 1
self.performance_metrics["ttft"].append(response.ttft)
self.performance_metrics["latency"].append(response.latency)
self.performance_metrics["prompt_len"].append(response.prompt_len)
self.performance_metrics["cached_tokens"].append(response.cached_tokens)
self.performance_metrics["generated_len"].append(response.generated_len)
self.completed_requests += 1
if self.client_records[client_id]["round"] < self.num_rounds:
# append new request to client's history
self.client_records[client_id][
"history"
] += self.sub_question_inputs.pop().prompt
self.ready_queue.append(
(
client_id,
gen_payload(
self.client_records[client_id]["history"],
self.output_length,
),
)
)
except queue.Empty:
if self.pbar.n == self.pbar.total:
break
except ValueError as e:
print(f"Error processing response for client {client_id}: {e}")
continue
def run(self):
request_thread = threading.Thread(target=self.request_sender, daemon=True)
response_thread = threading.Thread(target=self.response_handler, daemon=True)
self.start_time = time.perf_counter()
request_thread.start()
response_thread.start()
request_thread.join()
response_thread.join()
self.pbar.close()
duration = self.finished_time - self.start_time
performance_data = {
"summary": {
"total_requests": len(self.performance_metrics["ttft"]),
"request_rate": self.request_rate,
"average_ttft": sum(self.performance_metrics["ttft"])
/ len(self.performance_metrics["ttft"]),
"p90_ttft": sorted(self.performance_metrics["ttft"])[
int(0.9 * len(self.performance_metrics["ttft"]))
],
"median_ttft": sorted(self.performance_metrics["ttft"])[
len(self.performance_metrics["ttft"]) // 2
],
"average_latency": sum(self.performance_metrics["latency"])
/ len(self.performance_metrics["latency"]),
"p90_latency": sorted(self.performance_metrics["latency"])[
int(0.9 * len(self.performance_metrics["latency"]))
],
"median_latency": sorted(self.performance_metrics["latency"])[
len(self.performance_metrics["latency"]) // 2
],
"input_token_throughput": sum(self.performance_metrics["prompt_len"])
/ duration,
"output_token_throughput": sum(
self.performance_metrics["generated_len"]
)
/ duration,
"throughput": self.pbar.total / duration,
"cache_hit_rate": (
0
if sum(self.performance_metrics["prompt_len"]) == 0
else sum(self.performance_metrics["cached_tokens"])
/ sum(self.performance_metrics["prompt_len"])
),
},
}
print("All requests completed")
print("Performance metrics summary:")
print(
f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
)
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
print(
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
)
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
print(
f" Input token throughput: {performance_data['summary']['input_token_throughput']:.2f} tokens per second"
)
print(
f" Output token throughput: {performance_data['summary']['output_token_throughput']:.2f} tokens per second"
)
print(
f" Request Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
)
print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
return performance_data
if __name__ == "__main__":
args = parse_args()
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
random.seed(args.seed)
np.random.seed(args.seed)
if args.disable_auto_run:
print("Running with specified request rate...")
request_rates = [args.request_rate]
else:
print("Auto-running with different request rates...")
request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
for rate in request_rates:
args.request_rate = rate
requests.post(flush_cache_url)
time.sleep(1)
performance_data = WorkloadGenerator(args).run()
log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,590 @@
import json
import os
import pickle
import random
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
from nextqa import NExTQALoader
# from nextqa.video import , VideoPrompt
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
from sglang.bench_serving import (
download_and_cache_file,
gen_prompt,
get_gen_prefix_cache_path,
)
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
from sglang.utils import encode_video_base64
# type of content fields, can be only prompts or with images/videos
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
# A list of all the conversations. Each conversation is a list of
# tuples. If multiturn is not enabled, the length of list is 1,
# containing only the first Q&A pair.
# For the shared prefix workload (synthetic, loogle, nextqa), it
# is a list of conversations sharing the same prefix (synthetic,
# doc, video)
SampleOutput = List[List[Tuple[MsgContent, int, int]]]
def common_filter_chat(
num_requests: int,
new_dataset: List,
tokenizer: PreTrainedTokenizerBase,
min_prompt_len: Optional[int],
min_output_len: Optional[int],
max_prompt_len: Optional[int],
max_output_len: Optional[int],
fixed_output_len: Optional[int],
) -> SampleOutput:
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = []
l = 0
input_tokens = 0
output_tokens = 0
while l < num_requests:
for i in range(len(new_dataset)):
if l == num_requests:
break
processed = []
for j in new_dataset[i]:
# Tokenize the prompts and completions.
prompt = j[0]
prompt_token_ids = tokenizer.encode(prompt)
prompt_len = len(prompt_token_ids)
completion = j[1]
completion_token_ids = tokenizer.encode(completion)
output_len = (
len(completion_token_ids)
if fixed_output_len is None
else fixed_output_len
)
if (
min_prompt_len is not None
and prompt_len < min_prompt_len
or min_output_len is not None
and output_len < min_output_len
or max_prompt_len is not None
and prompt_len > max_prompt_len
or max_output_len is not None
and output_len > max_output_len
):
# Prune too short sequences.
continue
input_tokens += prompt_len
output_tokens += output_len
processed.append((prompt, prompt_len, output_len))
if len(processed) != 0:
filtered_dataset.append(processed)
l += 1
print(f"#Input tokens: {input_tokens}")
print(f"#Output tokens: {output_tokens}")
return filtered_dataset
def sample_sharegpt_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
disable_shuffle: bool = False,
enable_multiturn: bool = True,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Keep one conversation in one list
new_dataset = []
for data in dataset:
if len(data["conversations"]) % 2 != 0:
continue
if data["conversations"][0]["from"] != "human":
continue
chat = []
total_len = 2
if enable_multiturn:
total_len = len(data["conversations"])
for i in range(0, total_len, 2):
# One user One Assistant
chat.append(
(
data["conversations"][i]["value"],
data["conversations"][i + 1]["value"],
)
)
new_dataset.append(chat)
if not disable_shuffle:
# Shuffle the dataset.
random.shuffle(new_dataset)
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = common_filter_chat(
num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
)
return filtered_dataset
def sample_ultrachat_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
disable_shuffle: bool = False,
enable_multiturn: bool = True,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset
dataset = []
with open(dataset_path) as f:
while True:
line = f.readline()
if not line:
break
dataset.append(json.loads(line))
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["data"]) >= 2]
# Keep one conversation in one list
new_dataset = []
for data in dataset:
if len(data["data"]) % 2 != 0:
continue
chat = []
total_len = 2
if enable_multiturn:
total_len = len(data["data"])
for i in range(0, total_len, 2):
# One user One Assistant
chat.append((data["data"][i], data["data"][i + 1]))
new_dataset.append(chat)
# Shuffle the dataset.
if not disable_shuffle:
random.shuffle(new_dataset)
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = common_filter_chat(
num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
)
return filtered_dataset
def sample_loogle_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
disable_shuffle: bool = False,
enable_multiturn: bool = True,
enable_shared_prefix: bool = False,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset
dataset = []
with open(dataset_path) as f:
while True:
line = f.readline()
if not line:
break
dataset.append(json.loads(line))
# Keep one conversation in one list
new_dataset = []
# TODO: Add shared prefix support for loogle
# NOTE: Now we preprocess it only for chat
for data in dataset:
chat = []
if (
"qa_pairs" not in data
or data["qa_pairs"] == "none"
or len(data["qa_pairs"]) == 0
):
# If Q is none (for summarization),
# We add a question for summarization
# And keep the summary up to 1024 words
chat.append(
(
"Input: "
+ data["input"]
+ " Question: "
+ "Please summarize the input",
data["input"][:1024],
)
)
new_dataset.append(chat)
else:
qa_pairs = eval(data["qa_pairs"])
for i, qa in enumerate(qa_pairs):
if i == 0 or enable_shared_prefix:
# Combine input with the first Q
chat.append(
("Input: " + data["input"] + " Question: " + qa["Q"], qa["A"])
)
elif enable_multiturn:
chat.append((qa["Q"], qa["A"]))
new_dataset.append(chat)
# Shuffle the dataset.
if not disable_shuffle:
random.shuffle(new_dataset)
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = common_filter_chat(
num_requests, new_dataset, tokenizer, 4, None, None, None, fixed_output_len
)
return filtered_dataset
def sample_nextqa_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
max_frames: int, # Specific for video
model_path: str,
disable_shuffle: bool = False,
enable_multiturn: bool = True, # No multiturn support for now
backend: str = "sglang-oai",
chat_template_name: Optional[str] = None,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
"""
Example of messages:
message = {
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": base64_data}},
{"type": "text", "text": video.prompt},
],
}
"""
if fixed_output_len is None:
fixed_output_len = 4096
# TODO: Check for multiturn
dataset = NExTQALoader(video_dir=dataset_path, max_frames=max_frames)
new_dataset = []
for v in dataset:
new_dataset.append(v)
if not disable_shuffle:
random.shuffle(new_dataset)
# TODO: prompt len can get from server side
filtered_dataset = []
l = 0
while l < num_requests:
for i in range(len(new_dataset)):
if l == num_requests:
break
video = new_dataset[i]
# text prompt
prompt = video.prompt
# NOTE: Chat Template is a must for video benchmark because we have to
# add special image token for later expansion
if backend == "sglang" or backend == "sglang-native":
if "chat_template" in tokenizer.init_kwargs:
chat_template = get_chat_template(tokenizer.get_chat_template())
elif chat_template_name is not None:
chat_template = get_chat_template(chat_template_name)
else:
chat_template = get_chat_template_by_model_path(model_path)
prompt = chat_template.image_token + prompt
prompt_token_ids = tokenizer(prompt).input_ids
prompt_len = len(prompt_token_ids)
output_len = fixed_output_len # max output len, not real output len
# video input
base64_data = encode_video_base64(video.path, video.num_frames)
# NOTE: This will be replaced by the expanded length from the server
prompt_len += video.num_frames
# add to content
content = [
{"type": "image_url", "image_url": {"url": base64_data}},
{"type": "text", "text": prompt},
]
filtered_dataset.append([(content, prompt_len, output_len)])
l += 1
return filtered_dataset
def sample_random_requests(
input_len: int,
output_len: int,
num_prompts: int,
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
dataset_path: str,
disable_shuffle: bool = False,
) -> SampleOutput:
input_lens = np.random.randint(
max(int(input_len * range_ratio), 1),
input_len + 1,
size=num_prompts,
)
output_lens = np.random.randint(
int(output_len * range_ratio),
output_len + 1,
size=num_prompts,
)
if True:
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
if not disable_shuffle:
# Shuffle the dataset.
random.shuffle(dataset)
# Filter out sequences that are too long or too short
input_requests: SampleOutput = []
for data in dataset:
i = len(input_requests)
if i == num_prompts:
break
# Tokenize the prompts and completions.
prompt = data[0]
prompt_token_ids = tokenizer.encode(prompt)
prompt_len = len(prompt_token_ids)
# Skip empty prompt
if prompt_len == 0:
continue
if prompt_len > input_lens[i]:
input_ids = prompt_token_ids[: input_lens[i]]
else:
ratio = (input_lens[i] + prompt_len - 1) // prompt_len
input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
prompt = tokenizer.decode(input_ids)
input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])
else:
# Sample token ids from random integers. This can cause some NaN issues.
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
input_requests = []
for i in range(num_prompts):
prompt = tokenizer.decode(
[
(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])
]
)
input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])
print(f"#Input tokens: {np.sum(input_lens)}")
print(f"#Output tokens: {np.sum(output_lens)}")
return input_requests
def gen_prompt(tokenizer, token_num):
"""Generate a random prompt of specified token length using tokenizer vocabulary."""
all_available_tokens = list(tokenizer.get_vocab().values())
selected_tokens = random.choices(all_available_tokens, k=token_num)
return tokenizer.decode(selected_tokens)
def get_gen_prefix_cache_path(args, tokenizer):
"""Create cache directory under ~/.cache/sglang/benchmark"""
cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"
# Create a unique cache filename based on the generation parameters
cache_key = (
f"gsp_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
f"{tokenizer.__class__.__name__}.pkl"
)
return cache_dir / cache_key
def sample_generated_shared_prefix_requests(
num_groups: int,
prompts_per_group: int,
system_prompt_len: int,
question_len: int,
output_len: int,
tokenizer: PreTrainedTokenizerBase,
args,
disable_shuffle: bool = False,
) -> SampleOutput:
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
cache_path = get_gen_prefix_cache_path(args, tokenizer)
# Try to load from cache first
if cache_path.exists():
print(f"\nLoading cached generated input data from {cache_path}")
with open(cache_path, "rb") as f:
return pickle.load(f)
print("\nGenerating new input data...")
# Generate system prompts for each group
system_prompts = []
for _ in range(num_groups):
system_prompt = gen_prompt(tokenizer, system_prompt_len)
system_prompts.append(system_prompt)
# Generate questions
questions = []
for _ in range(num_groups * prompts_per_group):
question = gen_prompt(tokenizer, question_len)
questions.append(question)
# Combine system prompts with questions
input_requests = []
total_input_tokens = 0
total_output_tokens = 0
for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
system_prompt = system_prompts[group_idx]
input_requests.append([])
for prompt_idx in tqdm(
range(prompts_per_group), desc="Generating questions", leave=False
):
question = questions[group_idx * prompts_per_group + prompt_idx]
full_prompt = f"{system_prompt}\n\n{question}"
prompt_len = len(tokenizer.encode(full_prompt))
input_requests[-1].append((full_prompt, prompt_len, output_len))
total_input_tokens += prompt_len
total_output_tokens += output_len
if not disable_shuffle:
# Shuffle questions
random.shuffle(input_requests)
# Print statistics
print(f"\nGenerated shared prefix dataset statistics:")
print(f"Number of groups: {num_groups}")
print(f"Prompts per group: {prompts_per_group}")
print(f"Total prompts: {len(input_requests) * prompts_per_group}")
print(f"Total input tokens: {total_input_tokens}")
print(f"Total output tokens: {total_output_tokens}")
print(
f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
)
print(
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
)
# Save to cache
cache_path.parent.mkdir(parents=True, exist_ok=True)
print(f"Caching generated input data to {cache_path}")
with open(cache_path, "wb") as f:
pickle.dump(input_requests, f)
return input_requests
def get_dataset(args, tokenizer):
if args.dataset_name == "sharegpt":
input_requests = sample_sharegpt_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "ultrachat":
input_requests = sample_ultrachat_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "loogle":
input_requests = sample_loogle_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
enable_shared_prefix=args.enable_shared_prefix,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "nextqa":
input_requests = sample_nextqa_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
max_frames=args.max_frames,
model_path=args.model,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
backend=args.backend,
chat_template_name=args.chat_template,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "random":
input_requests = sample_random_requests(
input_len=args.random_input_len,
output_len=args.random_output_len,
num_prompts=args.num_prompts,
range_ratio=args.random_range_ratio,
tokenizer=tokenizer,
dataset_path=args.dataset_path,
)
elif args.dataset_name == "generated-shared-prefix":
input_requests = sample_generated_shared_prefix_requests(
num_groups=args.gsp_num_groups,
prompts_per_group=args.gsp_prompts_per_group,
system_prompt_len=args.gsp_system_prompt_len,
question_len=args.gsp_question_len,
output_len=args.gsp_output_len,
args=args,
tokenizer=tokenizer,
)
else:
raise ValueError(f"Unknown dataset: {args.dataset_name}")
return input_requests

66
benchmark/hicache/download.sh Executable file
View File

@@ -0,0 +1,66 @@
#!/usr/bin/bash
# The usage function
usage() {
echo "Usage: $0 {sharegpt|ultragpt|loogle|nextqa|all}"
exit 1
}
# The download function
download() {
case "$1" in
sharegpt)
echo $1
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
;;
ultragpt)
echo $1
# Questions about the world
wget https://cloud.tsinghua.edu.cn/seafhttp/files/be1d7b87-22ca-449e-a6a7-c61d1ea7e010/ultrachat_release_230407.json
# Writing and Creation
wget https://cloud.tsinghua.edu.cn/seafhttp/files/61742d2a-25e2-4d08-b2b9-15f47ae50ace/ultrachat_material_release_230417.json
wget https://cloud.tsinghua.edu.cn/seafhttp/files/f71f6aa6-d346-4b16-85b7-8502efa3d608/ultrachat_material_release_230412.json
# External materials
wget https://cloud.tsinghua.edu.cn/seafhttp/files/42d22e28-e899-4975-a70f-5eda163e265d/ultrachat_existent_material_release_230420.json.gz
gunzip ultrachat_existent_material_release_230420.json.gz
;;
loogle)
echo $1
git lfs install
git clone git@hf.co:datasets/bigainlco/LooGLE
unzip LooGLE/data.zip
;;
nextqa)
echo $1
git lfs install
git clone https://huggingface.co/datasets/lmms-lab/NExTQA
unzip NExTQA/videos.zip
;;
*)
usage
exit 1
;;
esac
}
# Arg check
if [ "$#" -ne 1 ]; then
usage
fi
# Invoke
case "$1" in
sharegpt|ultragpt|loogle|nextqa)
download "$1"
;;
all)
download sharegpt
download ultragpt
download loogle
download nextqa
;;
*)
usage
;;
esac

159
benchmark/hicache/nextqa.py Normal file
View File

@@ -0,0 +1,159 @@
import os
import sys
from typing import List
import av
from datasets import load_dataset
def find_video_files(video_dir) -> List[str]:
if os.path.isfile(video_dir):
return [video_dir]
video_files = []
for root, dirs, files in os.walk(video_dir):
for file in files:
if file.endswith((".mp4", ".avi", ".mov")):
video_files.append(os.path.join(root, file))
# if file is dir
elif os.path.isdir(file):
video_files.extend(find_video_files(file))
return video_files
def video_frames(video_path, max_frames) -> int:
container = av.open(video_path)
total_frames = container.streams.video[0].frames
return min(total_frames, max_frames)
class Video:
def __init__(self, video_path, num_frames):
self.path = video_path
self.num_frames = num_frames
def __str__(self):
return f"Video({self.path}, {self.num_frames})"
def __iter__(self):
return iter((self.path, self.num_frames))
class VideoPrompt(Video):
def __init__(self, video_path, num_frames, prompt):
super().__init__(video_path, num_frames)
self.prompt = prompt
def __str__(self):
return f"VideoPrompt({self.path}, {self.num_frames}, {self.prompt})"
def __iter__(self):
return iter((self.path, self.num_frames, self.prompt))
class VideoLoader:
pass
class VideoFileLoader(VideoLoader):
"""
Load all the videos in a directory
"""
def __init__(self, video_dir, batch_size=1, max_frames=sys.maxsize):
super().__init__()
self.video_dir = video_dir
self.video_files = find_video_files(video_dir)
self.batch_size = batch_size
self.max_frames = max_frames
print(f"batch_size: {batch_size}, max_frames: {max_frames}")
def __iter__(self): # (file, number of frames)
if self.batch_size == 1:
for video_file in self.video_files:
yield Video(video_file, video_frames(video_file, self.max_frames))
else:
batch = []
for video_file in self.video_files:
video = Video(video_file, video_frames(video_file, self.max_frames))
batch.append(video)
if len(batch) == self.batch_size:
yield batch
batch = []
class NExTQALoader(VideoLoader):
"""
Load vdideos and prompts from NExT dataset
set: train, test or validation
"""
def __init__(
self, video_dir, batch_size=1, max_frames=sys.maxsize, dset="test", task="OE"
):
"""
task: 'MV' or 'OE'
"""
super().__init__()
self.task = task
print(f"Loading the {dset} data of {task} from lmms-lab/NExTQA")
self.ds = load_dataset("lmms-lab/NExTQA", task)
self.ds = self.ds[dset]
# self.n = ds.num_rows
self.video_dir = video_dir
self.video_files = find_video_files(video_dir)
self.video_to_path = dict()
for video_file in self.video_files:
video_id = video_file.split("/")[-1].split(".")[0]
self.video_to_path[video_id] = video_file
self.batch_size = batch_size
self.max_frames = max_frames
def get_video_prompt(self, entry, max_frames) -> VideoPrompt:
# Get video
video_id = entry["video"]
video_path = self.video_to_path[video_id]
assert os.path.exists(video_path), f"Video not found: {video_path}"
num_frames = min(entry["frame_count"], max_frames)
video = Video(video_path, num_frames)
prompt = entry["question"] + "?"
if self.task == "MC": # add choices
prompt += f' a0: {entry["a0"]}, a1: {entry["a1"]}, a2: {entry["a2"]}, a3: {entry["a3"]}'
return VideoPrompt(video_path, num_frames, prompt)
def __iter__(self):
if self.batch_size == 1:
for entry in self.ds:
yield self.get_video_prompt(entry, self.max_frames)
else:
batch = []
for entry in self.ds:
video = self.get_video_prompt(entry, self.max_frames)
batch.append(video)
if len(batch) == self.batch_size:
yield batch
batch = []
# main
if __name__ == "__main__":
video_dir = "./videos"
# video_loader = VideoFileLoader(video_dir, batch_size=16)
# for batch in video_loader:
# print(f"Number of videos in batch: {len(batch)}")
# for video_file, num_frames in batch:
# print(f"Video: {video_file} number of frames: {num_frames}")
video_loader = NExTQALoader(video_dir, batch_size=16, dset="test", task="OE")
for batch in video_loader:
print(f"Number of videos in batch: {len(batch)}")
for video_file, num_frames, prompt in batch:
print(
f"Video: {video_file} number of frames: {num_frames}, prompt: {prompt}"
)
# break
# for video_file, prompt in batch:
# print(f"Video: {video_file} prompt: {prompt}")
# break

View File

@@ -0,0 +1,60 @@
## Run benchmark
### Build dataset
```
pip install wikipedia
python3 build_dataset.py
```
### Dependencies
```
llama_cpp_python 0.2.19
guidance 0.1.10
vllm 0.2.5
outlines 0.0.22
```
### Benchmark sglang
Run Llama-7B
```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Run Mixtral-8x7B
```
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
```
Benchmark
```
python3 bench_sglang.py --num-questions 10
```
### Benchmark Outlines + vLLM
Run Llama-7B
```
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
Benchmark
```
python3 bench_other.py --backend outlines --num-questions 10
```
### Benchmark guidance
Run Llama-7B and benchmark
```
python3 bench_other.py --backend guidance --num-questions 10 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```

View File

@@ -0,0 +1,98 @@
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from tqdm import tqdm
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
# fmt: off
def json_decode(document, generate):
s = "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += "{\n"
s += ' "name": '
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "country": '
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "latitude": '
s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": '
s += generate(s, max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": '
s += generate(s, max_tokens=24, regex=REGEX_LIST) + "\n"
s += "}\n"
return s
# fmt: on
def main(args):
lines = read_jsonl(args.data_path)
arguments = []
for i in range(len(lines[: args.num_questions])):
arguments.append(
{
"document": lines[i]["document"],
}
)
states = [None] * len(arguments)
# Select backend
call_generate = partial(get_call_generate(args), temperature=0)
# Run requests
def get_one_answer(i):
states[i] = json_decode(generate=call_generate, **arguments[i])
tic = time.perf_counter()
if args.parallel == 1:
for i in tqdm(range(len(arguments))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = list(
tqdm(
executor.map(get_one_answer, list(range(len(arguments)))),
total=len(arguments),
)
)
for _ in rets:
pass
latency = time.perf_counter() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "json_decode_regex",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=20)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,101 @@
import argparse
import json
import time
import sglang as sgl
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
# fmt: off
@sgl.function
def json_warm_up(s):
s += "The information about Hogwarts is in the following JSON format.\n"
with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
s += "}\n"
print(f'The warmp up json result is:\n{s["json_output"]}')
# fmt: on
# fmt: off
@sgl.function
def json_decode(s, document):
s += "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
s += "}\n"
# fmt: on
def main(args):
lines = read_jsonl(args.data_path)
lines = list(lines)
arguments = []
for i in range(len(lines[: args.num_questions])):
arguments.append(
{
"document": lines[i]["document"],
}
)
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Warm up
json_warm_up.run().sync()
# Run requests
tic = time.perf_counter()
states = json_decode.run_batch(
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
)
latency = time.perf_counter() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(f"tmp_{args.backend}_json_results.txt", "w") as fout:
for state in states:
fout.write(state["json_output"] + "\n")
with open(args.result_file, "a") as fout:
value = {
"task": "json_decode_regex",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=20)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,58 @@
import json
import transformers
import wikipedia
model_path = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(model_path)
city_names = [
"los angles",
"london",
"tokyo",
"beijing",
"singapore",
"paris",
"dubai",
"sydney",
"moscow",
"rome",
"toronto",
"rio de janeiro",
"istanbul",
"berlin",
"auckland",
"buenos aires",
"mexico city",
"mumbai",
"seoul",
"bangkok",
"cairo",
"athens",
"jerusalem",
]
def get_content(city_name):
content = str(wikipedia.page(city_name).content)
content = content.replace("\n\n", "\n")
tokens = t.encode(content)
expected_tokens = 3000
truncate_len = int((expected_tokens / len(tokens)) * len(content))
truncate_content = content[:truncate_len]
truncate_tokens = t.encode(truncate_content)
# Count token
print(
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
)
return truncate_content
if __name__ == "__main__":
with open("questions.jsonl", "w") as fout:
for city_name in city_names:
truncate_content = get_content(city_name)
fout.write(json.dumps({"document": truncate_content}) + "\n")

View File

@@ -0,0 +1,88 @@
## Run benchmark
### Dependencies
```
llama_cpp_python 0.2.38
guidance 0.1.10
vllm 0.2.7
outlines 0.0.25
```
### Build dataset
When benchmarking long document information retrieval, run the following command to build the dataset:
```bash
pip install wikipedia
python3 build_dataset.py
```
### Benchmark sglang
Run Llama-7B
```bash
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Benchmark Character Generation
```bash
python3 bench_sglang.py --mode character
```
Benchmark City Information Retrieval
```bash
python3 bench_sglang.py --mode city
```
### Benchmark Outlines + vLLM
Run Llama-7B
```bash
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
Benchmark Character Generation
```bash
python3 bench_other.py --mode character --backend outlines
```
Benchmark City Information Retrieval
```bash
python3 bench_other.py --mode city --backend outlines
```
### Benchmark guidance
Run Llama-7B and benchmark character generation
```bash
python3 bench_other.py --mode character --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
Run Llama-7B and benchmark city information retrieval
```bash
python3 bench_other.py --mode city --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
Run Llama-7B and benchmark character generation
```
python3 bench_other.py --mode character --backend lmql --parallel 1
```
Run Llama-7B and benchmark city information retrieval
```
python3 bench_other.py --mode city --backend lmql --parallel 1
```

View File

@@ -0,0 +1,288 @@
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import guidance
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
# there are some FSM bugs with json regex converted from pydantic model
# here use a string regex instead
# regex_string = build_regex_from_object(HarryPoterRole)
character_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
+ r""" "wand": \{\n"""
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
+ r""" "core": "[\w\d\s]{1,16}",\n"""
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
+ r""" \},\n"""
+ r""" "alive": "(Alive|Deceased)",\n"""
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
+ r"""\}"""
)
city_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "country": "[\w\d\s]{1,16}",\n"""
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
+ r""" "population": [-+]?[0-9]{1,9},\n"""
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
+ r"""\}"""
)
# fmt: off
def character_gen(name, generate):
s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
s += generate(s, max_tokens=256, regex=character_regex)
return s
# fmt: on
# fmt: off
def city_gen(document, generate):
s = "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += generate(s, max_tokens=256, regex=city_regex)
return s
# fmt: on
@guidance
def character_maker(lm, name):
regex_str_no_quote = r"[\w\d\s]+"
regex_float = r"[0-9]+\.[0-9]+"
lm += f"""\
{name} is a character in Harry Potter. Please fill in the following information about this character.
{{
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
"house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}",
"blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}",
"occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}",
"wand": {{
"wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}",
"core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}",
"length": {guidance.gen('length', max_tokens=10, regex=regex_float)}
}},
"alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}",
"patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}",
"bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}"
}}
"""
return lm
async def call_generate_lmql(
prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs
):
assert model is not None
import lmql
@lmql.query(model=model)
async def program(question, max_tokens, regex):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex)
return ANSWER
'''
return await program(
question=prompt,
temperature=temperature,
max_tokens=max_tokens,
max_len=max_len,
regex=regex,
**kwargs,
)
@guidance
def city_maker(lm, document):
regex_str_no_quote = r"[\w\d\s]+"
regex_float = r"[0-9]+\.[0-9]+"
lm += f"""\
Please extract the information of a city from the following wikipedia page.
Page begin.
{document}
Page end.
Here is the name, country, and symbol of the city in JSON format.
{{
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
"country": "{guidance.gen("country", max_tokens=16, regex=regex_str_no_quote)}",
"latitude": {guidance.gen("latitude", max_tokens=10, regex=regex_float)},
"population": {guidance.gen("population", max_tokens=10, regex=r"[0-9]+")},
"top 3 landmarks": [
"{guidance.gen("landmark1", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark2", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark3", max_tokens=16, regex=regex_str_no_quote)}"
]
}}
"""
return lm
def bench_character(args):
arguments = []
with open(args.data_path, "r") as f:
for line in f:
arguments.append({"name": line.strip()})
arguments = arguments[: args.num_jsons]
states = [None] * len(arguments)
# Select backend
if args.backend == "outlines":
call_generate = partial(get_call_generate(args), temperature=0)
def get_one_answer(i):
states[i] = character_gen(**arguments[i], generate=call_generate)
elif args.backend == "guidance":
model = guidance.models.LlamaCpp(
args.model_path,
n_gpu_layers=-1,
n_ctx=args.n_ctx,
)
def get_one_answer(i):
lm = model + character_maker(**arguments[i])
states[i] = lm
elif args.backend == "lmql":
import asyncio
import lmql
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
call_generate = partial(
call_generate_lmql,
model=model,
max_tokens=256,
regex=character_regex,
)
async def get_one_answer_async(i):
states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0)
else:
raise ValueError(f"Invalid backend: {args.backend}")
tic = time.perf_counter()
if args.backend != "lmql":
if args.parallel == 1:
for i in tqdm(range(len(arguments))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = list(
tqdm(
executor.map(get_one_answer, list(range(len(arguments)))),
total=len(arguments),
)
)
for _ in rets:
pass
else:
batches = []
for i in range(0, len(arguments), args.parallel):
batches.append(list(range(i, min(i + args.parallel, len(arguments)))))
loop = asyncio.get_event_loop()
for bt in tqdm(batches):
loop.run_until_complete(
asyncio.gather(*[get_one_answer_async(i) for i in bt])
)
latency = time.perf_counter() - tic
return states, latency
def bench_city_doc(args):
arguments = []
for line in read_jsonl(args.data_path):
arguments.append({"document": line["document"]})
arguments = arguments[: args.num_jsons]
states = [None] * len(arguments)
# Select backend
if args.backend == "outlines":
call_generate = partial(get_call_generate(args), temperature=0)
def get_one_answer(i):
states[i] = city_gen(**arguments[i], generate=call_generate)
elif args.backend == "guidance":
model = guidance.models.LlamaCpp(
args.model_path,
n_gpu_layers=-1,
n_ctx=args.n_ctx,
)
def get_one_answer(i):
lm = model + city_maker(**arguments[i])
states[i] = lm
else:
raise ValueError(f"Invalid backend: {args.backend}")
tic = time.perf_counter()
if args.parallel == 1:
for i in tqdm(range(len(arguments))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = executor.map(get_one_answer, list(range(len(arguments))))
for _ in rets:
pass
latency = time.perf_counter() - tic
return states, latency
def main(args):
if args.mode == "character":
args.data_path = "dataset.txt"
states, latency = bench_character(args)
elif args.mode == "city":
args.data_path = "questions.jsonl"
states, latency = bench_city_doc(args)
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "json_jump_forward",
"backend": args.backend,
"latency": round(latency, 3),
"num_jsons": args.num_jsons,
"mode": args.mode,
"parallel": args.parallel,
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str)
parser.add_argument("--num-jsons", type=int, default=50)
parser.add_argument(
"--mode", type=str, default="character", choices=["character", "city"]
)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,143 @@
import argparse
import json
import time
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
# there are some FSM bugs with json regex converted from pydantic model
# here use a string regex instead
# regex_string = build_regex_from_object(HarryPoterRole)
character_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
+ r""" "wand": \{\n"""
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
+ r""" "core": "[\w\d\s]{1,16}",\n"""
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
+ r""" \},\n"""
+ r""" "alive": "(Alive|Deceased)",\n"""
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
+ r"""\}"""
)
city_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "country": "[\w\d\s]{1,16}",\n"""
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
+ r""" "population": [-+]?[0-9]{1,9},\n"""
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
+ r"""\}"""
)
# fmt: off
@sgl.function
def character_gen(s, name):
s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
# fmt: on
# fmt: off
@sgl.function
def city_gen(s, document):
s += "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += sgl.gen("json_output",max_tokens=256, regex=city_regex)
# fmt: on
def bench_city_doc(args):
arguments = []
for line in read_jsonl(args.data_path):
arguments.append({"document": line["document"]})
arguments = arguments[: args.num_jsons]
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Run requests
tic = time.perf_counter()
states = city_gen.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
return states, latency
def bench_character(args):
arguments = []
with open(args.data_path, "r") as f:
for line in f:
arguments.append({"name": line.strip()})
arguments = arguments[: args.num_jsons]
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Run requests
tic = time.perf_counter()
states = character_gen.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
return states, latency
def main(args):
if args.mode == "character":
args.data_path = "dataset.txt"
states, latency = bench_character(args)
elif args.mode == "city":
args.data_path = "questions.jsonl"
states, latency = bench_city_doc(args)
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
with open(f"{args.backend}_{args.mode}.json", "w") as fout:
for state in states:
fout.write(state["json_output"] + "\n")
with open(args.result_file, "a") as fout:
value = {
"task": "json_jump_forward",
"backend": args.backend,
"latency": round(latency, 3),
"num_jsons": args.num_jsons,
"mode": args.mode,
"parallel": args.parallel,
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str)
parser.add_argument("--num-jsons", type=int, default=50)
parser.add_argument(
"--mode", type=str, default="character", choices=["character", "city"]
)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,58 @@
import json
import transformers
import wikipedia
model_path = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(model_path)
city_names = [
"los angles",
"london",
"tokyo",
"beijing",
"singapore",
"paris",
"dubai",
"sydney",
"moscow",
"rome",
"toronto",
"rio de janeiro",
"istanbul",
"berlin",
"auckland",
"buenos aires",
"mexico city",
"mumbai",
"seoul",
"bangkok",
"cairo",
"athens",
"jerusalem",
]
def get_content(city_name):
content = str(wikipedia.page(city_name).content)
content = content.replace("\n\n", "\n")
tokens = t.encode(content)
expected_tokens = 3000
truncate_len = int((expected_tokens / len(tokens)) * len(content))
truncate_content = content[:truncate_len]
truncate_tokens = t.encode(truncate_content)
# Count token
print(
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
)
return truncate_content
if __name__ == "__main__":
with open("questions.jsonl", "w") as fout:
for city_name in city_names:
truncate_content = get_content(city_name)
fout.write(json.dumps({"document": truncate_content}) + "\n")

View File

@@ -0,0 +1,50 @@
Harry Potter
Hermione Granger
Ron Weasley
Albus Dumbledore
Severus Snape
Rubeus Hagrid
Draco Malfoy
Ginny Weasley
Fred Weasley
George Weasley
Percy Weasley
Sirius Black
Remus Lupin
Neville Longbottom
Luna Lovegood
Cedric Diggory
Cho Chang
Lord Voldemort
Minerva McGonagall
Filius Flitwick
Dolores Umbridge
Bellatrix Lestrange
Lucius Malfoy
Molly Weasley
Arthur Weasley
Nymphadora Tonks
Dobby
Moaning Myrtle
Peter Pettigrew
Alastor 'Mad-Eye' Moody
Horace Slughorn
Vernon Dursley
Petunia Dursley
Dudley Dursley
Argus Filch
Sybill Trelawney
Gilderoy Lockhart
Fleur Delacour
Viktor Krum
Bill Weasley
Oliver Wood
Cornelius Fudge
Barty Crouch Sr.
Barty Crouch Jr.
Kingsley Shacklebolt
Quirinus Quirrell
Nearly Headless Nick
Aunt Marge
Griphook
Ludo Bagman

View File

@@ -0,0 +1,15 @@
## Run benchmark
### Benchmark sglang
Run Llama-8b
```bash
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000
```
Benchmark
```bash
python3 bench_sglang.py
```

View File

@@ -0,0 +1,146 @@
import argparse
import json
import time
from typing import List, Tuple
import jsonschema
from datasets import load_dataset
import sglang as sgl
from sglang.global_config import global_config
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
@sgl.function
def schema_gen(s, message: Tuple[str, str], json_schema: str):
system, user = message
s += sgl.system(system)
s += sgl.user(user)
s += sgl.assistant(
sgl.gen("json_output", temperature=0, max_tokens=256, json_schema=json_schema)
)
def contains_formats(schema, formats: List[str]):
if isinstance(schema, dict):
if schema.get("format", None) in formats:
return True
for value in schema.values():
if contains_formats(value, formats):
return True
elif isinstance(schema, list):
for item in schema:
if contains_formats(item, formats):
return True
return False
def convert_dataset(path: str):
raw_dataset = load_dataset(path)
dataset = []
for data in raw_dataset["train"]:
messages = data["prompt"]
schema = data["schema"]
obj = json.loads(schema)
# skip some corrupted examples
if obj.get("type", None) is None:
continue
# skip schema with format "email"
# which is not supported by outlines for now
if contains_formats(obj, ["email"]):
continue
system = messages[0]
user = messages[1]
assert system["role"] == "system", "invalid role"
assert user["role"] == "user", "invalid role"
assert len(messages) == 2, "invalid message length"
message = json.dumps(system["content"]), json.dumps(user["content"])
dataset.append(
{
"message": message,
"json_schema": schema,
}
)
return dataset
def bench_schema(args):
arguments = convert_dataset(args.data_path)
if args.num_jsons < 0 or args.num_jsons > len(arguments):
args.num_jsons = len(arguments)
arguments = arguments[: args.num_jsons]
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Run requests
tic = time.perf_counter()
states = schema_gen.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
# Check if the outputs are valid
indexes = []
for i, state in enumerate(states):
try:
schema = json.loads(arguments[i]["json_schema"])
obj = json.loads(state["json_output"])
assert jsonschema.validate(obj, schema) is None
except Exception as e:
print(e)
indexes.append(i)
return states, latency
def main(args):
states, latency = bench_schema(args)
# Compute accuracy
tokenizer = get_tokenizer(
global_config.default_backend.get_server_info()["tokenizer_path"]
)
output_jsons = [state["json_output"] for state in states]
num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
print(f"Latency: {latency:.3f}")
print(f"Output throughput: {num_output_tokens / latency:.3f} token/s")
print(f"#output tokens: {num_output_tokens}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(f"{args.backend}.jsonl", "w") as fout:
for state in states:
fout.write(state["json_output"] + "\n")
with open(args.result_file, "a") as fout:
value = {
"task": "json_schema",
"backend": args.backend,
"latency": round(latency, 3),
"num_jsons": args.num_jsons,
"parallel": args.parallel,
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="NousResearch/json-mode-eval")
parser.add_argument("--num-jsons", type=int, default=-1)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,224 @@
"""For Now, MSCCL is only supported on TP16 and TP8 case
export WORLD_SIZE=1
export RANK=0
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=12345
torchrun --nproc_per_node gpu \
--nnodes $WORLD_SIZE \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT benchmark/kernels/all_reduce/benchmark_mscclpp.py
"""
import os
from contextlib import nullcontext
from typing import List
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.parallel_state import (
get_tensor_model_parallel_group,
graph_capture,
initialize_model_parallel,
set_mscclpp_all_reduce,
)
def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
dist.all_reduce(torch_input, group=group)
return torch_input
def msccl_allreduce(
msccl_input: torch.Tensor, msccl_comm: PyMscclppCommunicator
) -> torch.Tensor:
return msccl_comm.all_reduce(msccl_input)
def pynccl_allreduce(
msccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
) -> torch.Tensor:
pynccl_comm.all_reduce(msccl_input)
return msccl_input
def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
graph_input = inp_randn.clone()
with graph_capture() as graph_capture_context:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
for _ in range(graph_loop):
graph_out = func(graph_input)
graph.replay()
func_output = graph_out.clone()
for _ in range(warmup_loop):
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: List[float] = []
for _ in range(test_loop):
torch.cuda.synchronize()
dist.barrier()
start_event.record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000
graph.reset()
return func_output, func_cost_us
def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
eager_input = inp_randn.clone()
eager_output = func(eager_input)
func_output = eager_output.clone()
for _ in range(warmup_loop):
func(eager_input)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(test_loop):
func(eager_input)
end_event.record()
torch.cuda.synchronize()
func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000
return func_output, func_cost_us
def get_torch_prof_ctx(do_prof: bool):
ctx = (
torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
with_stack=True,
)
if do_prof
else nullcontext()
)
return ctx
def human_readable_size(size, decimal_places=1):
for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
if size < 1024.0 or unit == "PiB":
break
size /= 1024.0
return f"{size:.{decimal_places}f} {unit}"
try:
from tabulate import tabulate
except ImportError:
print("tabulate not installed, skipping table printing")
tabulate = None
def print_markdown_table(data):
if tabulate is not None:
print(tabulate(data, headers="keys", tablefmt="github"))
return
headers = data[0].keys()
header_row = "| " + " | ".join(headers) + " |"
separator = "| " + " | ".join(["---"] * len(headers)) + " |"
rows = []
for item in data:
row = "| " + " | ".join(str(item[key]) for key in headers) + " |"
rows.append(row)
markdown_table = "\n".join([header_row, separator] + rows)
print(markdown_table)
if __name__ == "__main__":
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
force=True,
)
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
world, world_size = dist.group.WORLD, dist.get_world_size()
rank = dist.get_rank()
torch.cuda.set_device(rank % 8)
device = torch.cuda.current_device()
set_mscclpp_all_reduce(True)
init_distributed_environment(
world_size=world_size,
rank=rank,
local_rank=rank % 8,
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
group = get_tensor_model_parallel_group().device_group
cpu_group = get_tensor_model_parallel_group().cpu_group
pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
pymscclpp_comm = get_tensor_model_parallel_group().pymscclpp_comm
dist.barrier()
profile = False
dtype = torch.bfloat16
ctx = get_torch_prof_ctx(profile)
result = []
with ctx:
for i in range(10, 20):
sz = 2**i
if sz * dtype.itemsize > 2**20:
break
inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
memory = torch.empty_like(inp_randn)
memory_out = torch.empty_like(memory)
torch_eager_output, torch_eager_time = _bench_eager_time(
lambda inp: torch_allreduce(inp, group), inp_randn
)
msccl_eager_output, msccl_eager_time = _bench_eager_time(
lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
)
msccl_graph_output, msccl_graph_time = _bench_graph_time(
lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
)
# since pynccl is inplace op, this return result is not correct if graph loop > 1
_, pynccl_graph_time = _bench_graph_time(
lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
)
torch.testing.assert_close(torch_eager_output, msccl_graph_output)
torch.testing.assert_close(torch_eager_output, msccl_eager_output)
result.append(
{
"msg_size": human_readable_size(inp_randn.nbytes),
"torch eager time": torch_eager_time,
"msccl eager time": msccl_eager_time,
"msccl graph time": msccl_graph_time,
"pynccl graph time": pynccl_graph_time,
}
)
if rank == 0:
print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
if rank == 0:
print_markdown_table(result)
if profile:
prof_dir = f"prof/msccl"
os.makedirs(prof_dir, exist_ok=True)
ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")

View File

@@ -0,0 +1,403 @@
import itertools
import math
import cudnn
import torch
import torch.utils.benchmark as benchmark
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
from sglang.srt.utils import should_use_tensor_core
def benchmark_forward(
fn,
*inputs,
repeats=10,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
def amp_wrapper(*inputs, **kwinputs):
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
fn(*inputs, **kwinputs)
t = benchmark.Timer(
stmt="fn_amp(*inputs, **kwinputs)",
globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
return t, m
def time_fwd(func, *args, **kwargs):
time_f = benchmark_forward(func, *args, **kwargs)
return time_f[1].mean * 1e6
def decode_attention_sglang(
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
num_kv_splits,
warmup=10,
):
k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)
o = torch.empty_like(q)
total_tokens = batch_size * kv_len
req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
b_req_idx = torch.arange(0, batch_size).to(0).int()
b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
max_len_in_batch = kv_len
sm_scale = 1.0 / (head_dim**0.5)
attn_logits = torch.empty(
(batch_size, head_num_q, num_kv_splits, head_dim + 1),
dtype=torch.float32,
device="cuda",
)
for _ in range(warmup):
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
f = time_fwd(
decode_attention_fwd,
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
return f, o
def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
use_tensor_cores = should_use_tensor_core(
kv_cache_dtype=dtype,
num_attention_heads=head_num_q,
num_kv_heads=head_num_kv,
)
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
)
class FlashinferAttention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
dtype,
warmup=10,
):
total_tokens = batch_size * kv_len
kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
kv_indices = torch.arange(0, total_tokens).to(0).int()
kv_last_page_len = torch.full(
(batch_size,), 1, dtype=torch.int32, device="cuda"
)
flashinfer_decode_wrapper.end_forward()
flashinfer_decode_wrapper.begin_forward(
kv_indptr,
kv_indices,
kv_last_page_len,
head_num_q,
head_num_kv,
head_dim,
1,
pos_encoding_mode="NONE",
data_type=dtype,
)
for _ in range(warmup):
o = flashinfer_decode_wrapper.forward(
q.contiguous().view(-1, head_num_q, head_dim), kv_data
)
f = time_fwd(
flashinfer_decode_wrapper.forward,
q.contiguous().view(-1, head_num_q, head_dim),
kv_data,
)
return f, o
return FlashinferAttention
def convert_to_cudnn_type(torch_type):
if torch_type == torch.float16:
return cudnn.data_type.HALF
elif torch_type == torch.bfloat16:
return cudnn.data_type.BFLOAT16
elif torch_type == torch.float32:
return cudnn.data_type.FLOAT
elif torch_type == torch.int32:
return cudnn.data_type.INT32
elif torch_type == torch.int64:
return cudnn.data_type.INT64
else:
raise ValueError("Unsupported tensor data type.")
def decode_attention_cudnn(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
):
# Prepare data: continuous q,k,v
dims_q = (batch_size, head_num_q, 1, head_dim)
strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1)
q_gpu = q.as_strided(dims_q, strides_q)
o_gpu = (
torch.empty(batch_size * head_num_q * head_dim)
.half()
.cuda()
.as_strided(dims_q, strides_q)
)
dims_kv = (batch_size, head_num_kv, kv_len, head_dim)
strides_kv = (
kv_len * head_num_kv * head_dim,
head_dim,
head_num_kv * head_dim,
1,
)
k_gpu = kv_data[0].as_strided(dims_kv, strides_kv)
v_gpu = kv_data[1].as_strided(dims_kv, strides_kv)
seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device="cuda")
seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device="cuda")
attn_scale = 1.0 / (head_dim**0.5)
# Prepare data: paged k,v
block_size = 1
blocks_per_batch = math.ceil(kv_len / block_size)
# [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch
container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0)
container_v_gpu = torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0)
page_table_k_gpu = (
torch.linspace(
0,
batch_size * blocks_per_batch - 1,
batch_size * blocks_per_batch,
device="cuda",
dtype=torch.int32,
)
.reshape(blocks_per_batch, 1, batch_size, 1)
.transpose(0, 2)
)
page_table_v_gpu = page_table_k_gpu.clone()
graph = cudnn.pygraph(
io_data_type=convert_to_cudnn_type(dtype),
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
)
q = graph.tensor_like(q_gpu)
container_k = graph.tensor_like(container_k_gpu)
container_v = graph.tensor_like(container_v_gpu)
page_table_k = graph.tensor_like(page_table_k_gpu)
page_table_v = graph.tensor_like(page_table_v_gpu)
seq_len_q = graph.tensor_like(seq_len_q_gpu)
seq_len_kv = graph.tensor_like(seq_len_kv_gpu)
o, _ = graph.sdpa(
name="sdpa",
q=q,
k=container_k, # Container K: non contiguous container with K blocks
v=container_v, # Container V: non contiguous container with V blocks
is_inference=True,
attn_scale=attn_scale,
use_causal_mask=False,
use_padding_mask=True,
seq_len_q=seq_len_q,
seq_len_kv=seq_len_kv,
paged_attention_k_table=page_table_k, # Page Table K: Tensor containing offsets to the container with K blocks
paged_attention_v_table=page_table_v, # Page Table V: Tensor containing offsets to the container with V blocks
paged_attention_max_seq_len_kv=kv_len, # The maximum sequence length for K caches (this is optional, but recommended)
)
o.set_output(True).set_dim(dims_q).set_stride(strides_q)
graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A])
graph.check_support()
graph.build_plans()
workspace = torch.empty(
graph.get_workspace_size(), device="cuda", dtype=torch.uint8
)
variant_pack = {
q: q_gpu,
container_k: container_k_gpu,
container_v: container_v_gpu,
page_table_k: page_table_k_gpu,
page_table_v: page_table_v_gpu,
seq_len_q: seq_len_q_gpu,
seq_len_kv: seq_len_kv_gpu,
o: o_gpu,
}
for _ in range(warmup):
graph.execute(variant_pack, workspace)
f = time_fwd(
graph.execute,
variant_pack,
workspace,
)
return f, o_gpu.squeeze(dim=2)
def calculate_diff():
dtype = torch.float16
batch_size = 64
kv_len = 4096
head_num_q = 64
head_num_kv = 8
head_dim = 128
q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
kv_data = (
torch.randn(
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
),
torch.randn(
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
),
)
_, output_sglang = decode_attention_sglang(
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
num_kv_splits=8,
)
attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
_, output_flashinfer = attn_flashinfer(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
_, output_cudnn = decode_attention_cudnn(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
print(f"SGLang output={output_sglang}")
print(f"FlashInfer output={output_flashinfer}")
print(f"cuDNN output={output_cudnn}")
if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):
print("✅ SGLang[Triton] and FlashInfer match")
else:
print("❌ SGLang[Triton] and FlashInfer differ")
if torch.allclose(output_sglang, output_cudnn, atol=1e-2, rtol=1e-2):
print("✅ SGLang[Triton] and cuDNN match")
else:
print("❌ SGLang[Triton] and cuDNN differ")
if __name__ == "__main__":
calculate_diff()
head_dim = 128
dtype = torch.float16
batch_size_range = [2**i for i in range(0, 8, 2)]
kv_len_range = [2**i for i in range(6, 13, 1)]
configs = list(itertools.product(batch_size_range, kv_len_range))
for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
attn_flashinfer = decode_attention_flashinfer(
dtype, head_num_q, head_num_kv
).apply
for batch_size, kv_len in configs:
q = torch.randn(
batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
)
kv_data = (
torch.randn(
batch_size * kv_len,
head_num_kv,
head_dim,
dtype=dtype,
device="cuda",
),
torch.randn(
batch_size * kv_len,
head_num_kv,
head_dim,
dtype=dtype,
device="cuda",
),
)
us_cudnn, output_cudnn = decode_attention_cudnn(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
us_sglang, output_sglang = decode_attention_sglang(
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
num_kv_splits=8,
)
us_flashinfer, _ = attn_flashinfer(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
print(
head_num_q,
" ",
head_num_kv,
" ",
batch_size,
" ",
kv_len,
" ",
us_cudnn,
" ",
us_sglang,
" ",
us_flashinfer,
)

View File

@@ -0,0 +1,218 @@
# ADAPTED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/utils.py
import os
import sys
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
def init_dist(local_rank: int, num_local_ranks: int, args):
ip = args.master_addr
port = args.master_port
num_nodes = args.nnodes
node_rank = args.node_rank
assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
dist.init_process_group(
backend="nccl",
init_method=f"tcp://{ip}:{port}",
world_size=num_nodes * num_local_ranks,
rank=node_rank * num_local_ranks + local_rank,
)
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")
torch.cuda.set_device(local_rank)
return (
dist.get_rank(),
dist.get_world_size(),
dist.new_group(list(range(num_local_ranks * num_nodes))),
)
def calc_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double() + 1, y.double() + 1
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return (1 - sim).item()
def per_token_cast_to_fp8(x: torch.Tensor):
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
m, n
), (x_amax / 448.0).view(m, -1)
def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
x_scales = x_scales.view(x_fp8.size(0), -1, 1)
return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)
def inplace_unique(x: torch.Tensor, num_slots: int):
assert x.dim() == 2
mask = x < 0
x_padded = x.masked_fill(mask, num_slots)
bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
bin_count = bin_count[:, :num_slots]
sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
x[:, :].fill_(-1)
valid_len = min(num_slots, x.size(1))
x[:, :valid_len] = sorted_bin_idx[:, :valid_len]
def create_grouped_scores(
scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
):
num_tokens, num_experts = scores.shape
scores = scores.view(num_tokens, num_groups, -1)
mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
return (scores * mask).view(num_tokens, num_experts)
def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
# Flush L2 cache with 256 MB data
torch.cuda.synchronize()
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
# Warmup
for _ in range(num_warmups):
fn()
# Flush L2
cache.zero_()
# Testing
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
for i in range(num_tests):
# Record
start_events[i].record()
fn()
end_events[i].record()
if post_fn is not None:
post_fn()
torch.cuda.synchronize()
times = np.array(
[s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]
)[1:]
return np.average(times), np.min(times), np.max(times)
class empty_suppress:
def __enter__(self):
return self
def __exit__(self, *_):
pass
class suppress_stdout_stderr:
def __enter__(self):
self.outnull_file = open(os.devnull, "w")
self.errnull_file = open(os.devnull, "w")
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
self.outnull_file.close()
self.errnull_file.close()
def bench_kineto(
fn,
kernel_names,
num_tests: int = 30,
suppress_kineto_output: bool = False,
trace_path: Optional[str] = None,
barrier_comm_profiling: bool = False,
):
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
with suppress():
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
) as prof:
for i in range(2):
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
if barrier_comm_profiling:
lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
lhs @ rhs
dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
for _ in range(num_tests):
fn()
prof.step()
# Parse the profiling table
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tupled = isinstance(kernel_names, tuple)
prof_lines = (
prof.key_averages()
.table(sort_by="cuda_time_total", max_name_column_width=100)
.split("\n")
)
kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
for name in kernel_names:
assert (
sum([name in line for line in prof_lines]) == 1
), f"Errors of the kernel {name} in the profiling table"
# Save chrome traces
if trace_path is not None:
prof.export_chrome_trace(trace_path)
# Return average kernel times
units = {"ms": 1e3, "us": 1e6}
kernel_times = []
for name in kernel_names:
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, "")) / scale)
break
break
return tuple(kernel_times) if is_tupled else kernel_times[0]
def hash_tensor(t: torch.Tensor):
return t.view(torch.int64).sum().item()

View File

@@ -0,0 +1,476 @@
# MODIFIED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/test_internode.py
"""
Example usage:
python tuning_deepep.py --nnodes 4 --node-rank $MY_NODE_RANK --master-addr 1.2.3.4
Then check `deepep_tuned.json`
"""
import argparse
import json
import time
from copy import deepcopy
from pathlib import Path
# noinspection PyUnresolvedReferences
import deep_ep
import torch
import torch.distributed as dist
from deepep_utils import (
bench,
calc_diff,
create_grouped_scores,
init_dist,
inplace_unique,
per_token_cast_back,
per_token_cast_to_fp8,
)
def test_main(
num_sms: int,
local_rank: int,
num_local_ranks: int,
num_ranks: int,
num_nodes: int,
rank: int,
buffer: deep_ep.Buffer,
group: dist.ProcessGroup,
args,
):
# Settings
num_tokens, hidden, num_topk_groups, num_topk, num_experts = (
4096,
7168,
min(num_nodes, 4),
8,
(256 // num_ranks) * num_ranks,
)
assert num_experts % num_ranks == 0 and num_local_ranks == 8
if local_rank == 0:
print(
f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}",
flush=True,
)
# Random data
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
x_e4m3 = per_token_cast_to_fp8(x)
scores = (
torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
+ 1
)
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
group_idx = torch.topk(
group_scores, k=num_topk_groups, dim=-1, sorted=False
).indices
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[
1
]
topk_weights = (
torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
)
topk_weights_pure_rand = torch.randn(
(num_tokens, num_topk), dtype=torch.float32, device="cuda"
)
rank_idx = topk_idx // (num_experts // num_ranks)
rank_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rank_idx, num_ranks)
rdma_rank_idx = rank_idx // num_local_ranks
rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
inplace_unique(rdma_rank_idx, num_nodes)
# RDMA dispatch counts
rdma_idx = topk_idx // (num_experts // num_nodes)
rdma_idx.masked_fill_(topk_idx == -1, -1)
inplace_unique(rdma_idx, num_nodes)
num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
# Expert meta
num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
for i in range(num_experts):
num_tokens_per_expert[i] = (topk_idx == i).sum()
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
# Rank layout meta
num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
token_idx_in_rank = torch.full(
(num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
)
for i in range(num_ranks):
num_tokens_per_rank[i] = (rank_idx == i).sum()
token_sel = (rank_idx == i).max(dim=-1)[0]
count = token_sel.sum().item()
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
tokens[:count] = torch.sort(tokens[:count])[0]
token_idx_in_rank[i][tokens[:count]] = torch.arange(
count, dtype=torch.long, device="cuda"
)
for i in range(num_nodes):
num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
is_token_in_rank = token_idx_in_rank >= 0
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
(
ref_num_tokens_per_rank,
ref_num_tokens_per_rdma_rank,
ref_num_tokens_per_expert,
ref_is_token_in_rank,
_,
) = buffer.get_dispatch_layout(topk_idx, num_experts)
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
if local_rank == 0:
print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
print("", flush=True)
group.barrier()
time.sleep(1)
# Config
rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
# Test dispatch
# noinspection PyShadowingNames
def check_data(check_x, recv_gbl_rank_prefix_sum):
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
check_start = 0
for i in range(num_ranks):
check_end = recv_gbl_rank_prefix_sum[i].item()
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
check_start = check_end
for previous_mode in (False, True):
for async_mode in (False, True):
for current_x in (x_pure_rand, x, x_e4m3):
for with_topk in (False, True):
if local_rank == 0:
print(
f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
flush=True,
end="",
)
dispatch_args = {
"x": current_x,
"num_tokens_per_rank": num_tokens_per_rank,
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
"is_token_in_rank": is_token_in_rank,
"num_tokens_per_expert": num_tokens_per_expert,
"config": config,
"async_finish": async_mode,
}
if with_topk:
dispatch_args.update(
{
"topk_idx": topk_idx,
"topk_weights": (
topk_weights_pure_rand
if current_x is x_pure_rand
else topk_weights
),
}
)
if previous_mode:
dispatch_args.update({"previous_event": buffer.capture()})
(
recv_x,
recv_topk_idx,
recv_topk_weights,
recv_num_tokens_per_expert_list,
handle,
event,
) = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = (
per_token_cast_back(*recv_x)
if isinstance(recv_x, tuple)
else recv_x
)
# Checks
recv_gbl_rank_prefix_sum = handle[-4]
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
0
), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
assert (
gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
== recv_num_tokens_per_expert_list
)
if current_x is not x_pure_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
if with_topk:
# Check `topk_idx`
assert (
recv_topk_idx.eq(-1)
| (
(recv_topk_idx >= 0)
& (recv_topk_idx < (num_experts // num_ranks))
)
).sum().item() == recv_topk_idx.numel()
for i, count in enumerate(recv_num_tokens_per_expert_list):
assert recv_topk_idx.eq(i).sum().item() == count
# Check `topk_weights`
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = (
recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
recv_topk_weights
)[recv_topk_idx.eq(-1)]
)
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
# Test cached dispatch (must without top-k staffs)
if not with_topk:
dispatch_args = {
"x": current_x,
"handle": handle,
"config": config,
"async_finish": async_mode,
}
if previous_mode:
dispatch_args.update({"previous_event": buffer.capture()})
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_x = (
per_token_cast_back(*recv_x)
if isinstance(recv_x, tuple)
else recv_x
)
if current_x is not x_pure_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
# Test combine
combine_args = {
"x": recv_x,
"handle": handle,
"config": config,
"async_finish": async_mode,
}
if with_topk:
combine_args.update({"topk_weights": recv_topk_weights})
if previous_mode:
combine_args.update({"previous_event": buffer.capture()})
combined_x, combined_topk_weights, event = buffer.combine(
**combine_args
)
event.current_stream_wait() if async_mode else ()
check_x = combined_x.float() / is_token_in_rank.sum(
dim=1
).unsqueeze(1)
ref_x = x_pure_rand if current_x is x_pure_rand else x
assert calc_diff(check_x, ref_x) < 5e-6
if with_topk:
check_topk_weights = (
combined_topk_weights
if (current_x is x_pure_rand)
else (
combined_topk_weights
/ is_token_in_rank.sum(dim=1).unsqueeze(1)
)
)
ref_topk_weights = (
topk_weights_pure_rand
if current_x is x_pure_rand
else topk_weights
)
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
# For later tuning
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
if local_rank == 0:
print(" passed", flush=True)
if local_rank == 0:
print("", flush=True)
output_data = {}
# Tune dispatch performance
best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2
for current_x in (x_e4m3, x):
best_time, best_results = 1e10, None
rdma_send_bytes = (
(dispatch_bf16_rdma_send_bytes * fp8_factor)
if isinstance(current_x, tuple)
else dispatch_bf16_rdma_send_bytes
)
nvl_recv_bytes = (
(dispatch_bf16_nvl_recv_bytes * fp8_factor)
if isinstance(current_x, tuple)
else dispatch_bf16_nvl_recv_bytes
)
for nvl_chunk_size in range(4, 33, 4):
for rdma_chunk_size in range(4, 33, 4):
config_kwargs = {
"num_sms": num_sms,
"num_max_nvl_chunked_send_tokens": nvl_chunk_size,
"num_max_nvl_chunked_recv_tokens": nvl_buffer_size,
"num_max_rdma_chunked_send_tokens": rdma_chunk_size,
"num_max_rdma_chunked_recv_tokens": rdma_buffer_size,
}
config = deep_ep.Config(**config_kwargs)
tune_args = {"x": current_x, "handle": handle, "config": config}
t = bench(lambda: buffer.dispatch(**tune_args))[0]
if t < best_time:
best_time, best_results = t, (
num_sms,
nvl_chunk_size,
rdma_chunk_size,
config_kwargs,
)
if local_rank == 0:
print(
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
flush=True,
)
if local_rank == 0:
print(
f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
flush=True,
)
print("", flush=True)
is_fp8 = isinstance(current_x, tuple)
if is_fp8:
output_data["normal_dispatch"] = deepcopy(best_results[3])
if isinstance(current_x, tuple):
# Gather FP8 the best config from rank 0
best_dispatch_results = torch.tensor(
[best_results[0], best_results[1], best_results[2]],
dtype=torch.int32,
device="cuda",
)
all_best_fp8_results_list = [
torch.zeros_like(best_dispatch_results)
for _ in range(torch.distributed.get_world_size())
]
dist.all_gather(
all_best_fp8_results_list, best_dispatch_results, group=group
)
best_dispatch_results = all_best_fp8_results_list[0].tolist()
dispatch_config = deep_ep.Config(
best_dispatch_results[0],
best_dispatch_results[1],
nvl_buffer_size,
best_dispatch_results[2],
rdma_buffer_size,
)
dispatch_args = {
"x": x,
"num_tokens_per_rank": num_tokens_per_rank,
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
"is_token_in_rank": is_token_in_rank,
"num_tokens_per_expert": num_tokens_per_expert,
"config": dispatch_config if dispatch_config is not None else config,
}
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
# Tune combine performance
best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 5, 1):
for rdma_chunk_size in range(8, 33, 4):
config_kwargs = {
"num_sms": num_sms,
"num_max_nvl_chunked_send_tokens": nvl_chunk_size,
"num_max_nvl_chunked_recv_tokens": nvl_buffer_size,
"num_max_rdma_chunked_send_tokens": rdma_chunk_size,
"num_max_rdma_chunked_recv_tokens": rdma_buffer_size,
}
config = deep_ep.Config(**config_kwargs)
tune_args = {"x": recv_x, "handle": handle, "config": config}
t = bench(lambda: buffer.combine(**tune_args))[0]
if local_rank == 0:
print(
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
flush=True,
)
if t < best_time:
best_time, best_results = t, (
num_sms,
nvl_chunk_size,
rdma_chunk_size,
config_kwargs,
)
if local_rank == 0:
print(
f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
flush=True,
)
print("", flush=True)
output_data["normal_combine"] = deepcopy(best_results[3])
if rank == 0 and local_rank == 0:
_write_output(args, output_data)
def _write_output(args, output_data):
text = json.dumps(output_data, indent=4)
output_path = args.output_path
print(f"Write to {output_path} with {text}")
Path(output_path).write_text(text)
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int, args):
num_nodes = args.nnodes
rank, num_ranks, group = init_dist(local_rank, num_local_ranks, args)
num_sms = args.num_sms
num_qps_per_rank = num_sms // 2
buffer = deep_ep.Buffer(
group,
int(1e9),
int(1e9),
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
)
assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank)
for i in (num_sms,):
test_main(
i,
local_rank,
num_local_ranks,
num_ranks,
num_nodes,
rank,
buffer,
group,
args,
)
if local_rank == 0:
print("", flush=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-sms", type=int, default=24)
parser.add_argument("--output-path", type=str, default="deepep_tuned.json")
parser.add_argument("--nnodes", type=int, default=1)
parser.add_argument("--node-rank", type=int, default=0)
parser.add_argument("--master-addr", type=str, default="127.0.0.1")
parser.add_argument("--master-port", type=int, default=8361)
args = parser.parse_args()
print(f"Start system with {args=}")
num_processes = 8
torch.multiprocessing.spawn(
test_loop, args=(num_processes, args), nprocs=num_processes
)

View File

@@ -0,0 +1,19 @@
## DeepSeek kernels benchmark
### Prerequisites
- You should install [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) from source before run `benchmark_deepgemm_fp8_gemm.py` and `benchmark_deepgemm_fp8_group_gemm.py`.
### Benchmark
- `benchmark_deepgemm_fp8_gemm.py`
```bash
python benchmark_deepgemm_fp8_gemm.py --run_correctness --tp_size 1
```
- `benchmark_deepgemm_fp8_group_gemm.py`
```bash
python benchmark_deepgemm_fp8_group_gemm.py --run_correctness --tp_size 1
```
- You can use the `--run_correctness` parameter to verify all kernels results's correctness.
- You can use the `--tp_size` parameter to benchmark all FP8 w8a8 block-wise matrix multiplications involved in DeepSeek V3/R1 under the current tensor parallelism (TP) setting. This benchmark compares DeepSeek's open-source [DeepGemm](https://github.com/deepseek-ai/DeepGEMM) implementation with SGLang's and VLLM Triton implementation.

View File

@@ -0,0 +1,400 @@
from typing import Tuple
import deep_gemm
import tilelang
import tilelang.language as T
import torch
import triton
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
)
from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
)
# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
def tl_gemm(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"e4m3_float8",
], "Currently only e4m3_float8 is supported"
assert out_dtype in [
"bfloat16",
"float16",
], "Currently only bfloat16 and float16 are supported"
TILE_SIZE = (128, 128, 128)
block_M = TILE_SIZE[0]
block_N = TILE_SIZE[1]
block_K = TILE_SIZE[2]
A_shape = (M, K)
Scales_A_shape = (M, T.ceildiv(K, block_K))
B_shape = (N, K)
Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (block_M, block_N)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
scales_a: T.Buffer(Scales_A_shape, "float32"),
B: T.Buffer(B_shape, in_dtype),
scales_b: T.Buffer(Scales_B_shape, "float32"),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
bx,
by,
):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_shared = T.alloc_shared(C_shared_shape, out_dtype)
Scale_C_shared = T.alloc_shared((block_M), "float32")
C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=4):
# Load A into shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
# Load B into shared memory
T.copy(B[bx * block_N, k * block_K], B_shared)
# Load scale into shared memory
Scale_B = scales_b[bx, k]
for i in T.Parallel(block_M):
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Promote to enable 2xAcc
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
T.clear(C_local)
# TMA store
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
m, n
), (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
x_view.size(0), x_view.size(2)
)
def fp8_gemm_deepgemm(
x_fp8: torch.Tensor,
x_scale: torch.Tensor,
y_fp8: torch.Tensor,
y_scale: torch.Tensor,
m: int,
n: int,
k: int,
):
"""DeepGEMM implementation of FP8 GEMM"""
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
# Run DeepGEMM kernel
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
return out
def fp8_gemm_sglang(
x_fp8: torch.Tensor,
x_scale: torch.Tensor,
y_fp8: torch.Tensor,
y_scale: torch.Tensor,
m: int,
n: int,
k: int,
):
"""SGLang implementation of FP8 GEMM"""
block_size = [128, 128] # Matches the block size in per_block_cast_to_fp8
# Run SGLang kernel
out = w8a8_block_fp8_matmul(
x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
)
return out
def fp8_gemm_vllm(
x_fp8: torch.Tensor,
x_scale: torch.Tensor,
y_fp8: torch.Tensor,
y_scale: torch.Tensor,
m: int,
n: int,
k: int,
):
"""vLLM implementation of FP8 GEMM"""
block_size = [128, 128] # Matches the block size in per_block_cast_to_fp8
# Run vLLM kernel
out = vllm_w8a8_block_fp8_matmul(
x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
)
return out
def calculate_diff(m: int, n: int, k: int):
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
out_deepgemm = fp8_gemm_deepgemm(
x_fp8.clone(),
x_scale_col_major.clone(),
y_fp8.clone(),
y_scale.clone(),
m,
n,
k,
)
out_sglang = fp8_gemm_sglang(
x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
)
tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
out_tilelang = tilelang_kernel(
x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()
)
diff_sglang_deepgemm = torch.abs(out_deepgemm - out_sglang).mean().item()
diff_tilelang_deepgemm = torch.abs(out_deepgemm - out_tilelang).mean().item()
diff_tilelang_sglang = torch.abs(out_tilelang - out_sglang).mean().item()
print(f"Shape m={m}, n={n}, k={k}:")
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
print(f"SGLang output: {out_sglang[0, 0:5]}")
print(f"TileLang output: {out_tilelang[0, 0:5]}")
print(f"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}")
print(f"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}")
print(f"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}")
sglang_deepgemm_match = torch.allclose(
out_deepgemm, out_sglang, atol=1e-2, rtol=1e-2
)
tilelang_deepgemm_match = torch.allclose(
out_deepgemm, out_tilelang, atol=1e-2, rtol=1e-2
)
tilelang_sglang_match = torch.allclose(
out_tilelang, out_sglang, atol=1e-2, rtol=1e-2
)
if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:
print("✅ All implementations match\n")
else:
print("❌ Some implementations differ:")
print(f" - SGLang vs DeepGEMM: {'' if sglang_deepgemm_match else ''}")
print(f" - TileLang vs DeepGEMM: {'' if tilelang_deepgemm_match else ''}")
print(f" - TileLang vs SGLang: {'' if tilelang_sglang_match else ''}\n")
def get_weight_shapes(tp_size):
# cannot TP
total = [
(512 + 64, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(7168, 16384),
(7168, 18432),
]
# N can TP
n_tp = [
(18432 * 2, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(24576, 1536),
(4096, 7168),
]
# K can TP
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
weight_shapes = []
for t in total:
weight_shapes.append(t)
for n_t in n_tp:
new_t = (n_t[0] // tp_size, n_t[1])
weight_shapes.append(new_t)
for k_t in k_tp:
new_t = (k_t[0], k_t[1] // tp_size)
weight_shapes.append(new_t)
return weight_shapes
def create_benchmark_configs(tp_size):
configs = []
weight_shapes = get_weight_shapes(tp_size)
batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]
for n, k in weight_shapes:
for m in batch_sizes:
configs.append((m, n, k, tp_size))
return configs
def get_benchmark(tp_size):
all_configs = create_benchmark_configs(tp_size)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["m", "n", "k", "tp_size"],
x_vals=[list(config) for config in all_configs],
line_arg="provider",
line_vals=["deepgemm", "sglang", "tilelang"],
line_names=["DeepGEMM", "SGLang", "TileLang"],
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
ylabel="ms",
plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
args={},
)
)
def benchmark(m, n, k, tp_size, provider):
print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
# Preprocess data before benchmarking
x_fp8, x_scale = per_token_cast_to_fp8(x)
y_fp8, y_scale = per_block_cast_to_fp8(y)
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
quantiles = [0.5, 0.2, 0.8]
if provider == "deepgemm":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_gemm_deepgemm(
x_fp8.clone(),
x_scale_col_major.clone(),
y_fp8.clone(),
y_scale.clone(),
m,
n,
k,
),
quantiles=quantiles,
)
elif provider == "sglang":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_gemm_sglang(
x_fp8.clone(),
x_scale.clone(),
y_fp8.clone(),
y_scale.clone(),
m,
n,
k,
),
quantiles=quantiles,
)
else: # tilelang
tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: tilelang_kernel(
x_fp8.clone(),
x_scale.clone(),
y_fp8.clone(),
y_scale.clone(),
),
quantiles=quantiles,
)
# Calculate TFLOPS
flops = 2 * m * n * k # multiply-adds
tflops = flops / (ms * 1e-3) / 1e12
# Print shape-specific results with TFLOPS
print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/fp8_gemm/",
help="Path to save fp8 gemm benchmark results",
)
parser.add_argument(
"--run_correctness",
action="store_true",
default=True,
help="Whether to run correctness test",
)
parser.add_argument(
"--tp_size",
type=int,
default=1,
help="Tensor parallelism size to benchmark (default: 1)",
)
args = parser.parse_args()
# Set random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Run correctness tests on a few examples
if args.run_correctness:
print("Running correctness tests...")
calculate_diff(64, 512, 7168) # Small test
calculate_diff(64, 7168, 16384) # Medium test
calculate_diff(64, 18432, 7168) # Large test
# Get the benchmark function with the specified tp_size
benchmark = get_benchmark(args.tp_size)
print(f"Running performance benchmark for TP size = {args.tp_size}...")
benchmark.run(print_data=True, save_path=args.save_path)

View File

@@ -0,0 +1,486 @@
from typing import Tuple
import deep_gemm
import torch
import triton
import triton.language as tl
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
# Import shared functionality from the regular GEMM benchmark
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
per_block_cast_to_fp8,
per_token_cast_to_fp8,
)
def construct_grouped_and_flat_fp8(
x: torch.Tensor, y: torch.Tensor, num_groups: int, is_masked: bool
) -> Tuple[
Tuple[torch.Tensor, torch.Tensor], # grouped x_fp8
Tuple[torch.Tensor, torch.Tensor], # grouped y_fp8
Tuple[torch.Tensor, torch.Tensor], # flat x_fp8
Tuple[torch.Tensor, torch.Tensor], # flat y_fp8
torch.Tensor, # output
torch.Tensor, # reference output
]:
# Verify input shapes
m, k = x.shape
n, k_y = y.shape
assert k == k_y, f"Incompatible shapes: x({m}, {k}), y({n}, {k_y})"
assert m % num_groups == 0, f"m({m}) must be divisible by num_groups({num_groups})"
assert m % 4 == 0, f"TMA alignment error: {m}"
# Reshape inputs for grouped processing
m_per_group = m // num_groups
x_grouped = x.view(num_groups, m_per_group, k)
y_grouped = y.unsqueeze(0).expand(num_groups, n, k)
# Initialize output tensors
out = torch.empty((num_groups, m_per_group, n), device="cuda", dtype=torch.bfloat16)
ref_out = torch.einsum("gmk,gnk->gmn", x_grouped, y_grouped)
# Quantize grouped tensors
x_fp8_grouped = (
torch.empty_like(x_grouped, dtype=torch.float8_e4m3fn),
torch.empty(
(num_groups, m_per_group, k // 128), device="cuda", dtype=torch.float
),
)
y_fp8_grouped = (
torch.empty_like(y_grouped, dtype=torch.float8_e4m3fn),
torch.empty(
(num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float
),
)
for i in range(num_groups):
x_fp8_grouped[0][i], x_fp8_grouped[1][i] = per_token_cast_to_fp8(x_grouped[i])
y_fp8_grouped[0][i], y_fp8_grouped[1][i] = per_block_cast_to_fp8(y_grouped[i])
# Quantize flat tensors
x_fp8_flat = per_token_cast_to_fp8(x)
y_fp8_flat = per_block_cast_to_fp8(y)
# For non-masked input, merge the group and M dims in output
if not is_masked:
x_fp8_grouped = (
x_fp8_grouped[0].view(-1, k),
per_token_cast_to_fp8(x_grouped.view(-1, k))[1],
)
out, ref_out = out.view(-1, n), ref_out.view(-1, n)
# Transpose earlier for testing
x_fp8_grouped = (
x_fp8_grouped[0],
get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
)
x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
# Since we don't have a group gemm kernel in SGLang/vLLM, we implemented a
# custom kernel based on the Triton tutorial.
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
@triton.jit
def fp8_gemm_group_triton_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
# Pointers to scaling factors
a_scale_ptr,
b_scale_ptr,
# Matrix dimensions
M,
N,
K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension.
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Strides for scaling factors
stride_a_scale_m,
stride_a_scale_k,
stride_b_scale_n,
stride_b_scale_k,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
Note: Block sizes must be multiples of 32 for optimal TMA performance.
"""
# Map program ids to the block of C it should compute
pid_group = tl.program_id(axis=0) # Group ID
pid_n = tl.program_id(axis=1) # N dimension ID
# Compute the M block ID within this group
group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
pid_m_within_group = tl.program_id(axis=2) % group_size_m
pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group
# Create pointers for the first blocks of A and B
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# Initialize accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# Main loop
for k_block in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
k_offset = k_block * BLOCK_SIZE_K
# Load the next block of A and B, with masks
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k_offset, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_offset, other=0.0)
# Calculate indices for scaling factors for this K block
a_scale_ptrs = a_scale_ptr + (
offs_am * stride_a_scale_m + k_block * stride_a_scale_k
)
b_scale_ptrs = b_scale_ptr + (
pid_n * stride_b_scale_n + k_block * stride_b_scale_k
)
# Perform matrix multiplication in FP8
res = tl.dot(a, b)
# Load scaling factors for the current block
a_scale = tl.load(a_scale_ptrs)[:, None] # [BLOCK_SIZE_M, 1]
b_scale = tl.load(b_scale_ptrs)
# Apply scaling factors to the accumulated result
accumulator += res * a_scale * b_scale
# Advance pointers
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# Convert to bfloat16 for output
c = accumulator.to(tl.bfloat16)
# Write back the result
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
"""
Perform matrix multiplication with FP8 inputs and proper scaling.
Args:
a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
c: Output tensor in BF16 format
num_groups: Number of groups for grouped GEMM
Returns:
Result tensor in BF16 format
"""
# Unpack the tuples
a, a_scale = a_tuple
b, b_scale = b_tuple
M, K = a.shape
_, N = b.shape
# Configure block sizes - must be multiples of 32 for TMA alignment
BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 128
# Calculate grid dimensions
num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
num_groups_grid = triton.cdiv(num_pid_m, num_groups)
# 3D grid launch - (group, n_blocks, m_blocks_per_group)
grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))
fp8_gemm_group_triton_kernel[grid](
a,
b,
c,
a_scale,
b_scale,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
a_scale.stride(0),
1, # Stride in the K dimension may be 1
b_scale.stride(0),
1 if b_scale.dim() > 1 else 0,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=num_groups,
)
return c
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
x_fp8_grouped,
y_fp8_grouped,
out,
m_indices,
)
return out
def calculate_diff(m: int, n: int, k: int, num_groups: int):
print(f"Shape (m={m}, n={n}, k={k}")
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
)
m_per_group = m // num_groups
out_deepgemm = out.clone()
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
m_indices = (
m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)
)
fp8_gemm_group_deepgemm(
x_fp8_grouped,
y_fp8_grouped,
out_deepgemm,
m_indices,
)
torch.cuda.synchronize()
# Prepare inputs for Triton
a, a_scale = x_fp8_flat
b, b_scale = y_fp8_flat
b = b.T.contiguous()
# Ensure scales are in the right format and contiguous
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
M, _ = a.shape
_, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
torch.cuda.synchronize()
diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()
diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item()
print(f"Shape m={m}, n={n}, k={k}:")
print(f"Torch output: {out_torch[0, 0:5]}")
print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
print(f"Triton output: {out_triton[0, 0:5]}")
print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}")
print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}")
print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}")
deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch)
triton_torch_diff = calc_diff(out_triton, out_torch)
deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton)
DIFF_THRESHOLD = 0.001
all_match = (
deepgemm_torch_diff < DIFF_THRESHOLD
and triton_torch_diff < DIFF_THRESHOLD
and deepgemm_triton_diff < DIFF_THRESHOLD
)
if all_match:
print("✅ All implementations match\n")
else:
print("❌ Some implementations differ:")
print(
f" - Torch vs DeepGEMM: {'' if deepgemm_torch_diff < DIFF_THRESHOLD else ''}"
f" - Torch vs Triton: {'' if triton_torch_diff < DIFF_THRESHOLD else ''}"
f" - DeepGEMM vs Triton: {'' if deepgemm_triton_diff < DIFF_THRESHOLD else ''}"
)
def get_weight_shapes(tp_size):
# cannot TP
total = [
(512 + 64, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(7168, 16384),
(7168, 18432),
]
# N can TP
n_tp = [
(18432 * 2, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(24576, 1536),
(4096, 7168),
]
# K can TP
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
weight_shapes = []
for t in total:
weight_shapes.append(t)
for n_t in n_tp:
new_t = (n_t[0] // tp_size, n_t[1])
weight_shapes.append(new_t)
for k_t in k_tp:
new_t = (k_t[0], k_t[1] // tp_size)
weight_shapes.append(new_t)
return weight_shapes
def create_benchmark_configs(tp_size):
configs = []
weight_shapes = get_weight_shapes(tp_size)
batch_sizes = [2048, 4096]
group_sizes = [4, 8]
for n, k in weight_shapes:
for m in batch_sizes:
for num_groups in group_sizes:
configs.append((m, n, k, num_groups, tp_size))
return configs
def get_benchmark(tp_size):
all_configs = create_benchmark_configs(tp_size)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["m", "n", "k", "num_groups", "tp_size"],
x_vals=[config for config in all_configs],
line_arg="provider",
line_vals=["deepgemm", "triton"],
line_names=["DeepGEMM", "Triton"],
styles=[("blue", "-"), ("red", "-")],
ylabel="ms",
plot_name=f"fp8-group-gemm-performance-comparison-tp{tp_size}",
args={},
)
)
def benchmark(m, n, k, num_groups, tp_size, provider):
print(
f"Shape (m={m}, n={n}, k={k}, tp={tp_size}, num_groups={num_groups}, Provider: {provider}"
)
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
)
m_per_group = m // num_groups
m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
m_indices = (
m_indices.unsqueeze(-1)
.expand(num_groups, m_per_group)
.contiguous()
.view(-1)
)
quantiles = [0.5, 0.2, 0.8]
if provider == "deepgemm":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_gemm_group_deepgemm(
x_fp8_grouped,
y_fp8_grouped,
out,
m_indices,
),
quantiles=quantiles,
)
elif provider == "triton":
# Prepare inputs for Triton
# We did it outside of the lambda function to make it fair comparison like deepgemm
a, a_scale = x_fp8_flat
b, b_scale = y_fp8_flat
b = b.T.contiguous()
# Ensure scales are in the right format and contiguous
a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
M, _ = a.shape
_, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_gemm_group_triton(
(a, a_scale),
(b, b_scale),
c,
num_groups,
),
quantiles=quantiles,
)
# Calculate TFLOPS
flops = 2 * m * n * k # multiply-adds
tflops = flops / (ms * 1e-3) / 1e12
print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/fp8_group_gemm/",
help="Path to save deepgemm fp8 group gemm benchmark results",
)
parser.add_argument(
"--run_correctness",
action="store_true",
help="Whether to run correctness test",
)
parser.add_argument(
"--tp_size",
type=int,
default=1,
help="Tensor parallelism size to benchmark (default: 1)",
)
args = parser.parse_args()
# Set random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Run correctness tests on a few examples
if args.run_correctness:
print("Running correctness tests...")
calculate_diff(8192, 7168, 4096, 4)
calculate_diff(8192, 2048, 7168, 4)
calculate_diff(4096, 7168, 4096, 8)
calculate_diff(4096, 2048, 7168, 8)
calculate_diff(4096, 576, 7168, 8)
# Get the benchmark function with the specified tp_size
benchmark = get_benchmark(args.tp_size)
print(f"Running performance benchmark for TP size = {args.tp_size}...")
benchmark.run(print_data=True, save_path=args.save_path)

View File

@@ -0,0 +1,29 @@
## Benchmark FBGEMM Grouped GEMM
Benchmark FBGEMM Grouped GEMM in both Triton and CUDA version and SGLang Triton Grouped GEMM, it will be used to compare the bandwidth of different implementations.
### Requirements
```shell
pip install fbgemm-gpu-genai
```
### Usage
```bash
python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
```
For example, in H200, the Qwen2-57B-A14B-Instruct TP4 fp8w8a8 grouped gemm bandwidth result is as follows:
```shell
grouped-gemm-performance:
batch_size FBGEMM Triton Grouped GEMM FP8 FBGEMM CUTLASS F8F8BF16 Rowwise SGLang Grouped GEMM FP8
0 256.0 3704.841339 3042.626402 2254.725030
1 512.0 3691.426346 3029.065684 2269.504543
2 1024.0 3653.938629 2258.471467 2358.319020
3 2048.0 3596.644313 2271.611904 2476.895397
4 4096.0 3468.496435 2231.283986 2179.473910
```
The theoretical peak bandwidth of H200 is 4.8 TB/s. Taking batch_size 256 as an example, the bandwidth of FBGEMM Triton Grouped GEMM FP8 is 3704.841339 GB/s, the bandwidth of FBGEMM CUTLASS F8F8BF16 Rowwise is 3042.626402 GB/s, and the bandwidth of SGLang Grouped GEMM FP8 is 2254.725030 GB/s. Therefore, FBGEMM Triton Grouped GEMM FP8 achieves 77.9% of H200's theoretical peak bandwidth, FBGEMM CUTLASS F8F8BF16 Rowwise achieves 63.4% of H200's theoretical peak bandwidth, and SGLang Grouped GEMM FP8 achieves 46.9% of H200's theoretical peak bandwidth.

View File

@@ -0,0 +1,516 @@
# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
import argparse
import torch
import triton
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
quantize_fp8_row,
triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
grouped_gemm as fbgemm_grouped_gemm,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
from transformers import AutoConfig
from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton as sglang_grouped_gemm,
)
def get_model_config(model_name: str, tp_size: int):
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
num_groups = config.ffn_config.moe_num_experts
intermediate_size = config.ffn_config.ffn_hidden_size
elif config.architectures[0] == "JambaForCausalLM":
num_groups = config.num_experts
intermediate_size = config.intermediate_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
num_groups = config.num_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
]:
num_groups = config.n_routed_experts
intermediate_size = config.moe_intermediate_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
num_groups = config.text_config.num_local_experts
intermediate_size = config.text_config.intermediate_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
num_groups = config.num_local_experts
intermediate_size = config.moe_intermediate_size
else:
num_groups = config.num_local_experts
intermediate_size = config.intermediate_size
shape_configs = {
"num_groups": num_groups,
"hidden_size": config.hidden_size,
"intermediate_size": intermediate_size,
"dtype": config.torch_dtype,
}
print(f"{shape_configs=}")
return shape_configs
def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
torch.manual_seed(42)
tokens_per_group = batch_size // num_groups
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
)
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
base_weights = torch.randn(
num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda"
)
w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
w_sglang = base_weights
c_fbgemm = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
c_sglang = torch.empty(
batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
)
seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda")
for i in range(1, num_groups + 1):
seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
weight_indices = torch.arange(num_groups, dtype=torch.int32, device="cuda")
return (
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
)
def create_fp8_test_data(
batch_size, num_groups, hidden_size, intermediate_size, backend="triton"
):
"""
Create test data for FP8 grouped GEMM operations.
Args:
batch_size: Total batch size
num_groups: Number of groups
hidden_size: Hidden dimension size
intermediate_size: Intermediate dimension size
backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM
Returns:
For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale)
For cutlass: (x, wq, w_scale, m_sizes)
"""
torch.manual_seed(42)
tokens_per_group = batch_size // num_groups
# Create weight matrices for each group
w_list = []
for _ in range(num_groups):
w = torch.randn(
intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
)
w_list.append(w)
# Quantize weights using quantize_fp8_row for each group
wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list])
if backend == "triton":
# Triton format: concatenated weights
w_fp8 = torch.concat(wq_list, dim=0).contiguous()
w_scale = torch.concat(w_scale_list, dim=0).contiguous()
# Create m_sizes as int32 for triton
m_sizes = torch.full(
(num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
)
# Create and quantize input
x_fp16 = torch.randn(
batch_size, hidden_size, dtype=torch.float16, device="cuda"
)
x_fp8, x_scale = triton_quantize_fp8_row(x_fp16)
x_scale = x_scale.view(batch_size, -1)
return x_fp8, w_fp8, m_sizes, x_scale, w_scale
elif backend == "cutlass":
# CUTLASS format: stacked weights
wq = torch.stack(wq_list, dim=0).contiguous()
w_scale = torch.stack(w_scale_list, dim=0).contiguous()
# Create m_sizes as int64 for cutlass
m_values = [tokens_per_group] * num_groups
m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device="cuda")
# Create input data - separate for each group then concat
x_list = []
for _ in range(num_groups):
x = torch.randn(
tokens_per_group, hidden_size, dtype=torch.float16, device="cuda"
)
x_list.append(x)
# Concatenate inputs into single tensor
x = torch.concat(x_list, dim=0).contiguous()
return x, wq, w_scale, m_sizes
else:
raise ValueError(f"Unsupported backend: {backend}")
def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype):
"""
Calculate memory bandwidth based on accessed expert weights.
Args:
m_sizes: Tensor containing batch sizes for each group
hidden_size: Hidden dimension size
intermediate_size: Intermediate dimension size
dtype: Data type of weights
Returns:
Memory size in bytes for accessed expert weights
"""
# Count non-zero groups (active experts)
if hasattr(m_sizes, "cpu"):
active_experts = torch.count_nonzero(m_sizes).item()
else:
active_experts = sum(1 for m in m_sizes if m > 0)
# Calculate bytes per element based on dtype
if dtype in [torch.float16, torch.bfloat16]:
bytes_per_element = 2
elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
bytes_per_element = 1
elif dtype == torch.float32:
bytes_per_element = 4
else:
# Default to 2 bytes for unknown dtypes
bytes_per_element = 2
# Memory per expert weight matrix
memory_per_expert = hidden_size * intermediate_size * bytes_per_element
# Total memory for active experts
total_memory_bytes = active_experts * memory_per_expert
return total_memory_bytes
def get_benchmark_config(use_fp8_w8a8=False):
if use_fp8_w8a8:
return {
"line_vals": [
"fbgemm_triton_grouped_gemm_fp8",
"fbgemm_cutlass_f8f8bf16_rowwise",
"sglang_grouped_gemm",
],
"line_names": [
"FBGEMM Triton Grouped GEMM FP8",
"FBGEMM CUTLASS F8F8BF16 Rowwise",
"SGLang Grouped GEMM FP8",
],
"styles": [("blue", "-"), ("orange", "-"), ("red", "-")],
}
else:
return {
"line_vals": ["fbgemm_triton_grouped_gemm", "sglang_grouped_gemm"],
"line_names": [
"FBGEMM Triton Grouped GEMM BF16",
"SGLang Grouped GEMM BF16",
],
"styles": [("blue", "-"), ("green", "-")],
}
def run_benchmark(
model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
config = get_benchmark_config(use_fp8_w8a8)
benchmark_config = triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[256, 512, 1024, 2048, 4096],
line_arg="provider",
line_vals=config["line_vals"],
line_names=config["line_names"],
styles=config["styles"],
ylabel="Bandwidth (GB/s)",
plot_name="grouped-gemm-performance",
args={},
)
@triton.testing.perf_report(benchmark_config)
def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"Benchmarking {provider} with batch_size={batch_size}")
torch.cuda.manual_seed_all(0)
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
if provider == "fbgemm_triton_grouped_gemm_fp8":
try:
test_data = create_fp8_test_data(
batch_size,
num_groups,
hidden_size,
intermediate_size,
backend="triton",
)
x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data
# Calculate memory bandwidth
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
)
def run_func():
return fbgemm_grouped_gemm_fp8_rowwise(
x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
)
except Exception as e:
print(f"FP8 not supported, skipping: {e}")
return float("inf"), float("inf"), float("inf")
elif provider == "fbgemm_cutlass_f8f8bf16_rowwise":
try:
test_data = create_fp8_test_data(
batch_size,
num_groups,
hidden_size,
intermediate_size,
backend="cutlass",
)
x, wq, w_scale, m_sizes = test_data
# Calculate memory bandwidth
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
)
# Quantize input using triton_quantize_fp8_row
xq, x_scale = triton_quantize_fp8_row(x)
x_scale = x_scale.view(batch_size, -1)
def run_func():
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
xq, wq, x_scale, w_scale, m_sizes
)
except Exception as e:
print(
f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, "
f"skipping: {e}"
)
return float("inf"), float("inf"), float("inf")
else:
test_data = create_test_data(
batch_size, num_groups, hidden_size, intermediate_size
)
(
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
) = test_data
# Calculate memory bandwidth for BF16 operations
memory_bytes = calculate_memory_bandwidth(
m_sizes, hidden_size, intermediate_size, torch.bfloat16
)
if provider == "fbgemm_triton_grouped_gemm":
def run_func():
return fbgemm_grouped_gemm(
x, w_fbgemm, m_sizes, use_fast_accum=True
)
else:
def run_func():
return sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
for _ in range(10):
try:
run_func()
except Exception as e:
print(f"Error during warmup for {provider}: {e}")
return float("inf"), float("inf"), float("inf")
torch.cuda.synchronize()
try:
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
# Convert time (ms) to bandwidth (GB/s)
# Bandwidth = Memory (bytes) / Time (seconds)
# Convert ms to seconds and bytes to GB (1e9)
gb_per_s = (memory_bytes / 1e9) / (ms / 1000)
# min bandwidth = max time, max bandwidth = min time
min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000)
max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000)
return gb_per_s, min_gb_per_s, max_gb_per_s
except Exception as e:
print(f"Error during benchmarking for {provider}: {e}")
return 0.0, 0.0, 0.0
dynamic_benchmark.run(
show_plots=True,
print_data=True,
save_path=save_path,
model_config=model_config,
use_fp8_w8a8=use_fp8_w8a8,
)
def verify_correctness(model_config):
print("Verifying correctness...")
batch_size = 128
num_groups = model_config["num_groups"]
hidden_size = model_config["hidden_size"]
intermediate_size = model_config["intermediate_size"]
test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
(
x,
w_fbgemm,
w_sglang,
c_fbgemm,
c_sglang,
m_sizes,
seg_indptr,
weight_indices,
) = test_data
result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)
result_sglang = sglang_grouped_gemm(
x,
w_sglang,
c_sglang,
num_groups,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
c_dtype=c_sglang.dtype,
)
if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
print("✓ BF16 Correctness verification passed!")
else:
max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
return False
return True
def main():
parser = argparse.ArgumentParser(
description="Benchmark FBGEMM vs SGLang Grouped GEMM"
)
parser.add_argument(
"--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1",
help="Model name to get configuration from",
)
parser.add_argument(
"--tp-size", type=int, default=1, help="Tensor parallelism size"
)
parser.add_argument(
"--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
)
parser.add_argument(
"--save-path",
type=str,
default="./benchmark_grouped_gemm/",
help="Path to save benchmark results",
)
parser.add_argument(
"--verify-correctness",
action="store_true",
help="Verify correctness before benchmarking",
)
args = parser.parse_args()
try:
model_config = get_model_config(args.model, args.tp_size)
except Exception as e:
print(f"Failed to get model config: {e}")
print("Using default configuration...")
model_config = {
"num_groups": 8,
"hidden_size": 4096,
"intermediate_size": 14336,
"dtype": torch.bfloat16,
}
print("Running benchmark with:")
print(f" num_groups: {model_config['num_groups']}")
print(f" hidden_size: {model_config['hidden_size']}")
print(f" intermediate_size: {model_config['intermediate_size']}")
print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")
if args.verify_correctness:
if not verify_correctness(model_config):
print("Correctness verification failed. Exiting...")
return
try:
run_benchmark(
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
save_path=args.save_path,
)
except Exception as e:
print(f"Benchmark failed: {e}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,102 @@
# FlashInfer Fused AllReduce + RMSNorm Benchmark
This benchmark script is modified from the [original implementation](https://github.com/vllm-project/vllm/blob/237e1fb887c7f5a579420fa0295097f24b006594/benchmarks/kernels/benchmark_fused_collective.py) by the vLLM community. It aims to compare the performance differences between FlashInfer fused operators in SGLang (trtllm_allreduce_fusion: AllReduce + Residual Add + RMSNorm + optional quantization) and conventional implementations (standard `tensor_model_parallel_all_reduce` + separate RMSNorm/quantization). Specifically, this script tests the timing performance of two implementation paths: 1) Standard AllReduce and RMSNorm executed separately; 2) FlashInfer's fused operator combining AllReduce, Residual Add, RMSNorm, and optional quantization operations.
This benchmark script helps us tune the ipc workspace size of the `flashinfer_allreduce_residual_rmsnorm` operator in SGLang and prepare for applications with FP8/FP4 quantized fused operators.
Script path: `benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py`
## Feature Overview
- Compare average execution time (ms) and calculate speedup ratios for the following paths:
- standard_allreduce_rmsnorm (Standard AllReduce + RMSNorm)
- flashinfer_fused_allreduce_rmsnorm (Fused AllReduce + RMSNorm), including oneshot and twoshot modes
- Optionally compare FP8/FP4 quantized fused paths with standard paths
- Use CUDA Graph capture and batch replay to reduce measurement noise
- Automatically select the faster "standard baseline" (native/compiled version) as the denominator for speedup calculation
- Optionally export results in Markdown format
## Runtime Environment and Prerequisites
- At least 2 GPUs, and launch multi-process distributed training using `torchrun` (NCCL backend)
- Properly install/compile sglang along with sgl-kernel and custom operators
## Quick Start (Command Examples)
The following examples use world_size=2. You can modify `--nproc_per_node` and parameters according to your machine:
- Regular paths only (no quantization):
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--no-quant --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- FP8 quantization paths only:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--quant-fp8 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- FP4 quantization paths only:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--quant-fp4 --hidden-dim 1024 --seq-lens 512 1024 2048 4096 --trials 100
```
- Larger hidden dimensions:
```
torchrun --nproc_per_node=2 \
benchmark/kernels/flashinfer_allreduce_fusion/benchmark_fused_collective.py \
--no-quant --hidden-dim 4096 --seq-lens 512 1024 2048 4096 --trials 100
```
## Parameter Description
- `--seq-lens`: List of sequence lengths to test (default: 128 512 1024 2048)
- `--hidden-dim`: Hidden dimension (default: 8192)
- `--dtypes`: Data type list, `float16|bfloat16|float32` (default: bfloat16)
- `--no-residual`: Only test "no residual" scenarios (default tests both "with/without residual")
- Mutually exclusive quantization options:
- `--no-quant`: No quantization testing
- `--quant-fp8`: Only FP8 quantization testing
- `--quant-fp4`: Only FP4 quantization testing
- `--quant-all`: Test all (default)
- FlashInfer related:
- `--disable-oneshot`: Disable oneshot mode (default enables oneshot and tests twoshot simultaneously)
- Runtime configuration:
- `--warmup`: Warmup count before graph capture and before graph replay (default 5)
- `--trials`: Benchmark iteration count (default 20; internally each `graph.replay()` will batch replay multiple times)
- `--output-file`: Save results as Markdown file (only rank0 takes effect)
## Output Example
Each configuration group prints a table showing average execution time and relative speedup ratios (baseline is the faster standard implementation). For example:
```
================================================================================
Results: seq_len=1024, hidden_dim=1024
dtype=torch.bfloat16, residual=yes, quant_mode=none
================================================================================
Operation Time (ms) Speedup
--------------------------------------------------------------------------------
standard_allreduce_rmsnorm 0.024 0.98x
standard_allreduce_rmsnorm_native_compiled 0.023 baseline
flashinfer_fused_allreduce_rmsnorm_oneshot 0.011 2.19x
flashinfer_fused_allreduce_rmsnorm_twoshot 0.041 0.57x
```
If `--output-file` is specified, all configurations will be summarized in Markdown tables in that file.
## Important Notes and Recommendations
- Distributed: The script uses `torchrun` environment variables to initialize distributed training and binds tensors/communication groups to the current rank's corresponding device.
- World size: Requires `WORLD_SIZE > 1` to perform communication operator benchmarks. Otherwise, the script will error and prompt.
- FlashInfer:
- If not installed or interfaces are missing, the script will only run standard paths and provide prompts in the logs.
- The fused operator internally uses "oneshot"/"twoshot" two trigger methods; oneshot is enabled by default and twoshot is tested simultaneously.
- FP8/FP4:
- FP8 uses sglang's FP8 tools and dtype, with underlying platform selection of `e4m3`/`e4m3fnuz` etc.
- FP4 uses sgl-kernel's `scaled_fp4_quant`, requiring corresponding platform support.
- CUDA Graph:
- Uses sglang's `graph_capture()` to prepare capture-ready state for communication, then uses `torch.cuda.graph` to capture kernels, reducing measurement jitter.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,76 @@
## Tuning Triton MoE Kernels
This directory contains benchmarking tools for MoE (Mixture of Experts) kernels.
### Tuning Tool
- `tuning_fused_moe_triton.py`: A tool for tuning the `fused_moe_triton` kernel. Adapted from [vllm's benchmark_moe.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), with added support for various model architectures.
Example usage:
```bash
# Tune Mixtral-8x7B with default settings
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
--tune
# Tune Qwen2-57B with FP8 and TP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model Qwen/Qwen2-57B-A14B-Instruct \
--tp-size 4 \
--dtype fp8_w8a8 \
--tune
# Tune Qwen3-235B-A22B-FP8 and TP=4
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model Qwen/Qwen3-235B-A22B-FP8 \
--tp-size 4 \
--dtype fp8_w8a8 \
--tune
# Tune DeepSeek-V3 with FP8 and TP=8
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model deepseek-ai/DeepSeek-V3-0324 \
--tp-size 8 \
--dtype fp8_w8a8 \
--tune
# Tune DeepSeek-R1 with channel-wise INT8 and TP=16
python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \
--model meituan/DeepSeek-R1-Channel-INT8 \
--tp-size 16 \
--dtype int8_w8a8 \
--tune
```
After tuning, a configuration file (e.g., `E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json`) will be generated in the current directory. You can move this file to `sglang/srt/layers/fused_moe_triton/configs/triton_version` dir to use it in `sglang`.
### Performance Comparison Tool
- `benchmark_vllm_vs_sglang_fused_moe_triton.py`: A tool for comparing the performance of fused MoE kernels between vllm and sglang implementations. Supports various model architectures and data types.
Example usage:
```bash
# Compare with default settings (Mixtral model)
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py
# Compare with FP8 mode for Qwen2-57B
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
--model Qwen/Qwen2-57B-A14B-Instruct \
--use-fp8-w8a8
# Compare with custom TP size
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
--model deepseek-ai/DeepSeek-V3-0324 \
--tp-size 8
# Compare with custom TP size
python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \
--model deepseek-ai/DeepSeek-V3-0324 \
--tp-size 8
```
The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`).
- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel.
Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`, note that `torch.compile` does not support `fp8_w8a8` and `int8_w8a8` fused_moe_kernel.

View File

@@ -0,0 +1,292 @@
# python3 benchmark/kernels/fused_moe_triton/sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8
import argparse
import torch
import triton
from transformers import AutoConfig
from sglang.srt.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_sglang,
)
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward,
)
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
def get_model_config(model_name: str, tp_size: int):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
]:
E = (
config.n_routed_experts + 1
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"block_shape": block_shape,
}
print(f"{shape_configs=}")
return shape_configs
def fused_moe_triton_api(
x,
w1,
w2,
input_gating,
topk,
):
topk_op = TopK(
top_k=topk,
renormalize=False,
use_grouped_topk=False,
)
topk_op.use_triton_kernels = True
triton_topk_output = topk_op.forward_cuda(
hidden_states=x,
router_logits=input_gating,
)
moe_runner_config = MoeRunnerConfig(
inplace=False,
)
return triton_kernel_moe_forward(
x,
w1,
w2,
triton_topk_output,
moe_runner_config,
)
def fused_moe_sglang_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
):
topk_output = select_experts(
hidden_states=x,
router_logits=input_gating,
topk_config=TopKConfig(top_k=topk, renormalize=False),
)
return fused_moe_sglang(
x,
w1,
w2,
topk_output,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=list([128, 256, 512, 1024, 2048, 4096, 8192]),
line_arg="provider",
line_vals=[
"sglang_fused_moe_triton_v340",
"sglang_fused_moe_triton",
],
line_names=[
"sglang_fused_moe_triton_v340",
"sglang_fused_moe_triton",
],
styles=[
("blue", "-"),
("green", "-"),
],
ylabel="Time (ms)",
plot_name="fused-moe-performance",
args={},
)
)
def benchmark(
batch_size,
provider,
model_config,
use_fp8_w8a8=False,
use_cuda_graph: bool = False,
):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_tokens = batch_size
num_experts = model_config["num_experts"]
hidden_size = model_config["hidden_size"]
shard_intermediate_size = model_config["shard_intermediate_size"]
topk = model_config["topk"]
dtype = model_config["dtype"]
block_shape = model_config["block_shape"]
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
)
w1_tri = w1.clone()
w2_tri = w2.clone()
w1_tri = w1_tri.transpose(-2, -1).contiguous()
w2_tri = w2_tri.transpose(-2, -1).contiguous()
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
if provider == "sglang_fused_moe_triton_v340":
api_func = fused_moe_triton_api
api_kwargs = {
"x": x,
"w1": w1_tri,
"w2": w2_tri,
"input_gating": input_gating,
"topk": topk,
}
else:
api_func = fused_moe_sglang_api
api_kwargs = {
"x": x,
"w1": w1,
"w2": w2,
"input_gating": input_gating,
"topk": topk,
"use_fp8_w8a8": use_fp8_w8a8,
"block_shape": block_shape,
}
# Warmup
for _ in range(10):
_ = api_func(**api_kwargs)
torch.cuda.synchronize()
if use_cuda_graph:
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
api_func(**api_kwargs)
torch.cuda.synchronize()
bench_lambda = lambda: graph.replay()
else:
bench_lambda = lambda: api_func(**api_kwargs)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(bench_lambda, quantiles=quantiles)
return ms, min_ms, max_ms
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--use-cuda-graph", action="store_true", help="Enable CUDA Graph capture/replay"
)
parser.add_argument(
"--save-path",
type=str,
default="./configs/benchmark_ops/sglang_fused_moe/",
)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
try:
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method="tcp://127.0.0.1:23456",
world_size=1,
rank=0,
)
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method="tcp://127.0.0.1:23456",
local_rank=0,
backend="nccl" if torch.cuda.is_available() else "gloo",
)
initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
model_config = get_model_config(args.model, args.tp_size)
benchmark.run(
show_plots=True,
print_data=True,
save_path=args.save_path,
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
use_cuda_graph=args.use_cuda_graph,
)
finally:
destroy_model_parallel()
destroy_distributed_environment()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,199 @@
import torch
import triton
import triton.language as tl
from triton.testing import do_bench
# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
@triton.jit
def _moe_sum_reduce_kernel(
input_ptr,
input_stride_0,
input_stride_1,
input_stride_2,
output_ptr,
output_stride_0,
output_stride_1,
token_num: int,
topk_num: int,
hidden_dim: int,
routed_scaling_factor: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
token_block_id = tl.program_id(0)
dim_block_id = tl.program_id(1)
token_start = token_block_id * BLOCK_M
token_end = min((token_block_id + 1) * BLOCK_M, token_num)
dim_start = dim_block_id * BLOCK_DIM
dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
for token_index in range(token_start, token_end):
accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
tmp = tl.load(
input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
)
accumulator += tmp
accumulator = accumulator * routed_scaling_factor
store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
tl.store(
store_t_ptr,
accumulator.to(input_ptr.dtype.element_ty),
mask=offs_dim < dim_end,
)
def moe_sum_reduce(
input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
assert input.is_contiguous()
assert output.is_contiguous()
token_num, topk_num, hidden_dim = input.shape
assert output.shape[0] == token_num and output.shape[1] == hidden_dim
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 1
num_warps = 8
grid = (
triton.cdiv(token_num, BLOCK_M),
triton.cdiv(hidden_dim, BLOCK_DIM),
)
_moe_sum_reduce_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
token_num=token_num,
topk_num=topk_num,
hidden_dim=hidden_dim,
routed_scaling_factor=routed_scaling_factor,
BLOCK_M=BLOCK_M,
BLOCK_DIM=BLOCK_DIM,
NUM_STAGE=NUM_STAGE,
num_warps=num_warps,
)
return
def compute_sum_scaled_baseline(
x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
torch.sum(x, dim=1, out=out)
out.mul_(routed_scaling_factor)
return out
@torch.compile
def compute_sum_scaled_compiled(
x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
torch.sum(x * routed_scaling_factor, dim=1, out=out)
return out
def get_benchmark():
num_tokens_range = [2**i for i in range(0, 13)]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=num_tokens_range,
line_arg="version",
line_vals=["baseline", "compiled", "triton"],
line_names=["Original", "TorchCompile", "TritonKernel"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="sum_scaled_performance",
args={},
)
)
def benchmark(num_tokens, version):
topk = 9
hidden_size = 4096
dtype = torch.bfloat16
scaling_factor = 0.3
x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
# Warmup
for _ in range(3):
if version == "baseline":
compute_sum_scaled_baseline(x, out, scaling_factor)
elif version == "compiled":
compute_sum_scaled_compiled(x, out, scaling_factor)
else:
moe_sum_reduce(x, out, scaling_factor)
# Benchmark
quantiles = [0.5, 0.2, 0.8]
if version == "baseline":
ms, min_ms, max_ms = do_bench(
lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
quantiles=quantiles,
)
elif version == "compiled":
ms, min_ms, max_ms = do_bench(
lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = do_bench(
lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
def verify_correctness(num_tokens=1024):
x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
scaling_factor = 0.3
out_baseline = torch.empty_like(x[:, 0])
compute_sum_scaled_baseline(x, out_baseline, scaling_factor)
out_compiled = torch.empty_like(out_baseline)
compute_sum_scaled_compiled(x, out_compiled, scaling_factor)
out_triton = torch.empty_like(out_baseline)
moe_sum_reduce(x, out_triton, scaling_factor)
if torch.allclose(
out_baseline, out_compiled, atol=1e-2, rtol=1e-2
) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
print(
f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
)
print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")
if __name__ == "__main__":
print("Running correctness verification...")
verify_correctness()
print("\nRunning performance benchmark...")
benchmark = get_benchmark()
benchmark.run(
print_data=True,
# save_path="./configs/benchmark_ops/sum_scaled/"
)

View File

@@ -0,0 +1,305 @@
# python3 benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
import argparse
import torch
import triton
from torch.nn import functional as F
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_triton,
)
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
def get_model_config(model_name: str, tp_size: int):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
}
print(f"{shape_configs=}")
return shape_configs
def fused_topk_native(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
topk_weights = F.softmax(gating_output.float(), dim=-1)
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
@torch.compile(dynamic=False)
def fused_moe_torch(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
) -> torch.Tensor:
assert not use_fp8_w8a8, "Fp8_w8a8 fused_moe is not supported for torch compile"
topk_weights, topk_ids = fused_topk_native(
hidden_states=x,
gating_output=input_gating,
topk=topk,
renormalize=True,
)
w13_weights = w1[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = w2[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
def fused_moe_torch_compile(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
):
return fused_moe_torch(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
def fused_moe_sglang_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
):
return fused_moe_triton(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=list(range(1, 5)),
line_arg="provider",
line_vals=[
"fused_moe_triton",
"fused_moe_torch_compile",
],
line_names=[
"fused_moe_triton",
"fused_moe_torch_compile",
],
styles=[
("blue", "-"),
("green", "-"),
],
ylabel="Time (ms)",
plot_name="fused-moe-performance",
args={},
)
)
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
set_torch_compile_config()
num_tokens = batch_size
num_experts = model_config["num_experts"]
hidden_size = model_config["hidden_size"]
shard_intermediate_size = model_config["shard_intermediate_size"]
topk = model_config["topk"]
dtype = model_config["dtype"]
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_fp8_w8a8:
init_dtype = dtype
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
)
w1_scale = w2_scale = a1_scale = a2_scale = None
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
# Warmup
api_func = (
fused_moe_torch_compile
if provider == "fused_moe_torch_compile"
else fused_moe_sglang_api
)
for _ in range(10):
y = api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
torch.cuda.synchronize()
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)[0],
quantiles=quantiles,
)
return ms, min_ms, max_ms
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--save-path",
type=str,
default="./configs/benchmark_ops/fused_moe_torch_compile/",
)
args = parser.parse_args()
model_config = get_model_config(args.model, args.tp_size)
benchmark.run(
show_plots=True,
print_data=True,
save_path=args.save_path,
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,349 @@
# python3 benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py --model /DeepSeek-V3/ --tp-size 8 --use-fp8-w8a8
import argparse
import torch
import triton
import vllm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
from sglang.srt.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe as fused_moe_sglang,
)
def get_model_config(model_name: str, tp_size: int):
"""Get model configuration parameters"""
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Qwen3MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
]:
E = (
config.n_routed_experts + 1
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
vllm_version_num = (
vllm.__version_tuple__[0] * 100
+ vllm.__version_tuple__[1] * 10
+ vllm.__version_tuple__[2]
)
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
assert (
vllm_version_num >= 66
), "Block-wise quantized fp8 fused_moe is only supported for VLLM>=0.6.6.post1"
shape_configs = {
"num_experts": E,
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"block_shape": block_shape,
}
print(f"{shape_configs=}")
return shape_configs
def fused_moe_vllm_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
):
if block_shape is not None:
return fused_moe_vllm(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
else:
return fused_moe_vllm(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
def fused_moe_sglang_api(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=False,
w1_scale=None,
w2_scale=None,
a1_scale=None,
a2_scale=None,
block_shape=None,
):
return fused_moe_sglang(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=list(range(1, 513)),
line_arg="provider",
line_vals=[
"vllm_fused_moe_triton",
"sglang_fused_moe_triton",
],
line_names=[
"vllm_fused_moe_triton",
"sglang_fused_moe_triton",
],
styles=[
("blue", "-"),
("green", "-"),
],
ylabel="Time (ms)",
plot_name="fused-moe-performance",
args={},
)
)
def benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
print(f"benchmark {provider} with batch_size={batch_size}")
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_tokens = batch_size
num_experts = model_config["num_experts"]
hidden_size = model_config["hidden_size"]
shard_intermediate_size = model_config["shard_intermediate_size"]
topk = model_config["topk"]
dtype = model_config["dtype"]
block_shape = model_config["block_shape"]
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1_scale = w2_scale = a1_scale = a2_scale = None
if use_fp8_w8a8:
init_dtype = dtype
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
if block_shape is None:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
w1_scale = torch.rand(
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
)
w2_scale = torch.rand(
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)
else:
w1 = torch.randn(num_experts, shard_intermediate_size, hidden_size, dtype=dtype)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=dtype
)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
# Warmup
api_func = (
fused_moe_vllm_api
if provider == "vllm_fused_moe_triton"
else fused_moe_sglang_api
)
for _ in range(10):
y = api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
torch.cuda.synchronize()
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: api_func(
x,
w1,
w2,
input_gating,
topk,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)[0],
quantiles=quantiles,
)
return ms, min_ms, max_ms
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--use-fp8-w8a8", action="store_true")
parser.add_argument(
"--save-path",
type=str,
default="./configs/benchmark_ops/vllm_sglang_fused_moe/",
)
args = parser.parse_args()
try:
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method="tcp://127.0.0.1:23456",
world_size=1,
rank=0,
)
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method="tcp://127.0.0.1:23456",
local_rank=0,
backend="nccl" if torch.cuda.is_available() else "gloo",
)
initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
model_config = get_model_config(args.model, args.tp_size)
benchmark.run(
show_plots=True,
print_data=True,
save_path=args.save_path,
model_config=model_config,
use_fp8_w8a8=args.use_fp8_w8a8,
)
finally:
destroy_model_parallel()
destroy_distributed_environment()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,599 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py
import argparse
import json
import time
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict
import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
get_config_dtype_str,
get_config_file_name,
get_default_config,
get_moe_configs,
)
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.utils import is_hip
_is_hip = is_hip()
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
def benchmark_config(
config: BenchmarkConfig,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int] = None,
num_iters: int = 100,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16 or use_int8_w8a8:
w1 = torch.randint(
-127,
127,
(
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8,
)
w2 = torch.randint(
-127,
127,
(
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8,
)
else:
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8 or use_int8_w8a8:
if use_int8_w8a8 and block_shape is None:
w1_scale = torch.randn(
num_experts, shard_intermediate_size, dtype=torch.float32
)
w2_scale = torch.randn(num_experts, hidden_size, dtype=torch.float32)
elif block_shape is None:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
else:
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n
n_tiles_w2 = (hidden_size + block_n - 1) // block_n
k_tiles_w1 = (hidden_size + block_k - 1) // block_k
k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k
w1_scale = torch.rand(
(num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32
)
w2_scale = torch.rand(
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)
if use_fp8_w8a8:
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_config = TopKConfig(
top_k=topk,
renormalize=True,
)
topk_output = select_experts(x, input_gating, topk_config)
def prepare(i: int):
input_gating = gating_output[i]
new_topk_output = select_experts(x, input_gating, topk_config)
topk_output.topk_weights.copy_(new_topk_output.topk_weights)
topk_output.topk_ids.copy_(new_topk_output.topk_ids)
topk_output.router_logits.copy_(new_topk_output.router_logits)
def run():
moe_runner_config = MoeRunnerConfig(
inplace=True,
)
with override_config(config):
fused_moe(
x,
w1,
w2,
topk_output,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
)
# JIT compilation & warmup
run()
torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for _ in range(10):
run()
torch.cuda.synchronize()
# Warmup
for _ in range(5):
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: List[float] = []
for i in range(num_iters):
prepare(i)
torch.cuda.synchronize()
start_event.record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
return avg
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
waves_per_eu_range = 0
for num_stages in [2]:
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_warps in [1, 2, 4, 8]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
configs: List[BenchmarkConfig] = []
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
self.seed = seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU.
self.device_id = int(ray.get_gpu_ids()[0])
def benchmark(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(0)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0
op_config = get_moe_configs(
num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
)
if op_config is None:
config = get_default_config(
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype_str,
False,
block_shape,
)
else:
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
kernel_time = benchmark_config(
config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
return config, kernel_time
def tune(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
search_space: List[Dict[str, int]],
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
with torch.cuda.device(self.device_id) if is_hip() else nullcontext():
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
num_iters=10,
)
except (triton.runtime.autotuner.OutOfResources, RuntimeError):
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None
return best_config
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
**(
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
),
}
def save_configs(
configs: Dict[int, BenchmarkConfig],
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
) -> None:
dtype_str = get_config_dtype_str(
dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(
num_experts,
shard_intermediate_size // 2,
dtype_str,
block_shape,
)
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = (
config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
E = config.text_config.num_local_experts + (
0 if args.disable_shared_experts_fusion else 1
)
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
"Grok1AForCausalLM",
]:
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
dtype = config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a8 = args.dtype == "int8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_shape = None
if (
hasattr(config, "quantization_config")
and "weight_block_size" in config.quantization_config
):
block_shape = config.quantization_config["weight_block_size"]
assert len(block_shape) == 2
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
else:
batch_sizes = [args.batch_size]
ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
outputs = []
worker_idx = 0
for input_args in inputs:
worker = workers[worker_idx]
worker_method = getattr(worker, method)
output = worker_method.remote(*input_args)
outputs.append(output)
worker_idx = (worker_idx + 1) % num_gpus
return ray.get(outputs)
if args.tune:
search_space = get_configs_compute_bound()
if block_shape is not None:
block_n, block_k = block_shape[0], block_shape[1]
search_space = [
config
for config in search_space
if block_k % config["BLOCK_SIZE_K"] == 0
]
print(f"Start tuning over {len(search_space)} configurations...")
start = time.perf_counter()
configs = _distribute(
"tune",
[
(
batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
search_space,
)
for batch_size in batch_sizes
],
)
best_configs = {
M: sort_config(config) for M, config in zip(batch_sizes, configs)
}
save_configs(
best_configs,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
end = time.perf_counter()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute(
"benchmark",
[
(
batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
for batch_size in batch_sizes
],
)
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}")
print(f"Kernel time: {kernel_time:.2f} us")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument("--tp-size", "--tp", type=int, default=2)
parser.add_argument(
"--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
default="auto",
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--disable-shared-experts-fusion", action="store_true")
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,576 @@
import itertools
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode
@triton.jit
def _decode_kernel(
Q,
K,
V,
KV,
Out,
S,
b: tl.constexpr,
h: tl.constexpr,
n: tl.constexpr,
d: tl.constexpr,
d_original: tl.constexpr,
e: tl.constexpr,
e_original: tl.constexpr,
):
off_bh = tl.program_id(0)
off_h = off_bh % h
qk_offset = off_bh * n * d
v_offset = off_bh * n * e
o_offset = off_bh * n * e
kv_offset = off_bh * d * e
s = tl.load(S + off_h)
ratio = tl.exp(-s)
d_idx = tl.arange(0, d)
e_idx = tl.arange(0, e)
# Create masks for original dimensions
d_mask = d_idx < d_original
e_mask = e_idx < e_original
# Load with masking
q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
# Load KV with 2D masking
kv = tl.load(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
mask=(d_mask[:, None] & e_mask[None, :]),
other=0.0,
)
# Compute outer product using element-wise operations
k_v_prod = k[:, None] * v[None, :]
kv = ratio * kv + k_v_prod
# Store KV with 2D masking
tl.store(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
kv.to(KV.dtype.element_ty),
mask=(d_mask[:, None] & e_mask[None, :]),
)
# Compute matrix-vector multiplication using element-wise operations and reduction
o = tl.sum(q[:, None] * kv, axis=0)
# Store output with masking
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
def lightning_attn_decode(q, k, v, kv, s):
"""Triton implementation of Lightning Attention decode operation"""
b, h, n, d = q.shape
e = v.shape[-1]
assert n == 1, "Sequence length must be 1 in decode mode"
# Get padded dimensions (power of 2)
d_padded = next_power_of_2(d)
e_padded = next_power_of_2(e)
# Create output tensor (padded)
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
# Create padded tensors without actually padding the data
q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
kv_padded = torch.empty(
b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
)
# Copy data to padded tensors
q_padded[..., :d] = q
k_padded[..., :d] = k
v_padded[..., :e] = v
kv_padded[..., :d, :e] = kv
# Launch kernel
grid = (b * h, 1)
_decode_kernel[grid](
q_padded,
k_padded,
v_padded,
kv_padded,
o_padded,
s,
b=b,
h=h,
n=n,
d=d_padded,
d_original=d,
e=e_padded,
e_original=e,
)
# Get unpadded outputs
o = o_padded[..., :e]
kv_out = kv_padded[..., :d, :e]
return o, kv_out
def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2))))
class MiniMaxText01LightningAttention(nn.Module):
def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
super().__init__()
if config is None:
config = type("Config", (), kwargs)
bias = False
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.out_proj = nn.Linear(
self.head_dim * self.num_heads, self.hidden_size, bias=bias
)
self.act = get_activation_fn(config.hidden_act)
self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
self.qkv_proj = nn.Linear(
self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
)
self.output_gate = nn.Linear(
self.hidden_size, self.head_dim * self.num_heads, bias=bias
)
# for inference only
self.offset = 0
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None,
**kwargs,
):
if (not self.training) and (not do_eval):
return self.inference(
hidden_states,
attn_mask,
output_attentions,
past_key_value,
use_cache,
slope_rate,
)
def inference(
self,
x,
attn_mask: Optional[torch.Tensor] = None, # (b, n)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
):
# x: b n d
b, n, d = x.shape
# linear map
qkv = self.act(self.qkv_proj(x))
new_shape = qkv.size()[:-1] + (self.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
q = q.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d]
k = k.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d]
v = v.transpose(1, 2) # [b, n, h, d] -> [b, h, n, e]
self.offset += 1
ratio = torch.exp(-slope_rate) # [h, 1, 1]
# decode mode
kv = past_key_value # [b, h, d, e]
output = []
for i in range(n):
# kv: [b, h, d, e]
# ratio: [h, 1, 1]
# k: [b, h, n, d]
# v: [b, h, n, e]
# k[:, :, i : i + 1]: [b, h, 1, d]
# v[:, :, i : i + 1]: [b, h, 1, e]
# ratio * kv: [b, h, d, e]
# torch.einsum(
# "... n d, ... n e -> ... d e",
# k[:, :, i : i + 1],
# v[:, :, i : i + 1],
# )
# [b, h, d, e] + [b, h, d, e] -> [b, h, d, e]
kv = ratio * kv + torch.einsum(
"... n d, ... n e -> ... d e",
k[:, :, i : i + 1],
v[:, :, i : i + 1],
)
# q[:, :, i : i + 1]: [b, h, 1, d]
# kv.to(q.dtype): [b, h, d, e]
# torch.einsum(
# "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
# )
# [b, h, 1, d] * [b, h, d, e] -> [b, h, 1, e]
qkv = torch.einsum(
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
)
output.append(qkv)
output = torch.cat(output, dim=-2)
# reshape
output = rearrange(output, "b h n d -> b n (h d)")
# normalize
output = self.norm(output)
# gate
output = F.sigmoid(self.output_gate(x)) * output
# outproj
output = self.out_proj(output)
attn_weights = None
return output, attn_weights, kv
def get_activation_fn(activation):
if activation == "gelu":
return F.gelu
elif activation == "relu":
return F.relu
elif activation == "elu":
return F.elu
elif activation == "sigmoid":
return F.sigmoid
elif activation == "exp":
def f(x):
with torch.no_grad():
x_max = torch.max(x, dim=-1, keepdims=True).values
y = torch.exp(x - x_max)
return y
return f
elif activation == "leak":
return F.leaky_relu
elif activation == "1+elu":
def f(x):
return 1 + F.elu(x)
return f
elif activation == "2+elu":
def f(x):
return 2 + F.elu(x)
return f
elif activation == "silu" or activation == "swish":
return F.silu
elif activation == "sine":
return torch.sin
else:
return lambda x: x
class MiniMaxText01RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MiniMaxText01RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def test_lightning_attention_implementations(model_params):
torch.manual_seed(42)
batch_size = 64
seq_len = 1
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_states = torch.randn(
batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
model_attn.eval()
d = model_params["head_dim"]
past_kv = torch.randn(
batch_size,
model_params["num_attention_heads"],
d,
d,
device=device,
)
with torch.no_grad():
model_output, _, new_kv = model_attn.inference(
hidden_states,
attn_mask=attention_mask,
slope_rate=slope_rate,
past_key_value=past_kv,
)
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
past_kv = past_kv.contiguous()
slope_rate = slope_rate.contiguous()
# Test Triton implementation
triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
triton_output = triton_output.transpose(1, 2).contiguous()
triton_output = triton_output.view(batch_size, seq_len, -1)
triton_output = model_attn.norm(triton_output)
triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
triton_output = model_attn.out_proj(triton_output)
# Test SGL implementation
sgl_output = torch.empty_like(v)
sgl_new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
sgl_output = sgl_output.transpose(1, 2).contiguous()
sgl_output = sgl_output.view(batch_size, seq_len, -1)
sgl_output = model_attn.norm(sgl_output)
sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
sgl_output = model_attn.out_proj(sgl_output)
# Verify Triton implementation results
torch.testing.assert_close(
model_output,
triton_output,
rtol=1e-3,
atol=1e-2,
msg="Triton lightning attention implementation produces different output results",
)
torch.testing.assert_close(
new_kv,
triton_new_kv,
rtol=1e-3,
atol=1e-2,
msg="Triton lightning attention implementation produces different kv results",
)
# Verify SGL implementation results
torch.testing.assert_close(
model_output,
sgl_output,
rtol=1e-3,
atol=1e-2,
msg="SGL lightning attention implementation produces different output results",
)
torch.testing.assert_close(
new_kv,
sgl_new_kv,
rtol=1e-3,
atol=1e-2,
msg="SGL lightning attention implementation produces different kv results",
)
print("✅ All implementations match")
def _build_slope_tensor(n_attention_heads: int):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
n_attention_heads, 1, 1
)
return slopes
def get_benchmark():
batch_size_range = [i for i in range(1, 33)] # max 32
seq_length_range = [1] # decode mode sequence length is fixed to 1
configs = list(itertools.product(batch_size_range, seq_length_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["Original", "Triton", "SGL"],
line_names=[
"Original PyTorch Implementation",
"Triton Implementation",
"SGL Implementation",
],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="lightning-attention-decode-performance",
args={},
)
)
def benchmark(batch_size, seq_len, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
params = {
"hidden_size": 6144,
"num_attention_heads": 64,
"head_dim": 96,
"hidden_act": "gelu",
}
hidden_states = torch.randn(
batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
model_attn.eval()
d = params["head_dim"]
past_kv = torch.randn(
batch_size,
params["num_attention_heads"],
d,
d,
device=device,
)
quantiles = [0.5, 0.2, 0.8]
if provider == "Original":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: model_attn.inference(
hidden_states,
attn_mask=attention_mask,
slope_rate=slope_rate,
past_key_value=past_kv,
),
quantiles=quantiles,
)
elif provider == "Triton":
def run_triton():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
output, new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, -1)
output = model_attn.norm(output)
output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
return model_attn.out_proj(output)
ms, min_ms, max_ms = triton.testing.do_bench(
run_triton,
quantiles=quantiles,
)
else: # SGL
def run_sgl():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
output = torch.empty_like(v)
new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(
q, k, v, past_kv, slope_rate, output, new_kv
)
output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, -1)
output = model_attn.norm(output)
output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
return model_attn.out_proj(output)
ms, min_ms, max_ms = triton.testing.do_bench(
run_sgl,
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/lightning_attention_decode/",
help="Path to save lightning attention decode benchmark results",
)
args = parser.parse_args()
params = {
"hidden_size": 6144,
"num_attention_heads": 64,
"head_dim": 96,
"hidden_act": "silu",
}
# Run correctness test first
# Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
test_lightning_attention_implementations(params)
# Run performance benchmark
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=args.save_path)

View File

@@ -0,0 +1,603 @@
import itertools
import math
import os
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
@triton.jit
def _fwd_kernel(
Q,
K,
V,
Out,
S, # log lambda
b: tl.constexpr,
h: tl.constexpr,
n: tl.constexpr,
d: tl.constexpr,
e: tl.constexpr,
BLOCK: tl.constexpr,
NUM_BLOCK: tl.constexpr,
BLOCK_MODEL: tl.constexpr,
):
##### get offset
off_bh = tl.program_id(0)
off_h = off_bh % h
off_e = tl.program_id(1)
qk_offset = off_bh * n * d
v_offset = off_bh * n * e
o_offset = off_bh * n * e
# channel offset
e_offset = off_e * BLOCK_MODEL
##### get block ptr
Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :]
K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None]
V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
S_block_ptr = S + off_h
##### init diag decay(Lambda); q, k decay; kv
s = tl.load(S_block_ptr)
# q, k decay
off_block = tl.arange(
0, BLOCK
) # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent
q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None])
k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - off_block[None, :]))
block_decay = tl.exp(-s.to(tl.float32) * BLOCK)
# diag decay
index = off_block[:, None] - off_block[None, :]
s_index = s * index
s_index = tl.where(index >= 0, -s_index, float("-inf"))
diag_decay = tl.exp(s_index)
kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32)
##### compute
for i in range(NUM_BLOCK):
# load
q = tl.load(
Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0
).to(tl.float32)
k_trans = tl.load(
K_trans_block_ptr + off_block[None, :] * d,
mask=off_block[None, :] < n,
other=0.0,
).to(tl.float32)
v = tl.load(
V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0
).to(tl.float32)
# compute
qk = tl.dot(q, k_trans) * diag_decay
o_intra = tl.dot(qk, v)
o_inter = tl.dot(q, kv) * q_decay
o = o_intra + o_inter
# save and update
tl.store(
O_block_ptr + off_block[:, None] * e,
o.to(O_block_ptr.dtype.element_ty),
mask=off_block[:, None] < n,
)
kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v)
off_block += BLOCK
def lightning_attn2(q, k, v, s):
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
s = s.contiguous()
b, h, n, d = q.shape
e = v.shape[-1]
# Pad d to next power of 2
d_padded = next_power_of_2(d)
if d_padded != d:
q_padded = F.pad(q, (0, d_padded - d))
k_padded = F.pad(k, (0, d_padded - d))
else:
q_padded = q
k_padded = k
# Pad e to next power of 2
e_padded = next_power_of_2(e)
if e_padded != e:
v_padded = F.pad(v, (0, e_padded - e))
else:
v_padded = v
o_padded = torch.empty((b, h, n, e_padded), dtype=q.dtype, device=q.device)
BLOCK = 64
NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK)
# parallel over channel
BLOCK_MODEL = min(triton.next_power_of_2(e_padded), 32)
grid = (b * h, triton.cdiv(e_padded, BLOCK_MODEL))
_fwd_kernel[grid](
q_padded,
k_padded,
v_padded,
o_padded,
s,
b,
h,
n,
d_padded,
e_padded,
BLOCK=BLOCK,
NUM_BLOCK=NUM_BLOCK,
BLOCK_MODEL=BLOCK_MODEL,
)
# Remove padding from output
if e_padded != e:
o = o_padded[..., :e]
else:
o = o_padded
return o
def is_support(dim):
return 16 % dim
def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2))))
def lightning_attn_func(q, k, v, s):
b, h, n, d = q.shape
e = v.shape[-1]
assert is_support(d) and is_support(e)
# pad v's feature dim to power of 2
e_pad = next_power_of_2(e)
need_pad = e_pad != e
if need_pad:
v = F.pad(v, (0, e_pad - e))
if d > 128:
# split over head
if 64 % d:
m = 64
elif 32 % d:
m = 32
elif 16 % d:
m = 16
arr = [m * i for i in range(d // m + 1)]
if arr[-1] != d:
arr.append(d)
n = len(arr)
o = 0
for i in range(n - 1):
start = arr[i]
end = arr[i + 1]
q1 = q[..., start:end]
k1 = k[..., start:end]
o += lightning_attn2(q1, k1, v, s)
else:
o = lightning_attn2(q, k, v, s)
if need_pad:
o = o[:, :, :, :e]
return o
debug = eval(os.environ.get("debug", default="False"))
BLOCK = 256
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01
class MiniMaxText01RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MiniMaxText01RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
def get_activation_fn(activation):
if debug:
logger.info(f"activation: {activation}")
if activation == "gelu":
return F.gelu
elif activation == "relu":
return F.relu
elif activation == "elu":
return F.elu
elif activation == "sigmoid":
return F.sigmoid
elif activation == "exp":
def f(x):
with torch.no_grad():
x_max = torch.max(x, dim=-1, keepdims=True).values
y = torch.exp(x - x_max)
return y
return f
elif activation == "leak":
return F.leaky_relu
elif activation == "1+elu":
def f(x):
return 1 + F.elu(x)
return f
elif activation == "2+elu":
def f(x):
return 2 + F.elu(x)
return f
elif activation == "silu" or activation == "swish":
return F.silu
elif activation == "sine":
return torch.sin
else:
logger.info(f"activation: does not support {activation}, use Identity!!!")
return lambda x: x
# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
class MiniMaxText01LightningAttention(nn.Module):
def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
super().__init__()
if config is None:
config = type("Config", (), kwargs)
bias = False
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.out_proj = nn.Linear(
self.head_dim * self.num_heads, self.hidden_size, bias=bias
)
self.act = get_activation_fn(config.hidden_act)
self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)
self.qkv_proj = nn.Linear(
self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
)
self.output_gate = nn.Linear(
self.hidden_size, self.head_dim * self.num_heads, bias=bias
)
# for inference only
self.offset = 0
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None,
**kwargs,
):
if (not self.training) and (not do_eval):
return self.inference(
hidden_states,
attn_mask,
output_attentions,
past_key_value,
use_cache,
slope_rate,
)
def inference(
self,
x,
attn_mask: Optional[torch.Tensor] = None, # (b, n)
output_attentions: bool = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
):
# x: b n d
b, n, d = x.shape
# linear map
qkv = self.act(self.qkv_proj(x))
new_shape = qkv.size()[:-1] + (self.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if past_key_value is None:
self.offset = q.shape[-2]
else:
self.offset += 1
# for align with metaseq
ratio = torch.exp(-slope_rate)
# only use for the first time
if past_key_value is None:
slope_rate = slope_rate.to(torch.float32)
if attn_mask is not None:
v = v.masked_fill(
(1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0
)
NUM_BLOCK = (n + BLOCK - 1) // BLOCK
b, h, n, d = q.shape
e = v.shape[-1]
# other
array = torch.arange(BLOCK).to(q) + 1
q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
index = array[:, None] - array[None, :]
s_index = (
slope_rate
* index[
None,
None,
]
)
s_index = torch.where(index >= 0, -s_index, float("-inf"))
diag_decay = torch.exp(s_index)
kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
for i in range(NUM_BLOCK):
si = i * BLOCK
ei = min(si + BLOCK, n)
m = ei - si
qi = q[:, :, si:ei].contiguous()
ki = k[:, :, si:ei].contiguous()
vi = v[:, :, si:ei].contiguous()
qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32)
# diag
qk = (
torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32)
* diag_decay[:, :, :m, :m]
)
qkv_diag = torch.matmul(qk, vi.to(torch.float32))
block_decay = torch.exp(-slope_rate * m)
output[:, :, si:ei] = qkv_none_diag + qkv_diag
kv = block_decay * kv + torch.matmul(
(ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi
)
else:
kv = past_key_value
output = []
for i in range(n):
kv = ratio * kv + torch.einsum(
"... n d, ... n e -> ... d e",
k[:, :, i : i + 1],
v[:, :, i : i + 1],
)
qkv = torch.einsum(
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
)
output.append(qkv)
output = torch.cat(output, dim=-2)
# reshape
output = rearrange(output, "b h n d -> b n (h d)")
# normalize
output = self.norm(output)
# gate
output = F.sigmoid(self.output_gate(x)) * output
# outproj
output = self.out_proj(output)
attn_weights = None
return output, attn_weights, kv
def _build_slope_tensor(n_attention_heads: int):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(
n
) # In the paper, we only train models that have 2^a heads for some a. This function has
else: # some good properties that only occur when the input is a power of 2. To maintain that even
closest_power_of_2 = 2 ** math.floor(
math.log2(n)
) # when the number of heads is not a power of 2, we use this workaround.
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
# h, 1, 1
slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
n_attention_heads, 1, 1
)
return slopes
def test_lightning_attention_implementations(model_params):
torch.manual_seed(42)
batch_size = 2
seq_len = 1024
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_states = torch.randn(
batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
model_attn.eval()
with torch.no_grad():
model_output, _, _ = model_attn.inference(
hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
)
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
lib_output = lightning_attn_func(q, k, v, slope_rate)
lib_output = lib_output.transpose(1, 2).contiguous()
lib_output = lib_output.view(batch_size, seq_len, -1)
lib_output = model_attn.norm(lib_output)
lib_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
lib_output = model_attn.out_proj(lib_output)
torch.testing.assert_close(
model_output,
lib_output,
rtol=1e-3,
atol=1e-2,
msg="Lightning attention implementations produce different results",
)
print("✅ Two implementations match")
def get_benchmark():
batch_size_range = [2**i for i in range(0, 7)] # max 64
seq_length_range = [256, 512, 1024, 2048, 4096] # max 4096
configs = list(itertools.product(batch_size_range, seq_length_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["MiniMax-Text-01", "OpenNLPLab"],
line_names=[
"MiniMax-Text-01 Model Implementation",
"OpenNLPLab Library Implementation",
],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="lightning-attention-prefill-performance",
args={},
)
)
def benchmark(batch_size, seq_len, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
params = {
"hidden_size": 6144,
"num_attention_heads": 64,
"head_dim": 96,
"hidden_act": "gelu",
}
hidden_states = torch.randn(
batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
)
attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)
slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
model_attn.eval()
quantiles = [0.5, 0.2, 0.8]
if provider == "MiniMax-Text-01":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: model_attn.inference(
hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
),
quantiles=quantiles,
)
else:
def run_lib():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
lib_output = lightning_attn_func(q, k, v, slope_rate)
lib_output = lib_output.transpose(1, 2).contiguous()
lib_output = lib_output.view(batch_size, seq_len, -1)
lib_output = model_attn.norm(lib_output)
lib_output = (
torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
)
return model_attn.out_proj(lib_output)
ms, min_ms, max_ms = triton.testing.do_bench(
run_lib,
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/lightning_attention_prefill/",
help="Path to save lightning attention prefill benchmark results",
)
args = parser.parse_args()
# Run correctness test first
# Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
params = {
"hidden_size": 6144,
"num_attention_heads": 64,
"head_dim": 96,
"hidden_act": "silu",
}
test_lightning_attention_implementations(params)
# Run performance benchmark
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=args.save_path)

View File

@@ -0,0 +1,133 @@
import argparse
import itertools
import torch
import triton
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
from sgl_kernel.elementwise import silu_and_mul
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
from sglang.srt.layers.quantization import deep_gemm_wrapper
def _test_accuracy_once(E, M, K, input_dtype, device):
x = torch.randn(E, M, K, device=device, dtype=input_dtype)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.full((E,), M, dtype=torch.int32, device=device)
out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
out1, blk_scales1 = scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
masks,
)
torch.testing.assert_close(out, out1)
torch.testing.assert_close(blk_scales, blk_scales1)
print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK")
NUM_RANKS = 48
M_PER_RANKs = [128, 256, 512, 1024]
Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs]
Ks = [2048, 4096, 7168]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["M", "K"],
x_vals=list(itertools.product(Ms, Ks)),
x_log=False,
line_arg="provider",
line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
styles=[("blue", "-"), ("orange", "-"), ("green", "-")],
ylabel="ms",
plot_name="fp4 quant",
args={},
)
)
def benchmark(M, K, provider):
E = 6
device = "cuda"
x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device)
fp8_out = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2,
),
device=x.device,
dtype=torch.float8_e4m3fn,
)
scale_block_size = 128
fp8_scales = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2 // scale_block_size,
),
device=x.device,
dtype=torch.float32,
)
quantiles = [0.5, 0.2, 0.8]
if provider == "triton_fp8":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_masked_post_quant_fwd(
x,
fp8_out,
fp8_scales,
scale_block_size,
masks,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
),
quantiles=quantiles,
)
if provider == "cuda_unfused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
masks,
),
quantiles=quantiles,
)
if provider == "cuda_fused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_scaled_fp4_grouped_quant(
x,
glb_scales,
masks,
),
quantiles=quantiles,
)
return ms, min_ms, max_ms
def test_accuracy():
E = 6
N_RANKS = 48
Ms = [128, 256, 512, 1024]
Ks = [2048, 4096, 7168]
input_dtype = torch.bfloat16
for M in Ms:
for K in Ks:
_test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./bench_fp4_quant_res",
help="Path to save fp4 quant benchmark results",
)
args = parser.parse_args()
test_accuracy()
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)

View File

@@ -0,0 +1,94 @@
import argparse
import torch
import triton
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
@torch.compile(backend="inductor")
def torch_int8_quant(x):
int8_max = torch.iinfo(torch.int8).max
abs_max = x.abs().max(dim=-1, keepdim=True).values
scales = abs_max.to(torch.float32) / float(int8_max)
q_x = (x / scales).round().to(torch.int8)
return q_x, scales
def _test_accuracy_once(M, K, input_dtype, device):
x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000
out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True)
out1, scales1 = per_token_quant_int8(x)
out2, scales2 = torch_int8_quant(x)
torch.testing.assert_close(out, out2, atol=1, rtol=0)
torch.testing.assert_close(out, out1, atol=1, rtol=0)
torch.testing.assert_close(scales, scales2)
torch.testing.assert_close(scales1, scales2)
print(f"M: {M}, K: {K}, type: {input_dtype} OK")
def test_accuracy():
Ms = [1, 13, 128, 1024, 2048, 4096]
Ks = [512, 1024, 2048, 8192]
input_dtypes = [torch.float16, torch.bfloat16]
for M in Ms:
for K in Ks:
for input_dtype in input_dtypes:
_test_accuracy_once(M, K, input_dtype, "cuda")
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=["vllm op", "triton", "torch.compile"],
line_names=["vllm op", "triton", "torch.compile"],
styles=[("blue", "-"), ("orange", "-"), ("red", "-")],
ylabel="ms",
plot_name="int8 per token quant",
args={},
)
)
def benchmark(batch_size, provider):
M, K = batch_size, 16384
x = torch.randn(M, K, dtype=torch.float16, device="cuda") * 1000
quantiles = [0.5, 0.2, 0.8]
if provider == "vllm op":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_int8_quant(x, symmetric=True),
quantiles=quantiles,
)
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: per_token_quant_int8(x),
quantiles=quantiles,
)
if provider == "torch.compile":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch_int8_quant(x),
quantiles=quantiles,
)
return ms, min_ms, max_ms
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./bench_int8_quant_res",
help="Path to save int8 quant benchmark results",
)
args = parser.parse_args()
test_accuracy()
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)

View File

@@ -0,0 +1,474 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import json
import multiprocessing as mp
import os
import time
from datetime import datetime
from typing import Any, Dict, List
import torch
import triton
from tqdm import tqdm
mp.set_start_method("spawn", force=True)
from sglang.srt.layers.quantization.fp8_kernel import (
_w8a8_block_fp8_matmul,
_w8a8_block_fp8_matmul_unrolledx4,
)
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
_is_hip = is_hip()
DTYPE_MAP = {
"float32": torch.float32,
"float16": torch.float16,
"half": torch.half,
"bfloat16": torch.bfloat16,
}
def w8a8_block_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
config: Dict[str, Any],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
else:
kernel = _w8a8_block_int8_matmul
kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
return C
def get_rocm_configs_compute_bound():
configs = []
waves_per_eu_range = 0
for num_stages in [2]:
for block_m in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for block_n in [16, 32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 4, 8, 16, 32]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu_range,
}
)
return configs
def get_configs_compute_bound():
configs = []
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append(
{
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
}
)
return configs
def get_weight_shapes(tp_size):
# NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model.
# cannot TP
total = [
(512 + 64, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(7168, 16384),
(7168, 18432),
]
# N can TP
n_tp = [
(18432 * 2, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(24576, 1536),
(4096, 7168),
]
# K can TP
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
weight_shapes = []
for t in total:
weight_shapes.append(t)
for n_t in n_tp:
new_t = (n_t[0] // tp_size, n_t[1])
weight_shapes.append(new_t)
for k_t in k_tp:
new_t = (k_t[0], k_t[1] // tp_size)
weight_shapes.append(new_t)
return weight_shapes
def benchmark_config(
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
):
def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
torch.cuda.synchronize()
# JIT complication & warmup
for _ in range(5):
run()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: List[float] = []
for i in range(num_iters):
torch.cuda.synchronize()
start_event.record()
run()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
return avg
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
factor_for_scale = 1e-2
if input_type == "fp8":
fp8_info = torch.finfo(
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
else:
int8_info = torch.iinfo(torch.int8)
int8_max, int8_min = int8_info.max, int8_info.min
A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
)
A = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
)
B = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
Bs = (
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
* factor_for_scale
)
best_config = None
best_time = float("inf")
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
A,
B,
As,
Bs,
block_size,
config,
out_dtype,
num_iters=10,
)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={M}")
assert best_config is not None
return best_config
def save_configs(
N,
K,
block_n,
block_k,
configs,
save_path,
input_type="fp8",
) -> None:
os.makedirs(save_path, exist_ok=True)
device_name = get_device_name().replace(" ", "_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,block_shape=[{block_n}, {block_k}].json"
config_file_path = os.path.join(save_path, json_file_name)
print(f"Writing best config to {config_file_path}...")
with open(config_file_path, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def get_available_gpu_count():
"""Get the number of available GPUs."""
return torch.cuda.device_count()
def tune_on_gpu(args_dict):
"""Run tuning on a specific GPU."""
gpu_id = args_dict["gpu_id"]
batch_sizes = args_dict["batch_sizes"]
weight_shapes = args_dict["weight_shapes"]
args = args_dict["args"]
torch.cuda.set_device(gpu_id)
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
block_n = args.block_n
block_k = args.block_k
out_dtype = DTYPE_MAP[args.out_dtype]
save_path = args.save_path
input_type = args.input_type
search_space = get_configs_compute_bound()
search_space = [
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
]
start = time.perf_counter()
results = {}
for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
N, K = shape[0], shape[1]
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
benchmark_results = [
tune(
batch_size,
N,
K,
[block_n, block_k],
out_dtype,
search_space,
input_type,
)
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
]
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
end = time.perf_counter()
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
def distribute_batch_sizes(batch_sizes, num_gpus):
"""Distribute batch sizes across available GPUs."""
batches_per_gpu = []
for i in range(num_gpus):
start_idx = i * len(batch_sizes) // num_gpus
end_idx = (i + 1) * len(batch_sizes) // num_gpus
batches_per_gpu.append(batch_sizes[start_idx:end_idx])
return batches_per_gpu
def main(args):
print(args)
num_gpus = get_available_gpu_count()
if num_gpus == 0:
raise RuntimeError("No GPU available for tuning")
print(f"Found {num_gpus} GPUs for parallel tuning")
torch.cuda.init()
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
else:
batch_sizes = [args.batch_size]
num_gpus = 1 # If only one batch size, use only one GPU
weight_shapes = get_weight_shapes(args.tp_size)
batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)
process_args = []
for gpu_id in range(num_gpus):
process_args.append(
{
"gpu_id": gpu_id,
"batch_sizes": batches_per_gpu[gpu_id],
"weight_shapes": weight_shapes, # Each GPU processes all weight shapes
"args": args,
}
)
ctx = mp.get_context("spawn")
with ctx.Pool(num_gpus) as pool:
pool.map(tune_on_gpu, process_args)
print("Multi-GPU tuning completed")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tp-size", "-tp", type=int, default=8)
parser.add_argument(
"--input-type", type=str, choices=["fp8", "int8"], default="fp8"
)
parser.add_argument(
"--out-dtype",
type=str,
choices=["float32", "float16", "bfloat16", "half"],
default="float16",
)
parser.add_argument("--block-n", type=int, default=128)
parser.add_argument("--block-k", type=int, default=128)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument(
"--save-path", type=str, default="python/sglang/srt/layers/quantization/configs"
)
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,230 @@
import itertools
from typing import Optional, Tuple, Union
import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn
from vllm import _custom_ops as vllm_ops
class HuggingFaceRMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
if residual is None:
return x
else:
return x, residual
def rmsnorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
naive_norm.weight = nn.Parameter(weight)
naive_norm = naive_norm.to(x.device)
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
output = naive_norm(x, residual)
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def rmsnorm_flashinfer(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
if residual is not None:
fused_add_rmsnorm(x, residual, weight, eps)
output = (x, residual)
else:
output = rmsnorm(x, weight, eps)
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def rmsnorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
if residual is not None:
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
output = (x, residual)
else:
out = torch.empty_like(x)
vllm_ops.rms_norm(out, x, weight, eps)
output = out
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
output_naive = rmsnorm_naive(
x.clone(), weight, residual.clone() if residual is not None else None
)
output_flashinfer = rmsnorm_flashinfer(
x.clone(), weight, residual.clone() if residual is not None else None
)
output_vllm = rmsnorm_vllm(
x.clone(), weight, residual.clone() if residual is not None else None
)
if use_residual:
output_naive = output_naive[0]
output_flashinfer = output_flashinfer[0]
output_vllm = output_vllm[0]
print(f"Naive output={output_naive}")
print(f"FlashInfer output={output_flashinfer}")
print(f"VLLM output={output_vllm}")
if torch.allclose(
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48]
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
def get_benchmark(use_residual):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["head_num", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["huggingface", "flashinfer", "vllm"],
line_names=["HuggingFace", "FlashInfer", "vLLM"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
args={},
)
)
def benchmark(head_num, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_size = head_num * 128 # assuming head_dim = 128
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
quantiles = [0.5, 0.2, 0.8]
if provider == "huggingface":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_naive(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
elif provider == "flashinfer":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_flashinfer(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_vllm(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--use_residual", action="store_true", help="Whether to use residual connection"
)
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/rmsnorm/",
help="Path to save rmsnorm benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(
batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
)
# Get the benchmark function with proper use_residual setting
benchmark = get_benchmark(args.use_residual)
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)

View File

@@ -0,0 +1,169 @@
import os
import torch
import triton
import triton.language as tl
@torch.compile(dynamic=True)
def get_last_loc_torch(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
return torch.where(
prefix_lens_tensor > 0,
req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
torch.full_like(prefix_lens_tensor, -1),
)
@triton.jit
def get_last_loc_kernel(
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result,
num_tokens,
req_to_token_stride,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
mask = offset < num_tokens
prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
token_mask = prefix_lens > 0
token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
tl.store(result + offset, tokens, mask=mask)
def get_last_loc_triton(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
BLOCK_SIZE = 256
num_tokens = prefix_lens_tensor.shape[0]
result = torch.empty_like(prefix_lens_tensor)
grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
get_last_loc_kernel[grid](
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result,
num_tokens,
req_to_token.stride(0),
BLOCK_SIZE,
)
return result
def test_get_last_loc():
max_batch = 4097
max_context_len = 6148
batch_size = 20
# Initialize input tensors
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
pre_lens = torch.randint(
-max_context_len // 2,
max_context_len,
(batch_size,),
dtype=torch.int64,
device="cuda",
)
last_loc_res = get_last_loc_triton(req_to_token, req_pool_indices, pre_lens)
last_loc_ref = get_last_loc_torch(req_to_token, req_pool_indices, pre_lens)
# Compare results
torch.testing.assert_close(last_loc_res, last_loc_ref)
def get_benchmark():
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=batch_sizes,
line_arg="provider",
line_vals=["reference", "triton"],
line_names=["PyTorch", "Triton"],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="get-last-loc-performance",
args={},
)
)
def benchmark(batch_size, provider):
max_batch = 2048
max_context_len = 16384
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
pre_lens = torch.randint(
-max_context_len // 2,
max_context_len,
(batch_size,),
dtype=torch.int64,
device="cuda",
)
quantiles = [0.5, 0.2, 0.8]
if provider == "reference":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: get_last_loc_torch(req_to_token, req_pool_indices, pre_lens),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: get_last_loc_triton(req_to_token, req_pool_indices, pre_lens),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
def run_benchmark(save_path: str = "./configs/benchmark_ops/get_last_loc/"):
"""Run benchmark and save results"""
# Ensure save path exists
os.makedirs(save_path, exist_ok=True)
# Run correctness test
test_get_last_loc()
print("Correctness test passed!")
# Run performance test
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=save_path)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/get_last_loc/",
help="Path to save benchmark results",
)
args = parser.parse_args()
run_benchmark(args.save_path)

View File

@@ -0,0 +1,342 @@
import itertools
import os
import torch
import triton
import triton.language as tl
@triton.jit
def write_req_to_token_pool_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0)
req_pool_index = tl.load(req_pool_indices + pid)
pre_len = tl.load(pre_lens + pid)
seq_len = tl.load(seq_lens + pid)
# TODO: optimize this?
cumsum_start = 0
for i in range(pid):
cumsum_start += tl.load(extend_lens + i)
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < (seq_len - pre_len)
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
tl.store(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ offset
+ pre_len,
value,
mask=mask,
)
@triton.jit
def write_req_to_token_pool_triton_optimize(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_to_token_ptr_stride: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_batch = tl.program_id(0)
pid_token = tl.program_id(1)
req_pool_index = tl.load(req_pool_indices + pid_batch)
pre_len = tl.load(pre_lens + pid_batch)
seq_len = tl.load(seq_lens + pid_batch)
extend_len = seq_len - pre_len
cumsum_start = 0
for i in range(pid_batch):
cumsum_start += tl.load(extend_lens + i)
token_start = pid_token * BLOCK_SIZE
offset = tl.arange(0, BLOCK_SIZE)
actual_offset = token_start + offset
mask = actual_offset < extend_len
src_ptr = out_cache_loc + cumsum_start + actual_offset
src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)
value = tl.load(src_ptr, mask=mask)
dst_ptr = (
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ actual_offset
+ pre_len
)
dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)
tl.store(dst_ptr, value, mask=mask)
def write_req_to_token_pool_reference(
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
pre_lens: torch.Tensor,
seq_lens: torch.Tensor,
extend_lens: torch.Tensor,
out_cache_loc: torch.Tensor,
) -> None:
"""Reference implementation using PyTorch"""
for i in range(len(req_pool_indices)):
req_pool_idx = req_pool_indices[i].item()
pre_len = pre_lens[i].item()
seq_len = seq_lens[i].item()
extend_len = extend_lens[i].item()
cumsum_start = sum(extend_lens[:i].tolist())
# Copy values from out_cache_loc to req_to_token
req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[
cumsum_start : cumsum_start + extend_len
]
def test_write_req_to_token_pool():
max_batch = 4097
max_context_len = 6148
batch_size = 1
extend_len = 14
# Initialize input tensors
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda")
pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda")
extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda")
# Create copies for reference implementation
req_to_token_ref = req_to_token.clone()
req_to_token_opt = req_to_token.clone()
# Run original triton kernel
write_req_to_token_pool_triton[(batch_size,)](
req_to_token,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
)
# Run optimized triton kernel
def grid(batch_size, extend_len):
num_token_blocks = triton.cdiv(extend_len, 512)
return (batch_size, num_token_blocks)
write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](
req_to_token_opt,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=512,
)
# Run reference implementation
write_req_to_token_pool_reference(
req_to_token_ref,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
)
# Compare results
torch.testing.assert_close(req_to_token, req_to_token_ref)
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
# Test case 2: batch size > 1
batch_size = 3
extend_lens_list = [14, 20, 30]
total_extend_len = sum(extend_lens_list)
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda")
pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda")
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
req_to_token_ref = req_to_token.clone()
req_to_token_opt = req_to_token.clone()
# Run original triton kernel
write_req_to_token_pool_triton[(batch_size,)](
req_to_token,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
)
# Run optimized triton kernel
max_extend_len = max(extend_lens_list)
write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](
req_to_token_opt,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=512,
)
# Run reference implementation
write_req_to_token_pool_reference(
req_to_token_ref,
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
)
# Compare results
torch.testing.assert_close(req_to_token, req_to_token_ref)
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
def get_benchmark():
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
configs = list(itertools.product(batch_sizes, extend_lens))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "extend_len"],
x_vals=configs,
line_arg="provider",
line_vals=["reference", "triton", "triton_optimize"],
line_names=["PyTorch", "Triton", "Triton Optimized"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="write-req-to-token-pool-performance",
args={},
)
)
def benchmark(batch_size, extend_len, provider):
max_batch = 256
max_context_len = 16384
extend_lens_list = [extend_len] * batch_size
total_extend_len = sum(extend_lens_list)
req_to_token = torch.zeros(
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
)
req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8
seq_lens = pre_lens + extend_len
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
quantiles = [0.5, 0.2, 0.8]
if provider == "reference":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: write_req_to_token_pool_reference(
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: write_req_to_token_pool_triton[(batch_size,)](
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
),
quantiles=quantiles,
)
else:
def run_optimized():
block_size = 128 if extend_len <= 1024 else 512
grid_config = (batch_size, triton.cdiv(extend_len, block_size))
write_req_to_token_pool_triton_optimize[grid_config](
req_to_token.clone(),
req_pool_indices,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
max_context_len,
BLOCK_SIZE=block_size,
)
ms, min_ms, max_ms = triton.testing.do_bench(
run_optimized, quantiles=quantiles
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"):
"""Run benchmark and save results"""
# Ensure save path exists
os.makedirs(save_path, exist_ok=True)
# Run correctness test
test_write_req_to_token_pool()
print("Correctness test passed!")
# Run performance test
benchmark = get_benchmark()
benchmark.run(print_data=True, save_path=save_path)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/write_req_to_token_pool/",
help="Path to save benchmark results",
)
args = parser.parse_args()
run_benchmark(args.save_path)

View File

@@ -0,0 +1,283 @@
import itertools
import torch
import torch.nn.functional as F
import triton.testing as tt
from sglang.srt.layers.attention.triton_ops.extend_attention import extend_attention_fwd
def extend_attention_fwd_torch(
q: torch.Tensor, # [extend_tokens, H_Q, D]
k: torch.Tensor, # [extend_tokens, H_KV, D]
v: torch.Tensor, # [extend_tokens, H_KV, D]
o: torch.Tensor, # [extend_tokens, H_Q, D]
k_cache: torch.Tensor, # [total_tokens, H_KV, D]
v_cache: torch.Tensor, # [total_tokens, H_KV, D]
qo_indptr: torch.Tensor, # [B+1]
kv_indptr: torch.Tensor, # [B+1]
kv_indices: torch.Tensor, # [prefix_tokens]
sliding_window_size: int,
):
B = qo_indptr.size(0) - 1
_, H_Q, D = q.shape
_, H_KV, _ = k.shape
group_size = H_Q // H_KV
scale = 1.0 / D**0.5
for i in range(B):
q_start = int(qo_indptr[i].item())
q_end = int(qo_indptr[i + 1].item())
kv_start = int(kv_indptr[i].item())
kv_end = int(kv_indptr[i + 1].item())
prefix_indices = kv_indices[kv_start:kv_end]
k_prefix = k_cache[prefix_indices] # [prefix_len, H_KV, D]
v_prefix = v_cache[prefix_indices] # [prefix_len, H_KV, D]
k_extend = k[q_start:q_end] # [extend_len, H_KV, D]
v_extend = v[q_start:q_end] # [extend_len, H_KV, D]
q_extend = q[q_start:q_end] # [extend_len, H_Q, D]
k_full = torch.cat([k_prefix, k_extend], dim=0) # [total_len, H_KV, D]
v_full = torch.cat([v_prefix, v_extend], dim=0) # [total_len, H_KV, D]
if group_size != 1:
k_full_hq = k_full.repeat_interleave(
group_size, dim=1
) # [total_len, H_Q, D]
v_full_hq = v_full.repeat_interleave(
group_size, dim=1
) # [total_len, H_Q, D]
else:
k_full_hq = k_full
v_full_hq = v_full
prefix_len = k_prefix.size(0)
extend_len = k_extend.size(0)
total_len = prefix_len + extend_len
# causal
pos_keys = torch.arange(total_len, device=q.device)
t = prefix_len + torch.arange(extend_len, device=q.device) # [extend_len]
causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
# sliding window
if sliding_window_size is not None and sliding_window_size > 0:
start = (t - (sliding_window_size)).clamp_min(0) # [extend_len]
else:
start = torch.zeros_like(t)
window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
final_mask = causal_mask & window_mask
attn_scores = (
torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
) # [extend_len, H_Q, total_len]
attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
attn_weights = F.softmax(attn_scores, dim=-1)
o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
def _build_batch(
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=torch.bfloat16, device="cuda"
):
b_seq_len_prefix = torch.randint(
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
)
b_seq_len_extend = torch.randint(
1, max(2, N_CTX // 2), (B,), dtype=torch.int32, device=device
)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
kv_indices = torch.zeros(
(int(b_seq_len_prefix.sum().item()),), dtype=torch.int32, device=device
)
for i in range(B):
s = kv_indptr[i].item()
e = kv_indptr[i + 1].item()
kv_indices[s:e] = torch.arange(
b_start_loc[i],
b_start_loc[i] + b_seq_len_prefix[i],
dtype=torch.int32,
device=device,
)
total_token_num = int(torch.sum(b_seq_len).item())
extend_token_num = int(torch.sum(b_seq_len_extend).item())
k_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
v_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.empty(
(int(b_seq_len_extend[i].item()), H_Q, D), dtype=dtype, device=device
).normal_(mean=0.1, std=0.2)
o_extend_triton = torch.empty(
(extend_token_num, H_Q, D), dtype=dtype, device=device
)
o_extend_torch = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
b_seq_len_extend = b_seq_len - b_seq_len_prefix
max_len_extend = int(torch.max(b_seq_len_extend, 0)[0].item())
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device)
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
inputs = dict(
q_extend=q_extend,
k_extend=k_extend,
v_extend=v_extend,
k_buffer=k_buffer,
v_buffer=v_buffer,
o_extend_triton=o_extend_triton,
o_extend_torch=o_extend_torch,
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
kv_indices=kv_indices,
max_len_extend=max_len_extend,
WINDOW_SIZE=WINDOW_SIZE,
)
meta = dict(
B=B, N_CTX=N_CTX, H_Q=H_Q, H_KV=H_KV, D=D, extend_token_num=extend_token_num
)
return inputs, meta
def _run_triton(inputs):
extend_attention_fwd(
inputs["q_extend"],
inputs["k_extend"],
inputs["v_extend"],
inputs["o_extend_triton"],
inputs["k_buffer"],
inputs["v_buffer"],
inputs["qo_indptr"],
inputs["kv_indptr"],
inputs["kv_indices"],
custom_mask=None,
is_causal=True,
mask_indptr=None,
max_len_extend=inputs["max_len_extend"],
sliding_window_size=inputs["WINDOW_SIZE"],
)
def _run_torch_ref(inputs):
extend_attention_fwd_torch(
inputs["q_extend"],
inputs["k_extend"],
inputs["v_extend"],
inputs["o_extend_torch"],
inputs["k_buffer"],
inputs["v_buffer"],
inputs["qo_indptr"],
inputs["kv_indptr"],
inputs["kv_indices"],
inputs["WINDOW_SIZE"],
)
N_CTXS = [1024, 2048, 4096, 8192]
WINDOW_SIZES = [-1, 127, 256, 512]
CONFIGS = list(itertools.product(N_CTXS, WINDOW_SIZES))
PROVIDERS = ["torch", "triton"]
@tt.perf_report(
tt.Benchmark(
x_names=["N_CTX", "WINDOW_SIZE"],
x_vals=CONFIGS,
line_arg="provider",
line_vals=PROVIDERS,
line_names=PROVIDERS,
ylabel="Runtime (ms)",
plot_name="extend_attention_triton_vs_torch",
args={
"B": 32,
"H_Q": 64,
"H_KV": 8,
"D": 128,
"dtype": "bf16",
"device": "cuda",
"check_correctness": False,
"warmup": 25,
"rep": 100,
},
)
)
def bench(
N_CTX,
provider,
B,
H_Q,
H_KV,
D,
dtype,
device,
WINDOW_SIZE,
check_correctness,
warmup,
rep,
):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
dt = dtype_map[dtype]
inputs, _ = _build_batch(
B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE, dtype=dt, device=device
)
if check_correctness and provider == "triton":
_run_triton(inputs)
_run_torch_ref(inputs)
torch.cuda.synchronize()
if not torch.allclose(
inputs["o_extend_triton"], inputs["o_extend_torch"], rtol=1e-3, atol=1e-3
):
raise AssertionError("Mismatch between triton and torch reference.")
if provider == "triton":
ms = tt.do_bench(lambda: _run_triton(inputs), warmup=warmup, rep=rep)
elif provider == "torch":
ms = tt.do_bench(lambda: _run_torch_ref(inputs), warmup=warmup, rep=rep)
else:
raise ValueError(provider)
return ms
if __name__ == "__main__":
bench.run(print_data=True, show_plots=False)

View File

@@ -0,0 +1,37 @@
## Download data
```
wget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json
python3 gen_data.py --number 1000
```
## Run benchmark
### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000
```
```
python3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1
```
###
```
# original
Accuracy: 0.940, latency: 332.83 s
# parallel encoding (no_adjust, offset = 1000)
Accuracy: 0.760, latency: 238.46 s
# parallel encoding (no_adjust, offset = 3000)
Accuracy: 0.760, latency: 238.46 s
# parallel encoding (no_adjust, offset = 0)
Accuracy: 0.520, latency: 238.46 s
# parallel encoding (adjust_cache)
Accuracy: 0.460, latency: 257.66 s
```

View File

@@ -0,0 +1,149 @@
import argparse
import json
import re
import time
import numpy as np
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
@sgl.function
def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):
s += prefix + "\n"
contexts = [body_0, body_1, body_2, body_3]
position_ids_offset = [i * 1000 for i in range(len(contexts))]
forks = s.fork(len(contexts), position_ids_offset)
forks += lambda i: contexts[i] + "\n"
forks.join(mode="concate_and_append")
s += "\n" + suffix
s += sgl.gen("answer", max_tokens=16)
def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
arguments = []
labels = []
sum_src_indices = []
sum_dst_indices = []
for i in range(len(src_indices)):
for j in range(len(dst_percents)):
src_index = src_indices[i]
dst_percent = dst_percents[j]
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
query_indices = [
q
for q in query_indices
if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
]
dst_index = query_indices[
min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
]
label = line_obj["values"][dst_index]
body = line_obj["lines"][: src_index + 1]
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
body_part_len = len(body) // 4
arguments.append(
{
"prefix": line_obj["prefix"],
"body_0": "\n".join(body[:body_part_len]),
"body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
"body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
"body_3": "\n".join(body[3 * body_part_len :]),
"suffix": suffix,
}
)
labels.append(label)
sum_src_indices.append(src_index)
sum_dst_indices.append(dst_index)
# Select backend
backend = select_sglang_backend(args)
tic = time.perf_counter()
states = line_retrieval.run_batch(
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
corrects = []
for i in range(len(arguments)):
output = states[i]["answer"]
prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1)
label = labels[i]
# Try all numbers
findall = re.findall("\d+", output)
if not findall:
response_number = output
else:
for response_number in findall:
if response_number == label:
break
correct = response_number == label
corrects.append(correct)
# Log results
summary = (
f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, "
f"Prompt len: {prompt_len}, "
f"Correct: {correct}, "
f"Label: {label}, Predicted: {response_number}, "
)
print(summary)
accuracy = np.mean(corrects)
print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "line_retrieval",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(arguments),
"other": {
"num_questions": len(arguments),
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
def main(args):
line_obj = json.load(open(args.data_path, "r"))
num_hoops = args.num_hoops
for src_index in args.src_index:
src_indices = [src_index]
num_queries = args.num_queries_per_src
dst_percents = [i * (1 / (num_queries)) for i in range(num_queries)]
eval_model(args, line_obj, num_hoops, src_indices, dst_percents)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json")
parser.add_argument("--src-index", type=int, nargs="+", default=[100])
parser.add_argument("--num-queries-per-src", type=int, default=10)
parser.add_argument("--num-hoops", type=int, default=1)
args = add_common_sglang_args_and_parse(parser)
main(args)

Some files were not shown because too many files have changed in this diff Show More