[8/N] MoE Refactor: deprecate EPMoE (#11211)
This commit is contained in:
@@ -1,29 +0,0 @@
|
|||||||
## Benchmark FBGEMM Grouped GEMM
|
|
||||||
|
|
||||||
This benchmark compares the memory bandwidth achieved by three grouped GEMM implementations: the FBGEMM Triton version, the FBGEMM CUDA (CUTLASS) version, and the SGLang Triton version.
|
|
||||||
|
|
||||||
### Requirements
|
|
||||||
|
|
||||||
```shell
|
|
||||||
pip install fbgemm-gpu-genai
|
|
||||||
```
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
|
|
||||||
```
|
|
||||||
|
|
||||||
For example, on an H200, the grouped GEMM bandwidth results for Qwen2-57B-A14B-Instruct with TP4 and FP8 W8A8 are as follows:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
grouped-gemm-performance:
|
|
||||||
batch_size FBGEMM Triton Grouped GEMM FP8 FBGEMM CUTLASS F8F8BF16 Rowwise SGLang Grouped GEMM FP8
|
|
||||||
0 256.0 3704.841339 3042.626402 2254.725030
|
|
||||||
1 512.0 3691.426346 3029.065684 2269.504543
|
|
||||||
2 1024.0 3653.938629 2258.471467 2358.319020
|
|
||||||
3 2048.0 3596.644313 2271.611904 2476.895397
|
|
||||||
4 4096.0 3468.496435 2231.283986 2179.473910
|
|
||||||
```
|
|
||||||
|
|
||||||
The theoretical peak memory bandwidth of the H200 is 4.8 TB/s. Taking batch_size 256 as an example, FBGEMM Triton Grouped GEMM FP8 reaches 3704.841339 GB/s, FBGEMM CUTLASS F8F8BF16 Rowwise reaches 3042.626402 GB/s, and SGLang Grouped GEMM FP8 reaches 2254.725030 GB/s. These correspond to 77.2%, 63.4%, and 47.0% of the H200's theoretical peak bandwidth, respectively.
|
|
||||||
@@ -1,516 +0,0 @@
|
|||||||
# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
|
|
||||||
quantize_fp8_row,
|
|
||||||
triton_quantize_fp8_row,
|
|
||||||
)
|
|
||||||
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
|
|
||||||
grouped_gemm as fbgemm_grouped_gemm,
|
|
||||||
)
|
|
||||||
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
|
|
||||||
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
|
|
||||||
)
|
|
||||||
from transformers import AutoConfig
|
|
||||||
|
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
|
||||||
grouped_gemm_triton as sglang_grouped_gemm,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_config(model_name: str, tp_size: int):
    """Read the MoE shape parameters of a Hugging Face model.

    Args:
        model_name: Model identifier handed to ``AutoConfig.from_pretrained``.
        tp_size: Tensor parallelism size. It is accepted but not used here;
            the returned sizes are taken from the config unchanged.

    Returns:
        Dict with the keys ``num_groups`` (expert count), ``hidden_size``,
        ``intermediate_size`` and ``dtype``.
    """
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

    # architecture -> (nested sub-config attribute or None,
    #                  expert-count attribute, intermediate-size attribute)
    qwen_moe = (None, "num_experts", "moe_intermediate_size")
    deepseek = (None, "n_routed_experts", "moe_intermediate_size")
    grok = (None, "num_local_experts", "moe_intermediate_size")
    layouts = {
        "DbrxForCausalLM": ("ffn_config", "moe_num_experts", "ffn_hidden_size"),
        "JambaForCausalLM": (None, "num_experts", "intermediate_size"),
        "Qwen2MoeForCausalLM": qwen_moe,
        "Qwen3MoeForCausalLM": qwen_moe,
        "DeepseekV2ForCausalLM": deepseek,
        "DeepseekV3ForCausalLM": deepseek,
        "Llama4ForConditionalGeneration": (
            "text_config",
            "num_local_experts",
            "intermediate_size",
        ),
        "Grok1ForCausalLM": grok,
        "Grok1ImgGen": grok,
        "Grok1AForCausalLM": grok,
    }
    # Any architecture not listed above uses the Mixtral-style attribute names.
    fallback = (None, "num_local_experts", "intermediate_size")

    sub_config, experts_attr, intermediate_attr = layouts.get(
        config.architectures[0], fallback
    )
    source = config if sub_config is None else getattr(config, sub_config)
    num_groups = getattr(source, experts_attr)
    intermediate_size = getattr(source, intermediate_attr)

    shape_configs = dict(
        num_groups=num_groups,
        hidden_size=config.hidden_size,
        intermediate_size=intermediate_size,
        dtype=config.torch_dtype,
    )
    print(f"{shape_configs=}")
    return shape_configs
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_data(
    batch_size, num_groups, hidden_size, intermediate_size, device="cuda"
):
    """Build BF16 inputs shared by the FBGEMM and SGLang grouped GEMM runs.

    Tokens are split evenly across groups (``batch_size // num_groups`` per
    group); any remainder rows of ``x`` are not assigned to a group.

    Args:
        batch_size: Total number of token rows in ``x``.
        num_groups: Number of groups (experts).
        hidden_size: Inner (K) dimension.
        intermediate_size: Output (N) dimension per group.
        device: Device on which all tensors are allocated. Defaults to
            ``"cuda"``, the only device the benchmark itself uses.

    Returns:
        Tuple ``(x, w_fbgemm, w_sglang, c_fbgemm, c_sglang, m_sizes,
        seg_indptr, weight_indices)``. ``w_fbgemm`` is a 2-D reshape of the
        same weights that ``w_sglang`` holds as a 3-D tensor.
    """
    torch.manual_seed(42)

    tokens_per_group = batch_size // num_groups
    m_sizes = torch.full(
        (num_groups,), tokens_per_group, dtype=torch.int32, device=device
    )

    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)

    base_weights = torch.randn(
        num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    # Same weight values in the two layouts the kernels are called with.
    w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
    w_sglang = base_weights

    c_fbgemm = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )
    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )

    # Segment boundaries [0, t, 2t, ..., num_groups * t], built in a single
    # vectorized op instead of one device write per group.
    seg_indptr = (
        torch.arange(num_groups + 1, dtype=torch.int32, device=device)
        * tokens_per_group
    )

    weight_indices = torch.arange(num_groups, dtype=torch.int32, device=device)

    return (
        x,
        w_fbgemm,
        w_sglang,
        c_fbgemm,
        c_sglang,
        m_sizes,
        seg_indptr,
        weight_indices,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def create_fp8_test_data(
    batch_size, num_groups, hidden_size, intermediate_size, backend="triton"
):
    """Build FP8 inputs for one of the FBGEMM grouped GEMM backends.

    Args:
        batch_size: Total batch size.
        num_groups: Number of groups.
        hidden_size: Hidden dimension size.
        intermediate_size: Intermediate dimension size.
        backend: ``"triton"`` for the Triton GEMM or ``"cutlass"`` for the
            CUTLASS GEMM.

    Returns:
        For ``"triton"``: ``(x_fp8, w_fp8, m_sizes, x_scale, w_scale)``.
        For ``"cutlass"``: ``(x, wq, w_scale, m_sizes)``; ``x`` is returned
        unquantized in FP16.

    Raises:
        ValueError: If ``backend`` is neither ``"triton"`` nor ``"cutlass"``.
    """
    torch.manual_seed(42)

    tokens_per_group = batch_size // num_groups

    # One FP16 weight matrix per group, each quantized with quantize_fp8_row.
    w_list = [
        torch.randn(intermediate_size, hidden_size, dtype=torch.float16, device="cuda")
        for _ in range(num_groups)
    ]
    wq_list, w_scale_list = zip(*(quantize_fp8_row(w) for w in w_list))

    if backend == "triton":
        # Triton layout: per-group weights and scales concatenated along dim 0.
        w_fp8 = torch.concat(wq_list, dim=0).contiguous()
        w_scale = torch.concat(w_scale_list, dim=0).contiguous()

        # The Triton path takes int32 group sizes.
        m_sizes = torch.full(
            (num_groups,), tokens_per_group, dtype=torch.int32, device="cuda"
        )

        # Activations are drawn in a single call and quantized here.
        x_fp16 = torch.randn(
            batch_size, hidden_size, dtype=torch.float16, device="cuda"
        )
        x_fp8, x_scale = triton_quantize_fp8_row(x_fp16)
        x_scale = x_scale.view(batch_size, -1)

        return x_fp8, w_fp8, m_sizes, x_scale, w_scale

    if backend == "cutlass":
        # CUTLASS layout: per-group weights and scales stacked on a new dim 0.
        wq = torch.stack(wq_list, dim=0).contiguous()
        w_scale = torch.stack(w_scale_list, dim=0).contiguous()

        # The CUTLASS path takes int64 group sizes.
        m_sizes = torch.tensor([tokens_per_group] * num_groups).to(
            dtype=torch.int64, device="cuda"
        )

        # Activations are drawn group by group, then joined into one tensor.
        x_list = [
            torch.randn(
                tokens_per_group, hidden_size, dtype=torch.float16, device="cuda"
            )
            for _ in range(num_groups)
        ]
        x = torch.concat(x_list, dim=0).contiguous()

        return x, wq, w_scale, m_sizes

    raise ValueError(f"Unsupported backend: {backend}")
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype):
    """Return the number of weight bytes read for the active experts.

    Despite its name, this function returns a byte count; the caller divides
    it by the measured time to obtain a bandwidth.

    Args:
        m_sizes: Per-group token counts, either a tensor or a plain iterable
            of numbers. A group counts as active when its entry is non-zero
            (tensor) or greater than zero (iterable).
        hidden_size: Hidden dimension size.
        intermediate_size: Intermediate dimension size.
        dtype: Data type of the weights. Any ``torch.dtype`` is sized by its
            real element size; anything else is assumed to be 2 bytes wide.

    Returns:
        ``active_experts * hidden_size * intermediate_size * bytes_per_element``.
    """
    # Count the groups that actually receive tokens.
    if hasattr(m_sizes, "cpu"):
        active_experts = torch.count_nonzero(m_sizes).item()
    else:
        active_experts = sum(1 for m in m_sizes if m > 0)

    # Ask torch for the element size rather than enumerating dtypes, so that
    # int8, float64, etc. are sized correctly and no float8 attribute lookup
    # is needed on builds that lack those dtypes.
    bytes_per_element = 2  # fallback for unknown or non-torch dtypes
    if isinstance(dtype, torch.dtype):
        try:
            bytes_per_element = torch.empty(0, dtype=dtype).element_size()
        except (RuntimeError, TypeError):
            # Dtypes that cannot back a plain tensor keep the 2-byte default.
            bytes_per_element = 2

    memory_per_expert = hidden_size * intermediate_size * bytes_per_element
    return active_experts * memory_per_expert
|
|
||||||
|
|
||||||
|
|
||||||
def get_benchmark_config(use_fp8_w8a8=False):
    """Return the provider ids, legend names and plot styles to benchmark.

    Args:
        use_fp8_w8a8: Select the FP8 provider set when truthy, otherwise the
            BF16 provider set.

    Returns:
        Dict with parallel lists under ``line_vals``, ``line_names`` and
        ``styles``.
    """
    # Each row: (provider id, legend name, (color, line style)).
    if use_fp8_w8a8:
        providers = [
            (
                "fbgemm_triton_grouped_gemm_fp8",
                "FBGEMM Triton Grouped GEMM FP8",
                ("blue", "-"),
            ),
            (
                "fbgemm_cutlass_f8f8bf16_rowwise",
                "FBGEMM CUTLASS F8F8BF16 Rowwise",
                ("orange", "-"),
            ),
            ("sglang_grouped_gemm", "SGLang Grouped GEMM FP8", ("red", "-")),
        ]
    else:
        providers = [
            (
                "fbgemm_triton_grouped_gemm",
                "FBGEMM Triton Grouped GEMM BF16",
                ("blue", "-"),
            ),
            ("sglang_grouped_gemm", "SGLang Grouped GEMM BF16", ("green", "-")),
        ]

    line_vals, line_names, styles = (list(column) for column in zip(*providers))
    return {"line_vals": line_vals, "line_names": line_names, "styles": styles}
|
|
||||||
|
|
||||||
|
|
||||||
def run_benchmark(
    model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
    """Benchmark the grouped GEMM providers and report bandwidth in GB/s.

    Each provider is timed with ``triton.testing.do_bench`` for batch sizes
    256 through 4096. The bandwidth is the weight bytes of the active experts
    (see ``calculate_memory_bandwidth``) divided by the measured time.

    Args:
        model_config: Dict with ``num_groups``, ``hidden_size`` and
            ``intermediate_size``.
        use_fp8_w8a8: Benchmark the FP8 providers instead of the BF16 ones.
        save_path: Directory handed to the Triton perf report for its output.
    """
    config = get_benchmark_config(use_fp8_w8a8)

    benchmark_config = triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[256, 512, 1024, 2048, 4096],
        line_arg="provider",
        line_vals=config["line_vals"],
        line_names=config["line_names"],
        styles=config["styles"],
        ylabel="Bandwidth (GB/s)",
        plot_name="grouped-gemm-performance",
        args={},
    )

    # The reported metric is a bandwidth (higher is better), so a provider
    # that cannot be set up or run is reported as zero on every failure path.
    # Infinity would read as "infinitely fast" in the table and plot.
    failed = (0.0, 0.0, 0.0)

    @triton.testing.perf_report(benchmark_config)
    def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
        print(f"Benchmarking {provider} with batch_size={batch_size}")
        torch.cuda.manual_seed_all(0)

        num_groups = model_config["num_groups"]
        hidden_size = model_config["hidden_size"]
        intermediate_size = model_config["intermediate_size"]

        if provider == "fbgemm_triton_grouped_gemm_fp8":
            try:
                test_data = create_fp8_test_data(
                    batch_size,
                    num_groups,
                    hidden_size,
                    intermediate_size,
                    backend="triton",
                )
                x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data

                # Weight bytes read by the active experts (FP8 weights).
                memory_bytes = calculate_memory_bandwidth(
                    m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
                )

                def run_func():
                    return fbgemm_grouped_gemm_fp8_rowwise(
                        x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
                    )

            except Exception as e:
                print(f"FP8 not supported, skipping: {e}")
                return failed

        elif provider == "fbgemm_cutlass_f8f8bf16_rowwise":
            try:
                test_data = create_fp8_test_data(
                    batch_size,
                    num_groups,
                    hidden_size,
                    intermediate_size,
                    backend="cutlass",
                )
                x, wq, w_scale, m_sizes = test_data

                # Weight bytes read by the active experts (FP8 weights).
                memory_bytes = calculate_memory_bandwidth(
                    m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
                )

                # Quantize the input once, outside the timed function.
                xq, x_scale = triton_quantize_fp8_row(x)
                x_scale = x_scale.view(batch_size, -1)

                def run_func():
                    return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
                        xq, wq, x_scale, w_scale, m_sizes
                    )

            except Exception as e:
                print(
                    f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, "
                    f"skipping: {e}"
                )
                return failed
        else:
            test_data = create_test_data(
                batch_size, num_groups, hidden_size, intermediate_size
            )
            (
                x,
                w_fbgemm,
                w_sglang,
                c_fbgemm,
                c_sglang,
                m_sizes,
                seg_indptr,
                weight_indices,
            ) = test_data

            # Weight bytes read by the active experts (BF16 weights).
            memory_bytes = calculate_memory_bandwidth(
                m_sizes, hidden_size, intermediate_size, torch.bfloat16
            )

            if provider == "fbgemm_triton_grouped_gemm":

                def run_func():
                    return fbgemm_grouped_gemm(
                        x, w_fbgemm, m_sizes, use_fast_accum=True
                    )

            else:

                def run_func():
                    return sglang_grouped_gemm(
                        x,
                        w_sglang,
                        c_sglang,
                        num_groups,
                        weight_column_major=True,
                        seg_indptr=seg_indptr,
                        weight_indices=weight_indices,
                        c_dtype=c_sglang.dtype,
                    )

        # Warm up before timing; a failure here skips the provider.
        for _ in range(10):
            try:
                run_func()
            except Exception as e:
                print(f"Error during warmup for {provider}: {e}")
                return failed

        torch.cuda.synchronize()

        try:
            # Timings at the 50th, 20th and 80th percentiles, in that order.
            quantiles = [0.5, 0.2, 0.8]
            ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)

            # Bandwidth (GB/s) = bytes / 1e9 / seconds, with ms / 1000 = seconds.
            gb_per_s = (memory_bytes / 1e9) / (ms / 1000)
            # The slowest time gives the lowest bandwidth and vice versa.
            min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000)
            max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000)

            return gb_per_s, min_gb_per_s, max_gb_per_s
        except Exception as e:
            print(f"Error during benchmarking for {provider}: {e}")
            return failed

    dynamic_benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=save_path,
        model_config=model_config,
        use_fp8_w8a8=use_fp8_w8a8,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def verify_correctness(model_config):
    """Check that the FBGEMM and SGLang BF16 grouped GEMMs agree.

    Both kernels run on the same inputs (batch size 128) and their outputs
    are compared with ``torch.allclose`` using ``rtol=1e-3`` and ``atol=1e-3``.

    Args:
        model_config: Dict with ``num_groups``, ``hidden_size`` and
            ``intermediate_size``.

    Returns:
        ``True`` when the outputs match within tolerance, ``False`` otherwise.
    """
    print("Verifying correctness...")

    batch_size = 128
    num_groups = model_config["num_groups"]
    hidden_size = model_config["hidden_size"]
    intermediate_size = model_config["intermediate_size"]

    (
        x,
        w_fbgemm,
        w_sglang,
        _c_fbgemm,
        c_sglang,
        m_sizes,
        seg_indptr,
        weight_indices,
    ) = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)

    result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)
    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    outputs_match = torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3)
    if not outputs_match:
        # Report the largest element-wise deviation to help diagnose the miss.
        max_diff = (result_fbgemm - result_sglang).abs().max()
        print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
        return False

    print("✓ BF16 Correctness verification passed!")
    return True
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Parse the command line, resolve the model shapes and run the benchmark.

    If the model config cannot be loaded, a built-in default shape
    configuration is used instead. With ``--verify-correctness`` the BF16
    correctness check runs first and a failure stops the run before
    benchmarking.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark FBGEMM vs SGLang Grouped GEMM"
    )
    cli_options = (
        (
            "--model",
            dict(
                type=str,
                default="mistralai/Mixtral-8x7B-Instruct-v0.1",
                help="Model name to get configuration from",
            ),
        ),
        ("--tp-size", dict(type=int, default=1, help="Tensor parallelism size")),
        (
            "--use-fp8-w8a8",
            dict(action="store_true", help="Enable FP8 W8A8 benchmark"),
        ),
        (
            "--save-path",
            dict(
                type=str,
                default="./benchmark_grouped_gemm/",
                help="Path to save benchmark results",
            ),
        ),
        (
            "--verify-correctness",
            dict(action="store_true", help="Verify correctness before benchmarking"),
        ),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    try:
        model_config = get_model_config(args.model, args.tp_size)
    except Exception as e:
        # Best effort: fall back to fixed default shapes so the benchmark
        # can still run when the model config is unavailable.
        print(f"Failed to get model config: {e}")
        print("Using default configuration...")
        model_config = dict(
            num_groups=8,
            hidden_size=4096,
            intermediate_size=14336,
            dtype=torch.bfloat16,
        )

    print("Running benchmark with:")
    print(f" num_groups: {model_config['num_groups']}")
    print(f" hidden_size: {model_config['hidden_size']}")
    print(f" intermediate_size: {model_config['intermediate_size']}")
    print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")

    if args.verify_correctness and not verify_correctness(model_config):
        print("Correctness verification failed. Exiting...")
        return

    try:
        run_benchmark(
            model_config=model_config,
            use_fp8_w8a8=args.use_fp8_w8a8,
            save_path=args.save_path,
        )
    except Exception as e:
        print(f"Benchmark failed: {e}")
|
|
||||||
|
|
||||||
# Script entry point: run the benchmark CLI only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|
||||||
@@ -246,7 +246,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
|-----------|-------------|----------|
|
|-----------|-------------|----------|
|
||||||
| `--ep-size` | The expert parallelism size. | 1 |
|
| `--ep-size` | The expert parallelism size. | 1 |
|
||||||
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
|
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
|
||||||
| `--moe-runner-backend` | Select the runner backend for MoE. | 'triton' |
|
| `--moe-runner-backend` | Select the runner backend for MoE. | auto |
|
||||||
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
|
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
|
||||||
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
|
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
|
||||||
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. | None |
|
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. | None |
|
||||||
|
|||||||
@@ -13,22 +13,18 @@ from sgl_kernel import (
|
|||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
post_reorder_triton_kernel_for_cutlass_moe,
|
post_reorder_triton_kernel_for_cutlass_moe,
|
||||||
pre_reorder_triton_kernel_for_cutlass_moe,
|
pre_reorder_triton_kernel_for_cutlass_moe,
|
||||||
run_cutlass_moe_ep_preproess,
|
run_moe_ep_preproess,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cutlass_w4a8_moe(
|
def cutlass_w4a8_moe(
|
||||||
start_expert_id: int,
|
|
||||||
end_expert_id: int,
|
|
||||||
total_num_experts: int,
|
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
w1_q: torch.Tensor,
|
w1_q: torch.Tensor,
|
||||||
w2_q: torch.Tensor,
|
w2_q: torch.Tensor,
|
||||||
w1_scale: torch.Tensor,
|
w1_scale: torch.Tensor,
|
||||||
w2_scale: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids_: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
local_topk_ids: torch.Tensor,
|
|
||||||
a_strides1: torch.Tensor,
|
a_strides1: torch.Tensor,
|
||||||
b_strides1: torch.Tensor,
|
b_strides1: torch.Tensor,
|
||||||
c_strides1: torch.Tensor,
|
c_strides1: torch.Tensor,
|
||||||
@@ -64,6 +60,7 @@ def cutlass_w4a8_moe(
|
|||||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||||
Shape: [num_experts, N // 512, K * 4]
|
Shape: [num_experts, N // 512, K * 4]
|
||||||
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
||||||
|
- topk_ids (torch.Tensor): The ids of each token->expert mapping.
|
||||||
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
|
- a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
|
||||||
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
|
- b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
|
||||||
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
|
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
|
||||||
@@ -83,7 +80,7 @@ def cutlass_w4a8_moe(
|
|||||||
Returns:
|
Returns:
|
||||||
- torch.Tensor: The fp8 output tensor after applying the MoE layer.
|
- torch.Tensor: The fp8 output tensor after applying the MoE layer.
|
||||||
"""
|
"""
|
||||||
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
|
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||||
assert w1_q.dtype == torch.int8
|
assert w1_q.dtype == torch.int8
|
||||||
assert w2_q.dtype == torch.int8
|
assert w2_q.dtype == torch.int8
|
||||||
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
|
assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
|
||||||
@@ -96,20 +93,21 @@ def cutlass_w4a8_moe(
|
|||||||
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
|
assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
|
||||||
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
|
assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
|
||||||
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
|
assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
|
||||||
num_experts = w1_q.size(0)
|
num_local_experts = w1_q.size(0)
|
||||||
m = a.size(0)
|
m = a.size(0)
|
||||||
k = w1_q.size(2) * 2 # w1_q is transposed and packed
|
k = w1_q.size(2) * 2 # w1_q is transposed and packed
|
||||||
n = w2_q.size(2) * 2 # w2_q is transposed and packed
|
n = w2_q.size(2) * 2 # w2_q is transposed and packed
|
||||||
topk = topk_ids_.size(1)
|
topk = topk_ids.size(1)
|
||||||
|
|
||||||
if apply_router_weight_on_input:
|
if apply_router_weight_on_input:
|
||||||
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
|
assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
|
||||||
|
|
||||||
device = a.device
|
device = a.device
|
||||||
|
topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
|
||||||
|
|
||||||
_, src2dst, _ = run_cutlass_moe_ep_preproess(
|
_, src2dst, _ = run_moe_ep_preproess(
|
||||||
local_topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_local_experts,
|
||||||
)
|
)
|
||||||
|
|
||||||
gateup_input = torch.empty(
|
gateup_input = torch.empty(
|
||||||
@@ -122,9 +120,9 @@ def cutlass_w4a8_moe(
|
|||||||
a,
|
a,
|
||||||
gateup_input,
|
gateup_input,
|
||||||
src2dst,
|
src2dst,
|
||||||
local_topk_ids,
|
topk_ids,
|
||||||
a1_scale,
|
a1_scale,
|
||||||
total_num_experts,
|
num_local_experts,
|
||||||
topk,
|
topk,
|
||||||
k,
|
k,
|
||||||
BLOCK_SIZE=512,
|
BLOCK_SIZE=512,
|
||||||
@@ -133,16 +131,16 @@ def cutlass_w4a8_moe(
|
|||||||
# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
|
# NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
|
||||||
# they are kept to allow for a quick switch of the permutation logic
|
# they are kept to allow for a quick switch of the permutation logic
|
||||||
# from the current triton kernel implementation to the cutlass-based one if needed.
|
# from the current triton kernel implementation to the cutlass-based one if needed.
|
||||||
a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
|
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||||
c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
|
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||||
get_cutlass_w4a8_moe_mm_data(
|
get_cutlass_w4a8_moe_mm_data(
|
||||||
local_topk_ids,
|
topk_ids,
|
||||||
expert_offsets,
|
expert_offsets,
|
||||||
problem_sizes1,
|
problem_sizes1,
|
||||||
problem_sizes2,
|
problem_sizes2,
|
||||||
a_map,
|
a_map,
|
||||||
c_map,
|
c_map,
|
||||||
num_experts,
|
num_local_experts,
|
||||||
n,
|
n,
|
||||||
k,
|
k,
|
||||||
)
|
)
|
||||||
@@ -195,12 +193,11 @@ def cutlass_w4a8_moe(
|
|||||||
c2,
|
c2,
|
||||||
output,
|
output,
|
||||||
src2dst,
|
src2dst,
|
||||||
local_topk_ids,
|
topk_ids,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
num_experts,
|
|
||||||
topk,
|
topk,
|
||||||
|
num_local_experts,
|
||||||
k,
|
k,
|
||||||
0,
|
|
||||||
BLOCK_SIZE=512,
|
BLOCK_SIZE=512,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -130,28 +130,30 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
|
|||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
|
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
|
||||||
expert = tl.program_id(0)
|
expert_id_minus_1 = tl.program_id(0) - 1
|
||||||
low = 0
|
low = 0
|
||||||
high = num_toks - 1
|
high = num_toks - 1
|
||||||
target_location = -1
|
target_location = -1
|
||||||
while low <= high:
|
while low <= high:
|
||||||
mid = (low + high) // 2
|
mid = (low + high) // 2
|
||||||
|
|
||||||
if tl.load(reorder_topk_ids + mid) > expert:
|
if tl.load(reorder_topk_ids + mid) > expert_id_minus_1:
|
||||||
high = mid - 1
|
high = mid - 1
|
||||||
else:
|
else:
|
||||||
low = mid + 1
|
low = mid + 1
|
||||||
target_location = mid
|
target_location = mid
|
||||||
tl.store(seg_indptr + expert + 1, target_location + 1)
|
tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)
|
||||||
|
|
||||||
|
|
||||||
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
|
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int):
|
||||||
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
|
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
|
||||||
|
|
||||||
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
|
seg_indptr = torch.zeros(
|
||||||
|
num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
|
||||||
|
)
|
||||||
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
|
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
|
||||||
|
|
||||||
compute_seg_indptr_triton_kernel[(num_experts,)](
|
compute_seg_indptr_triton_kernel[(num_local_experts,)](
|
||||||
reorder_topk_ids, seg_indptr, topk_ids.numel()
|
reorder_topk_ids, seg_indptr, topk_ids.numel()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -164,25 +166,6 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
|
|||||||
return reorder_topk_ids, src2dst, seg_indptr
|
return reorder_topk_ids, src2dst, seg_indptr
|
||||||
|
|
||||||
|
|
||||||
def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
    """Sort the flattened local expert ids and build the ``src2dst`` map.

    Args:
        local_topk_ids: Expert ids per token; flattened with ``view(-1)``
            before sorting.
        local_num_experts: Number of local experts; only used to size
            ``seg_indptr``.

    Returns:
        Tuple ``(reorder_topk_ids, src2dst, seg_indptr)``: the stably sorted
        ids, an int32 tensor filled by ``compute_src2dst_triton_kernel``, and
        an int64 tensor of ``local_num_experts + 1`` zeros.
    """
    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)

    # NOTE(review): seg_indptr is allocated as zeros and never written in this
    # function, so it is returned as all zeros. Confirm that callers do not
    # rely on its contents.
    seg_indptr = torch.zeros(
        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
    )
    src2dst = torch.empty(
        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
    )

    # One program per BLOCK_SIZE entries of the flattened id tensor.
    BLOCK_SIZE = 512
    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
    compute_src2dst_triton_kernel[grid](
        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
    )

    return reorder_topk_ids, src2dst, seg_indptr
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def pre_reorder_triton_kernel_for_cutlass_moe(
|
def pre_reorder_triton_kernel_for_cutlass_moe(
|
||||||
input_ptr,
|
input_ptr,
|
||||||
@@ -190,52 +173,13 @@ def pre_reorder_triton_kernel_for_cutlass_moe(
|
|||||||
src2dst_ptr,
|
src2dst_ptr,
|
||||||
topk_ids_ptr,
|
topk_ids_ptr,
|
||||||
a1_scales_ptr,
|
a1_scales_ptr,
|
||||||
num_experts,
|
num_local_experts,
|
||||||
topk,
|
topk,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
):
|
):
|
||||||
OutDtype = gateup_input_ptr.dtype.element_ty
|
OutDtype = gateup_input_ptr.dtype.element_ty
|
||||||
|
|
||||||
src_idx = tl.program_id(0)
|
|
||||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
|
||||||
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
|
||||||
|
|
||||||
src_ptr = input_ptr + src_idx * hidden_size
|
|
||||||
for idx in range(topk):
|
|
||||||
expert_id = tl.load(topk_ids_ptr + idx)
|
|
||||||
if expert_id != num_experts:
|
|
||||||
if a1_scales_ptr is not None:
|
|
||||||
scale = 1.0 / tl.load(a1_scales_ptr)
|
|
||||||
else:
|
|
||||||
scale = 1.0
|
|
||||||
|
|
||||||
dst_idx = tl.load(src2dst_ptr + idx)
|
|
||||||
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
|
||||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
|
||||||
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = offset < hidden_size
|
|
||||||
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
|
|
||||||
out_data = (in_data * scale).to(OutDtype)
|
|
||||||
tl.store(dst_ptr + offset, out_data, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def pre_reorder_triton_kernel(
|
|
||||||
input_ptr,
|
|
||||||
gateup_input_ptr,
|
|
||||||
src2dst_ptr,
|
|
||||||
topk_ids_ptr,
|
|
||||||
a1_scales_ptr,
|
|
||||||
start_expert_id,
|
|
||||||
end_expert_id,
|
|
||||||
topk,
|
|
||||||
hidden_size,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
use_per_token_if_dynamic: tl.constexpr,
|
|
||||||
):
|
|
||||||
OutDtype = gateup_input_ptr.dtype.element_ty
|
|
||||||
|
|
||||||
src_idx_int32 = tl.program_id(0)
|
src_idx_int32 = tl.program_id(0)
|
||||||
src_idx = src_idx_int32.to(tl.int64)
|
src_idx = src_idx_int32.to(tl.int64)
|
||||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
src2dst_ptr = src2dst_ptr + src_idx * topk
|
||||||
@@ -244,15 +188,11 @@ def pre_reorder_triton_kernel(
|
|||||||
|
|
||||||
vec = tl.arange(0, BLOCK_SIZE)
|
vec = tl.arange(0, BLOCK_SIZE)
|
||||||
|
|
||||||
if a1_scales_ptr is not None and use_per_token_if_dynamic:
|
|
||||||
scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
|
|
||||||
|
|
||||||
for idx in range(topk):
|
for idx in range(topk):
|
||||||
expert_id = tl.load(topk_ids_ptr + idx)
|
expert_id = tl.load(topk_ids_ptr + idx)
|
||||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
if expert_id != num_local_experts:
|
||||||
if a1_scales_ptr is not None:
|
if a1_scales_ptr is not None:
|
||||||
if not use_per_token_if_dynamic:
|
scale = 1.0 / tl.load(a1_scales_ptr)
|
||||||
scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
|
|
||||||
else:
|
else:
|
||||||
scale = 1.0
|
scale = 1.0
|
||||||
|
|
||||||
@@ -267,52 +207,6 @@ def pre_reorder_triton_kernel(
|
|||||||
tl.store(dst_ptr + offset, out_data, mask=mask)
|
tl.store(dst_ptr + offset, out_data, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
def silu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute ``silu(gate) * up * scale`` for one row and store it in ``down_input``.

    Program id 0 selects the row. The row of ``gateup_output`` is read as
    ``hidden_size`` elements: the first ``hidden_size // 2`` are the gate
    values and the second half are the up values. The result row written to
    ``down_input`` has ``hidden_size // 2`` elements.

    Rows whose expert id (``reorder_topk_ids[pid]``) falls outside
    ``[start_expert_id, end_expert_id]`` are skipped entirely; nothing is
    written for them.

    When ``scales`` is not None the output is multiplied by
    ``1 / scales[expert_id - start_expert_id]``; otherwise the scale is 1.
    """
    InDtype = gateup_output.dtype.element_ty
    OutDtype = down_input.dtype.element_ty

    half_hidden_size = hidden_size // 2

    pid = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + pid)
    # Only rows routed to an expert in the inclusive range are processed.
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        gateup_output_ptr = gateup_output + pid * hidden_size
        gate_output_ptr = gateup_output_ptr
        up_output_ptr = gateup_output_ptr + half_hidden_size
        down_input_ptr = down_input + pid * half_hidden_size

        if scales is not None:
            # Scales are indexed relative to the first expert of the range;
            # the reciprocal is taken so the output is divided by the scale.
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(InDtype)
        else:
            scale = 1

        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < half_hidden_size

            # The gate is promoted to float32 for the sigmoid; up stays in
            # its stored dtype.
            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
            up_output = tl.load(up_output_ptr + offset, mask=mask)

            # silu & mul & quantize
            gate_output = gate_output * tl.sigmoid(gate_output)
            gate_output = gate_output.to(InDtype)

            silu_mul_output = gate_output * up_output * scale
            silu_mul_output = silu_mul_output.to(OutDtype)
            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
|
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _silu_and_mul_post_quant_kernel(
|
def _silu_and_mul_post_quant_kernel(
|
||||||
@@ -461,70 +355,44 @@ def silu_and_mul_masked_post_quant_fwd(
|
|||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def tanh(x):
|
def post_reorder_triton_kernel_for_cutlass_moe(
|
||||||
return 2 * tl.sigmoid(2 * x) - 1
|
down_output_ptr,
|
||||||
|
output_ptr,
|
||||||
|
src2dst_ptr,
|
||||||
@triton.jit
|
topk_ids_ptr,
|
||||||
def gelu_and_mul_triton_kernel(
|
topk_weights_ptr,
|
||||||
gateup_output,
|
topk,
|
||||||
down_input,
|
num_local_experts,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
reorder_topk_ids,
|
|
||||||
scales,
|
|
||||||
start_expert_id,
|
|
||||||
end_expert_id,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
):
|
):
|
||||||
InDtype = gateup_output.dtype.element_ty
|
InDtype = down_output_ptr.dtype.element_ty
|
||||||
OutDtype = down_input.dtype.element_ty
|
|
||||||
|
|
||||||
half_hidden_size = hidden_size // 2
|
src_idx_int32 = tl.program_id(0)
|
||||||
|
src_idx = src_idx_int32.to(tl.int64)
|
||||||
|
src2dst_ptr = src2dst_ptr + src_idx * topk
|
||||||
|
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
||||||
|
topk_weights_ptr = topk_weights_ptr + src_idx * topk
|
||||||
|
|
||||||
pid = tl.program_id(0)
|
store_ptr = output_ptr + src_idx * hidden_size
|
||||||
expert_id = tl.load(reorder_topk_ids + pid)
|
|
||||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
|
||||||
gateup_output_ptr = gateup_output + pid * hidden_size
|
|
||||||
gate_output_ptr = gateup_output_ptr
|
|
||||||
up_output_ptr = gateup_output_ptr + half_hidden_size
|
|
||||||
down_input_ptr = down_input + pid * half_hidden_size
|
|
||||||
|
|
||||||
if scales is not None:
|
vec = tl.arange(0, BLOCK_SIZE)
|
||||||
scale = tl.load(scales + expert_id - start_expert_id)
|
|
||||||
scale = (1 / scale).to(InDtype)
|
|
||||||
else:
|
|
||||||
scale = 1
|
|
||||||
|
|
||||||
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
|
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||||
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
offset = start_offset + vec
|
||||||
mask = offset < half_hidden_size
|
mask = offset < hidden_size
|
||||||
|
|
||||||
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
|
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
|
||||||
up_output = tl.load(up_output_ptr + offset, mask=mask)
|
for idx in range(topk):
|
||||||
|
expert_id = tl.load(topk_ids_ptr + idx)
|
||||||
# gelu & mul & quantize
|
if expert_id != num_local_experts:
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
|
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
||||||
# sqrt(2/pi)
|
dst_idx = dst_idx_int32.to(tl.int64)
|
||||||
kAlpha = 0.7978845608028654
|
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
|
||||||
gate_output = (
|
load_ptr = down_output_ptr + dst_idx * hidden_size
|
||||||
0.5
|
in_data = tl.load(load_ptr + offset, mask=mask)
|
||||||
* gate_output
|
sum_vec += in_data * weigh_scale
|
||||||
* (
|
tl.store(store_ptr + offset, sum_vec, mask=mask)
|
||||||
1
|
|
||||||
+ tanh(
|
|
||||||
kAlpha
|
|
||||||
* (
|
|
||||||
gate_output
|
|
||||||
+ 0.044715 * gate_output * gate_output * gate_output
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
gate_output = gate_output.to(InDtype)
|
|
||||||
|
|
||||||
gelu_mul_output = gate_output * up_output * scale
|
|
||||||
gelu_mul_output = gelu_mul_output.to(OutDtype)
|
|
||||||
tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -534,64 +402,8 @@ def post_reorder_triton_kernel(
|
|||||||
src2dst_ptr,
|
src2dst_ptr,
|
||||||
topk_ids_ptr,
|
topk_ids_ptr,
|
||||||
topk_weights_ptr,
|
topk_weights_ptr,
|
||||||
start_expert_id,
|
|
||||||
end_expert_id,
|
|
||||||
topk,
|
topk,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
dst_start,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
):
|
|
||||||
InDtype = down_output_ptr.dtype.element_ty
|
|
||||||
|
|
||||||
src_idx_int32 = tl.program_id(0)
|
|
||||||
src_idx = src_idx_int32.to(tl.int64)
|
|
||||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
|
||||||
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
|
||||||
topk_weights_ptr = topk_weights_ptr + src_idx * topk
|
|
||||||
|
|
||||||
computed = False
|
|
||||||
store_ptr = output_ptr + src_idx * hidden_size
|
|
||||||
|
|
||||||
vec = tl.arange(0, BLOCK_SIZE)
|
|
||||||
|
|
||||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
|
||||||
offset = start_offset + vec
|
|
||||||
mask = offset < hidden_size
|
|
||||||
|
|
||||||
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
|
|
||||||
for idx in range(topk):
|
|
||||||
expert_id = tl.load(topk_ids_ptr + idx)
|
|
||||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
|
||||||
computed = True
|
|
||||||
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
|
||||||
dst_idx = dst_idx_int32.to(tl.int64)
|
|
||||||
dst_idx = dst_idx - dst_start
|
|
||||||
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
|
|
||||||
load_ptr = down_output_ptr + dst_idx * hidden_size
|
|
||||||
in_data = tl.load(load_ptr + offset, mask=mask)
|
|
||||||
sum_vec += in_data * weigh_scale
|
|
||||||
tl.store(store_ptr + offset, sum_vec, mask=mask)
|
|
||||||
|
|
||||||
if computed == False:
|
|
||||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
|
||||||
offset = start_offset + vec
|
|
||||||
mask = offset < hidden_size
|
|
||||||
tl.store(
|
|
||||||
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def post_reorder_triton_kernel_for_cutlass_moe(
|
|
||||||
down_output_ptr,
|
|
||||||
output_ptr,
|
|
||||||
src2dst_ptr,
|
|
||||||
topk_ids_ptr,
|
|
||||||
topk_weights_ptr,
|
|
||||||
num_experts,
|
|
||||||
topk,
|
|
||||||
hidden_size,
|
|
||||||
dst_start,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
):
|
):
|
||||||
InDtype = down_output_ptr.dtype.element_ty
|
InDtype = down_output_ptr.dtype.element_ty
|
||||||
@@ -613,10 +425,9 @@ def post_reorder_triton_kernel_for_cutlass_moe(
|
|||||||
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
|
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
|
||||||
for idx in range(topk):
|
for idx in range(topk):
|
||||||
expert_id = tl.load(topk_ids_ptr + idx)
|
expert_id = tl.load(topk_ids_ptr + idx)
|
||||||
if expert_id != num_experts:
|
if expert_id > 0:
|
||||||
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
||||||
dst_idx = dst_idx_int32.to(tl.int64)
|
dst_idx = dst_idx_int32.to(tl.int64)
|
||||||
dst_idx = dst_idx - dst_start
|
|
||||||
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
|
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
|
||||||
load_ptr = down_output_ptr + dst_idx * hidden_size
|
load_ptr = down_output_ptr + dst_idx * hidden_size
|
||||||
in_data = tl.load(load_ptr + offset, mask=mask)
|
in_data = tl.load(load_ptr + offset, mask=mask)
|
||||||
@@ -624,232 +435,6 @@ def post_reorder_triton_kernel_for_cutlass_moe(
|
|||||||
tl.store(store_ptr + offset, sum_vec, mask=mask)
|
tl.store(store_ptr + offset, sum_vec, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
def compute_m_range(
    pid,
    batch_size,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Map a tile id ``pid`` to its row range and expert id.

    Returns ``(m_range_start, m_range_end, expert_id)`` where the range is a
    half-open span of at most ``BLOCK_SIZE_M`` rows, clipped to
    ``seg_indptr[idx + 1]``, and ``expert_id`` is ``weight_indices[idx]``.
    """
    # Linear scan: idx ends up as the last segment whose first tile index is
    # <= pid (0 if none matches).
    # NOTE(review): presumably m_num_tiles_indptr is a non-decreasing prefix
    # sum of per-segment tile counts -- confirm against the producer.
    idx = 0
    for bs in range(batch_size):
        tiles = tl.load(m_num_tiles_indptr + bs)
        if pid >= tiles:
            idx = bs

    idx_start = tl.load(m_num_tiles_indptr + idx)

    # (pid - idx_start) is this tile's position inside the chosen segment.
    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
    expert_id = tl.load(weight_indices + idx)
    return m_range_start, m_range_end, expert_id
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
def grouped_gemm_triton_kernel(
    a,
    b,
    c,
    batch_size,
    N,
    K,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    scale_a,
    scale_b,
    use_fp8_w8a8: tl.constexpr,
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    a_stride_0: tl.constexpr,
    b_stride_0: tl.constexpr,
    b_stride_1: tl.constexpr,
    as_stride_0: tl.constexpr,
    as_stride_1: tl.constexpr,
    bs_stride_0: tl.constexpr,
    bs_stride_2: tl.constexpr,
    bs_stride_1: tl.constexpr,
    use_per_token_if_dynamic: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile of a grouped GEMM.

    Program id 0 is a tile index along M that ``compute_m_range`` resolves to
    a row range and an expert id; program id 1 is the tile index along N.
    The tile is accumulated in float32 as ``a_tile @ b_tile.T`` over K and
    written to ``c`` with a row stride of ``N``.

    Scaling modes:
      * ``group_k > 0 and group_n > 0``: per-step scales loaded from
        ``scale_a`` / ``scale_b`` are applied inside the K loop.
      * otherwise, when ``use_fp8_w8a8``: a single scale product is applied
        after the K loop (per row of ``a`` if ``use_per_token_if_dynamic``,
        else per expert).
    """
    c_dtype = c.dtype.element_ty

    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # The last entry of m_num_tiles_indptr is the total number of M tiles;
    # programs beyond it have nothing to do.
    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
    if pid_m >= total_m_block:
        return

    m_range_start, m_range_end, expert_id = compute_m_range(
        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
    )
    if m_range_end - m_range_start == 0:
        return

    n_range_start = pid_n * BLOCK_SIZE_N
    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)

    offs_am = tl.arange(0, BLOCK_SIZE_M)
    offs_bn = tl.arange(0, BLOCK_SIZE_N)

    # Out-of-range rows/columns are redirected to offset 0, so the loads
    # below stay in bounds; the surplus results are masked at store time.
    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Both operands are addressed with a unit stride along K.
    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
    b_ptr = b + (
        (expert_id * b_stride_0)
        + (n_range_start + offs_bn[:, None]) * b_stride_1
        + offs_k[None, :]
    )

    if group_k > 0 and group_n > 0:
        a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
        offs_bsn = (n_range_start + offs_bn) // group_n
        b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the K tail of the last iteration; masked lanes read as 0.0.
        a_tile = tl.load(
            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )
        b_tile = tl.load(
            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )

        if group_k > 0 and group_n > 0:
            # One scale group per K step, selected by the step's start index.
            k_start = k * BLOCK_SIZE_K
            offs_ks = k_start // group_k
            a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
            b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
            accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
        else:
            accumulator = tl.dot(a_tile, b_tile.T, accumulator)
        a_ptr += BLOCK_SIZE_K
        b_ptr += BLOCK_SIZE_K

    if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
        if use_per_token_if_dynamic:
            # One activation scale per row of `a`.
            scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
        else:
            scale_a_value = tl.load(scale_a + expert_id)
        scale_b_value = tl.load(scale_b + expert_id)
        accumulator *= scale_a_value * scale_b_value

    c_tile = accumulator.to(c_dtype)

    # Store offsets use the unclamped ranges; the mask drops lanes past the
    # end of the row range or past N.
    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
    tl.store(c_ptr, c_tile, mask=c_mask)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
def compute_m_num_tiles_indptr(
    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
):
    """Fill ``m_num_tiles_indptr[1:]`` with a running sum of per-segment tile counts.

    Segment ``bs`` has ``seg_indptr[bs + 1] - seg_indptr[bs]`` rows and
    contributes ``ceil(rows / BLOCK_SIZE_M)`` tiles. Entry 0 is read but
    never written here, so it must already hold the starting value.
    """
    # Sequential on purpose: every entry depends on the one stored just before.
    for bs in range(batch_size):
        m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
        cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
        pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
        tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
|
|
||||||
|
|
||||||
|
|
||||||
def grouped_gemm_triton(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    batch_size: int,
    weight_column_major: bool,
    seg_indptr: Optional[torch.Tensor] = None,
    weight_indices: Optional[torch.Tensor] = None,
    use_fp8_w8a8: bool = False,
    scale_a: torch.Tensor = None,
    scale_b: torch.Tensor = None,
    block_shape: Optional[List[int]] = None,
    c_dtype=None,
    use_per_token_if_dynamic: bool = True,
):
    """Launch ``grouped_gemm_triton_kernel`` and return the output tensor.

    Args:
        a: Left operand; ``a.size(0)`` rows are tiled along M.
        b: Weights; ``b.size(1)`` is passed to the kernel as N and
            ``b.size(2)`` as K.
        c: Output buffer, or None to allocate an ``(a.shape[0], b.shape[1])``
            tensor of dtype ``c_dtype``.
        batch_size: Number of segments described by ``seg_indptr``.
        weight_column_major: Must be True (the only supported layout).
        seg_indptr: Segment boundaries forwarded to the kernels.
        weight_indices: Per-segment expert index forwarded to the kernel.
        use_fp8_w8a8: Enables the scale handling in the kernel.
        scale_a: Scale for ``a``; replaced when ``block_shape`` is given.
        scale_b: Scale for ``b``.
        block_shape: ``[block_n, block_k]``; when given, ``a`` is re-quantized
            with ``per_token_group_quant_fp8`` and the original ``a`` is passed
            to ``dispose_tensor``.
        c_dtype: Output dtype; required only when ``c`` is None.
        use_per_token_if_dynamic: Forwarded to the kernel; without
            ``block_shape`` and with ``use_fp8_w8a8`` it requires one
            ``scale_a`` entry per row of ``a``.

    Returns:
        The output tensor ``c``.
    """
    assert weight_column_major == True  # TODO: more

    block_quant = block_shape is not None
    if use_fp8_w8a8 and not block_quant:
        assert scale_a is not None and scale_b is not None

    if block_quant:
        a_original = a

        assert len(block_shape) == 2
        block_n = block_shape[0]
        block_k = block_shape[1]
        a, scale_a = per_token_group_quant_fp8(a, block_k)

        assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
        assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
        assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]

        dispose_tensor(a_original)

    # TODO: adjust config or tune kernel
    # Reduce block size to prevent L40 shared memory overflow.
    config = dict(BLOCK_SIZE_M=64, BLOCK_SIZE_N=32, BLOCK_SIZE_K=128)

    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
    compute_m_num_tiles_indptr[(1,)](
        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
    )

    if c is None:
        assert c_dtype is not None
        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)

    def grid(META):
        # Upper bound on M tiles: each segment may add one partial tile.
        return (
            triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
            triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
        )

    if use_fp8_w8a8 and not block_quant and use_per_token_if_dynamic:
        assert (
            scale_a.shape[0] == a.shape[0]
        ), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"

    if block_quant:
        group_n, group_k = block_shape[0], block_shape[1]
    else:
        group_n, group_k = 0, 0

    # Strides default to 0 whenever a scale tensor is absent or lacks the
    # dimension in question.
    scale_a_is_2d = scale_a is not None and scale_a.ndim == 2
    scale_b_has_2d = scale_b is not None and scale_b.ndim >= 2
    scale_b_is_3d = scale_b is not None and scale_b.ndim == 3
    as_stride_0 = scale_a.stride(0) if scale_a_is_2d else 0
    as_stride_1 = scale_a.stride(1) if scale_a_is_2d else 0
    bs_stride_0 = scale_b.stride(0) if scale_b_has_2d else 0
    bs_stride_2 = scale_b.stride(2) if scale_b_is_3d else 0
    bs_stride_1 = scale_b.stride(1) if scale_b_has_2d else 0

    grouped_gemm_triton_kernel[grid](
        a,
        b,
        c,
        batch_size,
        b.size(1),
        b.size(2),
        seg_indptr,
        weight_indices,
        m_num_tiles_indptr,
        scale_a,
        scale_b,
        use_fp8_w8a8,
        group_n,
        group_k,
        a.stride(0),
        b.stride(0),
        b.stride(1),
        as_stride_0,
        as_stride_1,
        bs_stride_0,
        bs_stride_2,
        bs_stride_1,
        use_per_token_if_dynamic,
        **config,
    )
    return c
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _fwd_kernel_ep_scatter_1(
|
def _fwd_kernel_ep_scatter_1(
|
||||||
num_recv_tokens_per_expert,
|
num_recv_tokens_per_expert,
|
||||||
@@ -1234,7 +819,7 @@ def deepgemm_compute_src2dst_triton_kernel(
|
|||||||
mask = dst_id < num_toks
|
mask = dst_id < num_toks
|
||||||
src_id = tl.load(reorder_ids + dst_id, mask=mask)
|
src_id = tl.load(reorder_ids + dst_id, mask=mask)
|
||||||
expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
|
expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
|
||||||
expert_dst_start = tl.load(seg_indptr + expert_id)
|
expert_dst_start = tl.load(seg_indptr + expert_id, mask=(expert_id >= 0))
|
||||||
expert_dst_offset = dst_id - expert_dst_start
|
expert_dst_offset = dst_id - expert_dst_start
|
||||||
dst_id = expert_id * m_max + expert_dst_offset
|
dst_id = expert_id * m_max + expert_dst_offset
|
||||||
tl.store(src2dst + src_id, dst_id, mask=mask)
|
tl.store(src2dst + src_id, dst_id, mask=mask)
|
||||||
@@ -1248,10 +833,7 @@ def fill_gateup_input_triton_kernel(
|
|||||||
gateup_input_scale_ptr,
|
gateup_input_scale_ptr,
|
||||||
src2dst_ptr,
|
src2dst_ptr,
|
||||||
topk_ids_ptr,
|
topk_ids_ptr,
|
||||||
start_expert_id,
|
|
||||||
end_expert_id,
|
|
||||||
topk,
|
topk,
|
||||||
m_max,
|
|
||||||
hidden_size,
|
hidden_size,
|
||||||
scale_size,
|
scale_size,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
@@ -1267,10 +849,9 @@ def fill_gateup_input_triton_kernel(
|
|||||||
vec = tl.arange(0, BLOCK_SIZE)
|
vec = tl.arange(0, BLOCK_SIZE)
|
||||||
for idx in range(topk):
|
for idx in range(topk):
|
||||||
expert_id = tl.load(topk_ids_ptr + idx)
|
expert_id = tl.load(topk_ids_ptr + idx)
|
||||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
if expert_id >= 0:
|
||||||
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
||||||
dst_idx = dst_idx_int32.to(tl.int64)
|
dst_idx = dst_idx_int32.to(tl.int64)
|
||||||
dst_idx = dst_idx - start_expert_id * m_max
|
|
||||||
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
||||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||||
offset = start_offset + vec
|
offset = start_offset + vec
|
||||||
@@ -1287,31 +868,31 @@ def fill_gateup_input_triton_kernel(
|
|||||||
|
|
||||||
def moe_ep_deepgemm_preprocess(
|
def moe_ep_deepgemm_preprocess(
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_local_experts: int,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
start_expert_id,
|
|
||||||
end_expert_id,
|
|
||||||
block_shape,
|
block_shape,
|
||||||
output_dtype: torch.dtype = torch.float8_e4m3fn,
|
output_dtype: torch.dtype = torch.float8_e4m3fn,
|
||||||
):
|
):
|
||||||
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
|
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
|
||||||
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
|
seg_indptr = torch.zeros(
|
||||||
|
num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
|
||||||
|
)
|
||||||
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
|
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
|
||||||
masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
|
masked_m = torch.empty(num_local_experts, device=topk_ids.device, dtype=torch.int32)
|
||||||
|
|
||||||
compute_seg_indptr_triton_kernel[(num_experts,)](
|
compute_seg_indptr_triton_kernel[(num_local_experts + 1,)](
|
||||||
reorder_topk_ids, seg_indptr, topk_ids.numel()
|
reorder_topk_ids, seg_indptr, topk_ids.numel()
|
||||||
)
|
)
|
||||||
|
|
||||||
grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
|
grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
|
||||||
compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
|
compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m)
|
||||||
|
|
||||||
# For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
|
# For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
|
||||||
m_max = (hidden_states.size(0) + 255) // 256 * 256
|
m_max = (hidden_states.size(0) // 256 + 1) * 256
|
||||||
expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
|
expected_m = (topk_ids.numel() - 1) // num_local_experts + 1
|
||||||
gateup_input = torch.empty(
|
gateup_input = torch.empty(
|
||||||
(int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
|
(num_local_experts, m_max, hidden_states.size(1)),
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=output_dtype,
|
dtype=output_dtype,
|
||||||
)
|
)
|
||||||
@@ -1330,6 +911,8 @@ def moe_ep_deepgemm_preprocess(
|
|||||||
block_shape = [128, 128]
|
block_shape = [128, 128]
|
||||||
assert len(block_shape) == 2
|
assert len(block_shape) == 2
|
||||||
block_n, block_k = block_shape[0], block_shape[1]
|
block_n, block_k = block_shape[0], block_shape[1]
|
||||||
|
|
||||||
|
# TODO: fuse this with the preprocess
|
||||||
hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
|
hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
|
||||||
|
|
||||||
gateup_input_scale = torch.empty(
|
gateup_input_scale = torch.empty(
|
||||||
@@ -1345,18 +928,14 @@ def moe_ep_deepgemm_preprocess(
|
|||||||
gateup_input_scale,
|
gateup_input_scale,
|
||||||
src2dst,
|
src2dst,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
start_expert_id,
|
|
||||||
end_expert_id,
|
|
||||||
top_k,
|
top_k,
|
||||||
m_max,
|
|
||||||
hidden_states.size(1),
|
hidden_states.size(1),
|
||||||
scale.size(1),
|
scale.size(1),
|
||||||
BLOCK_SIZE=1024,
|
BLOCK_SIZE=1024,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
m_max,
|
masked_m,
|
||||||
masked_m[start_expert_id : (end_expert_id + 1)],
|
|
||||||
expected_m,
|
expected_m,
|
||||||
src2dst,
|
src2dst,
|
||||||
gateup_input,
|
gateup_input,
|
||||||
|
|||||||
@@ -1,14 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from contextlib import nullcontext
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
|
||||||
from sglang.srt.layers.moe import (
|
from sglang.srt.layers.moe import (
|
||||||
get_deepep_mode,
|
get_deepep_mode,
|
||||||
get_moe_a2a_backend,
|
get_moe_a2a_backend,
|
||||||
@@ -18,13 +14,10 @@ from sglang.srt.layers.moe import (
|
|||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
ep_gather,
|
ep_gather,
|
||||||
ep_scatter,
|
ep_scatter,
|
||||||
moe_ep_deepgemm_preprocess,
|
|
||||||
post_reorder_triton_kernel,
|
|
||||||
silu_and_mul_masked_post_quant_fwd,
|
silu_and_mul_masked_post_quant_fwd,
|
||||||
tma_align_input_scale,
|
tma_align_input_scale,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||||
@@ -36,19 +29,10 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
|||||||
CUTEDSL_MOE_NVFP4_DISPATCH,
|
CUTEDSL_MOE_NVFP4_DISPATCH,
|
||||||
ModelOptNvFp4FusedMoEMethod,
|
ModelOptNvFp4FusedMoEMethod,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.offloader import get_offloader
|
from sglang.srt.offloader import get_offloader
|
||||||
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
|
||||||
ceil_div,
|
|
||||||
dispose_tensor,
|
|
||||||
get_bool_env_var,
|
|
||||||
get_int_env_var,
|
|
||||||
is_cuda,
|
|
||||||
is_hip,
|
|
||||||
is_npu,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
@@ -72,275 +56,7 @@ if _use_aiter:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# TODO(kaixih@nvidia): ideally we should merge this logic into
|
class DeepEPMoE(FusedMoE):
|
||||||
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
|
|
||||||
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
    """Convert ``x`` to rounded-up 8-bit exponents packed four per int32.

    ``x`` is cast to float32 and its bits are reinterpreted as int32. The
    bits above the 23-bit mantissa are taken as the exponent, which is
    incremented by one when the mantissa is non-zero (except where the
    exponent is 0xFE, or where the exponent is 0 and the mantissa is at most
    0x400000). The exponents are narrowed to uint8 and every four
    consecutive bytes along the last dimension are viewed as one int32.

    The returned tensor has the logical shape of that packed tensor, but its
    memory is laid out with dimensions 1 and 2 swapped, so ``x`` needs at
    least three dimensions.
    """
    bits = x.to(torch.float32).view(torch.int32)
    exponent = bits >> 23
    mantissa = bits & 0x7FFFFF

    has_fraction = (mantissa > 0) & (exponent != 0xFE)
    small_denormal = (exponent == 0) & (mantissa <= 0x400000)
    round_up = has_fraction & ~small_denormal

    exponent = torch.where(round_up, exponent + 1, exponent)
    packed = exponent.to(torch.uint8).view(torch.int)
    return packed.transpose(1, 2).contiguous().transpose(1, 2)
|
|
||||||
|
|
||||||
|
|
||||||
class EPMoE(FusedMoE):
|
|
||||||
"""
|
|
||||||
MoE Expert Parallel Impl
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_experts: int,
|
|
||||||
top_k: int,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
layer_id: int,
|
|
||||||
num_fused_shared_experts: int = 0,
|
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
prefix: str = "",
|
|
||||||
activation: str = "silu",
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
gemm1_alpha: Optional[float] = None,
|
|
||||||
gemm1_clamp_limit: Optional[float] = None,
|
|
||||||
with_bias: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
num_experts=num_experts,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
intermediate_size=intermediate_size,
|
|
||||||
num_fused_shared_experts=num_fused_shared_experts,
|
|
||||||
layer_id=layer_id,
|
|
||||||
top_k=top_k,
|
|
||||||
params_dtype=params_dtype,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=prefix,
|
|
||||||
activation=activation,
|
|
||||||
# apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
gemm1_alpha=gemm1_alpha,
|
|
||||||
gemm1_clamp_limit=gemm1_clamp_limit,
|
|
||||||
with_bias=with_bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
|
|
||||||
if isinstance(quant_config, Fp8Config):
|
|
||||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
|
||||||
self.block_shape = (
|
|
||||||
self.quant_method.quant_config.weight_block_size
|
|
||||||
if self.use_block_quant
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
self.use_fp8_w8a8 = True
|
|
||||||
self.fp8_dtype = torch.float8_e4m3fn
|
|
||||||
self.activation_scheme = quant_config.activation_scheme
|
|
||||||
else:
|
|
||||||
self.use_fp8_w8a8 = False
|
|
||||||
self.use_block_quant = False
|
|
||||||
self.block_shape = None
|
|
||||||
self.activation_scheme = None
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
|
||||||
return self.forward_deepgemm(hidden_states, topk_output)
|
|
||||||
else:
|
|
||||||
return super().forward(hidden_states, topk_output)
|
|
||||||
|
|
||||||
def forward_deepgemm(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
topk_output: TopKOutput,
|
|
||||||
):
|
|
||||||
|
|
||||||
self.w13_weight_fp8 = (
|
|
||||||
self.w13_weight,
|
|
||||||
(
|
|
||||||
self.w13_weight_scale_inv
|
|
||||||
if self.use_block_quant
|
|
||||||
else self.w13_weight_scale
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.w2_weight_fp8 = (
|
|
||||||
self.w2_weight,
|
|
||||||
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert self.quant_method is not None
|
|
||||||
assert self.moe_runner_config.activation == "silu"
|
|
||||||
|
|
||||||
hidden_states_shape = hidden_states.shape
|
|
||||||
hidden_states_dtype = hidden_states.dtype
|
|
||||||
hidden_states_device = hidden_states.device
|
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
|
||||||
|
|
||||||
if not self.use_block_quant:
|
|
||||||
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
|
|
||||||
scale_block_size = 128
|
|
||||||
w13_weight_scale_n = 2 * (
|
|
||||||
(self.intermediate_size + scale_block_size - 1) // scale_block_size
|
|
||||||
)
|
|
||||||
w13_weight_scale_k = (
|
|
||||||
hidden_states_shape[-1] + scale_block_size - 1
|
|
||||||
) // scale_block_size
|
|
||||||
w13_weight_scale = (
|
|
||||||
self.w13_weight_scale.unsqueeze(1)
|
|
||||||
.repeat_interleave(w13_weight_scale_n, dim=1)
|
|
||||||
.unsqueeze(2)
|
|
||||||
.repeat_interleave(w13_weight_scale_k, dim=2)
|
|
||||||
)
|
|
||||||
self.w13_weight_fp8 = (
|
|
||||||
self.w13_weight,
|
|
||||||
w13_weight_scale,
|
|
||||||
)
|
|
||||||
w2_weight_scale_n = (
|
|
||||||
hidden_states_shape[-1] + scale_block_size - 1
|
|
||||||
) // scale_block_size
|
|
||||||
w2_weight_scale_k = (
|
|
||||||
self.intermediate_size + scale_block_size - 1
|
|
||||||
) // scale_block_size
|
|
||||||
w2_weight_scale = (
|
|
||||||
self.w2_weight_scale.unsqueeze(1)
|
|
||||||
.repeat_interleave(w2_weight_scale_n, dim=1)
|
|
||||||
.unsqueeze(2)
|
|
||||||
.repeat_interleave(w2_weight_scale_k, dim=2)
|
|
||||||
)
|
|
||||||
self.w2_weight_fp8 = (
|
|
||||||
self.w2_weight,
|
|
||||||
w2_weight_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
# PreReorder
|
|
||||||
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
|
|
||||||
moe_ep_deepgemm_preprocess(
|
|
||||||
topk_ids,
|
|
||||||
self.num_experts,
|
|
||||||
hidden_states,
|
|
||||||
self.top_k,
|
|
||||||
self.start_expert_id,
|
|
||||||
self.end_expert_id,
|
|
||||||
self.block_shape,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
dispose_tensor(hidden_states)
|
|
||||||
|
|
||||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
|
||||||
b, s_mn, s_k = gateup_input_scale.shape
|
|
||||||
assert (
|
|
||||||
s_mn % 4 == 0 and s_k % 4 == 0
|
|
||||||
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
|
|
||||||
|
|
||||||
# GroupGemm-0
|
|
||||||
gateup_input_fp8 = (
|
|
||||||
gateup_input,
|
|
||||||
(
|
|
||||||
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
|
|
||||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
|
||||||
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
|
|
||||||
gateup_input_scale
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
num_groups, m, k = gateup_input_fp8[0].size()
|
|
||||||
n = self.w13_weight.size(1)
|
|
||||||
gateup_output = torch.empty(
|
|
||||||
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
|
|
||||||
)
|
|
||||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
|
||||||
gateup_input_fp8,
|
|
||||||
self.w13_weight_fp8,
|
|
||||||
gateup_output,
|
|
||||||
masked_m,
|
|
||||||
expected_m,
|
|
||||||
)
|
|
||||||
del gateup_input
|
|
||||||
del gateup_input_fp8
|
|
||||||
|
|
||||||
# Act
|
|
||||||
down_input = torch.empty(
|
|
||||||
(
|
|
||||||
gateup_output.shape[0],
|
|
||||||
gateup_output.shape[1],
|
|
||||||
gateup_output.shape[2] // 2,
|
|
||||||
),
|
|
||||||
device=hidden_states_device,
|
|
||||||
dtype=self.fp8_dtype,
|
|
||||||
)
|
|
||||||
scale_block_size = 128
|
|
||||||
down_input_scale = torch.empty(
|
|
||||||
(
|
|
||||||
gateup_output.shape[0],
|
|
||||||
gateup_output.shape[1],
|
|
||||||
gateup_output.shape[2] // 2 // scale_block_size,
|
|
||||||
),
|
|
||||||
device=hidden_states_device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
silu_and_mul_masked_post_quant_fwd(
|
|
||||||
gateup_output,
|
|
||||||
down_input,
|
|
||||||
down_input_scale,
|
|
||||||
scale_block_size,
|
|
||||||
masked_m,
|
|
||||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
|
||||||
)
|
|
||||||
del gateup_output
|
|
||||||
|
|
||||||
# GroupGemm-1
|
|
||||||
n = self.w2_weight.size(1)
|
|
||||||
down_input_fp8 = (
|
|
||||||
down_input,
|
|
||||||
(
|
|
||||||
down_input_scale
|
|
||||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
|
||||||
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
down_output = torch.empty(
|
|
||||||
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
|
|
||||||
)
|
|
||||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
|
||||||
down_input_fp8,
|
|
||||||
self.w2_weight_fp8,
|
|
||||||
down_output,
|
|
||||||
masked_m,
|
|
||||||
expected_m,
|
|
||||||
)
|
|
||||||
del down_input
|
|
||||||
del down_input_fp8
|
|
||||||
|
|
||||||
# PostReorder
|
|
||||||
output = torch.empty(
|
|
||||||
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
|
|
||||||
)
|
|
||||||
post_reorder_triton_kernel[(hidden_states_shape[0],)](
|
|
||||||
down_output,
|
|
||||||
output,
|
|
||||||
src2dst,
|
|
||||||
topk_ids,
|
|
||||||
topk_weights,
|
|
||||||
self.start_expert_id,
|
|
||||||
self.end_expert_id,
|
|
||||||
self.top_k,
|
|
||||||
hidden_states_shape[1],
|
|
||||||
m_max * self.start_expert_id,
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
if self.moe_runner_config.routed_scaling_factor is not None:
|
|
||||||
output *= self.moe_runner_config.routed_scaling_factor
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class DeepEPMoE(EPMoE):
|
|
||||||
"""
|
"""
|
||||||
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
|
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
|
||||||
"""
|
"""
|
||||||
@@ -374,6 +90,15 @@ class DeepEPMoE(EPMoE):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(quant_config, Fp8Config):
|
||||||
|
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||||
|
self.use_fp8_w8a8 = True
|
||||||
|
self.fp8_dtype = torch.float8_e4m3fn
|
||||||
|
else:
|
||||||
|
self.use_fp8_w8a8 = False
|
||||||
|
self.use_block_quant = False
|
||||||
|
|
||||||
self.deepep_mode = get_deepep_mode()
|
self.deepep_mode = get_deepep_mode()
|
||||||
|
|
||||||
# TODO: move to the beginning of the file
|
# TODO: move to the beginning of the file
|
||||||
@@ -567,7 +292,6 @@ class DeepEPMoE(EPMoE):
|
|||||||
N = self.w13_weight.size(1)
|
N = self.w13_weight.size(1)
|
||||||
scale_block_size = 128
|
scale_block_size = 128
|
||||||
|
|
||||||
# TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
|
|
||||||
w13_weight_fp8 = (
|
w13_weight_fp8 = (
|
||||||
self.w13_weight,
|
self.w13_weight,
|
||||||
(
|
(
|
||||||
@@ -988,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
|
|||||||
return FlashInferFusedMoE
|
return FlashInferFusedMoE
|
||||||
if get_moe_runner_backend().is_flashinfer_cutlass():
|
if get_moe_runner_backend().is_flashinfer_cutlass():
|
||||||
return FusedMoE
|
return FusedMoE
|
||||||
if get_moe_expert_parallel_world_size() > 1:
|
|
||||||
return EPMoE
|
|
||||||
return FusedMoE
|
return FusedMoE
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -156,8 +156,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.moe_tp_rank = get_moe_tensor_parallel_rank()
|
self.moe_tp_rank = get_moe_tensor_parallel_rank()
|
||||||
assert num_experts % self.moe_ep_size == 0
|
assert num_experts % self.moe_ep_size == 0
|
||||||
self.num_local_experts = num_experts // self.moe_ep_size
|
self.num_local_experts = num_experts // self.moe_ep_size
|
||||||
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
|
|
||||||
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
|
|
||||||
if self.moe_ep_size > 1:
|
if self.moe_ep_size > 1:
|
||||||
# TODO(ch-wan): support shared experts fusion
|
# TODO(ch-wan): support shared experts fusion
|
||||||
# Create a tensor of size num_experts filled with -1
|
# Create a tensor of size num_experts filled with -1
|
||||||
|
|||||||
304
python/sglang/srt/layers/moe/moe_runner/deep_gemm.py
Normal file
304
python/sglang/srt/layers/moe/moe_runner/deep_gemm.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.moe_runner.base import (
|
||||||
|
MoeQuantInfo,
|
||||||
|
MoeRunnerConfig,
|
||||||
|
MoeRunnerCore,
|
||||||
|
RunnerInput,
|
||||||
|
RunnerOutput,
|
||||||
|
register_post_permute,
|
||||||
|
register_pre_permute,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.moe.utils import MoeRunnerBackend
|
||||||
|
from sglang.srt.utils import dispose_tensor
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||||
|
StandardCombineInput,
|
||||||
|
StandardDispatchOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
    """Turn scale factors into packed 8-bit exponents, rounding upwards.

    Every element is reinterpreted as its float32 bit pattern. The bits above
    the 23-bit mantissa are kept as the exponent, which is incremented by one
    when the mantissa is non-zero, except when the exponent already equals
    0xFE or when the exponent is 0 and the mantissa is at most 0x400000.
    The exponents are then narrowed to uint8 and each run of four bytes along
    the last dimension is reinterpreted as a single int32.

    Args:
        x: Tensor with dimensions 1 and 2 present (the result is transposed
            over them); its last dimension must be divisible by 4 for the
            uint8 -> int32 reinterpretation.

    Returns:
        An int32 tensor whose last dimension is a quarter of ``x``'s. It has
        the logical shape of the packed tensor, but its memory is laid out as
        if dimensions 1 and 2 were swapped.
    """
    bits = x.to(torch.float32).view(torch.int32)
    # NOTE(review): the shift keeps the sign bit, so this presumably expects
    # non-negative scales - confirm against callers.
    exponent = bits >> 23
    mantissa = bits & 0x7FFFFF

    has_fraction = mantissa > 0
    below_cap = exponent != 0xFE
    tiny_subnormal = (exponent == 0) & (mantissa <= 0x400000)
    round_up = (has_fraction & below_cap) & ~tiny_subnormal

    exponent = torch.where(round_up, exponent + 1, exponent)
    packed = exponent.to(torch.uint8).view(torch.int)
    # Materialise with dims 1 and 2 swapped, then expose the original shape
    # again as a strided view over that storage.
    swapped = packed.transpose(1, 2).contiguous()
    return swapped.transpose(1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DeepGemmRunnerInput(RunnerInput):
    """Payload handed to ``DeepGemmRunnerCore.run``."""

    # Activations for the first grouped GEMM; the masked path unpacks the
    # shape as (num_groups, m, k).
    hidden_states: torch.Tensor
    # Scales accompanying ``hidden_states``; the masked path unpacks a 3-D
    # shape from it when UE8M0 scales are enabled.
    hidden_states_scale: torch.Tensor
    # Forwarded unchanged to the masked grouped GEMMs and the activation kernel.
    masked_m: torch.Tensor
    # Forwarded unchanged to the masked grouped GEMMs.
    expected_m: int
    # True selects the masked grouped-GEMM path, False the contiguous one.
    use_masked_gemm: bool

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        """Tag this input as belonging to the DeepGEMM runner backend."""
        return MoeRunnerBackend.DEEP_GEMM
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DeepGemmRunnerOutput(RunnerOutput):
    """Result produced by ``DeepGemmRunnerCore.run``."""

    # Output of the second grouped GEMM, before tokens are reordered back.
    hidden_states: torch.Tensor

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        """Tag this output as belonging to the DeepGEMM runner backend."""
        return MoeRunnerBackend.DEEP_GEMM
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DeepGemmMoeQuantInfo(MoeQuantInfo):
    """Weights and quantization metadata used by the DeepGEMM runner."""

    # Weight of the first (gate/up) grouped GEMM.
    w13_weight: torch.Tensor
    # Weight of the second (down) grouped GEMM.
    w2_weight: torch.Tensor
    # NOTE(review): not read anywhere in this module - confirm who consumes it.
    use_fp8: bool
    # Scales passed to the grouped GEMMs together with the matching weight.
    w13_scale: Optional[torch.Tensor] = None
    w2_scale: Optional[torch.Tensor] = None
    # Forwarded to ``moe_ep_deepgemm_preprocess`` by the pre-permute step.
    block_shape: Optional[List[int]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DeepGemmRunnerCore(MoeRunnerCore):
    """Runner core that evaluates the expert MLP with DeepGEMM FP8 grouped GEMMs.

    Only the masked grouped-GEMM path is implemented. Requesting the
    contiguous path raises ``NotImplementedError`` instead of silently
    producing an output whose ``hidden_states`` is ``None``.
    """

    def __init__(self, config: MoeRunnerConfig):
        super().__init__(config)
        # The activation kernel used by the masked path is SiLU-and-mul only.
        assert (
            self.config.activation == "silu"
        ), f"DeepGemmRunnerCore only supports 'silu', got {self.config.activation!r}"

    def run(
        self,
        runner_input: DeepGemmRunnerInput,
        quant_info: DeepGemmMoeQuantInfo,
        running_state: dict,
    ) -> DeepGemmRunnerOutput:
        """Run the expert computation and wrap the result.

        Args:
            runner_input: Activations, scales and masking info for the GEMMs.
            quant_info: Expert weights and their scales.
            running_state: Shared dict filled by the pre-permute step; the
                masked path reads ``"hidden_states_device"`` from it.

        Returns:
            The output of the second grouped GEMM wrapped in a
            ``DeepGemmRunnerOutput``.

        Raises:
            NotImplementedError: If ``runner_input.use_masked_gemm`` is False.
        """
        if runner_input.use_masked_gemm:
            hidden_states = self._run_masked_gemm(
                runner_input,
                quant_info,
                running_state,
            )
        else:
            hidden_states = self._run_contiguous_gemm(
                runner_input,
                quant_info,
                running_state,
            )
        return DeepGemmRunnerOutput(hidden_states=hidden_states)

    def _run_masked_gemm(
        self,
        runner_input: DeepGemmRunnerInput,
        quant_info: DeepGemmMoeQuantInfo,
        running_state: dict,
    ) -> torch.Tensor:
        """Gate/up GEMM -> SiLU-and-mul with FP8 re-quantization -> down GEMM.

        Returns:
            A bfloat16 tensor of shape ``(num_groups, m, w2_weight.shape[1])``.
        """
        # Imported lazily, as in the rest of this module.
        from sglang.srt.layers.moe.ep_moe.kernels import (
            silu_and_mul_masked_post_quant_fwd,
        )
        from sglang.srt.layers.quantization import deep_gemm_wrapper

        hidden_states = runner_input.hidden_states
        hidden_states_scale = runner_input.hidden_states_scale
        masked_m = runner_input.masked_m
        expected_m = runner_input.expected_m

        w13_weight = quant_info.w13_weight
        w2_weight = quant_info.w2_weight
        w13_scale = quant_info.w13_scale
        w2_scale = quant_info.w2_scale

        hidden_states_device = running_state["hidden_states_device"]

        # Bring the activation scales into the layout the GEMM expects.
        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            b, s_mn, s_k = hidden_states_scale.shape
            assert (
                s_mn % 4 == 0 and s_k % 4 == 0
            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
        else:
            hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                hidden_states_scale
            )

        # GroupGemm-0
        num_groups, m, _ = hidden_states.shape
        n = w13_weight.size(1)
        gateup_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            (hidden_states, hidden_states_scale),
            (w13_weight, w13_scale),
            gateup_output,
            masked_m,
            expected_m,
        )
        # The input activations are no longer needed once GEMM-0 has consumed them.
        dispose_tensor(hidden_states)

        # Act: the kernel writes FP8 values into `down_input` and one float32
        # scale per `scale_block_size` output columns into `down_input_scale`.
        scale_block_size = 128
        act_width = gateup_output.shape[2] // 2
        down_input = torch.empty(
            (gateup_output.shape[0], gateup_output.shape[1], act_width),
            device=hidden_states_device,
            dtype=torch.float8_e4m3fn,
        )
        down_input_scale = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                act_width // scale_block_size,
            ),
            device=hidden_states_device,
            dtype=torch.float32,
        )
        silu_and_mul_masked_post_quant_fwd(
            gateup_output,
            down_input,
            down_input_scale,
            scale_block_size,
            masked_m,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del gateup_output

        # GroupGemm-1
        n = w2_weight.shape[1]
        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                down_input_scale
            )
        down_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            (down_input, down_input_scale),
            (w2_weight, w2_scale),
            down_output,
            masked_m,
            expected_m,
        )
        del down_input

        return down_output

    def _run_contiguous_gemm(
        self,
        runner_input: DeepGemmRunnerInput,
        quant_info: DeepGemmMoeQuantInfo,
        running_state: dict,
    ) -> torch.Tensor:
        """Placeholder for the contiguous grouped-GEMM path.

        Raises:
            NotImplementedError: Always; this path is not implemented.
        """
        raise NotImplementedError(
            "DeepGemmRunnerCore does not implement the contiguous grouped GEMM "
            "path; use_masked_gemm must be True"
        )

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        """Tag this core as the DeepGEMM runner backend."""
        return MoeRunnerBackend.DEEP_GEMM
|
||||||
|
|
||||||
|
|
||||||
|
@register_pre_permute("standard", "deep_gemm")
def pre_permute_standard_to_deep_gemm(
    dispatch_output: StandardDispatchOutput,
    quant_info: DeepGemmMoeQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> DeepGemmRunnerInput:
    """Convert a standard dispatch output into the DeepGEMM runner input.

    The tokens are run through ``moe_ep_deepgemm_preprocess``, and everything
    the post-permute step needs later is stored in ``running_state``.

    Args:
        dispatch_output: Unpacked as ``(hidden_states, topk_output)``.
        quant_info: Supplies ``block_shape`` for the preprocessing kernel.
        runner_config: Supplies ``num_local_experts`` and ``top_k``.
        running_state: Mutated in place; receives ``topk_ids``,
            ``topk_weights``, ``hidden_states_shape``, ``hidden_states_dtype``,
            ``hidden_states_device`` and ``src2dst``.

    Returns:
        A ``DeepGemmRunnerInput`` with ``use_masked_gemm=True``.
    """
    from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess

    hidden_states, topk_output = dispatch_output
    topk_weights, topk_ids, _ = topk_output

    # Capture the input metadata before the tensor is disposed below.
    hidden_states_shape = hidden_states.shape
    hidden_states_dtype = hidden_states.dtype
    hidden_states_device = hidden_states.device

    # PreReorder
    masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
        moe_ep_deepgemm_preprocess(
            topk_ids,
            runner_config.num_local_experts,
            hidden_states,
            runner_config.top_k,
            quant_info.block_shape,
        )
    )

    # The original activations are not used after preprocessing.
    dispose_tensor(hidden_states)

    running_state.update(
        {
            "topk_ids": topk_ids,
            "topk_weights": topk_weights,
            "hidden_states_shape": hidden_states_shape,
            "hidden_states_dtype": hidden_states_dtype,
            "hidden_states_device": hidden_states_device,
            "src2dst": src2dst,
        }
    )

    return DeepGemmRunnerInput(
        hidden_states=gateup_input,
        hidden_states_scale=gateup_input_scale,
        masked_m=masked_m,
        expected_m=expected_m,
        use_masked_gemm=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@register_post_permute("deep_gemm", "standard")
def post_permute_deep_gemm_to_standard(
    runner_output: DeepGemmRunnerOutput,
    quant_info: DeepGemmMoeQuantInfo,
    runner_config: MoeRunnerConfig,
    running_state: dict,
) -> StandardCombineInput:
    """Reorder the DeepGEMM output back into a standard combine input.

    Args:
        runner_output: Holds the output of the second grouped GEMM.
        quant_info: Unused here; part of the post-permute signature.
        runner_config: Supplies ``top_k`` and ``routed_scaling_factor``.
        running_state: Must hold the entries stored by the matching
            pre-permute step (shape/dtype/device of the original input,
            ``src2dst``, ``topk_ids`` and ``topk_weights``).

    Returns:
        A ``StandardCombineInput`` wrapping a tensor with the shape, dtype and
        device of the original hidden states.
    """
    from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput

    out_shape = running_state["hidden_states_shape"]
    out_dtype = running_state["hidden_states_dtype"]
    out_device = running_state["hidden_states_device"]
    src2dst = running_state["src2dst"]
    topk_ids = running_state["topk_ids"]
    topk_weights = running_state["topk_weights"]

    combined = torch.empty(out_shape, dtype=out_dtype, device=out_device)

    # One kernel program per entry of the leading dimension of the original input.
    launch_grid = (out_shape[0],)
    post_reorder_triton_kernel[launch_grid](
        runner_output.hidden_states,
        combined,
        src2dst,
        topk_ids,
        topk_weights,
        runner_config.top_k,
        out_shape[1],
        BLOCK_SIZE=512,
    )

    # The GEMM output has been consumed by the reorder kernel.
    dispose_tensor(runner_output.hidden_states)

    scaling = runner_config.routed_scaling_factor
    if scaling is not None:
        combined *= scaling

    return StandardCombineInput(
        hidden_states=combined,
    )
|
||||||
@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
|
|||||||
MoeRunnerConfig,
|
MoeRunnerConfig,
|
||||||
PermuteMethodPool,
|
PermuteMethodPool,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
|
||||||
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
|
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
|
||||||
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
|
from sglang.srt.layers.moe.utils import get_moe_a2a_backend
|
||||||
|
|
||||||
@@ -30,6 +31,8 @@ class MoeRunner:
|
|||||||
|
|
||||||
if runner_backend.is_triton():
|
if runner_backend.is_triton():
|
||||||
self.runner_core = TritonRunnerCore(config)
|
self.runner_core = TritonRunnerCore(config)
|
||||||
|
elif runner_backend.is_deep_gemm():
|
||||||
|
self.runner_core = DeepGemmRunnerCore(config)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
|
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class MoeA2ABackend(Enum):
|
|||||||
class MoeRunnerBackend(Enum):
|
class MoeRunnerBackend(Enum):
|
||||||
|
|
||||||
AUTO = "auto"
|
AUTO = "auto"
|
||||||
|
DEEP_GEMM = "deep_gemm"
|
||||||
TRITON = "triton"
|
TRITON = "triton"
|
||||||
TRITON_KERNEL = "triton_kernel"
|
TRITON_KERNEL = "triton_kernel"
|
||||||
FLASHINFER_TRTLLM = "flashinfer_trtllm"
|
FLASHINFER_TRTLLM = "flashinfer_trtllm"
|
||||||
@@ -54,6 +55,9 @@ class MoeRunnerBackend(Enum):
|
|||||||
def is_auto(self):
|
def is_auto(self):
|
||||||
return self == MoeRunnerBackend.AUTO
|
return self == MoeRunnerBackend.AUTO
|
||||||
|
|
||||||
|
def is_deep_gemm(self):
|
||||||
|
return self == MoeRunnerBackend.DEEP_GEMM
|
||||||
|
|
||||||
def is_triton(self):
|
def is_triton(self):
|
||||||
return self == MoeRunnerBackend.TRITON
|
return self == MoeRunnerBackend.TRITON
|
||||||
|
|
||||||
@@ -147,7 +151,9 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
|
|||||||
def get_moe_runner_backend() -> MoeRunnerBackend:
|
def get_moe_runner_backend() -> MoeRunnerBackend:
|
||||||
global MOE_RUNNER_BACKEND
|
global MOE_RUNNER_BACKEND
|
||||||
if MOE_RUNNER_BACKEND is None:
|
if MOE_RUNNER_BACKEND is None:
|
||||||
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
|
logger.warning(
|
||||||
|
"MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected"
|
||||||
|
)
|
||||||
MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
|
MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
|
||||||
return MOE_RUNNER_BACKEND
|
return MOE_RUNNER_BACKEND
|
||||||
|
|
||||||
|
|||||||
@@ -31,8 +31,8 @@ except ImportError:
|
|||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
|
||||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||||
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
|
|
||||||
from sglang.srt.layers.parameter import (
|
from sglang.srt.layers.parameter import (
|
||||||
BlockQuantScaleParameter,
|
BlockQuantScaleParameter,
|
||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
@@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
def create_moe_runner(
|
def create_moe_runner(
|
||||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||||
):
|
):
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.utils import (
|
||||||
|
get_moe_a2a_backend,
|
||||||
|
get_moe_runner_backend,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
|
|
||||||
self.moe_runner_config = moe_runner_config
|
self.moe_runner_config = moe_runner_config
|
||||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
moe_runner_backend = get_moe_runner_backend()
|
||||||
|
|
||||||
|
if moe_runner_backend.is_auto():
|
||||||
|
if (
|
||||||
|
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||||
|
and get_moe_a2a_backend().is_deepep()
|
||||||
|
):
|
||||||
|
moe_runner_backend = MoeRunnerBackend.DEEP_GEMM
|
||||||
|
else:
|
||||||
|
moe_runner_backend = MoeRunnerBackend.TRITON
|
||||||
|
if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton():
|
||||||
|
self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
|
||||||
|
else:
|
||||||
|
# TODO(cwan): refactor other backends
|
||||||
|
pass
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@@ -1087,6 +1108,44 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
return StandardCombineInput(hidden_states=output)
|
return StandardCombineInput(hidden_states=output)
|
||||||
|
|
||||||
|
if self.runner.runner_backend.is_deep_gemm():
|
||||||
|
|
||||||
|
w13_weight = layer.w13_weight
|
||||||
|
w2_weight = layer.w2_weight
|
||||||
|
|
||||||
|
if self.block_quant:
|
||||||
|
block_shape = self.quant_config.weight_block_size
|
||||||
|
w13_scale = layer.w13_weight_scale_inv
|
||||||
|
w2_scale = layer.w2_weight_scale_inv
|
||||||
|
else:
|
||||||
|
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
|
||||||
|
scale_block_size = 128
|
||||||
|
block_shape = [scale_block_size, scale_block_size]
|
||||||
|
w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1
|
||||||
|
w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1
|
||||||
|
w13_scale = (
|
||||||
|
layer.w13_weight_scale.unsqueeze(1)
|
||||||
|
.repeat_interleave(w13_scale_n, dim=1)
|
||||||
|
.unsqueeze(2)
|
||||||
|
.repeat_interleave(w13_scale_k, dim=2)
|
||||||
|
)
|
||||||
|
w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1
|
||||||
|
w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1
|
||||||
|
w2_scale = (
|
||||||
|
layer.w2_weight_scale.unsqueeze(1)
|
||||||
|
.repeat_interleave(w2_scale_n, dim=1)
|
||||||
|
.unsqueeze(2)
|
||||||
|
.repeat_interleave(w2_scale_k, dim=2)
|
||||||
|
)
|
||||||
|
quant_info = DeepGemmMoeQuantInfo(
|
||||||
|
w13_weight=w13_weight,
|
||||||
|
w2_weight=w2_weight,
|
||||||
|
use_fp8=True,
|
||||||
|
w13_scale=w13_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
elif self.runner.runner_backend.is_triton():
|
||||||
quant_info = TritonMoeQuantInfo(
|
quant_info = TritonMoeQuantInfo(
|
||||||
w13_weight=layer.w13_weight,
|
w13_weight=layer.w13_weight,
|
||||||
w2_weight=layer.w2_weight,
|
w2_weight=layer.w2_weight,
|
||||||
@@ -1097,12 +1156,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
else layer.w13_weight_scale
|
else layer.w13_weight_scale
|
||||||
),
|
),
|
||||||
w2_scale=(
|
w2_scale=(
|
||||||
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
|
layer.w2_weight_scale_inv
|
||||||
|
if self.block_quant
|
||||||
|
else layer.w2_weight_scale
|
||||||
),
|
),
|
||||||
a13_scale=layer.w13_input_scale,
|
a13_scale=layer.w13_input_scale,
|
||||||
a2_scale=layer.w2_input_scale,
|
a2_scale=layer.w2_input_scale,
|
||||||
block_shape=self.quant_config.weight_block_size,
|
block_shape=self.quant_config.weight_block_size,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Unsupported runner backend: %s" % self.runner.runner_backend
|
||||||
|
)
|
||||||
|
|
||||||
return self.runner.run(dispatch_output, quant_info)
|
return self.runner.run(dispatch_output, quant_info)
|
||||||
|
|
||||||
def apply_with_router_logits(
|
def apply_with_router_logits(
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ from sglang.srt.utils import is_npu, set_weight_attrs
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
CombineInput,
|
CombineInput,
|
||||||
StandardDispatchOutput,
|
StandardDispatchOutput,
|
||||||
@@ -94,9 +93,7 @@ class W4AFp8Config(QuantizationConfig):
|
|||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional[QuantizeMethodBase]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if is_layer_skipped(prefix, self.ignored_layers):
|
if is_layer_skipped(prefix, self.ignored_layers):
|
||||||
@@ -133,7 +130,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: EPMoE,
|
layer: Module,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size_per_partition: int,
|
intermediate_size_per_partition: int,
|
||||||
@@ -292,7 +289,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: EPMoE,
|
layer: Module,
|
||||||
dispatch_output: StandardDispatchOutput,
|
dispatch_output: StandardDispatchOutput,
|
||||||
) -> CombineInput:
|
) -> CombineInput:
|
||||||
|
|
||||||
@@ -303,18 +300,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_output = dispatch_output.topk_output
|
topk_output = dispatch_output.topk_output
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
local_topk_ids = topk_ids
|
|
||||||
if get_moe_expert_parallel_world_size() > 1:
|
|
||||||
local_topk_ids = torch.where(
|
|
||||||
topk_ids == -1,
|
|
||||||
layer.num_experts,
|
|
||||||
topk_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
output = cutlass_w4a8_moe(
|
output = cutlass_w4a8_moe(
|
||||||
layer.start_expert_id,
|
|
||||||
layer.end_expert_id,
|
|
||||||
layer.num_experts,
|
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
@@ -322,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w2_weight_scale_inv,
|
layer.w2_weight_scale_inv,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
local_topk_ids,
|
|
||||||
self.a_strides1,
|
self.a_strides1,
|
||||||
self.b_strides1,
|
self.b_strides1,
|
||||||
self.c_strides1,
|
self.c_strides1,
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.moe.router import fused_moe_router_shim
|
from sglang.srt.layers.moe.router import fused_moe_router_shim
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
@@ -176,17 +175,7 @@ class Grok1MoE(nn.Module):
|
|||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
kwargs = {}
|
self.experts = FusedMoE(
|
||||||
if get_moe_expert_parallel_world_size() > 1:
|
|
||||||
MoEImpl = EPMoE
|
|
||||||
else:
|
|
||||||
MoEImpl = FusedMoE
|
|
||||||
kwargs["reduce_results"] = reduce_results
|
|
||||||
kwargs["use_presharded_weights"] = use_presharded_weights
|
|
||||||
kwargs["inplace"] = inplace
|
|
||||||
kwargs["no_combine"] = no_combine
|
|
||||||
|
|
||||||
self.experts = MoEImpl(
|
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
@@ -195,7 +184,10 @@ class Grok1MoE(nn.Module):
|
|||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
activation="gelu",
|
activation="gelu",
|
||||||
**kwargs,
|
reduce_results=reduce_results,
|
||||||
|
use_presharded_weights=use_presharded_weights,
|
||||||
|
inplace=inplace,
|
||||||
|
no_combine=no_combine,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
@@ -94,8 +93,7 @@ class MixtralMoE(nn.Module):
|
|||||||
renormalize=True,
|
renormalize=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE
|
self.experts = FusedMoE(
|
||||||
self.experts = MoEImpl(
|
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
|||||||
@@ -121,6 +121,17 @@ NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"
|
|||||||
|
|
||||||
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
|
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
|
||||||
|
|
||||||
|
MOE_RUNNER_BACKEND_CHOICES = [
|
||||||
|
"auto",
|
||||||
|
"deep_gemm",
|
||||||
|
"triton",
|
||||||
|
"triton_kernel",
|
||||||
|
"flashinfer_trtllm",
|
||||||
|
"flashinfer_cutlass",
|
||||||
|
"flashinfer_mxfp4",
|
||||||
|
"flashinfer_cutedsl",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# Allow external code to add more choices
|
# Allow external code to add more choices
|
||||||
def add_load_format_choices(choices):
|
def add_load_format_choices(choices):
|
||||||
@@ -143,6 +154,10 @@ def add_grammar_backend_choices(choices):
|
|||||||
GRAMMAR_BACKEND_CHOICES.extend(choices)
|
GRAMMAR_BACKEND_CHOICES.extend(choices)
|
||||||
|
|
||||||
|
|
||||||
|
def add_moe_runner_backend_choices(choices):
|
||||||
|
MOE_RUNNER_BACKEND_CHOICES.extend(choices)
|
||||||
|
|
||||||
|
|
||||||
def add_deterministic_attention_backend_choices(choices):
|
def add_deterministic_attention_backend_choices(choices):
|
||||||
DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)
|
DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)
|
||||||
|
|
||||||
@@ -315,14 +330,7 @@ class ServerArgs:
|
|||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
ep_size: int = 1
|
ep_size: int = 1
|
||||||
moe_a2a_backend: Literal["none", "deepep"] = "none"
|
moe_a2a_backend: Literal["none", "deepep"] = "none"
|
||||||
moe_runner_backend: Literal[
|
moe_runner_backend: str = "auto"
|
||||||
"auto",
|
|
||||||
"triton",
|
|
||||||
"triton_kernel",
|
|
||||||
"flashinfer_trtllm",
|
|
||||||
"flashinfer_cutlass",
|
|
||||||
"flashinfer_mxfp4",
|
|
||||||
] = "auto"
|
|
||||||
flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
|
flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
|
||||||
enable_flashinfer_allreduce_fusion: bool = False
|
enable_flashinfer_allreduce_fusion: bool = False
|
||||||
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
|
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
|
||||||
@@ -2191,15 +2199,7 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--moe-runner-backend",
|
"--moe-runner-backend",
|
||||||
type=str,
|
type=str,
|
||||||
choices=[
|
choices=MOE_RUNNER_BACKEND_CHOICES,
|
||||||
"auto",
|
|
||||||
"triton",
|
|
||||||
"triton_kernel",
|
|
||||||
"flashinfer_trtllm",
|
|
||||||
"flashinfer_cutlass",
|
|
||||||
"flashinfer_mxfp4",
|
|
||||||
"flashinfer_cutedsl",
|
|
||||||
],
|
|
||||||
default=ServerArgs.moe_runner_backend,
|
default=ServerArgs.moe_runner_backend,
|
||||||
help="Choose the runner backend for MoE.",
|
help="Choose the runner backend for MoE.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,358 +0,0 @@
|
|||||||
import itertools
|
|
||||||
import random
|
|
||||||
import unittest
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
|
||||||
grouped_gemm_triton,
|
|
||||||
post_reorder_triton_kernel,
|
|
||||||
pre_reorder_triton_kernel,
|
|
||||||
run_moe_ep_preproess,
|
|
||||||
silu_and_mul_triton_kernel,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
|
|
||||||
from sglang.test.test_utils import CustomTestCase
|
|
||||||
|
|
||||||
|
|
||||||
# For test
|
|
||||||
def ep_moe(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
topk_config: TopKConfig,
|
|
||||||
# ep config
|
|
||||||
num_experts: int = 256,
|
|
||||||
fp8_dtype: torch.types = torch.float8_e4m3fn,
|
|
||||||
num_experts_per_partition: int = 128,
|
|
||||||
start_expert_id: int = 0,
|
|
||||||
end_expert_id: int = 127,
|
|
||||||
use_fp8_w8a8: bool = False,
|
|
||||||
w1_scale_inv: Optional[torch.Tensor] = None,
|
|
||||||
w2_scale_inv: Optional[torch.Tensor] = None,
|
|
||||||
block_shape: Optional[List[int]] = None,
|
|
||||||
):
|
|
||||||
use_blockwise_fp8 = block_shape is not None
|
|
||||||
top_k = topk_config.top_k
|
|
||||||
topk_output = select_experts(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
router_logits=router_logits,
|
|
||||||
topk_config=topk_config,
|
|
||||||
)
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
|
||||||
|
|
||||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
|
|
||||||
|
|
||||||
gateup_input = torch.empty(
|
|
||||||
(int(hidden_states.shape[0] * top_k), hidden_states.shape[1]),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=(
|
|
||||||
fp8_dtype
|
|
||||||
if (use_fp8_w8a8 and not use_blockwise_fp8)
|
|
||||||
else hidden_states.dtype
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_fp8_w8a8 and not use_blockwise_fp8:
|
|
||||||
max_value = (
|
|
||||||
torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
|
|
||||||
)
|
|
||||||
w1_input_scale = max_value / torch.finfo(fp8_dtype).max
|
|
||||||
else:
|
|
||||||
w1_input_scale = None
|
|
||||||
|
|
||||||
# PreReorder
|
|
||||||
pre_reorder_triton_kernel[(hidden_states.shape[0],)](
|
|
||||||
hidden_states,
|
|
||||||
gateup_input,
|
|
||||||
src2dst,
|
|
||||||
topk_ids,
|
|
||||||
w1_input_scale,
|
|
||||||
start_expert_id,
|
|
||||||
end_expert_id,
|
|
||||||
top_k,
|
|
||||||
hidden_states.shape[1],
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
use_per_token_if_dynamic=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
|
|
||||||
weight_indices_cur_rank = torch.arange(
|
|
||||||
0,
|
|
||||||
num_experts_per_partition,
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
|
|
||||||
# GroupGemm-0
|
|
||||||
gateup_output = torch.empty(
|
|
||||||
gateup_input.shape[0],
|
|
||||||
w1.shape[1],
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
gateup_output = grouped_gemm_triton(
|
|
||||||
a=gateup_input,
|
|
||||||
b=w1,
|
|
||||||
c=gateup_output,
|
|
||||||
batch_size=num_experts_per_partition,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=seg_indptr_cur_rank,
|
|
||||||
weight_indices=weight_indices_cur_rank,
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
scale_a=w1_input_scale,
|
|
||||||
scale_b=w1_scale_inv,
|
|
||||||
block_shape=block_shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Act
|
|
||||||
down_input = torch.empty(
|
|
||||||
gateup_output.shape[0],
|
|
||||||
gateup_output.shape[1] // 2,
|
|
||||||
device=gateup_output.device,
|
|
||||||
dtype=(
|
|
||||||
fp8_dtype
|
|
||||||
if (use_fp8_w8a8 and not use_blockwise_fp8)
|
|
||||||
else hidden_states.dtype
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if use_fp8_w8a8 and not use_blockwise_fp8:
|
|
||||||
w2_input_scale = torch.ones(
|
|
||||||
num_experts_per_partition,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=hidden_states.device,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
w2_input_scale = None
|
|
||||||
|
|
||||||
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
|
||||||
gateup_output,
|
|
||||||
down_input,
|
|
||||||
gateup_output.shape[1],
|
|
||||||
reorder_topk_ids,
|
|
||||||
w2_input_scale,
|
|
||||||
start_expert_id,
|
|
||||||
end_expert_id,
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
|
|
||||||
# GroupGemm-1
|
|
||||||
down_output = torch.empty(
|
|
||||||
down_input.shape[0],
|
|
||||||
w2.shape[1],
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
down_output = grouped_gemm_triton(
|
|
||||||
a=down_input,
|
|
||||||
b=w2,
|
|
||||||
c=down_output,
|
|
||||||
batch_size=num_experts_per_partition,
|
|
||||||
weight_column_major=True,
|
|
||||||
seg_indptr=seg_indptr_cur_rank,
|
|
||||||
weight_indices=weight_indices_cur_rank,
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
scale_a=w2_input_scale,
|
|
||||||
scale_b=w2_scale_inv,
|
|
||||||
block_shape=block_shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
# PostReorder
|
|
||||||
output = torch.empty_like(hidden_states)
|
|
||||||
post_reorder_triton_kernel[(hidden_states.size(0),)](
|
|
||||||
down_output,
|
|
||||||
output,
|
|
||||||
src2dst,
|
|
||||||
topk_ids,
|
|
||||||
topk_weights,
|
|
||||||
start_expert_id,
|
|
||||||
end_expert_id,
|
|
||||||
top_k,
|
|
||||||
hidden_states.size(1),
|
|
||||||
0,
|
|
||||||
BLOCK_SIZE=512,
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
# test util
|
|
||||||
def block_dequant(
|
|
||||||
x_q_block: torch.Tensor,
|
|
||||||
x_s: torch.Tensor,
|
|
||||||
block_size: List[int],
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""This function converts block-wise quantization to tensor-wise quantization.
|
|
||||||
The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
|
|
||||||
and the block size.
|
|
||||||
The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
|
|
||||||
Note only float8 is supported for now.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# process 3D tensor
|
|
||||||
if x_q_block.dim() == 3:
|
|
||||||
batch_size = x_q_block.size(0)
|
|
||||||
return torch.stack(
|
|
||||||
[block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)]
|
|
||||||
)
|
|
||||||
|
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
|
||||||
n, k = x_q_block.shape
|
|
||||||
n_tiles = (n + block_n - 1) // block_n
|
|
||||||
k_tiles = (k + block_k - 1) // block_k
|
|
||||||
assert n_tiles == x_s.shape[0]
|
|
||||||
assert k_tiles == x_s.shape[1]
|
|
||||||
|
|
||||||
x_dq_block = x_q_block.to(torch.float32)
|
|
||||||
|
|
||||||
x_dq_block_tiles = [
|
|
||||||
[
|
|
||||||
x_dq_block[
|
|
||||||
j * block_n : min((j + 1) * block_n, n),
|
|
||||||
i * block_k : min((i + 1) * block_k, k),
|
|
||||||
]
|
|
||||||
for i in range(k_tiles)
|
|
||||||
]
|
|
||||||
for j in range(n_tiles)
|
|
||||||
]
|
|
||||||
|
|
||||||
for i in range(k_tiles):
|
|
||||||
for j in range(n_tiles):
|
|
||||||
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
|
|
||||||
|
|
||||||
return x_dq_block
|
|
||||||
|
|
||||||
|
|
||||||
class TestW8A8BlockFP8EPMoE(CustomTestCase):
|
|
||||||
DTYPES = [torch.half, torch.bfloat16]
|
|
||||||
M = [1, 222, 1024, 2048]
|
|
||||||
N = [128, 1024, 2048]
|
|
||||||
K = [256, 4096, 5120]
|
|
||||||
E = [8, 16]
|
|
||||||
ep_size = [2, 4]
|
|
||||||
TOP_KS = [2, 4]
|
|
||||||
BLOCK_SIZE = [[128, 128]]
|
|
||||||
SEEDS = [0]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
raise unittest.SkipTest("CUDA is not available")
|
|
||||||
torch.set_default_device("cuda")
|
|
||||||
|
|
||||||
def _w8a8_block_fp8_ep_moe(
|
|
||||||
self, M, N, K, E, ep_size, topk, block_size, dtype, seed
|
|
||||||
):
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
random.seed(seed)
|
|
||||||
# NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
|
|
||||||
factor_for_scale = 1e-2
|
|
||||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
|
||||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
|
||||||
|
|
||||||
a = torch.randn((M, K), dtype=dtype) / 10
|
|
||||||
|
|
||||||
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max
|
|
||||||
w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
|
||||||
|
|
||||||
w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max
|
|
||||||
w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
|
||||||
|
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
|
||||||
n_tiles_w1 = (2 * N + block_n - 1) // block_n
|
|
||||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
|
||||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
|
||||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
|
||||||
|
|
||||||
w1_s = (
|
|
||||||
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
|
|
||||||
* factor_for_scale
|
|
||||||
)
|
|
||||||
w2_s = (
|
|
||||||
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
|
|
||||||
* factor_for_scale
|
|
||||||
)
|
|
||||||
|
|
||||||
w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
|
|
||||||
w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)
|
|
||||||
|
|
||||||
score = torch.randn((M, E), dtype=dtype)
|
|
||||||
num_experts_per_partition = E // ep_size
|
|
||||||
cur_rank = random.randint(0, ep_size - 1)
|
|
||||||
start_id = cur_rank * num_experts_per_partition
|
|
||||||
end_id = start_id + num_experts_per_partition - 1
|
|
||||||
|
|
||||||
topk_config = TopKConfig(
|
|
||||||
top_k=topk,
|
|
||||||
renormalize=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
out = ep_moe(
|
|
||||||
hidden_states=a,
|
|
||||||
w1=w1,
|
|
||||||
w2=w2,
|
|
||||||
router_logits=score,
|
|
||||||
topk_config=topk_config,
|
|
||||||
use_fp8_w8a8=True,
|
|
||||||
w1_scale_inv=w1_s,
|
|
||||||
w2_scale_inv=w2_s,
|
|
||||||
block_shape=block_size,
|
|
||||||
num_experts=E,
|
|
||||||
num_experts_per_partition=num_experts_per_partition,
|
|
||||||
start_expert_id=start_id,
|
|
||||||
end_expert_id=end_id,
|
|
||||||
)
|
|
||||||
ref_out = ep_moe(
|
|
||||||
hidden_states=a,
|
|
||||||
w1=w1_ref,
|
|
||||||
w2=w2_ref,
|
|
||||||
router_logits=score,
|
|
||||||
topk_config=topk_config,
|
|
||||||
use_fp8_w8a8=False,
|
|
||||||
w1_scale_inv=None,
|
|
||||||
w2_scale_inv=None,
|
|
||||||
block_shape=None,
|
|
||||||
num_experts=E,
|
|
||||||
num_experts_per_partition=num_experts_per_partition,
|
|
||||||
start_expert_id=start_id,
|
|
||||||
end_expert_id=end_id,
|
|
||||||
)
|
|
||||||
self.assertTrue(
|
|
||||||
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
|
|
||||||
/ (torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6)
|
|
||||||
< 0.06
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_w8a8_block_fp8_ep_moe(self):
|
|
||||||
for params in itertools.product(
|
|
||||||
self.M,
|
|
||||||
self.N,
|
|
||||||
self.K,
|
|
||||||
self.E,
|
|
||||||
self.ep_size,
|
|
||||||
self.TOP_KS,
|
|
||||||
self.BLOCK_SIZE,
|
|
||||||
self.DTYPES,
|
|
||||||
self.SEEDS,
|
|
||||||
):
|
|
||||||
with self.subTest(
|
|
||||||
M=params[0],
|
|
||||||
N=params[1],
|
|
||||||
K=params[2],
|
|
||||||
E=params[3],
|
|
||||||
ep_size=params[4],
|
|
||||||
topk=params[5],
|
|
||||||
block_size=params[6],
|
|
||||||
dtype=params[7],
|
|
||||||
seed=params[8],
|
|
||||||
):
|
|
||||||
self._w8a8_block_fp8_ep_moe(*params)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main(verbosity=2)
|
|
||||||
@@ -120,7 +120,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
|
|||||||
)
|
)
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
expert_map = torch.arange(E, dtype=torch.int32, device=device)
|
expert_map = torch.arange(E, dtype=torch.int32, device=device)
|
||||||
expert_map[local_e:] = E
|
expert_map[local_e:] = -1
|
||||||
|
|
||||||
output = cutlass_moe(
|
output = cutlass_moe(
|
||||||
a,
|
a,
|
||||||
@@ -138,9 +138,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
|
|||||||
c_strides2,
|
c_strides2,
|
||||||
s_strides13,
|
s_strides13,
|
||||||
s_strides2,
|
s_strides2,
|
||||||
0,
|
local_e,
|
||||||
local_e - 1,
|
|
||||||
E,
|
|
||||||
a1_scale,
|
a1_scale,
|
||||||
a2_scale,
|
a2_scale,
|
||||||
expert_map,
|
expert_map,
|
||||||
@@ -178,7 +176,7 @@ def cutlass_moe(
|
|||||||
w1_scale: torch.Tensor,
|
w1_scale: torch.Tensor,
|
||||||
w2_scale: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids_: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
a_strides1: torch.Tensor,
|
a_strides1: torch.Tensor,
|
||||||
b_strides1: torch.Tensor,
|
b_strides1: torch.Tensor,
|
||||||
c_strides1: torch.Tensor,
|
c_strides1: torch.Tensor,
|
||||||
@@ -187,40 +185,32 @@ def cutlass_moe(
|
|||||||
c_strides2: torch.Tensor,
|
c_strides2: torch.Tensor,
|
||||||
s_strides13: torch.Tensor,
|
s_strides13: torch.Tensor,
|
||||||
s_strides2: torch.Tensor,
|
s_strides2: torch.Tensor,
|
||||||
start_expert_id: int,
|
num_local_experts: int,
|
||||||
end_expert_id: int,
|
|
||||||
E: int,
|
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
):
|
):
|
||||||
local_topk_ids = topk_ids_
|
topk_ids = expert_map[topk_ids]
|
||||||
local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
|
|
||||||
device = a.device
|
device = a.device
|
||||||
|
|
||||||
local_num_experts = end_expert_id - start_expert_id + 1
|
|
||||||
expert_offsets = torch.empty(
|
expert_offsets = torch.empty(
|
||||||
(local_num_experts + 1), dtype=torch.int32, device=device
|
(num_local_experts + 1), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
problem_sizes1 = torch.empty(
|
problem_sizes1 = torch.empty(
|
||||||
(local_num_experts, 3), dtype=torch.int32, device=device
|
(num_local_experts, 3), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
problem_sizes2 = torch.empty(
|
problem_sizes2 = torch.empty(
|
||||||
(local_num_experts, 3), dtype=torch.int32, device=device
|
(num_local_experts, 3), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
return cutlass_w4a8_moe(
|
return cutlass_w4a8_moe(
|
||||||
start_expert_id,
|
|
||||||
end_expert_id,
|
|
||||||
E,
|
|
||||||
a,
|
a,
|
||||||
w1_q,
|
w1_q,
|
||||||
w2_q,
|
w2_q,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
w2_scale,
|
w2_scale,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids_,
|
topk_ids,
|
||||||
local_topk_ids,
|
|
||||||
a_strides1,
|
a_strides1,
|
||||||
b_strides1,
|
b_strides1,
|
||||||
c_strides1,
|
c_strides1,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from sglang.test.test_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestEpMoE(CustomTestCase):
|
class TestEp(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||||
@@ -34,18 +34,6 @@ class TestEpMoE(CustomTestCase):
|
|||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_process_tree(cls.process.pid)
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
def test_mmlu(self):
|
|
||||||
args = SimpleNamespace(
|
|
||||||
base_url=self.base_url,
|
|
||||||
model=self.model,
|
|
||||||
eval_name="mmlu",
|
|
||||||
num_examples=64,
|
|
||||||
num_threads=32,
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval(args)
|
|
||||||
self.assertGreaterEqual(metrics["score"], 0.5)
|
|
||||||
|
|
||||||
def test_mgsm_en(self):
|
def test_mgsm_en(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
@@ -59,7 +47,7 @@ class TestEpMoE(CustomTestCase):
|
|||||||
self.assertGreaterEqual(metrics["score"], 0.8)
|
self.assertGreaterEqual(metrics["score"], 0.8)
|
||||||
|
|
||||||
|
|
||||||
class TestEpMoEFP8(CustomTestCase):
|
class TestEpDeepGEMM(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||||
@@ -76,6 +64,8 @@ class TestEpMoEFP8(CustomTestCase):
|
|||||||
"2",
|
"2",
|
||||||
"--quantization",
|
"--quantization",
|
||||||
"fp8",
|
"fp8",
|
||||||
|
"--moe-runner-backend",
|
||||||
|
"deep_gemm",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -83,18 +73,6 @@ class TestEpMoEFP8(CustomTestCase):
|
|||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_process_tree(cls.process.pid)
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
def test_mmlu(self):
|
|
||||||
args = SimpleNamespace(
|
|
||||||
base_url=self.base_url,
|
|
||||||
model=self.model,
|
|
||||||
eval_name="mmlu",
|
|
||||||
num_examples=64,
|
|
||||||
num_threads=32,
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval(args)
|
|
||||||
self.assertGreaterEqual(metrics["score"], 0.5)
|
|
||||||
|
|
||||||
def test_mgsm_en(self):
|
def test_mgsm_en(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
|
|||||||
@@ -130,6 +130,7 @@ suites = {
|
|||||||
TestFile("test_modelopt_loader.py", 30),
|
TestFile("test_modelopt_loader.py", 30),
|
||||||
],
|
],
|
||||||
"per-commit-2-gpu": [
|
"per-commit-2-gpu": [
|
||||||
|
TestFile("ep/test_moe_ep.py", 140),
|
||||||
TestFile("lora/test_lora_tp.py", 116),
|
TestFile("lora/test_lora_tp.py", 116),
|
||||||
TestFile("rl/test_update_weights_from_distributed.py", 103),
|
TestFile("rl/test_update_weights_from_distributed.py", 103),
|
||||||
TestFile("test_data_parallelism.py", 73),
|
TestFile("test_data_parallelism.py", 73),
|
||||||
|
|||||||
Reference in New Issue
Block a user