diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index c9a20d276..e74df36da 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -86,79 +86,6 @@ if use_flashinfer_trtllm_moe: logger = logging.getLogger(__name__) -class GroupedGemmRunner(torch.nn.Module): - flashinfer_gemm_warpper = None - - def __init__( - self, - device, - use_flashinfer: bool = False, - use_per_token_if_dynamic: bool = True, - ): - super().__init__() - self.device = device - self.use_flashinfer = use_flashinfer - self.use_per_token_if_dynamic = use_per_token_if_dynamic - if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None: - GroupedGemmRunner._init_flashinfer_wrapper(device) - - @classmethod - def _init_flashinfer_wrapper(cls, device): - from flashinfer import SegmentGEMMWrapper - - workspace_buffer = torch.empty( - 128 * 1024 * 1024, dtype=torch.int8, device=device - ) - cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer) - - # c = a * b - def forward( - self, - a: torch.Tensor, - b: torch.Tensor, - c: torch.Tensor, - batch_size: int, - weight_column_major: bool, - seg_indptr: Optional[torch.Tensor] = None, - weight_indices: Optional[torch.Tensor] = None, - use_fp8_w8a8: bool = False, - scale_a: torch.Tensor = None, - scale_b: torch.Tensor = None, - block_shape: Optional[List[int]] = None, - c_dtype=None, - ): - if self.use_flashinfer: - # TODO: flashinfer - assert False - assert GroupedGemmRunner.flashinfer_gemm_warpper is not None - c = GroupedGemmRunner.flashinfer_gemm_warpper.run( - x=a, - weights=b, - batch_size=batch_size, - weight_column_major=weight_column_major, - seg_indptr=seg_indptr, - weight_indices=weight_indices, - ) - else: - assert weight_column_major == True - c = grouped_gemm_triton( - a, - b, - c, - batch_size, - weight_column_major, - seg_indptr, - weight_indices, - use_fp8_w8a8, - scale_a, - scale_b, - block_shape=block_shape, - c_dtype=c_dtype, - use_per_token_if_dynamic=self.use_per_token_if_dynamic, - ) - return c - - def _get_tile_tokens_dim(num_tokens, top_k, num_experts): # Guess tokens per expert assuming perfect expert distribution first. num_tokens_per_expert = (num_tokens * top_k) // num_experts @@ -190,135 +117,50 @@ class EPMoE(FusedMoE): prefix: str = "", activation: str = "silu", routed_scaling_factor: Optional[float] = None, - use_per_token_if_dynamic: bool = True, ): super().__init__( num_experts=num_experts, hidden_size=hidden_size, intermediate_size=intermediate_size, - top_k=top_k, num_fused_shared_experts=num_fused_shared_experts, layer_id=layer_id, + top_k=top_k, params_dtype=params_dtype, quant_config=quant_config, tp_size=tp_size, prefix=prefix, activation=activation, + # apply_router_weight_on_input=apply_router_weight_on_input, routed_scaling_factor=routed_scaling_factor, enable_ep_moe=True, - skip_quant=True, ) - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - self.layer_id = layer_id - self.num_local_experts, self.expert_map = self.determine_expert_map() self.start_expert_id = self.ep_rank * self.num_local_experts self.end_expert_id = self.start_expert_id + self.num_local_experts - 1 self.intermediate_size = intermediate_size - self.use_per_token_if_dynamic = use_per_token_if_dynamic - # TODO(ch-wan): move quant preparation to FusedMoE - if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod() - ) - self.use_fp8_w8a8 = False - self.use_block_quant = False - self.block_shape = None - self.activation_scheme = None - self.w13_input_scale = None - self.w2_input_scale = None - self.w13_weight_scale = None - self.w2_weight_scale = None - elif isinstance(quant_config, W4AFp8Config): - self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod( - quant_config - ) - self.use_fp8_w8a8 = False - self.use_block_quant = False - self.fp8_dtype = torch.float8_e4m3fn - self.w13_input_scale = None - self.w2_input_scale = None - self.w13_weight_scale = None - self.w2_weight_scale = None - self.activation_scheme = quant_config.moe_activation_scheme - elif isinstance(quant_config, Fp8Config): - self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config) - self.use_fp8_w8a8 = True + if isinstance(quant_config, Fp8Config): self.use_block_quant = getattr(self.quant_method, "block_quant", False) self.block_shape = ( self.quant_method.quant_config.weight_block_size if self.use_block_quant else None ) + self.use_fp8_w8a8 = True self.fp8_dtype = torch.float8_e4m3fn self.activation_scheme = quant_config.activation_scheme else: - raise ValueError(f"Unsupported quant_config: {quant_config}") - - self.quant_config = quant_config - self.quant_method.create_weights( - layer=self, - num_experts=self.num_local_experts, - hidden_size=hidden_size, - intermediate_size=self.intermediate_size, - params_dtype=params_dtype, - weight_loader=self.weight_loader, - ) - - self.grouped_gemm_runner = None - - # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43 - # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank. - def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]: - """ - Calculates how many experts should be assigned to each rank for EP and - creates a mapping from global to local expert index. Experts are - distributed evenly across ranks. Any remaining are assigned to the - last rank. - - Returns: - Tuple[int, Optional[torch.Tensor]]: A tuple containing: - - local_num_experts (int): The number of experts assigned - to the current rank. - - expert_map (Optional[torch.Tensor]): A tensor of shape - (global_num_experts,) mapping from global to local index. - Contains global_num_experts for experts not assigned to the current rank. - Returns None if ep_size is 1. - """ - ep_size = self.ep_size - ep_rank = self.ep_rank - global_num_experts = self.num_experts - - assert ep_size > 0 - if ep_size == 1: - return (global_num_experts, None) - - local_num_experts = global_num_experts // ep_size - - expert_map = torch.full( - (global_num_experts,), global_num_experts, dtype=torch.int32 - ) - if ep_rank < (ep_size - 1): - expert_map[ - ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts - ] = torch.arange(0, local_num_experts, dtype=torch.int32) - else: - local_num_experts = global_num_experts - ep_rank * local_num_experts - - expert_map[-local_num_experts:] = torch.arange( - 0, local_num_experts, dtype=torch.int32 - ) - return (local_num_experts, expert_map) + self.use_fp8_w8a8 = False + self.use_block_quant = False + self.block_shape = None + self.activation_scheme = None def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8: return self.forward_deepgemm(hidden_states, topk_output) else: - return self.forward_normal(hidden_states, topk_output) + return super().forward(hidden_states, topk_output) def forward_deepgemm( self, @@ -477,303 +319,6 @@ class EPMoE(FusedMoE): ) return output - def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - return self.quant_method.apply(self, hidden_states, topk_output) - - def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - - topk_weights, topk_ids, _ = topk_output - - hidden_states_shape = hidden_states.shape - hidden_states_dtype = hidden_states.dtype - hidden_states_device = hidden_states.device - if self.grouped_gemm_runner is None: - self.grouped_gemm_runner = GroupedGemmRunner( - hidden_states.device, - use_flashinfer=False, # TODO: use flashinfer - use_per_token_if_dynamic=self.use_per_token_if_dynamic, - ) - - num_experts = self.num_experts - - reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( - topk_ids, - num_experts, - ) - - gateup_input = torch.empty( - (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]), - device=hidden_states.device, - dtype=( - self.fp8_dtype - if self.use_fp8_w8a8 and not self.use_block_quant - else hidden_states.dtype - ), - ) - if self.activation_scheme == "dynamic" and not self.use_block_quant: - if self.use_per_token_if_dynamic: - max_value = torch.max(hidden_states, dim=1).values.to(torch.float32) - self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max - else: - max_value = ( - torch.max(hidden_states) - .repeat(self.num_local_experts) - .to(torch.float32) - ) - self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max - - # PreReorder - pre_reorder_triton_kernel[(hidden_states.shape[0],)]( - hidden_states, - gateup_input, - src2dst, - topk_ids, - self.w13_input_scale, - self.start_expert_id, - self.end_expert_id, - self.top_k, - hidden_states.shape[1], - BLOCK_SIZE=512, - use_per_token_if_dynamic=self.use_per_token_if_dynamic, - ) - dispose_tensor(hidden_states) - - if ( - self.activation_scheme == "dynamic" - and not self.use_block_quant - and self.use_per_token_if_dynamic - ): - scale = torch.empty( - hidden_states_shape[0] * self.top_k, - device=hidden_states_device, - dtype=torch.float32, - ) - scale[src2dst] = ( - self.w13_input_scale.unsqueeze(1) - .expand(hidden_states_shape[0], self.top_k) - .reshape(-1) - ) - self.w13_input_scale = scale - - seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] - weight_indices_cur_rank = torch.arange( - 0, - self.num_local_experts, - device=hidden_states_device, - dtype=torch.int64, - ) - # GroupGemm-0 - gateup_output = self.grouped_gemm_runner( - a=gateup_input, - b=self.w13_weight, - c=None, - c_dtype=hidden_states_dtype, - batch_size=self.num_local_experts, - weight_column_major=True, - seg_indptr=seg_indptr_cur_rank, - weight_indices=weight_indices_cur_rank, - use_fp8_w8a8=self.use_fp8_w8a8, - scale_a=self.w13_input_scale, - scale_b=self.w13_weight_scale, - block_shape=self.block_shape, - ) - del gateup_input - - # Act - if self.activation_scheme == "dynamic" and not self.use_block_quant: - self.w2_input_scale = None - down_input = torch.empty( - gateup_output.shape[0], - gateup_output.shape[1] // 2, - device=gateup_output.device, - dtype=hidden_states_dtype, - ) - else: - down_input = torch.empty( - gateup_output.shape[0], - gateup_output.shape[1] // 2, - device=gateup_output.device, - dtype=( - self.fp8_dtype - if (self.use_fp8_w8a8 and not self.use_block_quant) - else hidden_states_dtype - ), - ) - - if self.activation == "silu": - silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( - gateup_output, - down_input, - gateup_output.shape[1], - reorder_topk_ids, - self.w2_input_scale, - self.start_expert_id, - self.end_expert_id, - BLOCK_SIZE=512, - ) - elif self.activation == "gelu": - gelu_and_mul_triton_kernel[(gateup_output.shape[0],)]( - gateup_output, - down_input, - gateup_output.shape[1], - reorder_topk_ids, - self.w2_input_scale, - self.start_expert_id, - self.end_expert_id, - BLOCK_SIZE=512, - ) - else: - raise ValueError(f"Unsupported activation: {self.activation=}") - del gateup_output - - if self.activation_scheme == "dynamic" and not self.use_block_quant: - if self.use_per_token_if_dynamic: - down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input) - else: - self.w2_input_scale = torch.ones( - self.num_local_experts, - dtype=torch.float32, - device=hidden_states_device, - ) - - # GroupGemm-1 - down_output = torch.empty( - down_input.shape[0], - self.w2_weight.shape[1], - device=hidden_states_device, - dtype=hidden_states_dtype, - ) - down_output = self.grouped_gemm_runner( - a=down_input, - b=self.w2_weight, - c=down_output, - batch_size=self.num_local_experts, - weight_column_major=True, - seg_indptr=seg_indptr_cur_rank, - weight_indices=weight_indices_cur_rank, - use_fp8_w8a8=self.use_fp8_w8a8, - scale_a=self.w2_input_scale, - scale_b=self.w2_weight_scale, - block_shape=self.block_shape, - ) - del down_input - - # PostReorder - output = torch.empty( - hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device - ) - post_reorder_triton_kernel[(hidden_states_shape[0],)]( - down_output, - output, - src2dst, - topk_ids, - topk_weights, - self.start_expert_id, - self.end_expert_id, - self.top_k, - hidden_states_shape[1], - 0, - BLOCK_SIZE=512, - ) - return output - - @classmethod - def make_expert_params_mapping( - cls, - ckpt_gate_proj_name: str, - ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int, - ) -> List[Tuple[str, str, int, str]]: - return [ - # (param_name, weight_name, expert_id, shard_id) - ( - ( - "experts.w13_" - if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] - else "experts.w2_" - ), - f"experts.{expert_id}.{weight_name}.", - expert_id, - shard_id, - ) - for expert_id in range(num_experts) - for shard_id, weight_name in [ - ("w1", ckpt_gate_proj_name), - ("w2", ckpt_down_proj_name), - ("w3", ckpt_up_proj_name), - ] - ] - - @classmethod - def make_expert_input_scale_params_mapping( - cls, - num_experts: int, - ) -> List[Tuple[str, str, int, str]]: - # (param_name, weight_name, expert_id, shard_id) - return [ - ( - "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_", - f"experts.{expert_id}.{shard_id}.", - expert_id, - shard_id, - ) - for expert_id in range(num_experts) - for shard_id in ["w1", "w2", "w3"] - ] - - def weight_loader( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, - ) -> None: - global_expert_location_metadata = get_global_expert_location_metadata() - if global_expert_location_metadata is None: - self._weight_loader_impl( - param=param, - loaded_weight=loaded_weight, - weight_name=weight_name, - shard_id=shard_id, - expert_id=expert_id, - ) - return - - physical_expert_ids = global_expert_location_metadata.logical_to_all_physical( - self.layer_id, expert_id - ) - for physical_expert_id in physical_expert_ids: - self._weight_loader_physical( - param=param, - loaded_weight=loaded_weight, - weight_name=weight_name, - shard_id=shard_id, - expert_id=physical_expert_id, - ) - - def _weight_loader_physical( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, - ) -> None: - if expert_id < self.start_expert_id or expert_id > self.end_expert_id: - return - expert_id = expert_id - self.start_expert_id - - self._weight_loader_impl( - param=param, - loaded_weight=loaded_weight, - weight_name=weight_name, - shard_id=shard_id, - expert_id=expert_id, - ) - return - class DeepEPMoE(EPMoE): """ @@ -905,14 +450,15 @@ class DeepEPMoE(EPMoE): # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel return self.forward_aiter(dispatch_output) if dispatch_output.format.is_deepep_normal(): - if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8: - return self.forward_deepgemm_contiguous(dispatch_output) - else: - return self.forward_normal(dispatch_output) + assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8 + return self.forward_deepgemm_contiguous(dispatch_output) elif dispatch_output.format.is_deepep_ll(): + assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8 return self.forward_deepgemm_masked(dispatch_output) else: - raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") + raise ValueError( + f"Dispatch output format {dispatch_output.format} is not supported" + ) def combine( self, @@ -928,185 +474,6 @@ class DeepEPMoE(EPMoE): forward_batch=forward_batch, ) - def _prepare_for_normal( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - ): - from sglang.srt.layers.moe.ep_moe.kernels import ( - deepep_permute_triton_kernel, - deepep_run_moe_deep_preprocess, - ) - - if hidden_states.shape[0] == 0: - reorder_topk_ids = torch.empty( - (0,), device=hidden_states.device, dtype=torch.int64 - ) - seg_indptr = torch.zeros( - (self.num_experts + 1,), - device=hidden_states.device, - dtype=torch.int64, - ) - return reorder_topk_ids, seg_indptr, hidden_states - else: - if _use_aiter: - # skip permutation here as aiter fused_moe has fused inside - reorder_topk_ids = torch.empty( - (0,), device=hidden_states.device, dtype=torch.int64 - ) - seg_indptr = torch.zeros( - (self.num_experts + 1,), - device=hidden_states.device, - dtype=torch.int64, - ) - return reorder_topk_ids, seg_indptr, hidden_states - - reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess( - topk_idx, self.num_experts - ) - num_total_tokens = reorder_topk_ids.numel() - gateup_input = torch.empty( - (int(num_total_tokens), hidden_states.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - # PreReorder - deepep_permute_triton_kernel[(hidden_states.shape[0],)]( - hidden_states, - gateup_input, - self.src2dst, - topk_idx, - None, - self.router_topk, - hidden_states.shape[1], - BLOCK_SIZE=512, - ) - return reorder_topk_ids, seg_indptr, gateup_input - - def forward_normal( - self, - dispatch_output: DeepEPNormalOutput, - ): - hidden_states, topk_idx = ( - dispatch_output.hidden_states, - dispatch_output.topk_idx, - ) - reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal( - hidden_states, topk_idx - ) - hidden_states_dtype = hidden_states.dtype - hidden_states_device = hidden_states.device - - assert self.quant_method is not None - assert self.activation == "silu" - if self.grouped_gemm_runner is None: - self.grouped_gemm_runner = GroupedGemmRunner( - hidden_states.device, use_flashinfer=False # TODO: use flashinfer - ) - - if self.activation_scheme == "dynamic" and not self.use_block_quant: - max_value = ( - torch.max(hidden_states) - .repeat(self.num_local_experts) - .to(torch.float32) - ) - self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max - weight_indices_cur_rank = torch.arange( - 0, - self.num_local_experts, - device=hidden_states.device, - dtype=torch.int64, - ) - - # GroupGemm-0 - if hidden_states.shape[0] > 0: - gateup_output = self.grouped_gemm_runner( - a=hidden_states, - b=self.w13_weight, - c=None, - c_dtype=hidden_states.dtype, - batch_size=self.num_local_experts, - weight_column_major=True, - seg_indptr=seg_indptr, - weight_indices=weight_indices_cur_rank, - use_fp8_w8a8=self.use_fp8_w8a8, - scale_a=self.w13_input_scale, - scale_b=( - self.w13_weight_scale_inv - if self.use_block_quant - else self.w13_weight_scale - ), - block_shape=self.block_shape, - ) - else: - gateup_output = torch.empty( - hidden_states.shape[0], - self.w13_weight.shape[1], - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - # Act - down_input = torch.empty( - gateup_output.shape[0], - gateup_output.shape[1] // 2, - device=gateup_output.device, - dtype=( - self.fp8_dtype - if (self.use_fp8_w8a8 and not self.use_block_quant) - else hidden_states_dtype - ), - ) - if self.w2_input_scale is None and not self.use_block_quant: - self.w2_input_scale = torch.ones( - self.num_local_experts, - dtype=torch.float32, - device=hidden_states_device, - ) - - if self.activation == "silu": - silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( - gateup_output, - down_input, - gateup_output.shape[1], - reorder_topk_ids, - self.w2_input_scale, - 0, - self.num_local_experts - 1, - BLOCK_SIZE=512, - ) - else: - raise ValueError(f"Unsupported activation: {self.activation=}") - - del gateup_output - - # GroupGemm-1 - down_output = torch.empty( - down_input.shape[0], - self.w2_weight.shape[1], - device=hidden_states_device, - dtype=hidden_states_dtype, - ) - if down_input.shape[0] > 0: - down_output = self.grouped_gemm_runner( - a=down_input, - b=self.w2_weight, - c=down_output, - batch_size=self.num_local_experts, - weight_column_major=True, - seg_indptr=seg_indptr, - weight_indices=weight_indices_cur_rank, - use_fp8_w8a8=self.use_fp8_w8a8, - scale_a=self.w2_input_scale, - scale_b=( - self.w2_weight_scale_inv - if self.use_block_quant - else self.w2_weight_scale - ), - block_shape=self.block_shape, - ) - return down_output - def forward_aiter( self, dispatch_output: DeepEPNormalOutput, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index cd027d113..d2c65d973 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -413,18 +413,37 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) offs_token = offs_token.to(tl.int64) token_mask = offs_token < num_valid_tokens - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) + return + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) - off_experts = tl.load(expert_ids_ptr + pid_m) b_ptrs = ( b_ptr + off_experts * stride_be @@ -497,7 +516,6 @@ def fused_moe_kernel( accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: - # fix out of shared memory issue if use_fp8_w8a8: accumulator = tl.dot(a, b, acc=accumulator) else: diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 316bced90..81e35d002 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -12,7 +12,7 @@ from sglang.srt.distributed import ( tensor_model_parallel_all_reduce, ) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata -from sglang.srt.layers.moe.topk import TopKOutput +from sglang.srt.layers.moe.topk import StandardTopKOutput from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -79,7 +79,6 @@ class FusedMoE(torch.nn.Module): routed_scaling_factor: Optional[float] = None, enable_flashinfer_cutlass_moe: Optional[bool] = False, enable_ep_moe: Optional[bool] = False, - skip_quant: Optional[bool] = False, ): super().__init__() @@ -95,7 +94,8 @@ class FusedMoE(torch.nn.Module): self.tp_rank = get_tensor_model_parallel_rank() self.num_experts = num_experts self.num_fused_shared_experts = num_fused_shared_experts - self.expert_map = None + self.expert_map_cpu = None + self.expert_map_gpu = None if enable_flashinfer_cutlass_moe and quant_config is None: logger.warning("Disable flashinfer MoE when quantization config is None.") @@ -104,20 +104,22 @@ class FusedMoE(torch.nn.Module): self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe if enable_ep_moe: + # TODO(ch-wan): support shared experts fusion self.ep_size = self.tp_size self.ep_rank = self.tp_rank self.tp_size = 1 self.tp_rank = 0 # Create a tensor of size num_experts filled with -1 - self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32) + self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32) # Create a expert map for the local experts assert num_experts % self.ep_size == 0 self.num_local_experts = num_experts // self.ep_size - self.expert_map[ + self.expert_map_cpu[ self.ep_rank * self.num_local_experts : (self.ep_rank + 1) * self.num_local_experts ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu") + self.expert_map_gpu = self.expert_map_cpu.to(device="cuda") else: self.ep_size = 1 self.ep_rank = 0 @@ -136,9 +138,6 @@ class FusedMoE(torch.nn.Module): not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"] ) - if skip_quant: - return - if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( self.use_triton_kernels @@ -367,9 +366,9 @@ class FusedMoE(torch.nn.Module): expert_data.copy_(loaded_weight) def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: - if self.expert_map is None: + if self.expert_map_cpu is None: return expert_id - return self.expert_map[expert_id].item() + return self.expert_map_cpu[expert_id].item() def weight_loader( self, @@ -421,7 +420,6 @@ class FusedMoE(torch.nn.Module): expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: return - self._weight_loader_impl( param=param, loaded_weight=loaded_weight, @@ -614,9 +612,14 @@ class FusedMoE(torch.nn.Module): ) return - def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): + def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput): assert self.quant_method is not None + if self.expert_map_gpu is not None: + topk_output = topk_output._replace( + topk_ids=self.expert_map_gpu[topk_output.topk_ids] + ) + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -670,3 +673,20 @@ class FusedMoE(torch.nn.Module): ("w3", ckpt_up_proj_name), ] ] + + @classmethod + def make_expert_input_scale_params_mapping( + cls, + num_experts: int, + ) -> List[Tuple[str, str, int, str]]: + # (param_name, weight_name, expert_id, shard_id) + return [ + ( + "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_", + f"experts.{expert_id}.{shard_id}.", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id in ["w1", "w2", "w3"] + ] diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index ff10b801b..49a3af57f 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -172,7 +172,6 @@ class Fp8Config(QuantizationConfig): self, layer: torch.nn.Module, prefix: str ) -> Optional[QuantizeMethodBase]: from sglang.srt.layers.linear import LinearBase - from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): @@ -181,8 +180,6 @@ class Fp8Config(QuantizationConfig): return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): return Fp8MoEMethod(self) - elif isinstance(layer, EPMoE): - return Fp8EPMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -984,23 +981,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: - from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - if isinstance(layer, EPMoE): - layer.w13_weight_scale = ( - layer.w13_weight_scale_inv - if self.block_quant - else layer.w13_weight_scale - ) - layer.w2_weight_scale = ( - layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale - ) - return layer.run_moe( - hidden_states=x, - topk_output=topk_output, - ) - if use_intel_amx_backend(layer): from sglang.srt.layers.moe.topk import apply_topk_weights_cpu diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index a307fcc11..38b889695 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -204,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: - from sglang.srt.layers.moe.ep_moe.layer import EPMoE - - if isinstance(layer, EPMoE): - return layer.run_moe( - hidden_states=x, - topk_output=topk_output, - ) - return self.forward( x=x, layer=layer, diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index 0a2f555c8..8619c042b 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -276,6 +276,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): layer: EPMoE, hidden_states: torch.Tensor, topk_output: TopKOutput, + **kwargs, ) -> torch.Tensor: # TODO(ch-wan): move it out of this class