Adapt to SGLang v0.5.2rc1 on DCU

maxiao
2025-09-04 15:56:33 +08:00
commit 909abb58f5
2320 changed files with 489411 additions and 0 deletions

3rdparty/amd/tuning/TUNING.md

@@ -0,0 +1,118 @@
## Tuning the SGLang Inference System with AMD GPUs
This AppNote describes SGLang performance tuning techniques, the code harness, and the running steps for systems with AMD Instinct GPUs.
Harness code, examples, and steps are provided in detail to make it easy to reproduce the results and tune performance for your workloads.
Four primary runtime areas are covered:
## 1. Triton Kernels
To maximize Triton kernel efficiency, several strategies can be employed:
### Key Tuning Parameters:
- **num_stages**: Adjusts the number of software pipeline stages to optimize kernel efficiency for the specific type of operation (e.g., General Matrix Multiplication, GEMM).
- **waves_per_eu**: Controls Vector General Purpose Register (VGPR) usage to raise occupancy, improving latency or throughput.
- **BLOCK_M, BLOCK_N, BLOCK_K**: Tunable tile sizes that balance memory transfer against computational efficiency.
- **matrix_instr_nonkdim**: Optimizes the use of Matrix Fused Multiply-Add (MFMA) instructions for specific kernel types, such as Flash Attention.
- **OPTIMIZE_EPILOGUE**: An environment variable that can be set to `1` to improve performance by eliminating the `convert_layout` operation in the kernel's epilogue.
```python
@triton.autotune(configs=[
triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
], key=['BLOCK_N', 'NUM_TOKEN_BLKS'], use_cuda_graph=True)
@triton.jit
def _triton_kernel_function():
...
```
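Unlike the autotune parameters above, `OPTIMIZE_EPILOGUE` is set in the environment when launching the workload:
```bash
# Drop the convert_layout op in the kernel epilogue
OPTIMIZE_EPILOGUE=1 your_script.sh
```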
## 2. Torch Tunable Operations
**TunableOp** is a feature in PyTorch that allows for the definition and optimization of custom kernels with tunable parameters. This feature is particularly useful for enhancing the performance of kernels by experimenting with different configurations.
### Key Environment Variables:
1. **PYTORCH_TUNABLEOP_ENABLED**:
- Default: `0`
- Set to `1` to enable TunableOp.
2. **PYTORCH_TUNABLEOP_TUNING**:
- Default: `1`
- Set to `0` to disable tuning and reuse only previously recorded entries. While tuning is enabled (together with `PYTORCH_TUNABLEOP_ENABLED=1`), any op whose tuned entry is not found triggers a tuning step, and the result is recorded.
3. **PYTORCH_TUNABLEOP_VERBOSE**:
- Default: `0`
- Set to `1` to enable verbose output for TunableOp.
### Usage Example:
To enable TunableOp and tuning, and optionally enable verbose mode, you can run the following command in your terminal:
```bash
# Tuning
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=1 your_script.sh
# Inference with tuned ops
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_script.sh
# Print the verbose log
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 PYTORCH_TUNABLEOP_VERBOSE=1 your_script.sh
```
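Recent PyTorch builds also expose TunableOp programmatically via `torch.cuda.tunable`. A minimal sketch, assuming a PyTorch/ROCm build that ships this module (the matrix sizes and filename are illustrative):
```python
import torch
import torch.cuda.tunable as tunable

# Equivalent to PYTORCH_TUNABLEOP_ENABLED=1 and PYTORCH_TUNABLEOP_TUNING=1
tunable.enable(True)
tunable.tuning_enable(True)
# Store tuned entries in an explicit file (illustrative name)
tunable.set_filename("tunableop_results.csv")

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
c = a @ b  # the first call tunes this GEMM; later calls reuse the entry

tunable.write_file()  # persist the tuned entries
```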
## 3. Torch Compilation
The following are suggestions for optimizing matrix multiplication (GEMM) and convolution (conv) operations in PyTorch using Inductor, a part of the PyTorch compilation framework. The goal is to leverage Triton to achieve better performance.
To tune Triton kernels for GEMM and convolution (conv) ops, use `torch.compile` with the max-autotune mode. This benchmarks a predefined list of Triton configurations and selects the fastest one for each shape.
### Key Configurations:
1. **Max Autotune**:
- Set `torch._inductor.config.max_autotune = True` or `TORCHINDUCTOR_MAX_AUTOTUNE=1`.
2. **Fine-Grained Control**:
- Enable GEMM tuning: `torch._inductor.config.max_autotune_gemm = True`.
- Enable tuning for pointwise/reduction ops: `torch._inductor.config.max_autotune_pointwise = True`.
3. **Backend Selection**:
- Use `torch._inductor.config.max_autotune_gemm_backends` to limit backends to `TRITON` for better performance.
4. **Freezing for Inference**:
- Use `torch._inductor.config.freezing=True` to enable constant folding optimizations.
5. **Debugging**:
- Set `TORCH_COMPILE_DEBUG=1` to extract Triton kernels generated by Inductor.
### Example Code Block:
```bash
# GEMM tuning
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 your_script.sh
# Restrict the GEMM tuning backend to TRITON
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON your_script.sh
# Inference with freezing, often a large improvement on AMD GPUs
TORCHINDUCTOR_FREEZING=1 your_script.sh
```
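The same knobs are available from Python. A minimal sketch, assuming a CUDA/ROCm device and using an illustrative `torch.nn.Linear` stand-in for the real model:
```python
import torch
import torch._inductor.config as inductor_config

# Equivalent to TORCHINDUCTOR_MAX_AUTOTUNE=1
inductor_config.max_autotune = True
# Restrict GEMM autotuning to Triton-generated kernels
inductor_config.max_autotune_gemm_backends = "TRITON"
# Equivalent to TORCHINDUCTOR_FREEZING=1 (constant folding for inference)
inductor_config.freezing = True

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.float16).eval()
compiled = torch.compile(model, mode="max-autotune")

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
with torch.no_grad():
    y = compiled(x)  # the first call benchmarks configs and picks the fastest
```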
## 4. Fused MoE Kernel
To maximize fused MoE kernel efficiency, use the script below to find the best launch configuration.
### Key Parameters:
- **--model**: the MoE model to tune; `d_model`, `model_intermediate_size`, and `num_layers` are derived automatically from its config.
- **--tp-size**: the tensor-parallel size of the full model run, used to shard the dimension sizes correctly.
- **--batch**: the M dimension of the MoE kernel; for the prefill MoE kernel this is `batch * input_len`, for the decode MoE kernel it is `batch`.
- **--dtype**: the computation data type.
```bash
# Tuning
# Example: suppose we run
#   python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy \
#     --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 \
#     --input 1024 --output 8 --attention-backend triton \
#     --sampling-backend pytorch --quantization fp8
# With batch-size 32, input length 1024, and output length 8, the "--batch"
# value from the MoE kernel's point of view is 32 * 1024 = 32768 for prefill
# and 32 * 1 = 32 for decode (only one output token is generated per step).
# Tune the decode MoE kernel:
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32"
# Tune the prefill MoE kernel:
python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32768"
```
## Reference
For more detailed information on tuning SGLang performance with AMD GPUs, please refer to the following link:
[ROCm Documentation: Triton Kernel Performance Optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#triton-kernel-performance-optimization)

3rdparty/amd/tuning/benchmark_moe_rocm.py

@@ -0,0 +1,380 @@
import argparse
import json
import os
import sys
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from tqdm import tqdm
from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
get_config_file_name,
)
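# Optional 128-element padding of the weights' trailing dim, enabled via
# SGLANG_MOE_PADDING=1; padding can improve memory alignment on some GPUs.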
padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
def main(model, tp_size, dtype: str, batches):
method = fused_moe
for bs in batches:
run_grid(int(bs), model=model, method=method, tp_size=tp_size, dtype=dtype)
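# Drop configs that cannot work, or are unlikely to win, for a GEMM of shape
# (M, N, K): invalid MFMA/tile combinations, tiles that overflow LDS (shared
# memory), and small tiles or warp counts for large GEMMs.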
def prune_configs(M, N, K, configs):
pruned_configs = []
    elemBytes_a = 1  # bytes per element of A; 1 assumes fp8 (use 2 for float16)
    elemBytes_b = 1  # bytes per element of B; 1 assumes fp8 (use 2 for float16)
mfma = 16 if M < 32 or N < 32 else 32
# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm = False
if M >= 2048 and N >= 2048:
large_gemm = True
for config in configs:
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
num_warps = config.get("num_warps")
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
# kpack = config.get("kpack")
if matrix_instr_nonkdim > mfma:
continue
if mfma == 4 and BLOCK_SIZE_K < 64:
continue
        # some layouts cannot work properly when the number of
        # elements per thread is less than 1
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
continue
SPLIT_K = 1 # config.get("SPLIT_K")
GROUP_M = config.get("GROUP_SIZE_M")
if matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N:
continue
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
continue
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
continue
        # Skip BLOCK_SIZE that is too large compared to M/N,
        # unless BLOCK_SIZE is already small enough
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
continue
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
continue
# skip large split_k when not necessary
if SPLIT_K != 1 and not need_split_k(M, N, K):
continue
# skip split_k that leads to EVEN_K = false
leap = SPLIT_K * BLOCK_SIZE_K
modv = K % leap
if modv != 0:
continue
# skip large GROUP_M
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = (
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
if large_gemm:
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
continue
if BLOCK_SIZE_K < 64:
continue
if num_warps < 4:
continue
pruned_configs.append(config)
return pruned_configs
def union_of_list_of_dicts(l1, l2):
result = []
temp_list = l1.copy()
temp_list.extend(l2)
for myDict in temp_list:
if myDict not in result:
result.append(myDict)
return result
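# Sweep the pruned config grid for one batch size: time the fused_moe kernel
# under each config and record the fastest config in the kernel's JSON config
# file (keyed by batch size).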
def run_grid(bs, model, method, tp_size, dtype: str):
config = AutoConfig.from_pretrained(model)
top_k = config.num_experts_per_tok
d_model = config.hidden_size
model_intermediate_size = config.intermediate_size
num_layers = config.num_hidden_layers
hidden_states_dtype = config.torch_dtype
if config.num_experts_per_tok:
if config.architectures[0] == "Grok1ModelForCausalLM":
num_total_experts = config.num_experts
else:
num_total_experts = config.num_local_experts
else:
        raise ValueError(f"Unsupported MoE model {model}")
# tp_size = 2
num_warmup_calls = 10
num_calls = 30
num_warmup_trials = 1
num_trials = 1
full_configs = []
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [16, 32, 64, 128, 256]
block_k_range = [32, 64, 128, 256] # MUST >= 32
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
    # For now we see better perf with num_stages=2 for all gemm configs we
    # care about, but keep this explicit so that we do not forget we may need
    # to set it to other values in the future
    num_stage_range = [2]
waves_per_eu_range = [0, 1, 2, 4, 8]
# Remove 32 because of triton compiling error
matrix_instr_nonkdim_range = [16]
kpack_range = [1, 2]
for block_size_m in block_m_range:
for block_size_n in block_n_range:
for block_size_k in block_k_range:
for group_size_m in group_m_range:
for num_warps in num_warps_range:
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
for kpack in kpack_range:
full_configs.append(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu,
"matrix_instr_nonkdim": matrix_instr_nonkdim,
"kpack": kpack,
}
)
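    # Prune against the two GEMM shapes inside the fused MoE layer: the gate/up
    # projection (N = 2 * intermediate_size / tp) and the down projection
    # (N = d_model). The factor 2 on M approximates tokens * top_k (top_k is 2
    # for the models targeted here).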
M1 = bs * 2
N1 = model_intermediate_size * 2 // tp_size
K1 = d_model
prune_configs_1 = prune_configs(M1, N1, K1, full_configs)
M2 = bs * 2
N2 = d_model
K2 = model_intermediate_size // tp_size
prune_configs_2 = prune_configs(M2, N2, K2, full_configs)
configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
    print(
        f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | "
        f"{len(prune_configs_2)=} | {len(configs)=}"
    )
best_config = None
best_time_us = 1e20
print(f"{tp_size=} {bs=}")
for config in tqdm(configs):
# warmup
try:
print(config)
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_warmup_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
dtype=dtype,
hidden_states_dtype=hidden_states_dtype,
)
except triton.runtime.autotuner.OutOfResources:
continue
# trial
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
bs=bs,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
dtype=dtype,
hidden_states_dtype=hidden_states_dtype,
)
kernel_dur_us = 1000 * kernel_dur_ms
model_dur_ms = kernel_dur_ms * num_layers
if kernel_dur_us < best_time_us:
best_config = config
best_time_us = kernel_dur_us
tqdm.write(
f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}"
f" {bs=} {tp_size=} {top_k=} {num_total_experts=} "
f"{d_model=} {model_intermediate_size=} {num_layers=}"
)
print("best_time_us", best_time_us)
print("best_config", best_config)
# holds Dict[str, Dict[str, int]]
filename = get_config_file_name(
num_total_experts,
model_intermediate_size // tp_size,
"float8" if dtype == "float8" else None,
)
print(f"writing config to file {filename}")
existing_content = {}
if os.path.exists(filename):
with open(filename, "r") as f:
existing_content = json.load(f)
existing_content[str(bs)] = best_config
with open(filename, "w") as f:
json.dump(existing_content, f, indent=4)
f.write("\n")
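# Time num_calls back-to-back fused_moe invocations between a pair of CUDA
# events and return the mean duration per call in milliseconds.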
def run_timing(
num_calls: int,
bs: int,
d_model: int,
num_total_experts: int,
top_k: int,
tp_size: int,
model_intermediate_size: int,
method,
config,
dtype: str,
hidden_states_dtype,
) -> float:
shard_intermediate_size = model_intermediate_size // tp_size
hidden_states = torch.rand(
(bs, d_model),
device="cuda:0",
dtype=hidden_states_dtype,
)
w1 = torch.rand(
(num_total_experts, 2 * shard_intermediate_size, d_model + padding_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
w2 = torch.rand(
(num_total_experts, d_model, shard_intermediate_size + padding_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if dtype == "float8":
w1 = w1.to(torch.float8_e4m3fnuz)
w2 = w2.to(torch.float8_e4m3fnuz)
w1_scale = torch.ones(
num_total_experts, device=hidden_states.device, dtype=torch.float32
)
w2_scale = torch.ones(
num_total_experts, device=hidden_states.device, dtype=torch.float32
)
a1_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
a2_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
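    # Random router logits, softmax-normalized; one slice per timing call.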
gating_output = F.softmax(
torch.rand(
(num_calls, bs, num_total_experts),
device=hidden_states.device,
dtype=torch.float32,
),
dim=-1,
)
##################################
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(num_calls):
hidden_states = method(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
            gating_output=gating_output[i],  # use a distinct pre-generated slice per call
topk=top_k,
renormalize=True,
inplace=True,
override_config=config,
use_fp8=dtype == "float8",
)
end_event.record()
end_event.synchronize()
dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms
if __name__ == "__main__":
parser = argparse.ArgumentParser(
        prog="benchmark_moe_rocm",
description="Benchmark and tune the fused_moe kernel",
)
parser.add_argument(
"--dtype",
type=str,
default="auto",
choices=["float8", "float16", "bfloat16"],
help="Data type used for fused_moe kernel computations",
)
parser.add_argument("--model", type=str, default="hpcai-tech/grok-1")
parser.add_argument("--tp-size", type=int, default=2, help="Tensor paralleli size")
parser.add_argument("-b", "--batches", type=str)
args = parser.parse_args()
batches = args.batches.split(",")
sys.exit(main(args.model, args.tp_size, args.dtype, batches))