diff --git a/benchmark/kernels/fbgemm/README.md b/benchmark/kernels/fbgemm/README.md deleted file mode 100644 index e51356d8a..000000000 --- a/benchmark/kernels/fbgemm/README.md +++ /dev/null @@ -1,29 +0,0 @@ -## Benchmark FBGEMM Grouped GEMM - -Benchmark FBGEMM Grouped GEMM in both Triton and CUDA version and SGLang Triton Grouped GEMM, it will be used to compare the bandwidth of different implementations. - -### Requirements - -```shell -pip install fbgemm-gpu-genai -``` - -### Usage - -```bash -python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8 -``` - -For example, in H200, the Qwen2-57B-A14B-Instruct TP4 fp8w8a8 grouped gemm bandwidth result is as follows: - -```shell -grouped-gemm-performance: - batch_size FBGEMM Triton Grouped GEMM FP8 FBGEMM CUTLASS F8F8BF16 Rowwise SGLang Grouped GEMM FP8 -0 256.0 3704.841339 3042.626402 2254.725030 -1 512.0 3691.426346 3029.065684 2269.504543 -2 1024.0 3653.938629 2258.471467 2358.319020 -3 2048.0 3596.644313 2271.611904 2476.895397 -4 4096.0 3468.496435 2231.283986 2179.473910 -``` - -The theoretical peak bandwidth of H200 is 4.8 TB/s. Taking batch_size 256 as an example, the bandwidth of FBGEMM Triton Grouped GEMM FP8 is 3704.841339 GB/s, the bandwidth of FBGEMM CUTLASS F8F8BF16 Rowwise is 3042.626402 GB/s, and the bandwidth of SGLang Grouped GEMM FP8 is 2254.725030 GB/s. Therefore, FBGEMM Triton Grouped GEMM FP8 achieves 77.9% of H200's theoretical peak bandwidth, FBGEMM CUTLASS F8F8BF16 Rowwise achieves 63.4% of H200's theoretical peak bandwidth, and SGLang Grouped GEMM FP8 achieves 46.9% of H200's theoretical peak bandwidth. diff --git a/benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py b/benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py deleted file mode 100644 index 6e8c8dcf2..000000000 --- a/benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py +++ /dev/null @@ -1,516 +0,0 @@ -# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8 -import argparse - -import torch -import triton -from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import ( - quantize_fp8_row, - triton_quantize_fp8_row, -) -from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import ( - grouped_gemm as fbgemm_grouped_gemm, -) -from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import ( - grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise, -) -from transformers import AutoConfig - -from sglang.srt.layers.moe.ep_moe.kernels import ( - grouped_gemm_triton as sglang_grouped_gemm, -) - - -def get_model_config(model_name: str, tp_size: int): - config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - - if config.architectures[0] == "DbrxForCausalLM": - num_groups = config.ffn_config.moe_num_experts - intermediate_size = config.ffn_config.ffn_hidden_size - elif config.architectures[0] == "JambaForCausalLM": - num_groups = config.num_experts - intermediate_size = config.intermediate_size - elif config.architectures[0] == "Qwen2MoeForCausalLM": - num_groups = config.num_experts - intermediate_size = config.moe_intermediate_size - elif config.architectures[0] == "Qwen3MoeForCausalLM": - num_groups = config.num_experts - intermediate_size = config.moe_intermediate_size - elif config.architectures[0] in [ - "DeepseekV2ForCausalLM", - "DeepseekV3ForCausalLM", - ]: - num_groups = config.n_routed_experts - intermediate_size = config.moe_intermediate_size - elif config.architectures[0] == 
"Llama4ForConditionalGeneration": - num_groups = config.text_config.num_local_experts - intermediate_size = config.text_config.intermediate_size - elif config.architectures[0] in [ - "Grok1ForCausalLM", - "Grok1ImgGen", - "Grok1AForCausalLM", - ]: - num_groups = config.num_local_experts - intermediate_size = config.moe_intermediate_size - else: - num_groups = config.num_local_experts - intermediate_size = config.intermediate_size - - shape_configs = { - "num_groups": num_groups, - "hidden_size": config.hidden_size, - "intermediate_size": intermediate_size, - "dtype": config.torch_dtype, - } - print(f"{shape_configs=}") - return shape_configs - - -def create_test_data(batch_size, num_groups, hidden_size, intermediate_size): - torch.manual_seed(42) - - tokens_per_group = batch_size // num_groups - m_sizes = torch.full( - (num_groups,), tokens_per_group, dtype=torch.int32, device="cuda" - ) - - x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda") - - base_weights = torch.randn( - num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda" - ) - - w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size) - w_sglang = base_weights - - c_fbgemm = torch.empty( - batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda" - ) - c_sglang = torch.empty( - batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda" - ) - - seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device="cuda") - for i in range(1, num_groups + 1): - seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group - - weight_indices = torch.arange(num_groups, dtype=torch.int32, device="cuda") - - return ( - x, - w_fbgemm, - w_sglang, - c_fbgemm, - c_sglang, - m_sizes, - seg_indptr, - weight_indices, - ) - - -def create_fp8_test_data( - batch_size, num_groups, hidden_size, intermediate_size, backend="triton" -): - """ - Create test data for FP8 grouped GEMM operations. 
- - Args: - batch_size: Total batch size - num_groups: Number of groups - hidden_size: Hidden dimension size - intermediate_size: Intermediate dimension size - backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM - - Returns: - For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale) - For cutlass: (x, wq, w_scale, m_sizes) - """ - torch.manual_seed(42) - - tokens_per_group = batch_size // num_groups - - # Create weight matrices for each group - w_list = [] - for _ in range(num_groups): - w = torch.randn( - intermediate_size, hidden_size, dtype=torch.float16, device="cuda" - ) - w_list.append(w) - - # Quantize weights using quantize_fp8_row for each group - wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list]) - - if backend == "triton": - # Triton format: concatenated weights - w_fp8 = torch.concat(wq_list, dim=0).contiguous() - w_scale = torch.concat(w_scale_list, dim=0).contiguous() - - # Create m_sizes as int32 for triton - m_sizes = torch.full( - (num_groups,), tokens_per_group, dtype=torch.int32, device="cuda" - ) - - # Create and quantize input - x_fp16 = torch.randn( - batch_size, hidden_size, dtype=torch.float16, device="cuda" - ) - x_fp8, x_scale = triton_quantize_fp8_row(x_fp16) - x_scale = x_scale.view(batch_size, -1) - - return x_fp8, w_fp8, m_sizes, x_scale, w_scale - - elif backend == "cutlass": - # CUTLASS format: stacked weights - wq = torch.stack(wq_list, dim=0).contiguous() - w_scale = torch.stack(w_scale_list, dim=0).contiguous() - - # Create m_sizes as int64 for cutlass - m_values = [tokens_per_group] * num_groups - m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device="cuda") - - # Create input data - separate for each group then concat - x_list = [] - for _ in range(num_groups): - x = torch.randn( - tokens_per_group, hidden_size, dtype=torch.float16, device="cuda" - ) - x_list.append(x) - - # Concatenate inputs into single tensor - x = torch.concat(x_list, dim=0).contiguous() - - return x, wq, w_scale, m_sizes - - else: - raise ValueError(f"Unsupported backend: {backend}") - - -def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype): - """ - Calculate memory bandwidth based on accessed expert weights. 
- - Args: - m_sizes: Tensor containing batch sizes for each group - hidden_size: Hidden dimension size - intermediate_size: Intermediate dimension size - dtype: Data type of weights - - Returns: - Memory size in bytes for accessed expert weights - """ - # Count non-zero groups (active experts) - if hasattr(m_sizes, "cpu"): - active_experts = torch.count_nonzero(m_sizes).item() - else: - active_experts = sum(1 for m in m_sizes if m > 0) - - # Calculate bytes per element based on dtype - if dtype in [torch.float16, torch.bfloat16]: - bytes_per_element = 2 - elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - bytes_per_element = 1 - elif dtype == torch.float32: - bytes_per_element = 4 - else: - # Default to 2 bytes for unknown dtypes - bytes_per_element = 2 - - # Memory per expert weight matrix - memory_per_expert = hidden_size * intermediate_size * bytes_per_element - - # Total memory for active experts - total_memory_bytes = active_experts * memory_per_expert - - return total_memory_bytes - - -def get_benchmark_config(use_fp8_w8a8=False): - if use_fp8_w8a8: - return { - "line_vals": [ - "fbgemm_triton_grouped_gemm_fp8", - "fbgemm_cutlass_f8f8bf16_rowwise", - "sglang_grouped_gemm", - ], - "line_names": [ - "FBGEMM Triton Grouped GEMM FP8", - "FBGEMM CUTLASS F8F8BF16 Rowwise", - "SGLang Grouped GEMM FP8", - ], - "styles": [("blue", "-"), ("orange", "-"), ("red", "-")], - } - else: - return { - "line_vals": ["fbgemm_triton_grouped_gemm", "sglang_grouped_gemm"], - "line_names": [ - "FBGEMM Triton Grouped GEMM BF16", - "SGLang Grouped GEMM BF16", - ], - "styles": [("blue", "-"), ("green", "-")], - } - - -def run_benchmark( - model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/" -): - config = get_benchmark_config(use_fp8_w8a8) - - benchmark_config = triton.testing.Benchmark( - x_names=["batch_size"], - x_vals=[256, 512, 1024, 2048, 4096], - line_arg="provider", - line_vals=config["line_vals"], - line_names=config["line_names"], - styles=config["styles"], - ylabel="Bandwidth (GB/s)", - plot_name="grouped-gemm-performance", - args={}, - ) - - @triton.testing.perf_report(benchmark_config) - def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False): - print(f"Benchmarking {provider} with batch_size={batch_size}") - torch.cuda.manual_seed_all(0) - - num_groups = model_config["num_groups"] - hidden_size = model_config["hidden_size"] - intermediate_size = model_config["intermediate_size"] - - if provider == "fbgemm_triton_grouped_gemm_fp8": - try: - test_data = create_fp8_test_data( - batch_size, - num_groups, - hidden_size, - intermediate_size, - backend="triton", - ) - x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data - - # Calculate memory bandwidth - memory_bytes = calculate_memory_bandwidth( - m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn - ) - - def run_func(): - return fbgemm_grouped_gemm_fp8_rowwise( - x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True - ) - - except Exception as e: - print(f"FP8 not supported, skipping: {e}") - return float("inf"), float("inf"), float("inf") - - elif provider == "fbgemm_cutlass_f8f8bf16_rowwise": - try: - test_data = create_fp8_test_data( - batch_size, - num_groups, - hidden_size, - intermediate_size, - backend="cutlass", - ) - x, wq, w_scale, m_sizes = test_data - - # Calculate memory bandwidth - memory_bytes = calculate_memory_bandwidth( - m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn - ) - - # Quantize input using triton_quantize_fp8_row - xq, x_scale = 
triton_quantize_fp8_row(x) - x_scale = x_scale.view(batch_size, -1) - - def run_func(): - return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked( - xq, wq, x_scale, w_scale, m_sizes - ) - - except Exception as e: - print( - f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, " - f"skipping: {e}" - ) - return float("inf"), float("inf"), float("inf") - else: - test_data = create_test_data( - batch_size, num_groups, hidden_size, intermediate_size - ) - ( - x, - w_fbgemm, - w_sglang, - c_fbgemm, - c_sglang, - m_sizes, - seg_indptr, - weight_indices, - ) = test_data - - # Calculate memory bandwidth for BF16 operations - memory_bytes = calculate_memory_bandwidth( - m_sizes, hidden_size, intermediate_size, torch.bfloat16 - ) - - if provider == "fbgemm_triton_grouped_gemm": - - def run_func(): - return fbgemm_grouped_gemm( - x, w_fbgemm, m_sizes, use_fast_accum=True - ) - - else: - - def run_func(): - return sglang_grouped_gemm( - x, - w_sglang, - c_sglang, - num_groups, - weight_column_major=True, - seg_indptr=seg_indptr, - weight_indices=weight_indices, - c_dtype=c_sglang.dtype, - ) - - for _ in range(10): - try: - run_func() - except Exception as e: - print(f"Error during warmup for {provider}: {e}") - return float("inf"), float("inf"), float("inf") - - torch.cuda.synchronize() - - try: - quantiles = [0.5, 0.2, 0.8] - ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles) - - # Convert time (ms) to bandwidth (GB/s) - # Bandwidth = Memory (bytes) / Time (seconds) - # Convert ms to seconds and bytes to GB (1e9) - gb_per_s = (memory_bytes / 1e9) / (ms / 1000) - # min bandwidth = max time, max bandwidth = min time - min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000) - max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000) - - return gb_per_s, min_gb_per_s, max_gb_per_s - except Exception as e: - print(f"Error during benchmarking for {provider}: {e}") - return 0.0, 0.0, 0.0 - - dynamic_benchmark.run( - show_plots=True, - print_data=True, - save_path=save_path, - model_config=model_config, - use_fp8_w8a8=use_fp8_w8a8, - ) - - -def verify_correctness(model_config): - print("Verifying correctness...") - batch_size = 128 - num_groups = model_config["num_groups"] - hidden_size = model_config["hidden_size"] - intermediate_size = model_config["intermediate_size"] - - test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size) - ( - x, - w_fbgemm, - w_sglang, - c_fbgemm, - c_sglang, - m_sizes, - seg_indptr, - weight_indices, - ) = test_data - - result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True) - - result_sglang = sglang_grouped_gemm( - x, - w_sglang, - c_sglang, - num_groups, - weight_column_major=True, - seg_indptr=seg_indptr, - weight_indices=weight_indices, - c_dtype=c_sglang.dtype, - ) - - if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3): - print("✓ BF16 Correctness verification passed!") - else: - max_diff = torch.max(torch.abs(result_fbgemm - result_sglang)) - print(f"✗ BF16 Correctness verification failed! 
Max diff: {max_diff}") - return False - - return True - - -def main(): - parser = argparse.ArgumentParser( - description="Benchmark FBGEMM vs SGLang Grouped GEMM" - ) - parser.add_argument( - "--model", - type=str, - default="mistralai/Mixtral-8x7B-Instruct-v0.1", - help="Model name to get configuration from", - ) - parser.add_argument( - "--tp-size", type=int, default=1, help="Tensor parallelism size" - ) - parser.add_argument( - "--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark" - ) - parser.add_argument( - "--save-path", - type=str, - default="./benchmark_grouped_gemm/", - help="Path to save benchmark results", - ) - parser.add_argument( - "--verify-correctness", - action="store_true", - help="Verify correctness before benchmarking", - ) - - args = parser.parse_args() - - try: - model_config = get_model_config(args.model, args.tp_size) - except Exception as e: - print(f"Failed to get model config: {e}") - print("Using default configuration...") - model_config = { - "num_groups": 8, - "hidden_size": 4096, - "intermediate_size": 14336, - "dtype": torch.bfloat16, - } - - print("Running benchmark with:") - print(f" num_groups: {model_config['num_groups']}") - print(f" hidden_size: {model_config['hidden_size']}") - print(f" intermediate_size: {model_config['intermediate_size']}") - print(f" use_fp8_w8a8: {args.use_fp8_w8a8}") - - if args.verify_correctness: - if not verify_correctness(model_config): - print("Correctness verification failed. Exiting...") - return - - try: - run_benchmark( - model_config=model_config, - use_fp8_w8a8=args.use_fp8_w8a8, - save_path=args.save_path, - ) - except Exception as e: - print(f"Benchmark failed: {e}") - - -if __name__ == "__main__": - main() diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index 1f1f801ac..ad8dc9405 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -246,7 +246,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s |-----------|-------------|----------| | `--ep-size` | The expert parallelism size. | 1 | | `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none | -| `--moe-runner-backend` | Select the runner backend for MoE. | 'triton' | +| `--moe-runner-backend` | Select the runner backend for MoE. | auto | | `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto | | `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 | | `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. 
| None | diff --git a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py index 216424eea..e1507be18 100644 --- a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py @@ -13,22 +13,18 @@ from sgl_kernel import ( from sglang.srt.layers.moe.ep_moe.kernels import ( post_reorder_triton_kernel_for_cutlass_moe, pre_reorder_triton_kernel_for_cutlass_moe, - run_cutlass_moe_ep_preproess, + run_moe_ep_preproess, ) def cutlass_w4a8_moe( - start_expert_id: int, - end_expert_id: int, - total_num_experts: int, a: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, - topk_ids_: torch.Tensor, - local_topk_ids: torch.Tensor, + topk_ids: torch.Tensor, a_strides1: torch.Tensor, b_strides1: torch.Tensor, c_strides1: torch.Tensor, @@ -64,6 +60,7 @@ def cutlass_w4a8_moe( - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. Shape: [num_experts, N // 512, K * 4] - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The ids of each token->expert mapping. - a_strides1 (torch.Tensor): The input strides of the first grouped gemm. - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm. - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. @@ -83,7 +80,7 @@ def cutlass_w4a8_moe( Returns: - torch.Tensor: The fp8 output tensor after applying the MoE layer. """ - assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert w1_q.dtype == torch.int8 assert w2_q.dtype == torch.int8 assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1" @@ -96,20 +93,21 @@ def cutlass_w4a8_moe( assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch" assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch" assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch" - num_experts = w1_q.size(0) + num_local_experts = w1_q.size(0) m = a.size(0) k = w1_q.size(2) * 2 # w1_q is transposed and packed n = w2_q.size(2) * 2 # w2_q is transposed and packed - topk = topk_ids_.size(1) + topk = topk_ids.size(1) if apply_router_weight_on_input: assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1" device = a.device + topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids) - _, src2dst, _ = run_cutlass_moe_ep_preproess( - local_topk_ids, - num_experts, + _, src2dst, _ = run_moe_ep_preproess( + topk_ids, + num_local_experts, ) gateup_input = torch.empty( @@ -122,9 +120,9 @@ def cutlass_w4a8_moe( a, gateup_input, src2dst, - local_topk_ids, + topk_ids, a1_scale, - total_num_experts, + num_local_experts, topk, k, BLOCK_SIZE=512, @@ -133,16 +131,16 @@ def cutlass_w4a8_moe( # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel, # they are kept to allow for a quick switch of the permutation logic # from the current triton kernel implementation to the cutlass-based one if needed. 
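The `topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)` remap above replaces the old start/end expert-id bookkeeping: tokens routed to non-local experts carry the sentinel id `num_local_experts`, and `run_moe_ep_preproess` then groups the remaining tokens by expert with a stable sort. Below is a minimal pure-PyTorch sketch of that sentinel remap plus preprocess, for illustration only (the name `ep_preprocess_reference` is made up here; the real path uses the Triton kernels in `ep_moe/kernels.py`):

```python
import torch

def ep_preprocess_reference(topk_ids: torch.Tensor, num_local_experts: int):
    # Non-local experts are marked -1 by the dispatcher; map them to the
    # sentinel id `num_local_experts` so they sort after all local experts.
    topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
    flat = topk_ids.view(-1)
    reorder_topk_ids, reorder_ids = torch.sort(flat, stable=True)

    # seg_indptr[e + 1] = number of routed slots whose expert id is <= e,
    # so [seg_indptr[e], seg_indptr[e + 1]) is expert e's segment.
    seg_indptr = torch.zeros(num_local_experts + 1, dtype=torch.int64, device=flat.device)
    for e in range(num_local_experts):
        seg_indptr[e + 1] = (reorder_topk_ids <= e).sum()

    # src2dst[i] = position of flattened slot i in the expert-sorted layout.
    src2dst = torch.empty_like(flat, dtype=torch.int32)
    src2dst[reorder_ids] = torch.arange(flat.numel(), dtype=torch.int32, device=flat.device)
    return reorder_topk_ids, src2dst, seg_indptr
```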
- a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) - c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) get_cutlass_w4a8_moe_mm_data( - local_topk_ids, + topk_ids, expert_offsets, problem_sizes1, problem_sizes2, a_map, c_map, - num_experts, + num_local_experts, n, k, ) @@ -195,12 +193,11 @@ def cutlass_w4a8_moe( c2, output, src2dst, - local_topk_ids, + topk_ids, topk_weights, - num_experts, topk, + num_local_experts, k, - 0, BLOCK_SIZE=512, ) return output diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index d8e221d5c..ef4262a1c 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -130,28 +130,30 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int): @triton.jit def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks): - expert = tl.program_id(0) + expert_id_minus_1 = tl.program_id(0) - 1 low = 0 high = num_toks - 1 target_location = -1 while low <= high: mid = (low + high) // 2 - if tl.load(reorder_topk_ids + mid) > expert: + if tl.load(reorder_topk_ids + mid) > expert_id_minus_1: high = mid - 1 else: low = mid + 1 target_location = mid - tl.store(seg_indptr + expert + 1, target_location + 1) + tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1) -def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int): +def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int): reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True) - seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64) + seg_indptr = torch.zeros( + num_local_experts + 1, device=topk_ids.device, dtype=torch.int64 + ) src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32) - compute_seg_indptr_triton_kernel[(num_experts,)]( + compute_seg_indptr_triton_kernel[(num_local_experts,)]( reorder_topk_ids, seg_indptr, topk_ids.numel() ) @@ -164,25 +166,6 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int): return reorder_topk_ids, src2dst, seg_indptr -def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int): - reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True) - - seg_indptr = torch.zeros( - local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64 - ) - src2dst = torch.empty( - local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32 - ) - - BLOCK_SIZE = 512 - grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),) - compute_src2dst_triton_kernel[grid]( - reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE - ) - - return reorder_topk_ids, src2dst, seg_indptr - - @triton.jit def pre_reorder_triton_kernel_for_cutlass_moe( input_ptr, @@ -190,52 +173,13 @@ def pre_reorder_triton_kernel_for_cutlass_moe( src2dst_ptr, topk_ids_ptr, a1_scales_ptr, - num_experts, + num_local_experts, topk, hidden_size, BLOCK_SIZE: tl.constexpr, ): OutDtype = gateup_input_ptr.dtype.element_ty - src_idx = tl.program_id(0) - src2dst_ptr = src2dst_ptr + src_idx * topk - topk_ids_ptr = topk_ids_ptr + src_idx * topk - - src_ptr = input_ptr + src_idx * hidden_size - for idx in range(topk): - expert_id = tl.load(topk_ids_ptr + idx) - if expert_id != num_experts: - if 
a1_scales_ptr is not None: - scale = 1.0 / tl.load(a1_scales_ptr) - else: - scale = 1.0 - - dst_idx = tl.load(src2dst_ptr + idx) - dst_ptr = gateup_input_ptr + dst_idx * hidden_size - for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): - offset = start_offset + tl.arange(0, BLOCK_SIZE) - mask = offset < hidden_size - in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32) - out_data = (in_data * scale).to(OutDtype) - tl.store(dst_ptr + offset, out_data, mask=mask) - - -@triton.jit -def pre_reorder_triton_kernel( - input_ptr, - gateup_input_ptr, - src2dst_ptr, - topk_ids_ptr, - a1_scales_ptr, - start_expert_id, - end_expert_id, - topk, - hidden_size, - BLOCK_SIZE: tl.constexpr, - use_per_token_if_dynamic: tl.constexpr, -): - OutDtype = gateup_input_ptr.dtype.element_ty - src_idx_int32 = tl.program_id(0) src_idx = src_idx_int32.to(tl.int64) src2dst_ptr = src2dst_ptr + src_idx * topk @@ -244,15 +188,11 @@ def pre_reorder_triton_kernel( vec = tl.arange(0, BLOCK_SIZE) - if a1_scales_ptr is not None and use_per_token_if_dynamic: - scale = 1.0 / tl.load(a1_scales_ptr + src_idx) - for idx in range(topk): expert_id = tl.load(topk_ids_ptr + idx) - if expert_id >= start_expert_id and expert_id <= end_expert_id: + if expert_id != num_local_experts: if a1_scales_ptr is not None: - if not use_per_token_if_dynamic: - scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id) + scale = 1.0 / tl.load(a1_scales_ptr) else: scale = 1.0 @@ -267,52 +207,6 @@ def pre_reorder_triton_kernel( tl.store(dst_ptr + offset, out_data, mask=mask) -@triton.jit -def silu_and_mul_triton_kernel( - gateup_output, - down_input, - hidden_size, - reorder_topk_ids, - scales, - start_expert_id, - end_expert_id, - BLOCK_SIZE: tl.constexpr, -): - InDtype = gateup_output.dtype.element_ty - OutDtype = down_input.dtype.element_ty - - half_hidden_size = hidden_size // 2 - - pid = tl.program_id(0) - expert_id = tl.load(reorder_topk_ids + pid) - if expert_id >= start_expert_id and expert_id <= end_expert_id: - gateup_output_ptr = gateup_output + pid * hidden_size - gate_output_ptr = gateup_output_ptr - up_output_ptr = gateup_output_ptr + half_hidden_size - down_input_ptr = down_input + pid * half_hidden_size - - if scales is not None: - scale = tl.load(scales + expert_id - start_expert_id) - scale = (1 / scale).to(InDtype) - else: - scale = 1 - - for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE): - offset = start_offset + tl.arange(0, BLOCK_SIZE) - mask = offset < half_hidden_size - - gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32) - up_output = tl.load(up_output_ptr + offset, mask=mask) - - # silu & mul & quantize - gate_output = gate_output * tl.sigmoid(gate_output) - gate_output = gate_output.to(InDtype) - - silu_mul_output = gate_output * up_output * scale - silu_mul_output = silu_mul_output.to(OutDtype) - tl.store(down_input_ptr + offset, silu_mul_output, mask=mask) - - # copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py @triton.jit def _silu_and_mul_post_quant_kernel( @@ -461,70 +355,44 @@ def silu_and_mul_masked_post_quant_fwd( @triton.jit -def tanh(x): - return 2 * tl.sigmoid(2 * x) - 1 - - -@triton.jit -def gelu_and_mul_triton_kernel( - gateup_output, - down_input, +def post_reorder_triton_kernel_for_cutlass_moe( + down_output_ptr, + output_ptr, + src2dst_ptr, + topk_ids_ptr, + topk_weights_ptr, + topk, + num_local_experts, hidden_size, - reorder_topk_ids, - scales, - 
start_expert_id, - end_expert_id, BLOCK_SIZE: tl.constexpr, ): - InDtype = gateup_output.dtype.element_ty - OutDtype = down_input.dtype.element_ty + InDtype = down_output_ptr.dtype.element_ty - half_hidden_size = hidden_size // 2 + src_idx_int32 = tl.program_id(0) + src_idx = src_idx_int32.to(tl.int64) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + topk_weights_ptr = topk_weights_ptr + src_idx * topk - pid = tl.program_id(0) - expert_id = tl.load(reorder_topk_ids + pid) - if expert_id >= start_expert_id and expert_id <= end_expert_id: - gateup_output_ptr = gateup_output + pid * hidden_size - gate_output_ptr = gateup_output_ptr - up_output_ptr = gateup_output_ptr + half_hidden_size - down_input_ptr = down_input + pid * half_hidden_size + store_ptr = output_ptr + src_idx * hidden_size - if scales is not None: - scale = tl.load(scales + expert_id - start_expert_id) - scale = (1 / scale).to(InDtype) - else: - scale = 1 + vec = tl.arange(0, BLOCK_SIZE) - for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE): - offset = start_offset + tl.arange(0, BLOCK_SIZE) - mask = offset < half_hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + vec + mask = offset < hidden_size - gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32) - up_output = tl.load(up_output_ptr + offset, mask=mask) - - # gelu & mul & quantize - # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html - # sqrt(2/pi) - kAlpha = 0.7978845608028654 - gate_output = ( - 0.5 - * gate_output - * ( - 1 - + tanh( - kAlpha - * ( - gate_output - + 0.044715 * gate_output * gate_output * gate_output - ) - ) - ) - ) - gate_output = gate_output.to(InDtype) - - gelu_mul_output = gate_output * up_output * scale - gelu_mul_output = gelu_mul_output.to(OutDtype) - tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask) + sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id != num_local_experts: + dst_idx_int32 = tl.load(src2dst_ptr + idx) + dst_idx = dst_idx_int32.to(tl.int64) + weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) + load_ptr = down_output_ptr + dst_idx * hidden_size + in_data = tl.load(load_ptr + offset, mask=mask) + sum_vec += in_data * weigh_scale + tl.store(store_ptr + offset, sum_vec, mask=mask) @triton.jit @@ -534,64 +402,8 @@ def post_reorder_triton_kernel( src2dst_ptr, topk_ids_ptr, topk_weights_ptr, - start_expert_id, - end_expert_id, topk, hidden_size, - dst_start, - BLOCK_SIZE: tl.constexpr, -): - InDtype = down_output_ptr.dtype.element_ty - - src_idx_int32 = tl.program_id(0) - src_idx = src_idx_int32.to(tl.int64) - src2dst_ptr = src2dst_ptr + src_idx * topk - topk_ids_ptr = topk_ids_ptr + src_idx * topk - topk_weights_ptr = topk_weights_ptr + src_idx * topk - - computed = False - store_ptr = output_ptr + src_idx * hidden_size - - vec = tl.arange(0, BLOCK_SIZE) - - for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): - offset = start_offset + vec - mask = offset < hidden_size - - sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) - for idx in range(topk): - expert_id = tl.load(topk_ids_ptr + idx) - if expert_id >= start_expert_id and expert_id <= end_expert_id: - computed = True - dst_idx_int32 = tl.load(src2dst_ptr + idx) - dst_idx = dst_idx_int32.to(tl.int64) - dst_idx = dst_idx - dst_start - weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) - load_ptr = down_output_ptr + dst_idx * hidden_size - in_data 
= tl.load(load_ptr + offset, mask=mask) - sum_vec += in_data * weigh_scale - tl.store(store_ptr + offset, sum_vec, mask=mask) - - if computed == False: - for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): - offset = start_offset + vec - mask = offset < hidden_size - tl.store( - store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask - ) - - -@triton.jit -def post_reorder_triton_kernel_for_cutlass_moe( - down_output_ptr, - output_ptr, - src2dst_ptr, - topk_ids_ptr, - topk_weights_ptr, - num_experts, - topk, - hidden_size, - dst_start, BLOCK_SIZE: tl.constexpr, ): InDtype = down_output_ptr.dtype.element_ty @@ -613,10 +425,9 @@ def post_reorder_triton_kernel_for_cutlass_moe( sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) for idx in range(topk): expert_id = tl.load(topk_ids_ptr + idx) - if expert_id != num_experts: + if expert_id > 0: dst_idx_int32 = tl.load(src2dst_ptr + idx) dst_idx = dst_idx_int32.to(tl.int64) - dst_idx = dst_idx - dst_start weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) load_ptr = down_output_ptr + dst_idx * hidden_size in_data = tl.load(load_ptr + offset, mask=mask) @@ -624,232 +435,6 @@ def post_reorder_triton_kernel_for_cutlass_moe( tl.store(store_ptr + offset, sum_vec, mask=mask) -@triton.jit -def compute_m_range( - pid, - batch_size, - seg_indptr, - weight_indices, - m_num_tiles_indptr, - BLOCK_SIZE_M: tl.constexpr, -): - idx = 0 - for bs in range(batch_size): - tiles = tl.load(m_num_tiles_indptr + bs) - if pid >= tiles: - idx = bs - - idx_start = tl.load(m_num_tiles_indptr + idx) - - m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M - m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M) - expert_id = tl.load(weight_indices + idx) - return m_range_start, m_range_end, expert_id - - -@triton.jit -def grouped_gemm_triton_kernel( - a, - b, - c, - batch_size, - N, - K, - seg_indptr, - weight_indices, - m_num_tiles_indptr, - scale_a, - scale_b, - use_fp8_w8a8: tl.constexpr, - group_n: tl.constexpr, - group_k: tl.constexpr, - a_stride_0: tl.constexpr, - b_stride_0: tl.constexpr, - b_stride_1: tl.constexpr, - as_stride_0: tl.constexpr, - as_stride_1: tl.constexpr, - bs_stride_0: tl.constexpr, - bs_stride_2: tl.constexpr, - bs_stride_1: tl.constexpr, - use_per_token_if_dynamic: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - c_dtype = c.dtype.element_ty - - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - total_m_block = tl.load(m_num_tiles_indptr + batch_size) - if pid_m >= total_m_block: - return - - m_range_start, m_range_end, expert_id = compute_m_range( - pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M - ) - if m_range_end - m_range_start == 0: - return - - n_range_start = pid_n * BLOCK_SIZE_N - n_range_end = min(n_range_start + BLOCK_SIZE_N, N) - - offs_am = tl.arange(0, BLOCK_SIZE_M) - offs_bn = tl.arange(0, BLOCK_SIZE_N) - - offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0) - offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0) - offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) - offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :] - b_ptr = b + ( - (expert_id * b_stride_0) - + (n_range_start + offs_bn[:, None]) * b_stride_1 - + offs_k[None, :] - ) - - if group_k > 0 and 
group_n > 0: - a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0 - offs_bsn = (n_range_start + offs_bn) // group_n - b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1 - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a_tile = tl.load( - a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 - ) - b_tile = tl.load( - b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 - ) - - if group_k > 0 and group_n > 0: - k_start = k * BLOCK_SIZE_K - offs_ks = k_start // group_k - a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1) - b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2) - accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :] - else: - accumulator = tl.dot(a_tile, b_tile.T, accumulator) - a_ptr += BLOCK_SIZE_K - b_ptr += BLOCK_SIZE_K - - if use_fp8_w8a8 and not (group_k > 0 and group_n > 0): - if use_per_token_if_dynamic: - scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None])) - else: - scale_a_value = tl.load(scale_a + expert_id) - scale_b_value = tl.load(scale_b + expert_id) - accumulator *= scale_a_value * scale_b_value - - c_tile = accumulator.to(c_dtype) - - offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N) - c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :] - c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end) - tl.store(c_ptr, c_tile, mask=c_mask) - - -@triton.jit -def compute_m_num_tiles_indptr( - m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr -): - for bs in range(batch_size): - m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs) - cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M) - pre_num_tiles = tl.load(m_num_tiles_indptr + bs) - tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles) - - -def grouped_gemm_triton( - a: torch.Tensor, - b: torch.Tensor, - c: torch.Tensor, - batch_size: int, - weight_column_major: bool, - seg_indptr: Optional[torch.Tensor] = None, - weight_indices: Optional[torch.Tensor] = None, - use_fp8_w8a8: bool = False, - scale_a: torch.Tensor = None, - scale_b: torch.Tensor = None, - block_shape: Optional[List[int]] = None, - c_dtype=None, - use_per_token_if_dynamic: bool = True, -): - assert weight_column_major == True # TODO: more - if use_fp8_w8a8 and block_shape is None: - assert scale_a is not None and scale_b is not None - - if block_shape is not None: - a_original = a - - assert len(block_shape) == 2 - block_n, block_k = block_shape[0], block_shape[1] - a, scale_a = per_token_group_quant_fp8(a, block_k) - - assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1] - assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2] - assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1] - - dispose_tensor(a_original) - - # TODO: adjust config or tune kernel - # Reduce block size to prevent L40 shared memory overflow. 
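The tile-to-segment indexing used by the removed kernel above (`compute_m_num_tiles_indptr` plus `compute_m_range`) can be summarized in plain Python. This is an illustrative sketch of the mapping only; `tile_to_m_range` is a hypothetical helper, not the Triton implementation:

```python
import math

def tile_to_m_range(pid_m, seg_indptr, weight_indices, block_m):
    # m_num_tiles_indptr[b] = total number of M-tiles owned by groups 0..b-1.
    m_num_tiles_indptr = [0]
    for b in range(len(seg_indptr) - 1):
        rows = seg_indptr[b + 1] - seg_indptr[b]
        m_num_tiles_indptr.append(m_num_tiles_indptr[-1] + math.ceil(rows / block_m))

    # The owning group is the last one whose starting tile offset is <= pid_m.
    idx = max(b for b in range(len(seg_indptr) - 1) if pid_m >= m_num_tiles_indptr[b])

    m_start = seg_indptr[idx] + (pid_m - m_num_tiles_indptr[idx]) * block_m
    m_end = min(seg_indptr[idx + 1], m_start + block_m)
    # Rows [m_start, m_end) of the packed input use expert weight_indices[idx].
    return m_start, m_end, weight_indices[idx]
```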
- config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - } - - m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64) - compute_m_num_tiles_indptr[(1,)]( - m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"] - ) - - if c is None: - assert c_dtype is not None - c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype) - - grid = lambda META: ( - triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size, - triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]), - ) - - if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic: - assert ( - scale_a.shape[0] == a.shape[0] - ), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}" - - grouped_gemm_triton_kernel[grid]( - a, - b, - c, - batch_size, - b.size(1), - b.size(2), - seg_indptr, - weight_indices, - m_num_tiles_indptr, - scale_a, - scale_b, - use_fp8_w8a8, - 0 if block_shape is None else block_shape[0], - 0 if block_shape is None else block_shape[1], - a.stride(0), - b.stride(0), - b.stride(1), - scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0, - scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0, - scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0, - scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0, - scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0, - use_per_token_if_dynamic, - **config, - ) - return c - - @triton.jit def _fwd_kernel_ep_scatter_1( num_recv_tokens_per_expert, @@ -1234,7 +819,7 @@ def deepgemm_compute_src2dst_triton_kernel( mask = dst_id < num_toks src_id = tl.load(reorder_ids + dst_id, mask=mask) expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks)) - expert_dst_start = tl.load(seg_indptr + expert_id) + expert_dst_start = tl.load(seg_indptr + expert_id, mask=(expert_id >= 0)) expert_dst_offset = dst_id - expert_dst_start dst_id = expert_id * m_max + expert_dst_offset tl.store(src2dst + src_id, dst_id, mask=mask) @@ -1248,10 +833,7 @@ def fill_gateup_input_triton_kernel( gateup_input_scale_ptr, src2dst_ptr, topk_ids_ptr, - start_expert_id, - end_expert_id, topk, - m_max, hidden_size, scale_size, BLOCK_SIZE: tl.constexpr, @@ -1267,10 +849,9 @@ def fill_gateup_input_triton_kernel( vec = tl.arange(0, BLOCK_SIZE) for idx in range(topk): expert_id = tl.load(topk_ids_ptr + idx) - if expert_id >= start_expert_id and expert_id <= end_expert_id: + if expert_id >= 0: dst_idx_int32 = tl.load(src2dst_ptr + idx) dst_idx = dst_idx_int32.to(tl.int64) - dst_idx = dst_idx - start_expert_id * m_max dst_ptr = gateup_input_ptr + dst_idx * hidden_size for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): offset = start_offset + vec @@ -1287,31 +868,31 @@ def fill_gateup_input_triton_kernel( def moe_ep_deepgemm_preprocess( topk_ids: torch.Tensor, - num_experts: int, + num_local_experts: int, hidden_states: torch.Tensor, top_k: int, - start_expert_id, - end_expert_id, block_shape, output_dtype: torch.dtype = torch.float8_e4m3fn, ): reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True) - seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64) + seg_indptr = torch.zeros( + num_local_experts + 1, device=topk_ids.device, dtype=torch.int64 + ) src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32) - masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32) + masked_m = torch.empty(num_local_experts, device=topk_ids.device, dtype=torch.int32) - 
compute_seg_indptr_triton_kernel[(num_experts,)]( + compute_seg_indptr_triton_kernel[(num_local_experts + 1,)]( reorder_topk_ids, seg_indptr, topk_ids.numel() ) grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),) - compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m) + compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m) # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165 - m_max = (hidden_states.size(0) + 255) // 256 * 256 - expected_m = (topk_ids.numel() + num_experts - 1) // num_experts + m_max = (hidden_states.size(0) // 256 + 1) * 256 + expected_m = (topk_ids.numel() - 1) // num_local_experts + 1 gateup_input = torch.empty( - (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)), + (num_local_experts, m_max, hidden_states.size(1)), device=hidden_states.device, dtype=output_dtype, ) @@ -1330,6 +911,8 @@ def moe_ep_deepgemm_preprocess( block_shape = [128, 128] assert len(block_shape) == 2 block_n, block_k = block_shape[0], block_shape[1] + + # TODO: fuse this with the preprocess hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k) gateup_input_scale = torch.empty( @@ -1345,18 +928,14 @@ def moe_ep_deepgemm_preprocess( gateup_input_scale, src2dst, topk_ids, - start_expert_id, - end_expert_id, top_k, - m_max, hidden_states.size(1), scale.size(1), BLOCK_SIZE=1024, ) return ( - m_max, - masked_m[start_expert_id : (end_expert_id + 1)], + masked_m, expected_m, src2dst, gateup_input, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 76f48bc4b..30e3faab3 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1,14 +1,10 @@ from __future__ import annotations import logging -from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch -import triton -import triton.language as tl -from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size from sglang.srt.layers.moe import ( get_deepep_mode, get_moe_a2a_backend, @@ -18,13 +14,10 @@ from sglang.srt.layers.moe import ( from sglang.srt.layers.moe.ep_moe.kernels import ( ep_gather, ep_scatter, - moe_ep_deepgemm_preprocess, - post_reorder_triton_kernel, silu_and_mul_masked_post_quant_fwd, tma_align_input_scale, ) from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE -from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8 import Fp8Config @@ -36,19 +29,10 @@ from sglang.srt.layers.quantization.modelopt_quant import ( CUTEDSL_MOE_NVFP4_DISPATCH, ModelOptNvFp4FusedMoEMethod, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.offloader import get_offloader from sglang.srt.single_batch_overlap import DownGemmOverlapArgs -from sglang.srt.utils import ( - ceil_div, - dispose_tensor, - get_bool_env_var, - get_int_env_var, - is_cuda, - is_hip, - is_npu, -) +from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu if 
TYPE_CHECKING: from sglang.srt.layers.moe.token_dispatcher import ( @@ -72,275 +56,7 @@ if _use_aiter: logger = logging.getLogger(__name__) -# TODO(kaixih@nvidia): ideally we should merge this logic into -# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale. -@torch.compile -def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor: - temp = x.to(torch.float32).view(torch.int32) - exp = torch.bitwise_right_shift(temp, 23) - mant = torch.bitwise_and(temp, 0x7FFFFF) - is_ru = torch.logical_and( - torch.logical_and((mant > 0), (exp != 0xFE)), - ~torch.logical_and((exp == 0), (mant <= 0x400000)), - ) - exp = torch.where(is_ru, exp + 1, exp) - new_x = exp.to(torch.uint8).view(torch.int) - return new_x.transpose(1, 2).contiguous().transpose(1, 2) - - -class EPMoE(FusedMoE): - """ - MoE Expert Parallel Impl - - - """ - - def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - layer_id: int, - num_fused_shared_experts: int = 0, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - activation: str = "silu", - routed_scaling_factor: Optional[float] = None, - gemm1_alpha: Optional[float] = None, - gemm1_clamp_limit: Optional[float] = None, - with_bias: bool = False, - ): - super().__init__( - num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_fused_shared_experts=num_fused_shared_experts, - layer_id=layer_id, - top_k=top_k, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix, - activation=activation, - # apply_router_weight_on_input=apply_router_weight_on_input, - routed_scaling_factor=routed_scaling_factor, - gemm1_alpha=gemm1_alpha, - gemm1_clamp_limit=gemm1_clamp_limit, - with_bias=with_bias, - ) - - self.intermediate_size = intermediate_size - - if isinstance(quant_config, Fp8Config): - self.use_block_quant = getattr(self.quant_method, "block_quant", False) - self.block_shape = ( - self.quant_method.quant_config.weight_block_size - if self.use_block_quant - else None - ) - self.use_fp8_w8a8 = True - self.fp8_dtype = torch.float8_e4m3fn - self.activation_scheme = quant_config.activation_scheme - else: - self.use_fp8_w8a8 = False - self.use_block_quant = False - self.block_shape = None - self.activation_scheme = None - - def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8: - return self.forward_deepgemm(hidden_states, topk_output) - else: - return super().forward(hidden_states, topk_output) - - def forward_deepgemm( - self, - hidden_states: torch.Tensor, - topk_output: TopKOutput, - ): - - self.w13_weight_fp8 = ( - self.w13_weight, - ( - self.w13_weight_scale_inv - if self.use_block_quant - else self.w13_weight_scale - ), - ) - self.w2_weight_fp8 = ( - self.w2_weight, - self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale, - ) - - assert self.quant_method is not None - assert self.moe_runner_config.activation == "silu" - - hidden_states_shape = hidden_states.shape - hidden_states_dtype = hidden_states.dtype - hidden_states_device = hidden_states.device - - topk_weights, topk_ids, _ = topk_output - - if not self.use_block_quant: - # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm - scale_block_size = 128 - w13_weight_scale_n = 2 * ( - (self.intermediate_size + scale_block_size - 1) // scale_block_size - ) - w13_weight_scale_k = ( - 
hidden_states_shape[-1] + scale_block_size - 1 - ) // scale_block_size - w13_weight_scale = ( - self.w13_weight_scale.unsqueeze(1) - .repeat_interleave(w13_weight_scale_n, dim=1) - .unsqueeze(2) - .repeat_interleave(w13_weight_scale_k, dim=2) - ) - self.w13_weight_fp8 = ( - self.w13_weight, - w13_weight_scale, - ) - w2_weight_scale_n = ( - hidden_states_shape[-1] + scale_block_size - 1 - ) // scale_block_size - w2_weight_scale_k = ( - self.intermediate_size + scale_block_size - 1 - ) // scale_block_size - w2_weight_scale = ( - self.w2_weight_scale.unsqueeze(1) - .repeat_interleave(w2_weight_scale_n, dim=1) - .unsqueeze(2) - .repeat_interleave(w2_weight_scale_k, dim=2) - ) - self.w2_weight_fp8 = ( - self.w2_weight, - w2_weight_scale, - ) - - # PreReorder - m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = ( - moe_ep_deepgemm_preprocess( - topk_ids, - self.num_experts, - hidden_states, - self.top_k, - self.start_expert_id, - self.end_expert_id, - self.block_shape, - ) - ) - - dispose_tensor(hidden_states) - - if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0: - b, s_mn, s_k = gateup_input_scale.shape - assert ( - s_mn % 4 == 0 and s_k % 4 == 0 - ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})" - - # GroupGemm-0 - gateup_input_fp8 = ( - gateup_input, - ( - _cast_to_e8m0_with_rounding_up(gateup_input_scale) - if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor( - gateup_input_scale - ) - ), - ) - num_groups, m, k = gateup_input_fp8[0].size() - n = self.w13_weight.size(1) - gateup_output = torch.empty( - (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16 - ) - deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( - gateup_input_fp8, - self.w13_weight_fp8, - gateup_output, - masked_m, - expected_m, - ) - del gateup_input - del gateup_input_fp8 - - # Act - down_input = torch.empty( - ( - gateup_output.shape[0], - gateup_output.shape[1], - gateup_output.shape[2] // 2, - ), - device=hidden_states_device, - dtype=self.fp8_dtype, - ) - scale_block_size = 128 - down_input_scale = torch.empty( - ( - gateup_output.shape[0], - gateup_output.shape[1], - gateup_output.shape[2] // 2 // scale_block_size, - ), - device=hidden_states_device, - dtype=torch.float32, - ) - silu_and_mul_masked_post_quant_fwd( - gateup_output, - down_input, - down_input_scale, - scale_block_size, - masked_m, - scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, - ) - del gateup_output - - # GroupGemm-1 - n = self.w2_weight.size(1) - down_input_fp8 = ( - down_input, - ( - down_input_scale - if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale) - ), - ) - down_output = torch.empty( - (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16 - ) - deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( - down_input_fp8, - self.w2_weight_fp8, - down_output, - masked_m, - expected_m, - ) - del down_input - del down_input_fp8 - - # PostReorder - output = torch.empty( - hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device - ) - post_reorder_triton_kernel[(hidden_states_shape[0],)]( - down_output, - output, - src2dst, - topk_ids, - topk_weights, - self.start_expert_id, - self.end_expert_id, - self.top_k, - hidden_states_shape[1], - m_max * self.start_expert_id, - BLOCK_SIZE=512, - ) - if self.moe_runner_config.routed_scaling_factor is not None: - output *= self.moe_runner_config.routed_scaling_factor - return output - - -class DeepEPMoE(EPMoE): +class 
DeepEPMoE(FusedMoE): """ MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main) """ @@ -374,6 +90,15 @@ class DeepEPMoE(EPMoE): activation=activation, routed_scaling_factor=routed_scaling_factor, ) + + if isinstance(quant_config, Fp8Config): + self.use_block_quant = getattr(self.quant_method, "block_quant", False) + self.use_fp8_w8a8 = True + self.fp8_dtype = torch.float8_e4m3fn + else: + self.use_fp8_w8a8 = False + self.use_block_quant = False + self.deepep_mode = get_deepep_mode() # TODO: move to the beginning of the file @@ -567,7 +292,6 @@ class DeepEPMoE(EPMoE): N = self.w13_weight.size(1) scale_block_size = 128 - # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass) w13_weight_fp8 = ( self.w13_weight, ( @@ -988,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]): return FlashInferFusedMoE if get_moe_runner_backend().is_flashinfer_cutlass(): return FusedMoE - if get_moe_expert_parallel_world_size() > 1: - return EPMoE return FusedMoE diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 132a0c31f..9cdfbc86c 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -156,8 +156,7 @@ class FusedMoE(torch.nn.Module): self.moe_tp_rank = get_moe_tensor_parallel_rank() assert num_experts % self.moe_ep_size == 0 self.num_local_experts = num_experts // self.moe_ep_size - self.start_expert_id = self.moe_ep_rank * self.num_local_experts - self.end_expert_id = self.start_expert_id + self.num_local_experts - 1 + if self.moe_ep_size > 1: # TODO(ch-wan): support shared experts fusion # Create a tensor of size num_experts filled with -1 diff --git a/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py new file mode 100644 index 000000000..9bc3824b9 --- /dev/null +++ b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py @@ -0,0 +1,304 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional + +import torch + +from sglang.srt.layers.moe.moe_runner.base import ( + MoeQuantInfo, + MoeRunnerConfig, + MoeRunnerCore, + RunnerInput, + RunnerOutput, + register_post_permute, + register_pre_permute, +) +from sglang.srt.layers.moe.utils import MoeRunnerBackend +from sglang.srt.utils import dispose_tensor + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher.standard import ( + StandardCombineInput, + StandardDispatchOutput, + ) + + +# TODO(kaixih@nvidia): ideally we should merge this logic into +# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale. 
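+# The helper below extracts the 8-bit FP32 exponent (an e8m0, power-of-two
+# scale) and rounds up whenever the mantissa is non-zero, so the resulting
+# scale does not underestimate the original one (with guards for the
+# exponent-overflow and small-subnormal edge cases).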
+@torch.compile +def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor: + temp = x.to(torch.float32).view(torch.int32) + exp = torch.bitwise_right_shift(temp, 23) + mant = torch.bitwise_and(temp, 0x7FFFFF) + is_ru = torch.logical_and( + torch.logical_and((mant > 0), (exp != 0xFE)), + ~torch.logical_and((exp == 0), (mant <= 0x400000)), + ) + exp = torch.where(is_ru, exp + 1, exp) + new_x = exp.to(torch.uint8).view(torch.int) + return new_x.transpose(1, 2).contiguous().transpose(1, 2) + + +@dataclass +class DeepGemmRunnerInput(RunnerInput): + hidden_states: torch.Tensor + hidden_states_scale: torch.Tensor + masked_m: torch.Tensor + expected_m: int + use_masked_gemm: bool + + @property + def runner_backend(self) -> MoeRunnerBackend: + return MoeRunnerBackend.DEEP_GEMM + + +@dataclass +class DeepGemmRunnerOutput(RunnerOutput): + hidden_states: torch.Tensor + + @property + def runner_backend(self) -> MoeRunnerBackend: + return MoeRunnerBackend.DEEP_GEMM + + +@dataclass +class DeepGemmMoeQuantInfo(MoeQuantInfo): + w13_weight: torch.Tensor + w2_weight: torch.Tensor + use_fp8: bool + w13_scale: Optional[torch.Tensor] = None + w2_scale: Optional[torch.Tensor] = None + block_shape: Optional[List[int]] = None + + +class DeepGemmRunnerCore(MoeRunnerCore): + def __init__(self, config: MoeRunnerConfig): + super().__init__(config) + assert self.config.activation == "silu" + + def run( + self, + runner_input: DeepGemmRunnerInput, + quant_info: DeepGemmMoeQuantInfo, + running_state: dict, + ) -> DeepGemmRunnerOutput: + + if runner_input.use_masked_gemm: + hidden_states = self._run_masked_gemm( + runner_input, + quant_info, + running_state, + ) + else: + hidden_states = self._run_contiguous_gemm( + runner_input, + quant_info, + running_state, + ) + return DeepGemmRunnerOutput(hidden_states=hidden_states) + + def _run_masked_gemm( + self, + runner_input: DeepGemmRunnerInput, + quant_info: DeepGemmMoeQuantInfo, + running_state: dict, + ) -> torch.Tensor: + + from sglang.srt.layers.moe.ep_moe.kernels import ( + silu_and_mul_masked_post_quant_fwd, + ) + from sglang.srt.layers.quantization import deep_gemm_wrapper + + hidden_states = runner_input.hidden_states + hidden_states_scale = runner_input.hidden_states_scale + masked_m = runner_input.masked_m + expected_m = runner_input.expected_m + + w13_weight = quant_info.w13_weight + w2_weight = quant_info.w2_weight + w13_scale = quant_info.w13_scale + w2_scale = quant_info.w2_scale + + hidden_states_device = running_state["hidden_states_device"] + + if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0: + b, s_mn, s_k = hidden_states_scale.shape + assert ( + s_mn % 4 == 0 and s_k % 4 == 0 + ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})" + + # GroupGemm-0 + if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0: + hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale) + else: + hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor( + hidden_states_scale + ) + + num_groups, m, k = hidden_states.shape + n = w13_weight.size(1) + gateup_output = torch.empty( + (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16 + ) + deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( + (hidden_states, hidden_states_scale), + (w13_weight, w13_scale), + gateup_output, + masked_m, + expected_m, + ) + dispose_tensor(hidden_states) + + # Act + down_input = torch.empty( + ( + gateup_output.shape[0], + gateup_output.shape[1], + gateup_output.shape[2] // 2, + ), + device=hidden_states_device, + dtype=torch.float8_e4m3fn, + ) + 
scale_block_size = 128 + down_input_scale = torch.empty( + ( + gateup_output.shape[0], + gateup_output.shape[1], + gateup_output.shape[2] // 2 // scale_block_size, + ), + device=hidden_states_device, + dtype=torch.float32, + ) + silu_and_mul_masked_post_quant_fwd( + gateup_output, + down_input, + down_input_scale, + scale_block_size, + masked_m, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + ) + del gateup_output + + # GroupGemm-1 + n = w2_weight.shape[1] + + if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0: + down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor( + down_input_scale + ) + + down_output = torch.empty( + (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16 + ) + deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( + (down_input, down_input_scale), + (w2_weight, w2_scale), + down_output, + masked_m, + expected_m, + ) + del down_input + + return down_output + + def _run_contiguous_gemm( + self, + runner_input: DeepGemmRunnerInput, + quant_info: DeepGemmMoeQuantInfo, + running_state: dict, + ) -> torch.Tensor: + pass + + @property + def runner_backend(self) -> MoeRunnerBackend: + return MoeRunnerBackend.DEEP_GEMM + + +@register_pre_permute("standard", "deep_gemm") +def pre_permute_standard_to_deep_gemm( + dispatch_output: StandardDispatchOutput, + quant_info: DeepGemmMoeQuantInfo, + runner_config: MoeRunnerConfig, + running_state: dict, +) -> DeepGemmRunnerInput: + from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess + + hidden_states, topk_output = dispatch_output + topk_weights, topk_ids, _ = topk_output + + hidden_states_shape = hidden_states.shape + hidden_states_dtype = hidden_states.dtype + hidden_states_device = hidden_states.device + hidden_states_ref = hidden_states + + topk_weights, topk_ids = topk_weights, topk_ids + + # PreReorder + masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = ( + moe_ep_deepgemm_preprocess( + topk_ids, + runner_config.num_local_experts, + hidden_states, + runner_config.top_k, + quant_info.block_shape, + ) + ) + + dispose_tensor(hidden_states_ref) + + running_state["topk_ids"] = topk_ids + running_state["topk_weights"] = topk_weights + running_state["hidden_states_shape"] = hidden_states_shape + running_state["hidden_states_dtype"] = hidden_states_dtype + running_state["hidden_states_device"] = hidden_states_device + running_state["src2dst"] = src2dst + + return DeepGemmRunnerInput( + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + masked_m=masked_m, + expected_m=expected_m, + use_masked_gemm=True, + ) + + +@register_post_permute("deep_gemm", "standard") +def post_permute_deep_gemm_to_standard( + runner_output: DeepGemmRunnerOutput, + quant_info: DeepGemmMoeQuantInfo, + runner_config: MoeRunnerConfig, + running_state: dict, +) -> StandardCombineInput: + from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel + from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput + + hidden_states_shape = running_state["hidden_states_shape"] + hidden_states_dtype = running_state["hidden_states_dtype"] + hidden_states_device = running_state["hidden_states_device"] + src2dst = running_state["src2dst"] + topk_ids = running_state["topk_ids"] + topk_weights = running_state["topk_weights"] + + output = torch.empty( + hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device + ) + post_reorder_triton_kernel[(hidden_states_shape[0],)]( + runner_output.hidden_states, + output, + src2dst, + topk_ids, + 
topk_weights, + runner_config.top_k, + hidden_states_shape[1], + BLOCK_SIZE=512, + ) + + dispose_tensor(runner_output.hidden_states) + + if runner_config.routed_scaling_factor is not None: + output *= runner_config.routed_scaling_factor + + return StandardCombineInput( + hidden_states=output, + ) diff --git a/python/sglang/srt/layers/moe/moe_runner/runner.py b/python/sglang/srt/layers/moe/moe_runner/runner.py index 3b6fcd980..1e4ada79d 100644 --- a/python/sglang/srt/layers/moe/moe_runner/runner.py +++ b/python/sglang/srt/layers/moe/moe_runner/runner.py @@ -9,6 +9,7 @@ from sglang.srt.layers.moe.moe_runner.base import ( MoeRunnerConfig, PermuteMethodPool, ) +from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore from sglang.srt.layers.moe.utils import get_moe_a2a_backend @@ -30,6 +31,8 @@ class MoeRunner: if runner_backend.is_triton(): self.runner_core = TritonRunnerCore(config) + elif runner_backend.is_deep_gemm(): + self.runner_core = DeepGemmRunnerCore(config) else: raise NotImplementedError(f"Unsupported runner backend: {runner_backend}") diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index a70d1be40..624249f4a 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -44,6 +44,7 @@ class MoeA2ABackend(Enum): class MoeRunnerBackend(Enum): AUTO = "auto" + DEEP_GEMM = "deep_gemm" TRITON = "triton" TRITON_KERNEL = "triton_kernel" FLASHINFER_TRTLLM = "flashinfer_trtllm" @@ -54,6 +55,9 @@ class MoeRunnerBackend(Enum): def is_auto(self): return self == MoeRunnerBackend.AUTO + def is_deep_gemm(self): + return self == MoeRunnerBackend.DEEP_GEMM + def is_triton(self): return self == MoeRunnerBackend.TRITON @@ -147,7 +151,9 @@ def get_moe_a2a_backend() -> MoeA2ABackend: def get_moe_runner_backend() -> MoeRunnerBackend: global MOE_RUNNER_BACKEND if MOE_RUNNER_BACKEND is None: - logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend") + logger.warning( + "MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected" + ) MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO return MOE_RUNNER_BACKEND diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index d14e9b18e..a1a25102d 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -31,8 +31,8 @@ except ImportError: from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo -from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker from sglang.srt.layers.parameter import ( BlockQuantScaleParameter, ModelWeightParameter, @@ -1006,8 +1006,29 @@ class Fp8MoEMethod(FusedMoEMethodBase): def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): + + from sglang.srt.layers.moe.utils import ( + get_moe_a2a_backend, + get_moe_runner_backend, + ) + from sglang.srt.layers.quantization import deep_gemm_wrapper + self.moe_runner_config = moe_runner_config - self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + moe_runner_backend = get_moe_runner_backend() + + if 
moe_runner_backend.is_auto(): + if ( + deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM + and get_moe_a2a_backend().is_deepep() + ): + moe_runner_backend = MoeRunnerBackend.DEEP_GEMM + else: + moe_runner_backend = MoeRunnerBackend.TRITON + if moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton(): + self.runner = MoeRunner(moe_runner_backend, moe_runner_config) + else: + # TODO(cwan): refactor other backends + pass def apply( self, @@ -1087,22 +1108,67 @@ class Fp8MoEMethod(FusedMoEMethodBase): ) return StandardCombineInput(hidden_states=output) - quant_info = TritonMoeQuantInfo( - w13_weight=layer.w13_weight, - w2_weight=layer.w2_weight, - use_fp8_w8a8=True, - w13_scale=( - layer.w13_weight_scale_inv - if self.block_quant - else layer.w13_weight_scale - ), - w2_scale=( - layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale - ), - a13_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - ) + if self.runner.runner_backend.is_deep_gemm(): + + w13_weight = layer.w13_weight + w2_weight = layer.w2_weight + + if self.block_quant: + block_shape = self.quant_config.weight_block_size + w13_scale = layer.w13_weight_scale_inv + w2_scale = layer.w2_weight_scale_inv + else: + # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm + scale_block_size = 128 + block_shape = [scale_block_size, scale_block_size] + w13_scale_n = (w13_weight.shape[1] - 1) // scale_block_size + 1 + w13_scale_k = (w13_weight.shape[2] - 1) // scale_block_size + 1 + w13_scale = ( + layer.w13_weight_scale.unsqueeze(1) + .repeat_interleave(w13_scale_n, dim=1) + .unsqueeze(2) + .repeat_interleave(w13_scale_k, dim=2) + ) + w2_scale_n = (w2_weight.shape[1] - 1) // scale_block_size + 1 + w2_scale_k = (w2_weight.shape[2] - 1) // scale_block_size + 1 + w2_scale = ( + layer.w2_weight_scale.unsqueeze(1) + .repeat_interleave(w2_scale_n, dim=1) + .unsqueeze(2) + .repeat_interleave(w2_scale_k, dim=2) + ) + quant_info = DeepGemmMoeQuantInfo( + w13_weight=w13_weight, + w2_weight=w2_weight, + use_fp8=True, + w13_scale=w13_scale, + w2_scale=w2_scale, + block_shape=block_shape, + ) + elif self.runner.runner_backend.is_triton(): + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + use_fp8_w8a8=True, + w13_scale=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_scale=( + layer.w2_weight_scale_inv + if self.block_quant + else layer.w2_weight_scale + ), + a13_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + ) + else: + raise NotImplementedError( + "Unsupported runner backend: %s" % self.runner.runner_backend + ) + return self.runner.run(dispatch_output, quant_info) def apply_with_router_logits( diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index 158ae6561..fb85f0b31 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -21,7 +21,6 @@ from sglang.srt.utils import is_npu, set_weight_attrs if TYPE_CHECKING: from sglang.srt.layers.moe import MoeRunnerConfig - from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.token_dispatcher import ( CombineInput, StandardDispatchOutput, @@ -94,9 +93,7 @@ class W4AFp8Config(QuantizationConfig): self, layer: torch.nn.Module, prefix: str ) -> Optional[QuantizeMethodBase]: from sglang.srt.layers.linear 
import LinearBase - from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE - from sglang.srt.managers.schedule_batch import global_server_args_dict if isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.ignored_layers): @@ -133,7 +130,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): def create_weights( self, - layer: EPMoE, + layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -292,7 +289,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): def apply( self, - layer: EPMoE, + layer: Module, dispatch_output: StandardDispatchOutput, ) -> CombineInput: @@ -303,18 +300,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): topk_output = dispatch_output.topk_output topk_weights, topk_ids, _ = topk_output - local_topk_ids = topk_ids - if get_moe_expert_parallel_world_size() > 1: - local_topk_ids = torch.where( - topk_ids == -1, - layer.num_experts, - topk_ids, - ) output = cutlass_w4a8_moe( - layer.start_expert_id, - layer.end_expert_id, - layer.num_experts, x, layer.w13_weight, layer.w2_weight, @@ -322,7 +309,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): layer.w2_weight_scale_inv, topk_weights, topk_ids, - local_topk_ids, self.a_strides1, self.b_strides1, self.c_strides1, diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index a35420993..652cb465d 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -49,7 +49,6 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.router import fused_moe_router_shim from sglang.srt.layers.moe.topk import TopK @@ -176,17 +175,7 @@ class Grok1MoE(nn.Module): custom_routing_function=custom_routing_function, ) - kwargs = {} - if get_moe_expert_parallel_world_size() > 1: - MoEImpl = EPMoE - else: - MoEImpl = FusedMoE - kwargs["reduce_results"] = reduce_results - kwargs["use_presharded_weights"] = use_presharded_weights - kwargs["inplace"] = inplace - kwargs["no_combine"] = no_combine - - self.experts = MoEImpl( + self.experts = FusedMoE( num_experts=num_experts, top_k=top_k, layer_id=layer_id, @@ -195,7 +184,10 @@ class Grok1MoE(nn.Module): params_dtype=params_dtype, quant_config=quant_config, activation="gelu", - **kwargs, + reduce_results=reduce_results, + use_presharded_weights=use_presharded_weights, + inplace=inplace, + no_combine=no_combine, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index c5f04a4fc..81026f9bb 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -36,7 +36,6 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -94,8 +93,7 @@ class MixtralMoE(nn.Module): renormalize=True, ) - MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE - self.experts = MoEImpl( + self.experts = FusedMoE( num_experts=num_experts, top_k=top_k, layer_id=layer_id, diff --git a/python/sglang/srt/server_args.py 
b/python/sglang/srt/server_args.py index 8c955aabc..7643fd6b3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -121,6 +121,17 @@ NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter" RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"] +MOE_RUNNER_BACKEND_CHOICES = [ + "auto", + "deep_gemm", + "triton", + "triton_kernel", + "flashinfer_trtllm", + "flashinfer_cutlass", + "flashinfer_mxfp4", + "flashinfer_cutedsl", +] + # Allow external code to add more choices def add_load_format_choices(choices): @@ -143,6 +154,10 @@ def add_grammar_backend_choices(choices): GRAMMAR_BACKEND_CHOICES.extend(choices) +def add_moe_runner_backend_choices(choices): + MOE_RUNNER_BACKEND_CHOICES.extend(choices) + + def add_deterministic_attention_backend_choices(choices): DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices) @@ -315,14 +330,7 @@ class ServerArgs: # Expert parallelism ep_size: int = 1 moe_a2a_backend: Literal["none", "deepep"] = "none" - moe_runner_backend: Literal[ - "auto", - "triton", - "triton_kernel", - "flashinfer_trtllm", - "flashinfer_cutlass", - "flashinfer_mxfp4", - ] = "auto" + moe_runner_backend: str = "auto" flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default" enable_flashinfer_allreduce_fusion: bool = False deepep_mode: Literal["auto", "normal", "low_latency"] = "auto" @@ -2191,15 +2199,7 @@ class ServerArgs: parser.add_argument( "--moe-runner-backend", type=str, - choices=[ - "auto", - "triton", - "triton_kernel", - "flashinfer_trtllm", - "flashinfer_cutlass", - "flashinfer_mxfp4", - "flashinfer_cutedsl", - ], + choices=MOE_RUNNER_BACKEND_CHOICES, default=ServerArgs.moe_runner_backend, help="Choose the runner backend for MoE.", ) diff --git a/python/sglang/test/test_block_fp8_ep.py b/python/sglang/test/test_block_fp8_ep.py deleted file mode 100644 index 670f2e0f8..000000000 --- a/python/sglang/test/test_block_fp8_ep.py +++ /dev/null @@ -1,358 +0,0 @@ -import itertools -import random -import unittest -from typing import Any, Callable, Dict, List, Optional, Tuple - -import torch - -from sglang.srt.layers.moe.ep_moe.kernels import ( - grouped_gemm_triton, - post_reorder_triton_kernel, - pre_reorder_triton_kernel, - run_moe_ep_preproess, - silu_and_mul_triton_kernel, -) -from sglang.srt.layers.moe.topk import TopKConfig, select_experts -from sglang.test.test_utils import CustomTestCase - - -# For test -def ep_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - router_logits: torch.Tensor, - topk_config: TopKConfig, - # ep config - num_experts: int = 256, - fp8_dtype: torch.types = torch.float8_e4m3fn, - num_experts_per_partition: int = 128, - start_expert_id: int = 0, - end_expert_id: int = 127, - use_fp8_w8a8: bool = False, - w1_scale_inv: Optional[torch.Tensor] = None, - w2_scale_inv: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, -): - use_blockwise_fp8 = block_shape is not None - top_k = topk_config.top_k - topk_output = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - topk_config=topk_config, - ) - topk_weights, topk_ids, _ = topk_output - - reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts) - - gateup_input = torch.empty( - (int(hidden_states.shape[0] * top_k), hidden_states.shape[1]), - device=hidden_states.device, - dtype=( - fp8_dtype - if (use_fp8_w8a8 and not use_blockwise_fp8) - else hidden_states.dtype - ), - ) - - if use_fp8_w8a8 and not use_blockwise_fp8: - max_value = ( - 
torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32) - ) - w1_input_scale = max_value / torch.finfo(fp8_dtype).max - else: - w1_input_scale = None - - # PreReorder - pre_reorder_triton_kernel[(hidden_states.shape[0],)]( - hidden_states, - gateup_input, - src2dst, - topk_ids, - w1_input_scale, - start_expert_id, - end_expert_id, - top_k, - hidden_states.shape[1], - BLOCK_SIZE=512, - use_per_token_if_dynamic=True, - ) - - seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2] - weight_indices_cur_rank = torch.arange( - 0, - num_experts_per_partition, - device=hidden_states.device, - dtype=torch.int64, - ) - - # GroupGemm-0 - gateup_output = torch.empty( - gateup_input.shape[0], - w1.shape[1], - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - gateup_output = grouped_gemm_triton( - a=gateup_input, - b=w1, - c=gateup_output, - batch_size=num_experts_per_partition, - weight_column_major=True, - seg_indptr=seg_indptr_cur_rank, - weight_indices=weight_indices_cur_rank, - use_fp8_w8a8=use_fp8_w8a8, - scale_a=w1_input_scale, - scale_b=w1_scale_inv, - block_shape=block_shape, - ) - - # Act - down_input = torch.empty( - gateup_output.shape[0], - gateup_output.shape[1] // 2, - device=gateup_output.device, - dtype=( - fp8_dtype - if (use_fp8_w8a8 and not use_blockwise_fp8) - else hidden_states.dtype - ), - ) - if use_fp8_w8a8 and not use_blockwise_fp8: - w2_input_scale = torch.ones( - num_experts_per_partition, - dtype=torch.float32, - device=hidden_states.device, - ) - else: - w2_input_scale = None - - silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( - gateup_output, - down_input, - gateup_output.shape[1], - reorder_topk_ids, - w2_input_scale, - start_expert_id, - end_expert_id, - BLOCK_SIZE=512, - ) - - # GroupGemm-1 - down_output = torch.empty( - down_input.shape[0], - w2.shape[1], - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - down_output = grouped_gemm_triton( - a=down_input, - b=w2, - c=down_output, - batch_size=num_experts_per_partition, - weight_column_major=True, - seg_indptr=seg_indptr_cur_rank, - weight_indices=weight_indices_cur_rank, - use_fp8_w8a8=use_fp8_w8a8, - scale_a=w2_input_scale, - scale_b=w2_scale_inv, - block_shape=block_shape, - ) - - # PostReorder - output = torch.empty_like(hidden_states) - post_reorder_triton_kernel[(hidden_states.size(0),)]( - down_output, - output, - src2dst, - topk_ids, - topk_weights, - start_expert_id, - end_expert_id, - top_k, - hidden_states.size(1), - 0, - BLOCK_SIZE=512, - ) - return output - - -# test util -def block_dequant( - x_q_block: torch.Tensor, - x_s: torch.Tensor, - block_size: List[int], -) -> Tuple[torch.Tensor, torch.Tensor]: - """This function converts block-wise quantization to tensor-wise quantization. - The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale - and the block size. - The outputs are tensor-wise quantization tensor and tensor-wise quantization scale. - Note only float8 is supported for now. 
- """ - - # process 3D tensor - if x_q_block.dim() == 3: - batch_size = x_q_block.size(0) - return torch.stack( - [block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)] - ) - - block_n, block_k = block_size[0], block_size[1] - n, k = x_q_block.shape - n_tiles = (n + block_n - 1) // block_n - k_tiles = (k + block_k - 1) // block_k - assert n_tiles == x_s.shape[0] - assert k_tiles == x_s.shape[1] - - x_dq_block = x_q_block.to(torch.float32) - - x_dq_block_tiles = [ - [ - x_dq_block[ - j * block_n : min((j + 1) * block_n, n), - i * block_k : min((i + 1) * block_k, k), - ] - for i in range(k_tiles) - ] - for j in range(n_tiles) - ] - - for i in range(k_tiles): - for j in range(n_tiles): - x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i] - - return x_dq_block - - -class TestW8A8BlockFP8EPMoE(CustomTestCase): - DTYPES = [torch.half, torch.bfloat16] - M = [1, 222, 1024, 2048] - N = [128, 1024, 2048] - K = [256, 4096, 5120] - E = [8, 16] - ep_size = [2, 4] - TOP_KS = [2, 4] - BLOCK_SIZE = [[128, 128]] - SEEDS = [0] - - @classmethod - def setUpClass(cls): - if not torch.cuda.is_available(): - raise unittest.SkipTest("CUDA is not available") - torch.set_default_device("cuda") - - def _w8a8_block_fp8_ep_moe( - self, M, N, K, E, ep_size, topk, block_size, dtype, seed - ): - torch.manual_seed(seed) - random.seed(seed) - # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max - w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max - w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k - - w1_s = ( - torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) - * factor_for_scale - ) - w2_s = ( - torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) - * factor_for_scale - ) - - w1_ref = block_dequant(w1, w1_s, block_size).to(dtype) - w2_ref = block_dequant(w2, w2_s, block_size).to(dtype) - - score = torch.randn((M, E), dtype=dtype) - num_experts_per_partition = E // ep_size - cur_rank = random.randint(0, ep_size - 1) - start_id = cur_rank * num_experts_per_partition - end_id = start_id + num_experts_per_partition - 1 - - topk_config = TopKConfig( - top_k=topk, - renormalize=False, - ) - - with torch.inference_mode(): - out = ep_moe( - hidden_states=a, - w1=w1, - w2=w2, - router_logits=score, - topk_config=topk_config, - use_fp8_w8a8=True, - w1_scale_inv=w1_s, - w2_scale_inv=w2_s, - block_shape=block_size, - num_experts=E, - num_experts_per_partition=num_experts_per_partition, - start_expert_id=start_id, - end_expert_id=end_id, - ) - ref_out = ep_moe( - hidden_states=a, - w1=w1_ref, - w2=w2_ref, - router_logits=score, - topk_config=topk_config, - use_fp8_w8a8=False, - w1_scale_inv=None, - w2_scale_inv=None, - block_shape=None, - num_experts=E, - num_experts_per_partition=num_experts_per_partition, - start_expert_id=start_id, - end_expert_id=end_id, - ) - self.assertTrue( - torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) - / 
(torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6) - < 0.06 - ) - - def test_w8a8_block_fp8_ep_moe(self): - for params in itertools.product( - self.M, - self.N, - self.K, - self.E, - self.ep_size, - self.TOP_KS, - self.BLOCK_SIZE, - self.DTYPES, - self.SEEDS, - ): - with self.subTest( - M=params[0], - N=params[1], - K=params[2], - E=params[3], - ep_size=params[4], - topk=params[5], - block_size=params[6], - dtype=params[7], - seed=params[8], - ): - self._w8a8_block_fp8_ep_moe(*params) - torch.cuda.empty_cache() - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/python/sglang/test/test_cutlass_w4a8_moe.py b/python/sglang/test/test_cutlass_w4a8_moe.py index 6706fc962..7d96cccd5 100644 --- a/python/sglang/test/test_cutlass_w4a8_moe.py +++ b/python/sglang/test/test_cutlass_w4a8_moe.py @@ -120,7 +120,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty ) topk_weights, topk_ids, _ = topk_output expert_map = torch.arange(E, dtype=torch.int32, device=device) - expert_map[local_e:] = E + expert_map[local_e:] = -1 output = cutlass_moe( a, @@ -138,9 +138,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty c_strides2, s_strides13, s_strides2, - 0, - local_e - 1, - E, + local_e, a1_scale, a2_scale, expert_map, @@ -178,7 +176,7 @@ def cutlass_moe( w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, - topk_ids_: torch.Tensor, + topk_ids: torch.Tensor, a_strides1: torch.Tensor, b_strides1: torch.Tensor, c_strides1: torch.Tensor, @@ -187,40 +185,32 @@ def cutlass_moe( c_strides2: torch.Tensor, s_strides13: torch.Tensor, s_strides2: torch.Tensor, - start_expert_id: int, - end_expert_id: int, - E: int, + num_local_experts: int, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, ): - local_topk_ids = topk_ids_ - local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E) + topk_ids = expert_map[topk_ids] device = a.device - local_num_experts = end_expert_id - start_expert_id + 1 expert_offsets = torch.empty( - (local_num_experts + 1), dtype=torch.int32, device=device + (num_local_experts + 1), dtype=torch.int32, device=device ) problem_sizes1 = torch.empty( - (local_num_experts, 3), dtype=torch.int32, device=device + (num_local_experts, 3), dtype=torch.int32, device=device ) problem_sizes2 = torch.empty( - (local_num_experts, 3), dtype=torch.int32, device=device + (num_local_experts, 3), dtype=torch.int32, device=device ) return cutlass_w4a8_moe( - start_expert_id, - end_expert_id, - E, a, w1_q, w2_q, w1_scale, w2_scale, topk_weights, - topk_ids_, - local_topk_ids, + topk_ids, a_strides1, b_strides1, c_strides1, diff --git a/test/srt/ep/test_moe_ep.py b/test/srt/ep/test_moe_ep.py index 7456c9329..74a5790d4 100644 --- a/test/srt/ep/test_moe_ep.py +++ b/test/srt/ep/test_moe_ep.py @@ -12,7 +12,7 @@ from sglang.test.test_utils import ( ) -class TestEpMoE(CustomTestCase): +class TestEp(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -34,18 +34,6 @@ class TestEpMoE(CustomTestCase): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mmlu(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - ) - - metrics = run_eval(args) - self.assertGreaterEqual(metrics["score"], 0.5) - def test_mgsm_en(self): args = 
SimpleNamespace( base_url=self.base_url, @@ -59,7 +47,7 @@ class TestEpMoE(CustomTestCase): self.assertGreaterEqual(metrics["score"], 0.8) -class TestEpMoEFP8(CustomTestCase): +class TestEpDeepGEMM(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -76,6 +64,8 @@ class TestEpMoEFP8(CustomTestCase): "2", "--quantization", "fp8", + "--moe-runner-backend", + "deep_gemm", ], ) @@ -83,18 +73,6 @@ class TestEpMoEFP8(CustomTestCase): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_mmlu(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - ) - - metrics = run_eval(args) - self.assertGreaterEqual(metrics["score"], 0.5) - def test_mgsm_en(self): args = SimpleNamespace( base_url=self.base_url, diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index a2ec504cd..186f7c260 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -130,6 +130,7 @@ suites = { TestFile("test_modelopt_loader.py", 30), ], "per-commit-2-gpu": [ + TestFile("ep/test_moe_ep.py", 140), TestFile("lora/test_lora_tp.py", 116), TestFile("rl/test_update_weights_from_distributed.py", 103), TestFile("test_data_parallelism.py", 73),