sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
425
3rdparty/amd/profiling/PROFILING.md
vendored
Normal file
425
3rdparty/amd/profiling/PROFILING.md
vendored
Normal file
@@ -0,0 +1,425 @@
|
||||
## Profiling SGLang Infer System with AMD GPUs
|
||||
This AppNote describes the SGLang profiling technical, code augment and running steps for systems with AMD Instinct GPUs, nevertheless the same procedure may work with Nvidia GPUs too.
|
||||
Examples and steps are provided in detail, to facilitate easy reproduce and use to localize performance problem towards optimizations.
|
||||
Two primary methods are covered:
|
||||
- [RPD](https://github.com/ROCm/rocmProfileData.git)
|
||||
- [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
|
||||
|
||||
### Profiling SGLang Infer System with RPD Profiler
|
||||
RPD profiler is a low-overhead cross-platform profiler. Therefore, the same RPD code augment not only works for profiling on ROCm/AMD GPUs, but also works for profiling on CUDA/Nvidia GPUs as well. To do RPD profiling on SGLang repository, please use scripts and patch files included in this directory and follow the steps below:
|
||||
1. Install RPD with rpd.patch applied during installation using install_rpd.sh, both files are in this directory.
|
||||
|
||||
install_rpd.sh
|
||||
|
||||
```bash
|
||||
# download and install RPD
|
||||
apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
|
||||
|
||||
# install rpd module
|
||||
git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
|
||||
cd rocmProfileData
|
||||
git checkout 976899e9c6dbc6dd2bccf770818e4e44125590ac
|
||||
git apply rpd.patch
|
||||
make && make install
|
||||
cd rocpd_python && python setup.py install && cd ..
|
||||
cd rpd_tracer && make clean;make install && python setup.py install && cd ..
|
||||
```
|
||||
|
||||
rpd.patch
|
||||
|
||||
```bash
|
||||
diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
|
||||
index e9d9feb..b2e9e1a 100644
|
||||
--- a/rpd_tracer/Makefile
|
||||
+++ b/rpd_tracer/Makefile
|
||||
@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
|
||||
$(info Building with roctracer)
|
||||
RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
|
||||
RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
|
||||
- RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
|
||||
+ RPD_SRCS += RoctracerDataSource.cpp
|
||||
RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
|
||||
endif
|
||||
```
|
||||
2. Add loadTracer.sh file included in this directory to /sglang/python/sglang.
|
||||
|
||||
loadTracer.sh
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
################################################################################
|
||||
# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
################################################################################
|
||||
OUTPUT_FILE="trace.rpd"
|
||||
|
||||
if [ "$1" = "-o" ] ; then
|
||||
OUTPUT_FILE=$2
|
||||
shift
|
||||
shift
|
||||
fi
|
||||
|
||||
if [ -e ${OUTPUT_FILE} ] ; then
|
||||
rm ${OUTPUT_FILE}
|
||||
fi
|
||||
|
||||
python3 -m rocpd.schema --create ${OUTPUT_FILE}
|
||||
if [ $? != 0 ] ; then
|
||||
echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
|
||||
exit
|
||||
fi
|
||||
|
||||
export RPDT_FILENAME=${OUTPUT_FILE}
|
||||
export RPDT_AUTOSTART=0
|
||||
LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
|
||||
```
|
||||
3. Apply patch (provided in this directory) with "git apply rpd_profile_server_enable.patch" if the main profiling purpose is to get info on gpu kernels as well as limited cpu activity info.
|
||||
|
||||
#### Common Notes 1
|
||||
Please note that although we are doing TP=8 in the example, we purposely only log RPD profiling on 2 ranks in the patch file (i.e.tp_rank=0/1) for profiling/visualization convenience, as even Perfetto streaming mode can only load maximal 8GB json file for visualization. With 2 ranks logged in RPD profiling, we could still check whether there are issues among ranks (e.g. load imbalance issue, nccl issue), and at the same time, we could log relatively longer time duration before the json file generated from RPD file hits 8GB size.
|
||||
|
||||
rpd_profile_server_enable.patch
|
||||
|
||||
```bash
|
||||
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
||||
index 62d1ff9..9021c01 100644
|
||||
--- a/python/sglang/srt/managers/scheduler.py
|
||||
+++ b/python/sglang/srt/managers/scheduler.py
|
||||
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
||||
suppress_other_loggers,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
+from rpdTracerControl import rpdTracerControl
|
||||
+rpdTracerControl.skipCreate()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -245,6 +247,7 @@ class Scheduler:
|
||||
],
|
||||
with_stack=True,
|
||||
)
|
||||
+ self.rpd = rpdTracerControl()
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop(self):
|
||||
@@ -1027,15 +1030,24 @@ class Scheduler:
|
||||
def start_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
- self.profiler.start()
|
||||
+ #self.profiler.start() #block pytorch profiler for rpd profiler enabling
|
||||
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
||||
+ self.rpd.start()
|
||||
+ self.rpd.rangePush("", "rpd profile range", "")
|
||||
+ logger.info("rpd is enabled")
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
- self.profiler.stop()
|
||||
- self.profiler.export_chrome_trace(
|
||||
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
- )
|
||||
+ #self.profiler.stop()
|
||||
+ #self.profiler.export_chrome_trace(
|
||||
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
+ #)
|
||||
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
||||
+ self.rpd.rangePop()
|
||||
+ self.rpd.stop()
|
||||
+ self.rpd.flush()
|
||||
+ logger.info("rpd is done")
|
||||
logger.info("Profiler is done")
|
||||
```
|
||||
|
||||
#### Advanced Debugging with RPD Profiler
|
||||
Sometimes, we want to use rpd profiler to capture more CPU and python activities in order to debug some challenging issues (e.g. root cause of load imbalance across gpu processes, root cause of bubbles, etc). Only in such cases, we need to apply patch "git apply rpd_profile_server_enable_wCPU_activities.patch", where 3 files are modified.
|
||||
|
||||
rpd_profile_server_enable_wCPU_activities.patch
|
||||
|
||||
```bash
|
||||
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
||||
index 62d1ff9..2edb427 100644
|
||||
--- a/python/sglang/srt/managers/scheduler.py
|
||||
+++ b/python/sglang/srt/managers/scheduler.py
|
||||
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
||||
suppress_other_loggers,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
+from rpdTracerControl import rpdTracerControl
|
||||
+rpdTracerControl.skipCreate()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -245,6 +247,7 @@ class Scheduler:
|
||||
],
|
||||
with_stack=True,
|
||||
)
|
||||
+ self.rpd = rpdTracerControl()
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop(self):
|
||||
@@ -1027,15 +1030,26 @@ class Scheduler:
|
||||
def start_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
- self.profiler.start()
|
||||
+ #self.profiler.start()
|
||||
+ logger.info("torch profiler is disabled")
|
||||
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
||||
+ self.rpd.setPythonTrace(True)
|
||||
+ self.rpd.start()
|
||||
+ self.rpd.rangePush("", "scheduler", "")
|
||||
+ logger.info("rpd is enabled inside scheduler profiling")
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
- self.profiler.stop()
|
||||
- self.profiler.export_chrome_trace(
|
||||
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
- )
|
||||
+ #self.profiler.stop()
|
||||
+ #self.profiler.export_chrome_trace(
|
||||
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
+ #)
|
||||
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
||||
+ self.rpd.rangePop()
|
||||
+ self.rpd.stop()
|
||||
+ self.rpd.flush()
|
||||
+ logger.info("rpd is done inside scheduler")
|
||||
logger.info("Profiler is done")
|
||||
|
||||
|
||||
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
|
||||
index 2621ccd..181df85 100644
|
||||
--- a/python/sglang/srt/managers/tokenizer_manager.py
|
||||
+++ b/python/sglang/srt/managers/tokenizer_manager.py
|
||||
@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import is_generation_model, is_multimodal_model
|
||||
|
||||
+from rpdTracerControl import rpdTracerControl
|
||||
+rpdTracerControl.skipCreate()
|
||||
+
|
||||
+
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -514,10 +518,20 @@ class TokenizerManager:
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def start_profile(self):
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.setPythonTrace(True)
|
||||
+ rpd.start()
|
||||
+ rpd.rangePush("", "tokenizer_manager", "")
|
||||
+ logger.info("tokenizer_manager rpd profiling started!")
|
||||
req = ProfileReq.START_PROFILE
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def stop_profile(self):
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.rangePop()
|
||||
+ rpd.stop()
|
||||
+ rpd.flush()
|
||||
+ logger.info("rpd profiling is done inside tokenizer_manager!")
|
||||
req = ProfileReq.STOP_PROFILE
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
|
||||
index 7111c93..2bd722c 100644
|
||||
--- a/python/sglang/srt/server.py
|
||||
+++ b/python/sglang/srt/server.py
|
||||
@@ -30,6 +30,8 @@ import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Union
|
||||
+from rpdTracerControl import rpdTracerControl
|
||||
+rpdTracerControl.skipCreate()
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
@@ -152,6 +154,11 @@ async def flush_cache():
|
||||
@app.post("/start_profile")
|
||||
async def start_profile():
|
||||
"""Start profiling."""
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.setPythonTrace(True)
|
||||
+ rpd.start()
|
||||
+ rpd.rangePush("", "server rpd profile range", "")
|
||||
+ logger.info("rpd profiling started in server.py!")
|
||||
tokenizer_manager.start_profile()
|
||||
return Response(
|
||||
content="Start profiling.\n",
|
||||
@@ -164,6 +171,11 @@ async def start_profile():
|
||||
async def stop_profile():
|
||||
"""Stop profiling."""
|
||||
tokenizer_manager.stop_profile()
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.rangePop()
|
||||
+ rpd.stop()
|
||||
+ rpd.flush()
|
||||
+ logger.info("rpd profiling is done in server.py!")
|
||||
return Response(
|
||||
content="Stop profiling. This will take some time.\n",
|
||||
status_code=200,
|
||||
```
|
||||
|
||||
4. As an example for grok1 profiling, we create a dummy_grok1 directory with config.json (see content below) inside this directory and copy this directory to the right path for "--model-path" if you want to use the example server.sh file provided.
|
||||
```bash
|
||||
cat ../dummy_grok1/config.json
|
||||
{
|
||||
"architectures": [
|
||||
"Grok1ModelForCausalLM"
|
||||
],
|
||||
"embedding_multiplier_scale": 78.38367176906169,
|
||||
"output_multiplier_scale": 0.5773502691896257,
|
||||
"vocab_size": 131072,
|
||||
"hidden_size": 6144,
|
||||
"intermediate_size": 32768,
|
||||
"max_position_embeddings": 8192,
|
||||
"num_experts_per_tok": 2,
|
||||
"num_local_experts": 8,
|
||||
"num_attention_heads": 48,
|
||||
"num_hidden_layers": 64,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_theta": 10000.0,
|
||||
"model_type": "mixtral",
|
||||
"torch_dtype": "bfloat16"
|
||||
}
|
||||
```
|
||||
5. Launch server with rpd enabled script ./server.sh in one terminal inside the docker container.
|
||||
|
||||
#### Common Notes 2
|
||||
- Remember to change model-path to the correct path
|
||||
- loadTracer.sh is needed to conduct profiling
|
||||
- SGLANG_TORCH_PROFILER_DIR is used for default torch profiler
|
||||
- Do not use loadTracer.sh if you are using the torch profiler, simply use python3 -m sglang.launch_server.
|
||||
|
||||
|
||||
server.sh
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
|
||||
export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
|
||||
|
||||
# Get the current timestamp
|
||||
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
|
||||
|
||||
# Define the log file with a timestamp
|
||||
LOGFILE="sglang_server_log_$TIMESTAMP.json"
|
||||
|
||||
# Run the Python command and save the output to the log file
|
||||
loadTracer.sh python3 -m sglang.launch_server \
|
||||
--model-path /sgl-workspace/sglang/dummy_grok1 \
|
||||
--tokenizer-path Xenova/grok-1-tokenizer \
|
||||
--load-format dummy \
|
||||
--quantization fp8 \
|
||||
--tp 8 \
|
||||
--port 30000 \
|
||||
--disable-radix-cache 2>&1 | tee "$LOGFILE"
|
||||
```
|
||||
6. Open another terminal for the same docker container, and run the rpd enabled ./client.sh after you see "The server is fired up and is ready to roll!" message from server side terminal.
|
||||
|
||||
#### Common Notes 3
|
||||
- Use curl http://localhost:30000/start_profile & curl http://localhost:30000/stop_profile to control the start and end of profiling. Check sglang/python/sglang/srt/managers/scheduler.py for more details.
|
||||
- Please don't use RPD profiler together with PyTorch profiler to avoid interference.
|
||||
- The rocmProfileData/tools/rpd2tracing.py file is used to generate json file from RPD file.
|
||||
|
||||
client.sh
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
# Start profiling via API
|
||||
curl http://localhost:30000/start_profile -H "Content-Type: application/json"
|
||||
|
||||
# Benchmark serving using sglang with random dataset and tokenizer
|
||||
# Define the log file with a timestamp
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
LOGFILE="sglang_client_log_$TIMESTAMP.json"
|
||||
|
||||
# Run the benchmark with specified parameters and save logs
|
||||
python3 -m sglang.bench_serving \
|
||||
--backend sglang \
|
||||
--tokenizer Xenova/grok-1-tokenizer \
|
||||
--dataset-name random \
|
||||
--random-input 1024\
|
||||
--random-output 1024 \
|
||||
--num-prompts 120 \
|
||||
--request-rate 8 \
|
||||
--output-file online.jsonl 2>&1 | tee "$LOGFILE"
|
||||
|
||||
# Stop profiling via API
|
||||
curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
|
||||
|
||||
# Convert tracing file to csv & json
|
||||
sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
|
||||
python3 ./rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
|
||||
```
|
||||
7. Follow [Perfetto docs](https://perfetto.dev/docs/visualization/large-traces) to visualize large json files. Try to adjust parameters so that the trace.json file size is less than 9GB.
|
||||
|
||||
### Profiling SGLang Infer System with PyTorch Profiler
|
||||
|
||||
Please use the steps as follows:
|
||||
|
||||
1. Apply the patch torch_profiler.patch. Note that you can modify "if self.tp_rank == 0" in the patch to allow more ranks be recorded in profiling.
|
||||
|
||||
torch_profiler.patch
|
||||
```bash
|
||||
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
||||
index 62d1ff9..6ecd78c 100644
|
||||
--- a/python/sglang/srt/managers/scheduler.py
|
||||
+++ b/python/sglang/srt/managers/scheduler.py
|
||||
@@ -240,7 +240,6 @@ class Scheduler:
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
- torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
@@ -1033,9 +1032,11 @@ class Scheduler:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.stop()
|
||||
- self.profiler.export_chrome_trace(
|
||||
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
- )
|
||||
+ if self.tp_rank == 0:
|
||||
+ with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
|
||||
+ print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
|
||||
+ print("Profiling stats done.")
|
||||
+
|
||||
logger.info("Profiler is done")
|
||||
```
|
||||
|
||||
2. Create the model path directory and copy it to the right path for "--model-path" if you want to use the server.sh file provided.
|
||||
|
||||
3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container.
|
||||
|
||||
4. Similar to step 6 in RPD profiling section, but remove the last 2 lines in client.sh, which converted rpd file into csv and json files. Run modified client.sh for PyTorch profiling.
|
||||
-------
|
||||
- [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
|
||||
27
3rdparty/amd/profiling/client.sh
vendored
Executable file
27
3rdparty/amd/profiling/client.sh
vendored
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Start profiling via API
|
||||
curl http://localhost:30000/start_profile -H "Content-Type: application/json"
|
||||
|
||||
# Benchmark serving using sglang with random dataset and tokenizer
|
||||
# Define the log file with a timestamp
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
LOGFILE="sglang_client_log_$TIMESTAMP.json"
|
||||
|
||||
# Run the benchmark with specified parameters and save logs
|
||||
python3 -m sglang.bench_serving \
|
||||
--backend sglang \
|
||||
--tokenizer Xenova/grok-1-tokenizer \
|
||||
--dataset-name random \
|
||||
--random-input 1024\
|
||||
--random-output 1024 \
|
||||
--num-prompts 240 \
|
||||
--request-rate 8 \
|
||||
--output-file online.jsonl 2>&1 | tee "$LOGFILE"
|
||||
|
||||
# Stop profiling via API
|
||||
curl http://localhost:30000/stop_profile -H "Content-Type: application/json"
|
||||
|
||||
# Convert tracing file to csv & json
|
||||
sqlite3 trace.rpd ".mode csv" ".header on" ".output trace.csv" "select * from top;" ".output stdout"
|
||||
python3 /sgl-workspace/rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json
|
||||
10
3rdparty/amd/profiling/install_rpd.sh
vendored
Normal file
10
3rdparty/amd/profiling/install_rpd.sh
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
# download and install RPD
|
||||
apt update && apt install -y sqlite3 libsqlite3-dev libfmt-dev
|
||||
|
||||
# install rpd module
|
||||
git clone https://github.com/ROCmSoftwarePlatform/rocmProfileData
|
||||
cd rocmProfileData
|
||||
git apply rpd.patch
|
||||
make && make install
|
||||
cd rocpd_python && python setup.py install && cd ..
|
||||
cd rpd_tracer && make clean;make install && python setup.py install && cd ..
|
||||
43
3rdparty/amd/profiling/loadTracer.sh
vendored
Executable file
43
3rdparty/amd/profiling/loadTracer.sh
vendored
Executable file
@@ -0,0 +1,43 @@
|
||||
#!/bin/bash
|
||||
################################################################################
|
||||
# Copyright (c) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
################################################################################
|
||||
OUTPUT_FILE="trace.rpd"
|
||||
|
||||
if [ "$1" = "-o" ] ; then
|
||||
OUTPUT_FILE=$2
|
||||
shift
|
||||
shift
|
||||
fi
|
||||
|
||||
if [ -e ${OUTPUT_FILE} ] ; then
|
||||
rm ${OUTPUT_FILE}
|
||||
fi
|
||||
|
||||
python3 -m rocpd.schema --create ${OUTPUT_FILE}
|
||||
if [ $? != 0 ] ; then
|
||||
echo "Error: Could not create rpd file. Please run 'python setup.py install' from the rocpd_python dir"
|
||||
exit
|
||||
fi
|
||||
|
||||
export RPDT_FILENAME=${OUTPUT_FILE}
|
||||
export RPDT_AUTOSTART=0
|
||||
LD_PRELOAD=librocm-smi_64:librpd_tracer.so "$@"
|
||||
12
3rdparty/amd/profiling/rpd.patch
vendored
Normal file
12
3rdparty/amd/profiling/rpd.patch
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
diff --git a/rpd_tracer/Makefile b/rpd_tracer/Makefile
|
||||
index e9d9feb..b2e9e1a 100644
|
||||
--- a/rpd_tracer/Makefile
|
||||
+++ b/rpd_tracer/Makefile
|
||||
@@ -16,7 +16,7 @@ ifneq (,$(HIP_PATH))
|
||||
$(info Building with roctracer)
|
||||
RPD_LIBS += -L/opt/rocm/lib -lroctracer64 -lroctx64 -lamdhip64 -lrocm_smi64
|
||||
RPD_INCLUDES += -I/opt/rocm/include -I/opt/rocm/include/roctracer -I/opt/rocm/include/hsa
|
||||
- RPD_SRCS += RoctracerDataSource.cpp RocmSmiDataSource.cpp
|
||||
+ RPD_SRCS += RoctracerDataSource.cpp
|
||||
RPD_INCLUDES += -D__HIP_PLATFORM_AMD__
|
||||
endif
|
||||
49
3rdparty/amd/profiling/rpd_profile_server_enable.patch
vendored
Normal file
49
3rdparty/amd/profiling/rpd_profile_server_enable.patch
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
||||
index 62d1ff9..9021c01 100644
|
||||
--- a/python/sglang/srt/managers/scheduler.py
|
||||
+++ b/python/sglang/srt/managers/scheduler.py
|
||||
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
||||
suppress_other_loggers,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
+from rpdTracerControl import rpdTracerControl
|
||||
+rpdTracerControl.skipCreate()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -245,6 +247,7 @@ class Scheduler:
|
||||
],
|
||||
with_stack=True,
|
||||
)
|
||||
+ self.rpd = rpdTracerControl()
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop(self):
|
||||
@@ -1027,15 +1030,24 @@ class Scheduler:
|
||||
def start_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
- self.profiler.start()
|
||||
+ #self.profiler.start() #block pytorch profiler for rpd profiler enabling
|
||||
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
||||
+ self.rpd.start()
|
||||
+ self.rpd.rangePush("", "rpd profile range", "")
|
||||
+ logger.info("rpd is enabled")
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
- self.profiler.stop()
|
||||
- self.profiler.export_chrome_trace(
|
||||
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
- )
|
||||
+ #self.profiler.stop()
|
||||
+ #self.profiler.export_chrome_trace(
|
||||
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
+ #)
|
||||
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
||||
+ self.rpd.rangePop()
|
||||
+ self.rpd.stop()
|
||||
+ self.rpd.flush()
|
||||
+ logger.info("rpd is done")
|
||||
logger.info("Profiler is done")
|
||||
126
3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch
vendored
Normal file
126
3rdparty/amd/profiling/rpd_profile_server_enable_wCPU_activities.patch
vendored
Normal file
@@ -0,0 +1,126 @@
|
||||
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
||||
index 62d1ff9..2edb427 100644
|
||||
--- a/python/sglang/srt/managers/scheduler.py
|
||||
+++ b/python/sglang/srt/managers/scheduler.py
|
||||
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
||||
suppress_other_loggers,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
+from rpdTracerControl import rpdTracerControl
|
||||
+rpdTracerControl.skipCreate()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -245,6 +247,7 @@ class Scheduler:
|
||||
],
|
||||
with_stack=True,
|
||||
)
|
||||
+ self.rpd = rpdTracerControl()
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop(self):
|
||||
@@ -1027,15 +1030,26 @@ class Scheduler:
|
||||
def start_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
- self.profiler.start()
|
||||
+ #self.profiler.start()
|
||||
+ logger.info("torch profiler is disabled")
|
||||
+ if self.tp_rank == 0 or self.tp_rank == 1:
|
||||
+ self.rpd.setPythonTrace(True)
|
||||
+ self.rpd.start()
|
||||
+ self.rpd.rangePush("", "scheduler", "")
|
||||
+ logger.info("rpd is enabled inside scheduler profiling")
|
||||
|
||||
def stop_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
- self.profiler.stop()
|
||||
- self.profiler.export_chrome_trace(
|
||||
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
- )
|
||||
+ #self.profiler.stop()
|
||||
+ #self.profiler.export_chrome_trace(
|
||||
+ # self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
+ #)
|
||||
+ if self.tp_rank ==0 or self.tp_rank ==1:
|
||||
+ self.rpd.rangePop()
|
||||
+ self.rpd.stop()
|
||||
+ self.rpd.flush()
|
||||
+ logger.info("rpd is done inside scheduler")
|
||||
logger.info("Profiler is done")
|
||||
|
||||
|
||||
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
|
||||
index 2621ccd..181df85 100644
|
||||
--- a/python/sglang/srt/managers/tokenizer_manager.py
|
||||
+++ b/python/sglang/srt/managers/tokenizer_manager.py
|
||||
@@ -58,6 +58,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import is_generation_model, is_multimodal_model
|
||||
|
||||
+from rpdTracerControl import rpdTracerControl
|
||||
+rpdTracerControl.skipCreate()
|
||||
+
|
||||
+
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -514,10 +518,20 @@ class TokenizerManager:
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def start_profile(self):
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.setPythonTrace(True)
|
||||
+ rpd.start()
|
||||
+ rpd.rangePush("", "tokenizer_manager", "")
|
||||
+ logger.info("tokenizer_manager rpd profiling started!")
|
||||
req = ProfileReq.START_PROFILE
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def stop_profile(self):
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.rangePop()
|
||||
+ rpd.stop()
|
||||
+ rpd.flush()
|
||||
+ logger.info("rpd profiling is done inside tokenizer_manager!")
|
||||
req = ProfileReq.STOP_PROFILE
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
|
||||
index 7111c93..2bd722c 100644
|
||||
--- a/python/sglang/srt/server.py
|
||||
+++ b/python/sglang/srt/server.py
|
||||
@@ -30,6 +30,8 @@ import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Union
|
||||
+from rpdTracerControl import rpdTracerControl
|
||||
+rpdTracerControl.skipCreate()
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
@@ -152,6 +154,11 @@ async def flush_cache():
|
||||
@app.post("/start_profile")
|
||||
async def start_profile():
|
||||
"""Start profiling."""
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.setPythonTrace(True)
|
||||
+ rpd.start()
|
||||
+ rpd.rangePush("", "server rpd profile range", "")
|
||||
+ logger.info("rpd profiling started in server.py!")
|
||||
tokenizer_manager.start_profile()
|
||||
return Response(
|
||||
content="Start profiling.\n",
|
||||
@@ -164,6 +171,11 @@ async def start_profile():
|
||||
async def stop_profile():
|
||||
"""Stop profiling."""
|
||||
tokenizer_manager.stop_profile()
|
||||
+ rpd = rpdTracerControl()
|
||||
+ rpd.rangePop()
|
||||
+ rpd.stop()
|
||||
+ rpd.flush()
|
||||
+ logger.info("rpd profiling is done in server.py!")
|
||||
return Response(
|
||||
content="Stop profiling. This will take some time.\n",
|
||||
status_code=200,
|
||||
20
3rdparty/amd/profiling/server.sh
vendored
Executable file
20
3rdparty/amd/profiling/server.sh
vendored
Executable file
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
|
||||
|
||||
# export SGLANG_TORCH_PROFILER_DIR=/data/sglang/
|
||||
export SGLANG_TORCH_PROFILER_DIR=/sgl-workspace/sglang/profile/
|
||||
|
||||
# Get the current timestamp
|
||||
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
|
||||
|
||||
# Define the log file with a timestamp
|
||||
LOGFILE="sglang_server_log_$TIMESTAMP.json"
|
||||
|
||||
# Run the Python command and save the output to the log file
|
||||
loadTracer.sh python3 -m sglang.launch_server \
|
||||
--model-path /sgl-workspace/sglang/dummy_grok1 \
|
||||
--tokenizer-path Xenova/grok-1-tokenizer \
|
||||
--load-format dummy \
|
||||
--quantization fp8 \
|
||||
--tp 8 \
|
||||
--port 30000 \
|
||||
--disable-radix-cache 2>&1 | tee "$LOGFILE"
|
||||
25
3rdparty/amd/profiling/torch_profiler.patch
vendored
Normal file
25
3rdparty/amd/profiling/torch_profiler.patch
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
||||
index 62d1ff9..6ecd78c 100644
|
||||
--- a/python/sglang/srt/managers/scheduler.py
|
||||
+++ b/python/sglang/srt/managers/scheduler.py
|
||||
@@ -240,7 +240,6 @@ class Scheduler:
|
||||
)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
- torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
@@ -1033,9 +1032,11 @@ class Scheduler:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.stop()
|
||||
- self.profiler.export_chrome_trace(
|
||||
- self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
|
||||
- )
|
||||
+ if self.tp_rank == 0:
|
||||
+ with open(f"stats_repro_{int(time.time())}.txt", "w") as f:
|
||||
+ print(self.profiler.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=-1), file=f)
|
||||
+ print("Profiling stats done.")
|
||||
+
|
||||
logger.info("Profiler is done")
|
||||
118
3rdparty/amd/tuning/TUNING.md
vendored
Normal file
118
3rdparty/amd/tuning/TUNING.md
vendored
Normal file
@@ -0,0 +1,118 @@
|
||||
## Tuning SGLang Infer System with AMD GPUs
|
||||
This AppNote describes the SGLang performance tuning technical, code harness and running steps for systems with AMD Instinct GPUs.
|
||||
Harness code, examples and steps are provided in detail, to facilitate easy reproduce & use to tune performance towards workloads.
|
||||
Three primary runtime areas are covered:
|
||||
|
||||
## 1. Triton Kernels
|
||||
To maximize Triton kernel efficiency, several strategies can be employed:
|
||||
|
||||
### Key Environment Variables:
|
||||
- **num_stages**: Adjusts the number of pipeline stages to optimize kernel efficiency based on the specific type of operations (e.g., General Matrix Multiplication - GEMM).
|
||||
- **waves_per_eu**: Controls the usage of Vector General Purpose Registers (VGPR) to enhance occupancy, thereby improving latency or throughput.
|
||||
- **BLOCK_M, BLOCK_N, BLOCK_K**: Tunable tile sizes that assist in balancing memory transfer and computational efficiency.
|
||||
- **matrix_instr_nonkdim**: Optimizes the usage of Matrix-Fused Multiply-Add (MFMA) instructions for specific kernel types, such as Flash Attention.
|
||||
- **OPTIMIZE_EPILOGUE**: An environment variable that can be set to `1` to enhance performance by eliminating the `convert_layout` operation in the kernel's epilogue.
|
||||
```python
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
|
||||
triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
|
||||
triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
|
||||
triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
|
||||
triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
|
||||
triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
|
||||
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
|
||||
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
|
||||
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
|
||||
], key=['BLOCK_N', 'NUM_TOKEN_BLKS'], use_cuda_graph=True)
|
||||
@triton.jit
|
||||
def _triton_kernel_funtion():
|
||||
...
|
||||
```
|
||||
## 2. Torch Tunable Operations
|
||||
**TunableOp** is a feature in PyTorch that allows for the definition and optimization of custom kernels with tunable parameters. This feature is particularly useful for enhancing the performance of kernels by experimenting with different configurations.
|
||||
|
||||
### Key Environment Variables:
|
||||
1. **PYTORCH_TUNABLEOP_ENABLED**:
|
||||
- Default: `0`
|
||||
- Set to `1` to enable TunableOp.
|
||||
|
||||
2. **PYTORCH_TUNABLEOP_TUNING**:
|
||||
- Default: `1`
|
||||
- Set to `0` to disable tuning. If a tuned entry is not found, it will run the tuning step and record the entry when PYTORCH_TUNABLEOP_ENABLED is enabled.
|
||||
|
||||
3. **PYTORCH_TUNABLEOP_VERBOSE**:
|
||||
- Default: `0`
|
||||
- Set to `1` to enable verbose output for TunableOp.
|
||||
|
||||
### Usage Example:
|
||||
To enable TunableOp and tuning, and optionally enable verbose mode, you can run the following command in your terminal:
|
||||
|
||||
```bash
|
||||
#Tuning
|
||||
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 your_script.sh
|
||||
|
||||
#Inference with tuning op
|
||||
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_script.sh
|
||||
|
||||
#Print out the log
|
||||
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 PYTORCH_TUNABLEOP_VERBOSE=1 your_script.sh
|
||||
|
||||
```
|
||||
## 3. Torch Compilation
|
||||
|
||||
|
||||
The following are suggestions for optimizing matrix multiplication (GEMM) and convolution (conv) operations in PyTorch using Inductor, a part of the PyTorch compilation framework. The goal is to leverage Triton to achieve better performance.
|
||||
|
||||
To tune Triton kernels with GEMM and convolution ops (conv), use the `torch.compile` function with the max-autotune mode. This benchmarks a predefined list of Triton configurations and selects the fastest one for each shape.
|
||||
|
||||
### Key Configurations:
|
||||
1. **Max Autotune**:
|
||||
- Set `torch._inductor.config.max_autotune = True` or `TORCHINDUCTOR_MAX_AUTOTUNE=1`.
|
||||
|
||||
2. **Fine-Grained Control**:
|
||||
- Enable GEMM tuning: `torch._inductor.config.max_autotune_gemm = True`.
|
||||
- Enable tuning for pointwise/reduction ops: `torch._inductor.config.max_autotune.pointwise = True`.
|
||||
|
||||
3. **Backend Selection**:
|
||||
- Use `torch._inductor.max_autotune_gemm_backends` to limit backends to TRITON for better performance.
|
||||
|
||||
4. **Freezing for Inference**:
|
||||
- Use `torch._inductor.config.freezing=True` to enable constant folding optimizations.
|
||||
|
||||
5. **Debugging**:
|
||||
- Set `TORCH_COMPILE_DEBUG=1` to extract Triton kernels generated by Inductor.
|
||||
|
||||
### Example Code Block:
|
||||
```bash
|
||||
#Gemm Tuning
|
||||
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 your_script.sh
|
||||
|
||||
#Specify your backend to TRITON for Gemm Tuning
|
||||
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON your_script.sh
|
||||
|
||||
#Inference with large improvement on AMD GPU
|
||||
TORCHINDUCTOR_FREEZING=1 your_script.sh
|
||||
```
|
||||
## 4. Fused MOE kernel
|
||||
To maximize moe kernel efficiency, need to use below scripts to find out the best launch configuration
|
||||
|
||||
### Key parameters:
|
||||
- **--model**: what moe model type to do tuning, it will automatically decide the size of d_model, model_intermediate_size, num_layers
|
||||
- **--tp-size**: simulate the whole model run configuration to set the dimension size using tp correctly
|
||||
- **--batch**: M dimension size of moe kernel, for prefill moe kernel the value is batch*input_len, for decode moe kernel the value is batch
|
||||
- **--dtype**: computation type
|
||||
|
||||
```bash
|
||||
#Tuning
|
||||
#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8" to run, it defined batch-size 32 input length 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run).
|
||||
#so we can tune decode moe use below command
|
||||
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
|
||||
# and use this command to tune prefill moe
|
||||
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32768"
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
For more detailed information on tuning SGLang performance with AMD GPUs, please refer to the following link:
|
||||
|
||||
[ROCm Documentation: Triton Kernel Performance Optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#triton-kernel-performance-optimization)
|
||||
380
3rdparty/amd/tuning/benchmark_moe_rocm.py
vendored
Normal file
380
3rdparty/amd/tuning/benchmark_moe_rocm.py
vendored
Normal file
@@ -0,0 +1,380 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
fused_moe,
|
||||
get_config_file_name,
|
||||
)
|
||||
|
||||
padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
|
||||
|
||||
|
||||
def main(model, tp_size, dtype: str, batches):
|
||||
method = fused_moe
|
||||
|
||||
for bs in batches:
|
||||
run_grid(int(bs), model=model, method=method, tp_size=tp_size, dtype=dtype)
|
||||
|
||||
|
||||
def prune_configs(M, N, K, configs):
|
||||
pruned_configs = []
|
||||
elemBytes_a = 1 # [DV Note] Hard-coded for float16 (2 bytes)
|
||||
elemBytes_b = 1 # [DV Note] Hard-coded for float16 (2 bytes)
|
||||
|
||||
mfma = 16 if M < 32 or N < 32 else 32
|
||||
|
||||
# TODO (zhanglx): figure out the boundary between large and small gemms
|
||||
large_gemm = False
|
||||
if M >= 2048 and N >= 2048:
|
||||
large_gemm = True
|
||||
|
||||
for config in configs:
|
||||
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
|
||||
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
|
||||
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
|
||||
num_warps = config.get("num_warps")
|
||||
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
|
||||
# kpack = config.get("kpack")
|
||||
if matrix_instr_nonkdim > mfma:
|
||||
continue
|
||||
if mfma == 4 and BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
# some layouts could not work properly in case
|
||||
# number elements per thread is less 1
|
||||
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
|
||||
continue
|
||||
SPLIT_K = 1 # config.get("SPLIT_K")
|
||||
GROUP_M = config.get("GROUP_SIZE_M")
|
||||
if matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N:
|
||||
continue
|
||||
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
|
||||
continue
|
||||
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
|
||||
continue
|
||||
# Skip BLOCK_SIZE that is too large compare to M/N
|
||||
# unless BLOCK_SIZE is already small enough
|
||||
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
|
||||
continue
|
||||
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
|
||||
continue
|
||||
# skip large split_k when not necessary
|
||||
if SPLIT_K != 1 and not need_split_k(M, N, K):
|
||||
continue
|
||||
# skip split_k that leads to EVEN_K = false
|
||||
leap = SPLIT_K * BLOCK_SIZE_K
|
||||
modv = K % leap
|
||||
if modv != 0:
|
||||
continue
|
||||
# skip large GROUP_M
|
||||
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
|
||||
continue
|
||||
# out of shared memory resource
|
||||
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
||||
LDS = (
|
||||
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
|
||||
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
|
||||
)
|
||||
if LDS > 65536:
|
||||
continue
|
||||
# Skip small block sizes and num_warps for large gemm
|
||||
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
|
||||
if large_gemm:
|
||||
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
|
||||
continue
|
||||
if BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
if num_warps < 4:
|
||||
continue
|
||||
|
||||
pruned_configs.append(config)
|
||||
|
||||
return pruned_configs
|
||||
|
||||
|
||||
def union_of_list_of_dicts(l1, l2):
|
||||
result = []
|
||||
temp_list = l1.copy()
|
||||
temp_list.extend(l2)
|
||||
for myDict in temp_list:
|
||||
if myDict not in result:
|
||||
result.append(myDict)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def run_grid(bs, model, method, tp_size, dtype: str):
|
||||
|
||||
config = AutoConfig.from_pretrained(model)
|
||||
|
||||
top_k = config.num_experts_per_tok
|
||||
d_model = config.hidden_size
|
||||
model_intermediate_size = config.intermediate_size
|
||||
num_layers = config.num_hidden_layers
|
||||
hidden_states_dtype = config.torch_dtype
|
||||
|
||||
if config.num_experts_per_tok:
|
||||
if config.architectures[0] == "Grok1ModelForCausalLM":
|
||||
num_total_experts = config.num_experts
|
||||
else:
|
||||
num_total_experts = config.num_local_experts
|
||||
else:
|
||||
raise ValueError(f"Unsupported Mixtral model {model}")
|
||||
|
||||
# tp_size = 2
|
||||
num_warmup_calls = 10
|
||||
num_calls = 30
|
||||
|
||||
num_warmup_trials = 1
|
||||
num_trials = 1
|
||||
|
||||
full_configs = []
|
||||
|
||||
block_m_range = [16, 32, 64, 128, 256]
|
||||
block_n_range = [16, 32, 64, 128, 256]
|
||||
block_k_range = [32, 64, 128, 256] # MUST >= 32
|
||||
num_warps_range = [1, 2, 4, 8]
|
||||
group_m_range = [1, 4, 8, 16, 32]
|
||||
# For now we see better perf with num_stages=0 for all gemm configs we care
|
||||
# But keep this explicit so that we do not forget we may need to set it to
|
||||
# other values in the future
|
||||
num_stage_range = [2]
|
||||
waves_per_eu_range = [0, 1, 2, 4, 8]
|
||||
# Remove 32 because of triton compiling error
|
||||
matrix_instr_nonkdim_range = [16]
|
||||
kpack_range = [1, 2]
|
||||
|
||||
for block_size_m in block_m_range:
|
||||
for block_size_n in block_n_range:
|
||||
for block_size_k in block_k_range:
|
||||
for group_size_m in group_m_range:
|
||||
for num_warps in num_warps_range:
|
||||
for num_stages in num_stage_range:
|
||||
for waves_per_eu in waves_per_eu_range:
|
||||
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
|
||||
for kpack in kpack_range:
|
||||
full_configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_size_m,
|
||||
"BLOCK_SIZE_N": block_size_n,
|
||||
"BLOCK_SIZE_K": block_size_k,
|
||||
"GROUP_SIZE_M": group_size_m,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
"waves_per_eu": waves_per_eu,
|
||||
"matrix_instr_nonkdim": matrix_instr_nonkdim,
|
||||
"kpack": kpack,
|
||||
}
|
||||
)
|
||||
|
||||
M1 = bs * 2
|
||||
N1 = model_intermediate_size * 2 // tp_size
|
||||
K1 = d_model
|
||||
prune_configs_1 = prune_configs(M1, N1, K1, full_configs)
|
||||
|
||||
M2 = bs * 2
|
||||
N2 = d_model
|
||||
K2 = model_intermediate_size // tp_size
|
||||
prune_configs_2 = prune_configs(M2, N2, K2, full_configs)
|
||||
|
||||
configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
|
||||
|
||||
print(
|
||||
f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
|
||||
{len(prune_configs_2)=} | {len(configs)=}"
|
||||
)
|
||||
|
||||
best_config = None
|
||||
best_time_us = 1e20
|
||||
|
||||
print(f"{tp_size=} {bs=}")
|
||||
|
||||
for config in tqdm(configs):
|
||||
# warmup
|
||||
try:
|
||||
print(config)
|
||||
for _ in range(num_warmup_trials):
|
||||
run_timing(
|
||||
num_calls=num_warmup_calls,
|
||||
bs=bs,
|
||||
d_model=d_model,
|
||||
num_total_experts=num_total_experts,
|
||||
top_k=top_k,
|
||||
tp_size=tp_size,
|
||||
model_intermediate_size=model_intermediate_size,
|
||||
method=method,
|
||||
config=config,
|
||||
dtype=dtype,
|
||||
hidden_states_dtype=hidden_states_dtype,
|
||||
)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
continue
|
||||
|
||||
# trial
|
||||
for _ in range(num_trials):
|
||||
kernel_dur_ms = run_timing(
|
||||
num_calls=num_calls,
|
||||
bs=bs,
|
||||
d_model=d_model,
|
||||
num_total_experts=num_total_experts,
|
||||
top_k=top_k,
|
||||
tp_size=tp_size,
|
||||
model_intermediate_size=model_intermediate_size,
|
||||
method=method,
|
||||
config=config,
|
||||
dtype=dtype,
|
||||
hidden_states_dtype=hidden_states_dtype,
|
||||
)
|
||||
|
||||
kernel_dur_us = 1000 * kernel_dur_ms
|
||||
model_dur_ms = kernel_dur_ms * num_layers
|
||||
|
||||
if kernel_dur_us < best_time_us:
|
||||
best_config = config
|
||||
best_time_us = kernel_dur_us
|
||||
|
||||
tqdm.write(
|
||||
f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}"
|
||||
f" {bs=} {tp_size=} {top_k=} {num_total_experts=} "
|
||||
f"{d_model=} {model_intermediate_size=} {num_layers=}"
|
||||
)
|
||||
|
||||
print("best_time_us", best_time_us)
|
||||
print("best_config", best_config)
|
||||
|
||||
# holds Dict[str, Dict[str, int]]
|
||||
filename = get_config_file_name(
|
||||
num_total_experts,
|
||||
model_intermediate_size // tp_size,
|
||||
"float8" if dtype == "float8" else None,
|
||||
)
|
||||
print(f"writing config to file {filename}")
|
||||
existing_content = {}
|
||||
if os.path.exists(filename):
|
||||
with open(filename, "r") as f:
|
||||
existing_content = json.load(f)
|
||||
existing_content[str(bs)] = best_config
|
||||
with open(filename, "w") as f:
|
||||
json.dump(existing_content, f, indent=4)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def run_timing(
|
||||
num_calls: int,
|
||||
bs: int,
|
||||
d_model: int,
|
||||
num_total_experts: int,
|
||||
top_k: int,
|
||||
tp_size: int,
|
||||
model_intermediate_size: int,
|
||||
method,
|
||||
config,
|
||||
dtype: str,
|
||||
hidden_states_dtype,
|
||||
) -> float:
|
||||
shard_intermediate_size = model_intermediate_size // tp_size
|
||||
|
||||
hidden_states = torch.rand(
|
||||
(bs, d_model),
|
||||
device="cuda:0",
|
||||
dtype=hidden_states_dtype,
|
||||
)
|
||||
|
||||
w1 = torch.rand(
|
||||
(num_total_experts, 2 * shard_intermediate_size, d_model + padding_size),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
w2 = torch.rand(
|
||||
(num_total_experts, d_model, shard_intermediate_size + padding_size),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
w1_scale = None
|
||||
w2_scale = None
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
|
||||
if dtype == "float8":
|
||||
w1 = w1.to(torch.float8_e4m3fnuz)
|
||||
w2 = w2.to(torch.float8_e4m3fnuz)
|
||||
w1_scale = torch.ones(
|
||||
num_total_experts, device=hidden_states.device, dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.ones(
|
||||
num_total_experts, device=hidden_states.device, dtype=torch.float32
|
||||
)
|
||||
a1_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
|
||||
a2_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
|
||||
|
||||
gating_output = F.softmax(
|
||||
torch.rand(
|
||||
(num_calls, bs, num_total_experts),
|
||||
device=hidden_states.device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
##################################
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
for i in range(num_calls):
|
||||
hidden_states = method(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
gating_output=gating_output[0],
|
||||
topk=top_k,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
override_config=config,
|
||||
use_fp8=dtype == "float8",
|
||||
)
|
||||
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
|
||||
dur_ms = start_event.elapsed_time(end_event) / num_calls
|
||||
return dur_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="benchmark_mixtral_moe",
|
||||
description="Benchmark and tune the fused_moe kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["float8", "float16", "bfloat16"],
|
||||
help="Data type used for fused_moe kernel computations",
|
||||
)
|
||||
parser.add_argument("--model", type=str, default="hpcai-tech/grok-1")
|
||||
|
||||
parser.add_argument("--tp-size", type=int, default=2, help="Tensor paralleli size")
|
||||
parser.add_argument("-b", "--batches", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
batches = args.batches.split(",")
|
||||
|
||||
sys.exit(main(args.model, args.tp_size, args.dtype, batches))
|
||||
Reference in New Issue
Block a user