From acc816d8a24ebe6c5ccd778f95c9a406d12b13f5 Mon Sep 17 00:00:00 2001 From: lukec <118525388+sleepcoo@users.noreply.github.com> Date: Thu, 8 May 2025 16:20:32 +0800 Subject: [PATCH] DeepEP normal support deepgemm-contiguous (#5626) Co-authored-by: Yingyi Huang Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Co-authored-by: Xuting Zhou Co-authored-by: ZhengHSI --- .../sglang/srt/layers/moe/ep_moe/kernels.py | 342 +++++++++++++++++- python/sglang/srt/layers/moe/ep_moe/layer.py | 121 ++++++- .../srt/layers/moe/ep_moe/token_dispatcher.py | 151 +++++--- .../srt/layers/quantization/deep_gemm.py | 5 + .../srt/layers/quantization/fp8_kernel.py | 4 +- python/sglang/srt/models/deepseek_v2.py | 4 + 6 files changed, 568 insertions(+), 59 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index ab7350555..37c87ed0c 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -5,16 +5,23 @@ import torch import triton import triton.language as tl -from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.utils import is_cuda +logger = logging.getLogger(__name__) + _is_cuda = is_cuda() if _is_cuda: from sglang.srt.layers.quantization.fp8_kernel import ( sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8, ) -logger = logging.getLogger(__name__) + + try: + from deep_gemm import ceil_div + except ImportError: + logger.error(f"Failed to import ceil_div from deep_gemm.") + +import triton.language as tl @triton.jit @@ -704,3 +711,334 @@ def grouped_gemm_triton( **config, ) return c + + +@triton.jit +def _fwd_kernel_ep_scatter_1( + num_recv_tokens_per_expert, + expert_start_loc, + m_indices, + num_experts: tl.constexpr, + BLOCK_E: tl.constexpr, + BLOCK_EXPERT_NUM: tl.constexpr, +): + cur_expert = tl.program_id(0) + + offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM) + tokens_per_expert = tl.load( + num_recv_tokens_per_expert + offset_cumsum, + mask=offset_cumsum < num_experts, + other=0, + ) + cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert + tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts) + + cur_expert_start = tl.load(expert_start_loc + cur_expert) + cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) + + m_indices_start_ptr = m_indices + cur_expert_start + off_expert = tl.arange(0, BLOCK_E) + + for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4): + tl.store( + m_indices_start_ptr + start_m + off_expert, + cur_expert, + ) + + +@triton.jit +def _fwd_kernel_ep_scatter_2( + total_token_num, + expert_start_loc, + recv_x, + recv_x_stride0, + recv_x_stride1, + recv_x_scale, + recv_x_scale_stride0, + recv_x_scale_stride1, + recv_topk, + recv_topk_stride0, + recv_topk_stride1, + output_tensor, + output_tensor_stride0, + output_tensor_stride1, + output_tensor_scale, + output_tensor_scale_stride0, + output_tensor_scale_stride1, + output_index, + output_index_stride0, + output_index_stride1, + topk_num: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + HIDDEN_SIZE_PAD: tl.constexpr, + SCALE_HIDDEN_SIZE: tl.constexpr, + SCALE_HIDDEN_SIZE_PAD: tl.constexpr, +): + start_token_id = tl.program_id(0) + grid_num = tl.num_programs(0) + + offset_in = tl.arange(0, HIDDEN_SIZE_PAD) + mask = offset_in < HIDDEN_SIZE + + offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD) + mask_s = offset_in_s < 
SCALE_HIDDEN_SIZE + + for token_id in range(start_token_id, total_token_num, grid_num): + to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask) + to_copy_s = tl.load( + recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s + ) + + for topk_index in tl.range(0, topk_num, 1, num_stages=4): + expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) + if expert_id >= 0: + dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1) + tl.store( + output_index + token_id * output_index_stride0 + topk_index, + dest_token_index, + ) + output_tensor_ptr = ( + output_tensor + dest_token_index * output_tensor_stride0 + ) + output_tensor_scale_ptr = ( + output_tensor_scale + dest_token_index * output_tensor_scale_stride0 + ) + tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask) + tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s) + + +# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py +@torch.no_grad() +def ep_scatter( + recv_x: torch.Tensor, + recv_x_scale: torch.Tensor, + recv_topk: torch.Tensor, + num_recv_tokens_per_expert: torch.Tensor, + expert_start_loc: torch.Tensor, + output_tensor: torch.Tensor, + output_tensor_scale: torch.Tensor, + m_indices: torch.Tensor, + output_index: torch.Tensor, +): + BLOCK_E = 128 # token num of per expert is aligned to 128 + BLOCK_D = 128 # block size of quantization + num_warps = 8 + num_experts = num_recv_tokens_per_expert.shape[0] + hidden_size = recv_x.shape[1] + # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts) + grid = num_experts + + assert m_indices.shape[0] % BLOCK_E == 0 + + _fwd_kernel_ep_scatter_1[(grid,)]( + num_recv_tokens_per_expert, + expert_start_loc, + m_indices, + num_experts=num_experts, + num_warps=num_warps, + BLOCK_E=BLOCK_E, + BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts), + ) + + grid = min(recv_topk.shape[0], 1024 * 8) + + _fwd_kernel_ep_scatter_2[(grid,)]( + recv_topk.shape[0], + expert_start_loc, + recv_x, + recv_x.stride(0), + recv_x.stride(1), + recv_x_scale, + recv_x_scale.stride(0), + recv_x_scale.stride(1), + recv_topk, + recv_topk.stride(0), + recv_topk.stride(1), + output_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + output_tensor_scale, + output_tensor_scale.stride(0), + output_tensor_scale.stride(1), + output_index, + output_index.stride(0), + output_index.stride(1), + topk_num=recv_topk.shape[1], + num_warps=num_warps, + HIDDEN_SIZE=hidden_size, + HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), + SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D, + SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D), + ) + return + + +@triton.jit +def _fwd_kernel_ep_gather( + total_token_num, + input_tensor, + input_tensor_stride0, + input_tensor_stride1, + recv_topk_ids, + recv_topk_ids_stride0, + recv_topk_ids_stride1, + recv_topk_weight, + recv_topk_weight_stride0, + recv_topk_weight_stride1, + input_index, + input_index_stride0, + input_index_stride1, + output_tensor, + output_tensor_stride0, + output_tensor_stride1, + topk_num: tl.constexpr, + BLOCK_D: tl.constexpr, +): + cur_block = tl.program_id(0) + start_cur_token = tl.program_id(1) + grid_num = tl.num_programs(1) + + for cur_token in range(start_cur_token, total_token_num, grid_num): + off_d = tl.arange(0, BLOCK_D) + accumulator = tl.zeros([BLOCK_D], dtype=tl.float32) + for topk_index in range(0, topk_num): + expert_id = tl.load( + recv_topk_ids + cur_token * recv_topk_ids_stride0 + 
topk_index
+            )
+            if expert_id >= 0:
+                source_token_index = tl.load(
+                    input_index + cur_token * input_index_stride0 + topk_index
+                )
+                acc_weight = tl.load(
+                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
+                )
+                tmp = tl.load(
+                    input_tensor
+                    + source_token_index * input_tensor_stride0
+                    + cur_block * BLOCK_D
+                    + off_d
+                )
+                accumulator += tmp.to(tl.float32) * acc_weight
+
+        tl.store(
+            output_tensor
+            + cur_token * output_tensor_stride0
+            + cur_block * BLOCK_D
+            + off_d,
+            accumulator.to(output_tensor.dtype.element_ty),
+        )
+
+
+@torch.no_grad()
+def ep_gather(
+    input_tensor: torch.Tensor,
+    recv_topk_ids: torch.Tensor,
+    recv_topk_weight: torch.Tensor,
+    input_index: torch.Tensor,
+    output_tensor: torch.Tensor,
+):
+    BLOCK_D = 1024  # tile width along the hidden dimension (not a quantization block)
+    num_warps = 2
+    num_tokens = output_tensor.shape[0]
+    hidden_size = input_tensor.shape[1]
+    assert hidden_size % BLOCK_D == 0
+    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
+    _fwd_kernel_ep_gather[grid](
+        num_tokens,
+        input_tensor,
+        input_tensor.stride(0),
+        input_tensor.stride(1),
+        recv_topk_ids,
+        recv_topk_ids.stride(0),
+        recv_topk_ids.stride(1),
+        recv_topk_weight,
+        recv_topk_weight.stride(0),
+        recv_topk_weight.stride(1),
+        input_index,
+        input_index.stride(0),
+        input_index.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        topk_num=recv_topk_ids.shape[1],
+        num_warps=num_warps,
+        BLOCK_D=BLOCK_D,
+    )
+    return
+
+
+# copy from
+# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
+def get_tma_aligned_size(x: int, element_size: int) -> int:
+    """
+    Global memory address of TMA must be 16-byte aligned.
+    Since we use column-major layout for the LHS scaling tensor,
+    the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
+
+    Arguments:
+        x: original M-axis shape of the LHS scaling tensor.
+        element_size: element size of the LHS scaling tensor.
+
+    Returns:
+        M-axis shape of the LHS scaling tensor after padding.
+ """ + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return ceil_div(x, alignment) * alignment + + +@triton.jit +def _tma_align_input_scale_kernel( + input_scale_ptr, + output_ptr, + m, + k_div_block_size, + input_scale_stride_m, + input_scale_stride_k, + output_stride_m, + output_stride_k, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + grid_m = tl.num_programs(0) + k_offsets = tl.arange(0, BLOCK_SIZE_K) + + for m_base in range(pid_m, m, grid_m): + input_offset = ( + input_scale_ptr + + m_base * input_scale_stride_m + + k_offsets * input_scale_stride_k + ) + input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size) + + output_offset = ( + output_ptr + k_offsets * output_stride_k + m_base * output_stride_m + ) + tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size) + + +# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py +def tma_align_input_scale(input_scale: torch.Tensor): + assert input_scale.dim() == 2 + m, k_div_block_size = input_scale.shape + padd_m = get_tma_aligned_size(m, input_scale.element_size()) + output = torch.empty( + (k_div_block_size, padd_m), dtype=input_scale.dtype, device=input_scale.device + ) + + grid_m = min(m, 8192) + BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size) + + _tma_align_input_scale_kernel[(grid_m,)]( + input_scale_ptr=input_scale, + output_ptr=output, + m=m, + k_div_block_size=k_div_block_size, + input_scale_stride_m=input_scale.stride(0), + input_scale_stride_k=input_scale.stride(1), + output_stride_m=output.stride(1), # Note: these are swapped + output_stride_k=output.stride(0), # for column-major + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + return output.t()[:m] diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 767196770..b2c76b33a 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -4,11 +4,19 @@ from typing import Callable, List, Optional, Tuple import torch from torch.nn import Module +from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM + try: from deep_gemm import ( get_col_major_tma_aligned_tensor, + m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, m_grouped_gemm_fp8_fp8_bf16_nt_masked, ) + from sgl_kernel import silu_and_mul + + from sglang.srt.layers.quantization.fp8_kernel import ( + sglang_per_token_group_quant_fp8, + ) use_deep_gemm = True except ImportError: @@ -20,6 +28,8 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, ) from sglang.srt.layers.moe.ep_moe.kernels import ( + ep_gather, + ep_scatter, gelu_and_mul_triton_kernel, grouped_gemm_triton, post_reorder_triton_kernel, @@ -27,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( run_moe_ep_preproess, silu_and_mul_masked_post_quant_fwd, silu_and_mul_triton_kernel, + tma_align_input_scale, ) from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase @@ -842,15 +853,23 @@ class DeepEPMoE(EPMoE): def forward( self, hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, reorder_topk_ids: torch.Tensor, seg_indptr: torch.Tensor, masked_m: torch.Tensor, expected_m: int, + num_recv_tokens_per_expert: List[int], forward_mode: ForwardMode, ): resolved_deepep_mode = 
self.deepep_mode.resolve(forward_mode) if resolved_deepep_mode == DeepEPMode.normal: - return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr) + if _ENABLE_JIT_DEEPGEMM: + return self.forward_deepgemm_contiguous( + hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert + ) + else: + return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr) elif resolved_deepep_mode == DeepEPMode.low_latency: return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m) else: @@ -969,6 +988,106 @@ class DeepEPMoE(EPMoE): ) return down_output + def forward_deepgemm_contiguous( + self, + hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor], + topk_idx, + topk_weights, + num_recv_tokens_per_expert: List[int], + ): + hidden_states_fp8, hidden_states_scale = hidden_states_fp8 + assert self.quant_method is not None + assert self.activation == "silu" + if num_recv_tokens_per_expert is None: + return hidden_states_fp8.bfloat16() + all_tokens = sum(num_recv_tokens_per_expert) + if all_tokens <= 0: + return hidden_states_fp8.bfloat16() + M, K = hidden_states_fp8.size() + N = self.w13_weight.size(1) + scale_block_size = 128 + + gather_out = torch.empty_like( + hidden_states_fp8, + device=hidden_states_fp8.device, + dtype=torch.bfloat16, + ) + + input_tensor = [ + torch.empty( + (all_tokens, K), + device=hidden_states_fp8.device, + dtype=hidden_states_fp8.dtype, + ), + torch.empty( + (all_tokens, K // 128), + device=hidden_states_fp8.device, + dtype=torch.float32, + ), + ] + m_indices = torch.empty( + all_tokens, device=hidden_states_fp8.device, dtype=torch.int32 + ) + output_index = torch.empty_like(topk_idx) + + num_recv_tokens_per_expert_gpu = torch.tensor( + num_recv_tokens_per_expert, + dtype=torch.int32, + pin_memory=True, + device="cpu", + ).cuda(non_blocking=True) + expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu) + + ep_scatter( + hidden_states_fp8, + hidden_states_scale, + topk_idx, + num_recv_tokens_per_expert_gpu, + expert_start_loc, + input_tensor[0], + input_tensor[1], + m_indices, + output_index, + ) + + gateup_output = torch.empty( + (all_tokens, N), + device=hidden_states_fp8.device, + dtype=torch.bfloat16, + ) + input_tensor[1] = tma_align_input_scale(input_tensor[1]) + m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + input_tensor, self.w13_weight_fp8, gateup_output, m_indices + ) + down_input = torch.empty( + ( + all_tokens, + N // 2, + ), + device=gateup_output.device, + dtype=torch.bfloat16, + ) + silu_and_mul(gateup_output.view(-1, N), down_input) + down_output = torch.empty( + (all_tokens, K), + device=hidden_states_fp8.device, + dtype=torch.bfloat16, + ) + down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8( + down_input, scale_block_size + ) + down_input_scale = tma_align_input_scale(down_input_scale) + m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (down_input_fp8, down_input_scale), + self.w2_weight_fp8, + down_output, + m_indices, + ) + + ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out) + + return gather_out + def forward_deepgemm_masked( self, hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor], diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index f273c55cb..34a79f0e8 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,14 +1,19 @@ +from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from 
sglang.srt.utils import DeepEPMode try: from deep_ep import Buffer + from sglang.srt.layers.quantization.fp8_kernel import ( + sglang_per_token_group_quant_fp8, + ) + use_deepep = True except ImportError: use_deepep = False from enum import IntEnum, auto -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import torch.distributed as dist @@ -78,7 +83,6 @@ class DeepEPBuffer: ), num_rdma_bytes, ) - cls._buffer = Buffer( group, num_nvl_bytes, @@ -181,44 +185,74 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): topk_weights: torch.Tensor, ): topk_idx = topk_idx.to(torch.int64) + if _ENABLE_JIT_DEEPGEMM: + # TODO hard code 128 block quant,use fp8 communication + hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128) previous_event = Buffer.capture() if self.async_finish else None return hidden_states, topk_idx, topk_weights, previous_event def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event): - ( - hidden_states, - topk_idx, - topk_weights, - event, - ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event) - event.current_stream_wait() if self.async_finish else () - if hidden_states.shape[0] > 0: - reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( - hidden_states, topk_idx, fp8_dtype=hidden_states.dtype + if _ENABLE_JIT_DEEPGEMM: + ( + hidden_states, + topk_idx, + topk_weights, + num_recv_tokens_per_expert_list, + event, + ) = self._dispatch_core( + hidden_states, topk_idx, topk_weights, previous_event + ) + event.current_stream_wait() if self.async_finish else () + return ( + hidden_states, + topk_idx, + topk_weights, + None, + num_recv_tokens_per_expert_list, + None, + None, + None, ) else: - reorder_topk_ids = torch.empty( - (0,), device=hidden_states.device, dtype=torch.int64 + ( + hidden_states, + topk_idx, + topk_weights, + num_recv_tokens_per_expert_list, + event, + ) = self._dispatch_core( + hidden_states, topk_idx, topk_weights, previous_event ) - seg_indptr = torch.zeros( - (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64 + event.current_stream_wait() if self.async_finish else () + if hidden_states.shape[0] > 0: + reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( + hidden_states, topk_idx, fp8_dtype=hidden_states.dtype + ) + else: + reorder_topk_ids = torch.empty( + (0,), device=hidden_states.device, dtype=torch.int64 + ) + seg_indptr = torch.zeros( + (self.num_experts + 1,), + device=hidden_states.device, + dtype=torch.int64, + ) + + masked_m = expected_m = None + return ( + hidden_states, + topk_idx, + topk_weights, + reorder_topk_ids, + None, + seg_indptr, + masked_m, + expected_m, ) - masked_m = expected_m = None - - return ( - hidden_states, - topk_idx, - topk_weights, - reorder_topk_ids, - seg_indptr, - masked_m, - expected_m, - ) - def _dispatch_core( self, - x: torch.Tensor, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], topk_idx: torch.Tensor, topk_weights: torch.Tensor, previous_event, @@ -246,7 +280,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): recv_x, recv_topk_idx, recv_topk_weights, - _, # num_recv_tokens_per_expert_list + num_recv_tokens_per_expert_list, self.handle, event, ) = buffer.dispatch( @@ -260,12 +294,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): previous_event=previous_event, async_finish=self.async_finish, allocate_on_comm_stream=(previous_event is not None) and self.async_finish, + expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 
1, ) return ( recv_x, recv_topk_idx, recv_topk_weights, + num_recv_tokens_per_expert_list, event, ) @@ -314,29 +350,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): topk_idx: torch.Tensor, topk_weights: torch.Tensor, ): - if hidden_states.shape[0] > 0: - num_tokens = self.src2dst.shape[0] // self.router_topk - output = torch.empty( - (num_tokens, hidden_states.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - deepep_post_reorder_triton_kernel[(num_tokens,)]( - hidden_states, - output, - self.src2dst, - topk_idx, - topk_weights, - self.router_topk, - hidden_states.shape[1], - BLOCK_SIZE=512, - ) + if _ENABLE_JIT_DEEPGEMM: + output = hidden_states else: - output = torch.zeros( - (0, hidden_states.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) + if hidden_states.shape[0] > 0: + num_tokens = self.src2dst.shape[0] // self.router_topk + output = torch.empty( + (num_tokens, hidden_states.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + deepep_post_reorder_triton_kernel[(num_tokens,)]( + hidden_states, + output, + self.src2dst, + topk_idx, + topk_weights, + self.router_topk, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + else: + output = torch.zeros( + (0, hidden_states.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) previous_event = Buffer.capture() if self.async_finish else None return output, previous_event @@ -360,6 +399,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): def _get_buffer(self): DeepEPBuffer.set_dispatch_mode_as_normal() + return DeepEPBuffer.get_deepep_buffer( self.group, self.hidden_size, @@ -426,6 +466,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): topk_idx, topk_weights, reorder_topk_ids, + None, seg_indptr, masked_m, expected_m, @@ -570,7 +611,8 @@ class DeepEPDispatcher: def dispatch(self, *args, **kwargs) -> Tuple: self.dispatch_a(*args, **kwargs) - return self.dispatch_b() + ret = self.dispatch_b() + return ret def dispatch_a( self, @@ -593,7 +635,8 @@ class DeepEPDispatcher: def combine(self, *args, **kwargs) -> Tuple: self.combine_a(*args, **kwargs) - return self.combine_b() + ret = self.combine_b() + return ret def combine_a( self, diff --git a/python/sglang/srt/layers/quantization/deep_gemm.py b/python/sglang/srt/layers/quantization/deep_gemm.py index 6fa7a6dd6..3d6ba6281 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm.py +++ b/python/sglang/srt/layers/quantization/deep_gemm.py @@ -28,6 +28,11 @@ if is_cuda(): if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"): _ENABLE_JIT_DEEPGEMM = True + +def get_enable_jit_deepgemm(): + return _ENABLE_JIT_DEEPGEMM + + logger = logging.getLogger(__name__) _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1)) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index d1a0ffa91..cd63e19a5 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -308,8 +308,8 @@ def sglang_per_token_group_quant_fp8( device=x.device, dtype=torch.float32, ) - - sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) + if x.shape[0] > 0: + sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) return x_q, x_s diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 92c5057db..339aaad6b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ 
b/python/sglang/srt/models/deepseek_v2.py @@ -357,6 +357,7 @@ class DeepseekV2MoE(nn.Module): topk_idx, topk_weights, reorder_topk_ids, + num_recv_tokens_per_expert, seg_indptr, masked_m, expected_m, @@ -368,10 +369,13 @@ class DeepseekV2MoE(nn.Module): ) final_hidden_states = self.experts( hidden_states=hidden_states, + topk_idx=topk_idx, + topk_weights=topk_weights, reorder_topk_ids=reorder_topk_ids, seg_indptr=seg_indptr, masked_m=masked_m, expected_m=expected_m, + num_recv_tokens_per_expert=num_recv_tokens_per_expert, forward_mode=forward_mode, ) if self.ep_size > 1:
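
Editor's note (not part of the patch): the new normal-mode path in forward_deepgemm_contiguous chains the kernels added above. ep_scatter copies each received token once per selected expert into a single expert-contiguous buffer whose per-expert segments are 128-token aligned (hence expert_alignment=128 in the dispatcher), m_indices labels every row of that buffer with its owning expert so m_grouped_gemm_fp8_fp8_bf16_nt_contiguous can run all expert GEMMs in one call, and ep_gather combines each token's top-k expert outputs with its routing weights. The sketch below is a minimal pure-PyTorch reference for that layout under simplified assumptions: no fp8 tensors or scale buffers, exact rather than 128-aligned per-expert counts, and identity experts in place of the grouped GEMMs. reference_scatter_gather is an illustrative name, not an SGLang API.

import torch


def reference_scatter_gather(x, topk_ids, topk_weights, tokens_per_expert):
    """Reference semantics for ep_scatter/ep_gather (fp8 and alignment omitted)."""
    num_experts = len(tokens_per_expert)
    counts = torch.tensor(tokens_per_expert)
    expert_start = torch.cumsum(counts, 0) - counts  # exclusive prefix sum

    # m_indices[i] = expert that owns row i of the expert-contiguous buffer
    m_indices = torch.repeat_interleave(torch.arange(num_experts), counts)

    scattered = torch.zeros(int(counts.sum()), x.shape[1], dtype=x.dtype)
    out_index = torch.zeros_like(topk_ids)
    cursor = expert_start.clone()
    for t in range(x.shape[0]):
        for k in range(topk_ids.shape[1]):
            e = int(topk_ids[t, k])
            if e < 0:
                continue  # slot not routed to a local expert
            dest = int(cursor[e])
            cursor[e] += 1
            scattered[dest] = x[t]  # one copy per (token, expert) pair
            out_index[t, k] = dest

    # The two grouped GEMMs (plus SiLU-and-mul) would run here on `scattered`,
    # segmented by m_indices; this sketch treats each expert as the identity.
    expert_out = scattered

    # ep_gather: weighted sum of each token's top-k expert outputs.
    combined = torch.zeros(x.shape, dtype=torch.float32)
    for t in range(x.shape[0]):
        for k in range(topk_ids.shape[1]):
            if int(topk_ids[t, k]) >= 0:
                combined[t] += expert_out[out_index[t, k]].float() * float(
                    topk_weights[t, k]
                )
    return combined.to(x.dtype), m_indices


if __name__ == "__main__":
    x = torch.randn(4, 8)
    topk_ids = torch.tensor([[0, 1], [1, -1], [0, 2], [2, 0]])  # -1 = not routed here
    topk_weights = torch.full((4, 2), 0.5)
    out, m_indices = reference_scatter_gather(x, topk_ids, topk_weights, [3, 2, 2])
    # With identity experts and weights summing to 1, a fully routed token
    # reconstructs itself.
    print(torch.allclose(out[0], x[0], atol=1e-6), m_indices.tolist())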