From 924ca7c92c86fa3a6a321e7944e2fdd193f30c50 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Fri, 4 Apr 2025 16:59:29 +0800 Subject: [PATCH] Add DeepSeek V3/R1 shared experts fusion (#4918) --- .../tuning_fused_moe_triton.py | 13 +- python/sglang/bench_serving.py | 37 ++++- ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ .../layers/moe/fused_moe_triton/fused_moe.py | 18 ++- python/sglang/srt/layers/moe/topk.py | 52 ++++++- .../srt/layers/quantization/__init__.py | 3 +- .../compressed_tensors/compressed_tensors.py | 3 +- .../compressed_tensors_moe.py | 44 ++++-- python/sglang/srt/managers/schedule_batch.py | 2 + .../sglang/srt/model_executor/model_runner.py | 2 + python/sglang/srt/models/deepseek_v2.py | 86 ++++++++++- python/sglang/srt/server_args.py | 18 +++ sgl-kernel/tests/test_moe_align.py | 2 +- 14 files changed, 536 insertions(+), 36 deletions(-) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py index 89309ac6e..6f406e7ce 100644 --- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -399,7 +399,12 @@ def main(args: argparse.Namespace): intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: - E = config.n_routed_experts + n_share_fusion_experts = args.n_share_experts_fusion + E = ( + config.n_routed_experts + n_share_fusion_experts + if config.architectures[0] in ["DeepseekV3ForCausalLM"] + else config.n_routed_experts + ) topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size @@ -559,6 +564,12 @@ if __name__ == "__main__": parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--tune", action="store_true") + parser.add_argument( + "--n-share-experts-fusion", + type=int, + default=0, + help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1", + ) args = parser.parse_args() main(args) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 6a8d4d00a..f6e03a308 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -993,13 +993,16 @@ async def benchmark( return await request_func(request_func_input=request_func_input, pbar=pbar) # Warmup - print("Starting initial single prompt test run...") + print(f"Starting warmup with {args.warmup_requests} sequences...") + + # Use the first request for all warmup iterations test_prompt, test_prompt_len, test_output_len = input_requests[0] if lora_names != None and len(lora_names) != 0: lora_name = lora_names[0] else: lora_name = None + # Create the test input once test_input = RequestFuncInput( model=model_id, prompt=test_prompt, @@ -1009,14 +1012,26 @@ async def benchmark( lora_name=lora_name, extra_request_body=extra_request_body, ) - test_output = await request_func(request_func_input=test_input) - if not test_output.success: + + # Run warmup requests + warmup_tasks = [] + for _ in range(args.warmup_requests): + warmup_tasks.append( + asyncio.create_task(request_func(request_func_input=test_input)) + ) + + warmup_outputs = await asyncio.gather(*warmup_tasks) + + # Check if at least one warmup request succeeded + if not any(output.success for output in warmup_outputs): raise ValueError( - "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}" + "Warmup failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {warmup_outputs[0].error}" ) else: - print("Initial test run completed. Starting main benchmark run...") + print( + f"Warmup completed with {args.warmup_requests} sequences. Starting main benchmark run..." + ) # Flush cache if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache: @@ -1253,6 +1268,10 @@ def run_benchmark(args_: argparse.Namespace): if not hasattr(args, "max_concurrency"): args.max_concurrency = None + # Set default value for warmup_requests if not present + if not hasattr(args, "warmup_requests"): + args.warmup_requests = 1 + print(f"benchmark_args={args}") # Set global environments @@ -1560,6 +1579,12 @@ if __name__ == "__main__": action="store_true", help="Flush the cache before running the benchmark", ) + parser.add_argument( + "--warmup-requests", + type=int, + default=1, + help="Number of warmup requests to run before the benchmark", + ) group = parser.add_argument_group("generated-shared-prefix dataset arguments") group.add_argument( diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000..453d04c6a --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000..c7726d87d --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 7a5a6d3cd..7c9ead9ce 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -13,11 +13,6 @@ import triton import triton.language as tl from sglang.srt.layers.moe.topk import select_experts -from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 -from sglang.srt.layers.quantization.int8_kernel import ( - per_token_group_quant_int8, - per_token_quant_int8, -) from sglang.srt.utils import ( direct_register_custom_op, get_bool_env_var, @@ -42,9 +37,6 @@ if _is_cuda: from sgl_kernel import gelu_and_mul, silu_and_mul from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant - from sglang.srt.layers.quantization.fp8_kernel import ( - sglang_per_token_group_quant_fp8, - ) else: from vllm import _custom_ops as vllm_ops @@ -764,6 +756,16 @@ def invoke_fused_moe_kernel( block_shape: Optional[List[int]] = None, no_combine: bool = False, ) -> None: + from sglang.srt.layers.quantization.int8_kernel import ( + per_token_group_quant_int8, + per_token_quant_int8, + ) + + if _is_cuda: + from sglang.srt.layers.quantization.fp8_kernel import ( + sglang_per_token_group_quant_fp8, + ) + assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 29984f3f2..53c36c63a 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -12,12 +12,14 @@ # limitations under the License. # ============================================================================== +import os from typing import Callable, Optional import torch import torch.nn.functional as F from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip _is_cuda = is_cuda() @@ -102,11 +104,13 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + n_share_experts_fusion: int = 0, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" scores = torch.softmax(gating_output, dim=-1) num_token = scores.shape[0] + num_experts = scores.shape[1] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values ) # [n, n_group] @@ -122,9 +126,25 @@ def grouped_topk( ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + if n_share_experts_fusion: + topk_ids[:, -1] = torch.randint( + low=num_experts, + high=num_experts + n_share_experts_fusion, + size=(topk_ids.size(0),), + dtype=topk_ids.dtype, + device=topk_ids.device, + ) + topk_weights[:, -1] = ( + topk_weights[:, :-1].sum(dim=-1) / 2.5 + ) # 2.5 is the routed_scaling_factor. if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights_sum = ( + topk_weights.sum(dim=-1, keepdim=True) + if n_share_experts_fusion == 0 + else topk_weights[:, :-1].sum(dim=-1, keepdim=True) + ) + topk_weights = topk_weights / topk_weights_sum return topk_weights.to(torch.float32), topk_ids.to(torch.int32) @@ -137,11 +157,13 @@ def biased_grouped_topk_impl( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + n_share_experts_fusion: int = 0, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" scores = gating_output.sigmoid() num_token = scores.shape[0] + num_experts = scores.shape[1] scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0) group_scores = ( scores_for_choice.view(num_token, num_expert_group, -1) @@ -164,8 +186,25 @@ def biased_grouped_topk_impl( _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) topk_weights = scores.gather(1, topk_ids) + if n_share_experts_fusion: + topk_ids[:, -1] = torch.randint( + low=num_experts, + high=num_experts + n_share_experts_fusion, + size=(topk_ids.size(0),), + dtype=topk_ids.dtype, + device=topk_ids.device, + ) + topk_weights[:, -1] = ( + topk_weights[:, :-1].sum(dim=-1) / 2.5 + ) # 2.5 is the routed_scaling_factor. + if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights_sum = ( + topk_weights.sum(dim=-1, keepdim=True) + if n_share_experts_fusion == 0 + else topk_weights[:, :-1].sum(dim=-1, keepdim=True) + ) + topk_weights = topk_weights / topk_weights_sum return topk_weights.to(torch.float32), topk_ids.to(torch.int32) @@ -179,6 +218,7 @@ def biased_grouped_topk( num_expert_group: int = 0, topk_group: int = 0, compiled: bool = True, + n_share_experts_fusion: int = 0, ): biased_grouped_topk_fn = ( torch.compile( @@ -195,6 +235,7 @@ def biased_grouped_topk( renormalize, num_expert_group, topk_group, + n_share_experts_fusion=n_share_experts_fusion, ) @@ -210,7 +251,10 @@ def select_experts( correction_bias: Optional[torch.Tensor] = None, torch_native: bool = False, ): - # DeekSeekv2 uses grouped_top_k + n_share_experts_fusion = 0 + if global_server_args_dict["n_share_experts_fusion"] is not None: + n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"] + # DeekSeek V2/V3/R1 serices models uses grouped_top_k if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None @@ -222,6 +266,7 @@ def select_experts( renormalize=renormalize, num_expert_group=num_expert_group, topk_group=topk_group, + n_share_experts_fusion=n_share_experts_fusion, ) else: topk_weights, topk_ids = biased_grouped_topk( @@ -232,6 +277,7 @@ def select_experts( renormalize=renormalize, num_expert_group=num_expert_group, topk_group=topk_group, + n_share_experts_fusion=n_share_experts_fusion, ) elif torch_native and custom_routing_function is None: topk_weights, topk_ids = fused_topk_native( diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index bee35c9c7..3152e265f 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -51,7 +51,6 @@ except ImportError: from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.quantization.awq import AWQConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config @@ -203,6 +202,8 @@ def get_linear_quant_method( def gptq_get_quant_method(self, layer, prefix): + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE + if isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index e056ce95f..c60e09be4 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -23,7 +23,6 @@ from sglang.srt.layers.linear import ( LinearMethodBase, UnquantizedLinearMethod, ) -from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -123,6 +122,8 @@ class CompressedTensorsConfig(QuantizationConfig): return UnquantizedLinearMethod() layer.scheme = scheme return CompressedTensorsLinearMethod(self) + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + if isinstance(layer, FusedMoE): return CompressedTensorsMoEMethod.get_moe_method(self) return None diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 032ff8b60..7e5b3231f 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -4,18 +4,19 @@ import enum import logging from enum import Enum -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional import torch from compressed_tensors import CompressionFormat from compressed_tensors.quantization import QuantizationStrategy -from sglang.srt.layers.moe.fused_moe_triton import ( - FusedMoE, - FusedMoEMethodBase, - FusedMoeWeightScaleSupported, -) -from sglang.srt.layers.moe.topk import select_experts +if TYPE_CHECKING: + from sglang.srt.layers.moe.fused_moe_triton import ( + FusedMoE, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, + ) + from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.utils import ( all_close_1d, @@ -55,7 +56,13 @@ __all__ = [ ] -class CompressedTensorsMoEMethod(FusedMoEMethodBase): +class CompressedTensorsMoEMethod: + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase + + if cls is CompressedTensorsMoEMethod: + return super().__new__(cls) + return super().__new__(cls) @staticmethod def get_moe_method( @@ -85,6 +92,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__( self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 ): + from sglang.srt.layers.moe.fused_moe_triton import ( + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, + ) + self.quant_config = quant_config self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") self.input_quant = self.quant_config.target_scheme_map["Linear"].get( @@ -112,6 +124,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): params_dtype: torch.dtype, **extra_weight_attrs, ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported params_dtype = torch.float8_e4m3fn @@ -270,8 +283,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): scoring_func: str = "softmax", correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", + inplace: bool = True, + no_combine: bool = False, ) -> torch.Tensor: - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + from sglang.srt.layers.moe.fused_moe_triton import fused_experts + from sglang.srt.layers.moe.topk import select_experts topk_weights, topk_ids = select_experts( hidden_states=x, @@ -291,7 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=True, + inplace=inplace, activation=activation, use_fp8_w8a8=True, w1_scale=layer.w13_weight_scale, @@ -306,6 +322,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def __init__( self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 ): + from sglang.srt.layers.moe.fused_moe_triton import ( + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, + ) + self.quant_config = quant_config # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. @@ -617,6 +638,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", ) -> torch.Tensor: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + from sglang.srt.layers.moe.topk import select_experts + assert activation == "silu", "Only SiLU activation is supported." if not VLLM_AVAILABLE: raise ImportError( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a8796cb42..2c47a6ed2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -81,6 +81,8 @@ global_server_args_dict = { "disable_radix_cache": ServerArgs.disable_radix_cache, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, "chunked_prefill_size": ServerArgs.chunked_prefill_size, + "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, + "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion, } logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f42ea02d5..0f345daba 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -157,6 +157,8 @@ class ModelRunner: "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder, "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject, + "n_share_experts_fusion": server_args.n_share_experts_fusion, + "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion, } ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2f78de492..8b859bfdc 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -16,12 +16,14 @@ # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py """Inference-only DeepseekV2 model.""" +import logging import os from typing import Any, Dict, Iterable, Optional, Tuple import torch import torch.nn.functional as F from torch import nn +from tqdm import tqdm from transformers import PretrainedConfig from sglang.srt.distributed import ( @@ -87,6 +89,8 @@ if _is_hip: expert_distribution_recorder = ExpertDistributionRecorder() +logger = logging.getLogger(__name__) + class DeepseekV2MLP(nn.Module): def __init__( @@ -168,6 +172,12 @@ class DeepseekV2MoE(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor self.n_shared_experts = config.n_shared_experts + self.n_share_experts_fusion = ( + global_server_args_dict["n_share_experts_fusion"] + if global_server_args_dict["n_share_experts_fusion"] is not None + else 0 + ) + self.routed_scaling_factor = config.routed_scaling_factor if self.tp_size > config.n_routed_experts: raise ValueError( @@ -188,9 +198,10 @@ class DeepseekV2MoE(nn.Module): if global_server_args_dict["enable_deepep_moe"] else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE) ) + self.experts = MoEImpl( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, + num_experts=config.n_routed_experts + self.n_share_experts_fusion, + top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1), hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, renormalize=config.norm_topk_prob, @@ -207,7 +218,7 @@ class DeepseekV2MoE(nn.Module): ), ) - if config.n_shared_experts is not None: + if config.n_shared_experts is not None and self.n_share_experts_fusion == 0: intermediate_size = config.moe_intermediate_size * config.n_shared_experts # disable tp for shared experts when enable deepep moe if not global_server_args_dict["enable_deepep_moe"]: @@ -267,8 +278,10 @@ class DeepseekV2MoE(nn.Module): return self.forward_deepep(hidden_states, forward_mode) def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.n_shared_experts is not None: + if self.n_shared_experts is not None and self.n_share_experts_fusion == 0: shared_output = self.shared_experts(hidden_states) + else: + shared_output = None # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) final_hidden_states = ( @@ -1315,7 +1328,28 @@ class DeepseekV2ForCausalLM(nn.Module): ) -> None: super().__init__() self.config = config + self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config + self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"] + # Only Deepseek V3/R1 can use shared experts fusion optimization now. + if ( + global_server_args_dict.get("disable_shared_experts_fusion", False) + or self.config.architectures[0] != "DeepseekV3ForCausalLM" + or self.config.n_routed_experts != 256 + or self.config.routed_scaling_factor != 2.5 + ): + self.n_share_experts_fusion = None + global_server_args_dict["n_share_experts_fusion"] = None + logger.info( + "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled." + ) + elif self.n_share_experts_fusion is None: + global_server_args_dict["n_share_experts_fusion"] = self.tp_size + self.n_share_experts_fusion = self.tp_size + logger.info( + f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion." + ) + self.model = DeepseekV2Model( config, quant_config, prefix=add_prefix("model", prefix) ) @@ -1352,6 +1386,43 @@ class DeepseekV2ForCausalLM(nn.Module): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] + if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0: + weights_list = list(weights) + weights_dict = dict(weights_list) + suffix_list = [ + "down_proj.weight", + "down_proj.weight_scale_inv", + "gate_proj.weight", + "gate_proj.weight_scale_inv", + "up_proj.weight", + "up_proj.weight_scale_inv", + ] + names_to_remove = [] + for moe_layer in tqdm( + range( + self.config.first_k_dense_replace, + self.config.num_hidden_layers, + self.config.moe_layer_freq, + ), + desc=f"Cloning {self.n_share_experts_fusion} " + "replicas of the shared expert into MoE", + ): + for num_repeat in range(self.n_share_experts_fusion): + for suffix in suffix_list: + shared_expert_weight_name = ( + f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}" + ) + weights_list.append( + ( + f"model.layers.{moe_layer}." + f"mlp.experts." + f"{self.config.n_routed_experts + num_repeat}" + f".{suffix}", + weights_dict[shared_expert_weight_name].clone(), + ) + ) + names_to_remove += [shared_expert_weight_name] + weights = [w for w in weights_list if w[0] not in names_to_remove] # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) @@ -1364,7 +1435,12 @@ class DeepseekV2ForCausalLM(nn.Module): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts, + num_experts=self.config.n_routed_experts + + ( + self.n_share_experts_fusion + if self.n_share_experts_fusion is not None + else 0 + ), ) params_dict = dict(self.named_parameters()) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3bdb7de9c..54b532b4a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -183,6 +183,8 @@ class ServerArgs: enable_flashmla: bool = False flashinfer_mla_disable_ragged: bool = False warmups: Optional[str] = None + n_share_experts_fusion: Optional[int] = None + disable_shared_experts_fusion: bool = False # Debug tensor dumps debug_tensor_dump_output_folder: Optional[str] = None @@ -224,6 +226,9 @@ class ServerArgs: # GPU memory is not known yet or no GPU is available. gpu_mem = None + if is_hip(): + self.disable_shared_experts_fusion = True + # Set mem fraction static, which depends on the tensor parallelism size if self.mem_fraction_static is None: if self.tp_size >= 16: @@ -1102,6 +1107,19 @@ class ServerArgs: help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.", ) + parser.add_argument( + "--n-share-experts-fusion", + type=int, + default=None, + help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 " + "we use tp_size by default.", + ) + parser.add_argument( + "--disable-shared-experts-fusion", + action="store_true", + help="Disable shared experts fusion by setting n_share_experts_fusion to 0.", + ) + # Server warmups parser.add_argument( "--warmups", diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py index 3d89c3406..fb7c4c640 100644 --- a/sgl-kernel/tests/test_moe_align.py +++ b/sgl-kernel/tests/test_moe_align.py @@ -144,7 +144,7 @@ def moe_align_block_size_triton( [32, 64, 128, 256], # block_size [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens [1, 2, 4, 8, 16, 32, 64], # topk - [64, 160, 256], # num_experts + [64, 160, 256, 257, 260, 264], # num_experts ) ), )