diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 80fbadd57..e99946869 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -30,13 +30,13 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod +from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from sglang.srt.layers.quantization.fp8_kernel import ( is_fp8_fnuz, sglang_per_token_group_quant_fp8, sglang_per_token_quant_fp8, ) -from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod +from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -62,8 +62,6 @@ use_flashinfer_trtllm_moe = ( if not (_is_npu or _is_hip): from sgl_kernel import silu_and_mul - from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe - if _use_aiter: from aiter import ActivationType, QuantType from aiter.fused_moe import fused_moe @@ -162,7 +160,7 @@ def _get_tile_tokens_dim(num_tokens, top_k, num_experts): return tile_tokens_dim -class EPMoE(torch.nn.Module): +class EPMoE(FusedMoE): """ MoE Expert Parallel Impl @@ -184,51 +182,60 @@ class EPMoE(torch.nn.Module): routed_scaling_factor: Optional[float] = None, use_per_token_if_dynamic: bool = True, ): - super().__init__() + super().__init__( + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + top_k=top_k, + layer_id=layer_id, + params_dtype=params_dtype, + quant_config=quant_config, + tp_size=tp_size, + prefix=prefix, + activation=activation, + routed_scaling_factor=routed_scaling_factor, + enable_ep_moe=True, + skip_quant=True, + ) if params_dtype is None: params_dtype = torch.get_default_dtype() - self.tp_size = ( - tp_size if tp_size is not None else get_tensor_model_parallel_world_size() - ) - self.tp_rank = get_tensor_model_parallel_rank() - self.layer_id = layer_id - self.num_experts = num_experts - assert self.num_experts % self.tp_size == 0 - self.num_experts_per_partition, self.expert_map = self.determine_expert_map() - self.start_expert_id = self.tp_rank * self.num_experts_per_partition - self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1 + self.num_local_experts, self.expert_map = self.determine_expert_map() + self.start_expert_id = self.ep_rank * self.num_local_experts + self.end_expert_id = self.start_expert_id + self.num_local_experts - 1 - self.top_k = top_k self.intermediate_size = intermediate_size - self.activation = activation - self.routed_scaling_factor = routed_scaling_factor self.use_per_token_if_dynamic = use_per_token_if_dynamic + # TODO(ch-wan): move quant preparation to FusedMoE if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod() + self.quant_method: Optional[QuantizeMethodBase] = ( + UnquantizedFusedMoEMethod() + ) self.use_fp8_w8a8 = False self.use_block_quant = False self.block_shape = None self.activation_scheme = None - self.use_w4afp8 = False + self.w13_input_scale = None + self.w2_input_scale = None + self.w13_weight_scale = None + self.w2_weight_scale = None elif isinstance(quant_config, W4AFp8Config): self.quant_method: Optional[QuantizeMethodBase] = 
W4AFp8MoEMethod( quant_config ) - self.use_w4afp8 = True self.use_fp8_w8a8 = False self.use_block_quant = False self.fp8_dtype = torch.float8_e4m3fn + self.w13_input_scale = None + self.w2_input_scale = None self.w13_weight_scale = None self.w2_weight_scale = None self.activation_scheme = quant_config.moe_activation_scheme - else: - self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod( - quant_config - ) + elif isinstance(quant_config, Fp8Config): + self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config) self.use_fp8_w8a8 = True self.use_block_quant = getattr(self.quant_method, "block_quant", False) self.block_shape = ( @@ -238,11 +245,13 @@ class EPMoE(torch.nn.Module): ) self.fp8_dtype = torch.float8_e4m3fn self.activation_scheme = quant_config.activation_scheme - self.use_w4afp8 = False + else: + raise ValueError(f"Unsupported quant_config: {quant_config}") + self.quant_config = quant_config self.quant_method.create_weights( layer=self, - num_experts_per_partition=self.num_experts_per_partition, + num_experts=self.num_local_experts, hidden_size=hidden_size, intermediate_size=self.intermediate_size, params_dtype=params_dtype, @@ -251,19 +260,6 @@ class EPMoE(torch.nn.Module): self.grouped_gemm_runner = None - self.w13_weight_fp8 = ( - self.w13_weight, - ( - self.w13_weight_scale_inv - if self.use_block_quant - else self.w13_weight_scale - ), - ) - self.w2_weight_fp8 = ( - self.w2_weight, - self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale, - ) - # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43 # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank. def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]: @@ -282,8 +278,8 @@ class EPMoE(torch.nn.Module): Contains global_num_experts for experts not assigned to the current rank. Returns None if ep_size is 1. 
""" - ep_size = self.tp_size - ep_rank = self.tp_rank + ep_size = self.ep_size + ep_rank = self.ep_rank global_num_experts = self.num_experts assert ep_size > 0 @@ -293,7 +289,7 @@ class EPMoE(torch.nn.Module): local_num_experts = global_num_experts // ep_size expert_map = torch.full( - (global_num_experts,), self.num_experts, dtype=torch.int32 + (global_num_experts,), global_num_experts, dtype=torch.int32 ) if ep_rank < (ep_size - 1): expert_map[ @@ -318,6 +314,20 @@ class EPMoE(torch.nn.Module): hidden_states: torch.Tensor, topk_output: TopKOutput, ): + + self.w13_weight_fp8 = ( + self.w13_weight, + ( + self.w13_weight_scale_inv + if self.use_block_quant + else self.w13_weight_scale + ), + ) + self.w2_weight_fp8 = ( + self.w2_weight, + self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale, + ) + assert self.quant_method is not None assert self.activation == "silu" hidden_states_shape = hidden_states.shape @@ -457,7 +467,10 @@ class EPMoE(torch.nn.Module): return output def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - assert self.quant_method is not None + return self.quant_method.apply(self, hidden_states, topk_output) + + def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput): + topk_weights, topk_ids, _ = topk_output hidden_states_shape = hidden_states.shape @@ -470,53 +483,11 @@ class EPMoE(torch.nn.Module): use_per_token_if_dynamic=self.use_per_token_if_dynamic, ) - if self.use_w4afp8: - local_topk_ids = topk_ids - if self.expert_map is not None: - "Translate info from expert_map to topk_ids" - local_topk_ids = torch.where( - self.expert_map[topk_ids] != self.num_experts, - self.expert_map[topk_ids], - self.num_experts, - ) - - output = cutlass_w4a8_moe( - self.start_expert_id, - self.end_expert_id, - self.num_experts, - hidden_states, - self.w13_weight, - self.w2_weight, - self.w13_weight_scale_inv, - self.w2_weight_scale_inv, - topk_weights, - topk_ids, - local_topk_ids, - self.quant_method.a_strides1, - self.quant_method.b_strides1, - self.quant_method.c_strides1, - self.quant_method.a_strides2, - self.quant_method.b_strides2, - self.quant_method.c_strides2, - self.quant_method.s_strides13, - self.quant_method.s_strides2, - self.quant_method.expert_offsets, - self.quant_method.problem_sizes1, - self.quant_method.problem_sizes2, - self.w13_input_scale, - self.w2_input_scale, - ) - return output - - if self.grouped_gemm_runner is None: - self.grouped_gemm_runner = GroupedGemmRunner( - hidden_states.device, - use_flashinfer=False, # TODO: use flashinfer - use_per_token_if_dynamic=self.use_per_token_if_dynamic, - ) + num_experts = self.num_experts reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( - topk_ids, self.num_experts + topk_ids, + num_experts, ) gateup_input = torch.empty( @@ -524,7 +495,7 @@ class EPMoE(torch.nn.Module): device=hidden_states.device, dtype=( self.fp8_dtype - if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant) + if self.use_fp8_w8a8 and not self.use_block_quant else hidden_states.dtype ), ) @@ -535,7 +506,7 @@ class EPMoE(torch.nn.Module): else: max_value = ( torch.max(hidden_states) - .repeat(self.num_experts_per_partition) + .repeat(self.num_local_experts) .to(torch.float32) ) self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max @@ -576,7 +547,7 @@ class EPMoE(torch.nn.Module): seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2] weight_indices_cur_rank = torch.arange( 0, - self.num_experts_per_partition, + 
self.num_local_experts, device=hidden_states_device, dtype=torch.int64, ) @@ -586,17 +557,13 @@ class EPMoE(torch.nn.Module): b=self.w13_weight, c=None, c_dtype=hidden_states_dtype, - batch_size=self.num_experts_per_partition, + batch_size=self.num_local_experts, weight_column_major=True, seg_indptr=seg_indptr_cur_rank, weight_indices=weight_indices_cur_rank, use_fp8_w8a8=self.use_fp8_w8a8, scale_a=self.w13_input_scale, - scale_b=( - self.w13_weight_scale_inv - if self.use_block_quant - else self.w13_weight_scale - ), + scale_b=self.w13_weight_scale, block_shape=self.block_shape, ) del gateup_input @@ -653,7 +620,7 @@ class EPMoE(torch.nn.Module): down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input) else: self.w2_input_scale = torch.ones( - self.num_experts_per_partition, + self.num_local_experts, dtype=torch.float32, device=hidden_states_device, ) @@ -669,17 +636,13 @@ class EPMoE(torch.nn.Module): a=down_input, b=self.w2_weight, c=down_output, - batch_size=self.num_experts_per_partition, + batch_size=self.num_local_experts, weight_column_major=True, seg_indptr=seg_indptr_cur_rank, weight_indices=weight_indices_cur_rank, use_fp8_w8a8=self.use_fp8_w8a8, scale_a=self.w2_input_scale, - scale_b=( - self.w2_weight_scale_inv - if self.use_block_quant - else self.w2_weight_scale - ), + scale_b=self.w2_weight_scale, block_shape=self.block_shape, ) del down_input @@ -782,107 +745,14 @@ class EPMoE(torch.nn.Module): return expert_id = expert_id - self.start_expert_id - if shard_id not in ("w1", "w2", "w3"): - raise ValueError( - f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}." - ) - - # Special case for fp8 scales. - if "scale" in weight_name: - self._load_fp8_scale( - param.data, - loaded_weight, - weight_name, - shard_id, - expert_id, - ) - return - - # Flashinfer assumes w31 format for w13_weight. Same for the scales. - if use_flashinfer_trtllm_moe: - actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id] - else: - actual_shard_id = shard_id - - if actual_shard_id == "w2": - param.data[expert_id] = loaded_weight - elif actual_shard_id == "w1": - param.data[expert_id][: self.intermediate_size, :] = loaded_weight - elif actual_shard_id == "w3": - param.data[expert_id][self.intermediate_size :, :] = loaded_weight - else: - raise ValueError(f"Expected shard_id w1,w2 or w3 but got {actual_shard_id}") - - def _load_fp8_scale( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, - ) -> None: - param_data = param.data - - # Input scales can be loaded directly and should be equal. - if "input_scale" in weight_name: - if self.use_w4afp8: - if shard_id == "w1": - param_data[expert_id][0] = loaded_weight - elif shard_id == "w3": - param_data[expert_id][1] = loaded_weight - else: - param_data[expert_id] = loaded_weight - return - - if ( - (shard_id == "w1" or shard_id == "w3") - and param_data[expert_id] != 1 - and (param_data[expert_id] - loaded_weight).abs() > 1e-5 - ): - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param_data[expert_id]} " - f"vs. 
{loaded_weight}" - ) - param_data[expert_id] = loaded_weight - # Weight scales - elif "weight_scale" in weight_name: - if self.use_block_quant: - if use_flashinfer_trtllm_moe: - actual_shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id] - else: - actual_shard_id = shard_id - - block_n, block_k = self.block_shape[0], self.block_shape[1] - - if actual_shard_id == "w1": - param_data[expert_id][ - : (self.intermediate_size + block_n - 1) // block_n, : - ] = loaded_weight - elif actual_shard_id == "w3": - param_data[expert_id][ - (self.intermediate_size + block_n - 1) // block_n :, : - ] = loaded_weight - else: # w2 - param_data[expert_id] = loaded_weight - elif self.use_w4afp8: - if shard_id == "w1": - param_data[expert_id][: self.intermediate_size, :] = loaded_weight - elif shard_id == "w3": - param_data[expert_id][self.intermediate_size :, :] = loaded_weight - else: - param_data[expert_id] = loaded_weight - # If we are in merged column case (gate_up_proj) - else: - if shard_id in ("w1", "w3"): - # We have to keep the weight scales of w1 and w3 because - # we need to re-quantize w1/w3 weights after weight loading. - idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight - - # If we are in the row parallel case (down_proj) - else: - param_data[expert_id] = loaded_weight + self._weight_loader_impl( + param=param, + loaded_weight=loaded_weight, + weight_name=weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + return class DeepEPMoE(EPMoE): @@ -932,13 +802,13 @@ class DeepEPMoE(EPMoE): deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM ), f"DeepEP {self.deepep_mode} mode requires deep_gemm" if _use_aiter: - # expert_mask is of size (self.num_experts_per_partition + 1), + # expert_mask is of size (self.num_local_experts + 1), # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid) # for instance, if we have 4 experts on this rank, we would have a expert_mask like: # self.expert_mask = [1, 1, 1, 1, 0] # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out self.expert_mask = torch.zeros( - (self.num_experts_per_partition + 1), + (self.num_local_experts + 1), device=torch.cuda.current_device(), dtype=torch.int, ) @@ -1011,13 +881,13 @@ class DeepEPMoE(EPMoE): if self.activation_scheme == "dynamic" and not self.use_block_quant: max_value = ( torch.max(hidden_states) - .repeat(self.num_experts_per_partition) + .repeat(self.num_local_experts) .to(torch.float32) ) self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max weight_indices_cur_rank = torch.arange( 0, - self.num_experts_per_partition, + self.num_local_experts, device=hidden_states.device, dtype=torch.int64, ) @@ -1029,7 +899,7 @@ class DeepEPMoE(EPMoE): b=self.w13_weight, c=None, c_dtype=hidden_states.dtype, - batch_size=self.num_experts_per_partition, + batch_size=self.num_local_experts, weight_column_major=True, seg_indptr=seg_indptr, weight_indices=weight_indices_cur_rank, @@ -1063,7 +933,7 @@ class DeepEPMoE(EPMoE): ) if self.w2_input_scale is None and not self.use_block_quant: self.w2_input_scale = torch.ones( - self.num_experts_per_partition, + self.num_local_experts, dtype=torch.float32, device=hidden_states_device, ) @@ -1076,7 +946,7 @@ class DeepEPMoE(EPMoE): reorder_topk_ids, self.w2_input_scale, 0, - self.num_experts_per_partition - 1, + self.num_local_experts - 1, BLOCK_SIZE=512, ) else: @@ -1096,7 +966,7 @@ class DeepEPMoE(EPMoE): a=down_input, b=self.w2_weight, 
c=down_output, - batch_size=self.num_experts_per_partition, + batch_size=self.num_local_experts, weight_column_major=True, seg_indptr=seg_indptr, weight_indices=weight_indices_cur_rank, @@ -1121,9 +991,9 @@ class DeepEPMoE(EPMoE): return hidden_states # in original deepep, idx == -1 meaning invalid and will not be processed. # aiter does not accept -1, we use a expert mask to make these idx invalid - # (idx == num_experts_per_partition) meaning not used in aiter fused_moe + # (idx == num_local_experts) meaning not used in aiter fused_moe topk_idx_copy = topk_idx.to(torch.int32) - topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition + topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts return fused_moe( hidden_states, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 5983a6beb..39368e879 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -77,6 +77,7 @@ class FusedMoE(torch.nn.Module): routed_scaling_factor: Optional[float] = None, enable_flashinfer_cutlass_moe: Optional[bool] = False, enable_ep_moe: Optional[bool] = False, + skip_quant: Optional[bool] = False, ): super().__init__() @@ -99,9 +100,6 @@ class FusedMoE(torch.nn.Module): self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe if enable_ep_moe: - assert ( - self.enable_flashinfer_cutlass_moe - ), "FusedMoE only supports EP with --enable-flashinfer-cutlass-moe" self.ep_size = self.tp_size self.ep_rank = self.tp_rank self.tp_size = 1 @@ -110,16 +108,16 @@ class FusedMoE(torch.nn.Module): self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32) # Create a expert map for the local experts assert num_experts % self.ep_size == 0 - self.local_num_experts = num_experts // self.ep_size + self.num_local_experts = num_experts // self.ep_size self.expert_map[ self.ep_rank - * self.local_num_experts : (self.ep_rank + 1) - * self.local_num_experts - ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu") + * self.num_local_experts : (self.ep_rank + 1) + * self.num_local_experts + ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu") else: self.ep_size = 1 self.ep_rank = 0 - self.local_num_experts = num_experts + self.num_local_experts = num_experts self.routed_scaling_factor = routed_scaling_factor assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size @@ -134,6 +132,9 @@ class FusedMoE(torch.nn.Module): not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"] ) + if skip_quant: + return + if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( self.use_triton_kernels @@ -149,7 +150,7 @@ class FusedMoE(torch.nn.Module): self.quant_config = quant_config self.quant_method.create_weights( layer=self, - num_experts=self.local_num_experts, + num_experts=self.num_local_experts, hidden_size=hidden_size, # FIXME: figure out which intermediate_size to use intermediate_size=self.intermediate_size_per_partition, @@ -378,6 +379,23 @@ class FusedMoE(torch.nn.Module): if expert_id == -1: return + self._weight_loader_impl( + param=param, + loaded_weight=loaded_weight, + weight_name=weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + + def _weight_loader_impl( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: 
int, + ) -> None: + # TP rank is set to 0 if EP is enabled tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank() @@ -398,6 +416,10 @@ class FusedMoE(torch.nn.Module): f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}." ) + # Flashinfer assumes w31 format for w13_weight. Same for the scales. + if getattr(self, "use_flashinfer_trtllm_moe", False): + shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id] + WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported] # Fetch the dim to shard the parameter/loaded weight # based on the shard id. This will be whatever @@ -605,37 +627,3 @@ class FusedMoE(torch.nn.Module): ("w3", ckpt_up_proj_name), ] ] - - def _load_fp8_scale( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, - ) -> None: - param_data = param.data - - # Input scales can be loaded directly and should be equal. - if "input_scale" in weight_name: - if ( - param_data[expert_id] != 1 - and (param_data[expert_id] - loaded_weight).abs() > 1e-5 - ): - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param_data[expert_id]} " - f"vs. {loaded_weight}" - ) - param_data[expert_id] = loaded_weight - # Weight scales - elif "weight_scale" in weight_name: - # If we are in merged column case (gate_up_proj) - if shard_id in ("w1", "w3"): - # We have to keep the weight scales of w1 and w3 because - # we need to re-quantize w1/w3 weights after weight loading. - idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight - # If we are in the row parallel case (down_proj) - else: - param_data[expert_id] = loaded_weight diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 6fa3ccc59..ff10b801b 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -172,6 +172,7 @@ class Fp8Config(QuantizationConfig): self, layer: torch.nn.Module, prefix: str ) -> Optional[QuantizeMethodBase]: from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): @@ -180,6 +181,8 @@ class Fp8Config(QuantizationConfig): return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): return Fp8MoEMethod(self) + elif isinstance(layer, EPMoE): + return Fp8EPMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -791,11 +794,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): # merged w13 weights and generate a single scaling factor. 
layer.w13_weight_scale = torch.nn.Parameter( torch.ones( - layer.num_experts, dtype=torch.float32, device=w13_weight.device + layer.num_local_experts, + dtype=torch.float32, + device=w13_weight.device, ), requires_grad=False, ) - for expert in range(layer.num_experts): + for expert in range(layer.num_local_experts): w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) ) @@ -871,7 +876,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): assert layer.w13_weight_scale is not None shard_size = layer.intermediate_size_per_partition max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.num_experts): + for expert_id in range(layer.num_local_experts): start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( @@ -914,7 +919,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): assert layer.w13_weight_scale is not None shard_size = layer.intermediate_size_per_partition max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.num_experts): + for expert_id in range(layer.num_local_experts): start = 0 max_w13_scale_fp8 = max_w13_scales[expert_id] for shard_id in range(2): @@ -931,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post - for expert_id in range(layer.num_experts): + for expert_id in range(layer.num_local_experts): layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id] layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id] @@ -979,8 +984,23 @@ class Fp8MoEMethod(FusedMoEMethodBase): no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: + from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + if isinstance(layer, EPMoE): + layer.w13_weight_scale = ( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ) + layer.w2_weight_scale = ( + layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale + ) + return layer.run_moe( + hidden_states=x, + topk_output=topk_output, + ) + if use_intel_amx_backend(layer): from sglang.srt.layers.moe.topk import apply_topk_weights_cpu @@ -1138,248 +1158,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): return None -class Fp8EPMoEMethod(Fp8MoEMethod): - """MoE method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. - - Args: - quant_config: The quantization config. - """ - - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - self.block_quant = self.quant_config.weight_block_size is not None - - def create_weights( - self, - layer: Module, - num_experts_per_partition: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - - if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - tp_size = get_tensor_model_parallel_world_size() - if self.block_quant: - block_n, block_k = ( - self.quant_config.weight_block_size[0], - self.quant_config.weight_block_size[1], - ) - # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. 
- # Required by column parallel or enabling merged weights - if intermediate_size % block_n != 0: - raise ValueError( - f"The output_size of gate's and up's weight = " - f"{intermediate_size} is not divisible by " - f"weight quantization block_n = {block_n}." - ) - if tp_size > 1: - # Required by row parallel - if intermediate_size % block_k != 0: - raise ValueError( - f"The input_size of down's weight = " - f"{intermediate_size} is not divisible by " - f"weight quantization block_k = {block_k}." - ) - - # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - hidden_size, - intermediate_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - if self.block_quant: - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts_per_partition, - 2 * ((intermediate_size + block_n - 1) // block_n), - (hidden_size + block_k - 1) // block_k, - dtype=torch.float32, - ), - requires_grad=False, - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts_per_partition, - (hidden_size + block_n - 1) // block_n, - (intermediate_size + block_k - 1) // block_k, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) - layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) - assert self.quant_config.activation_scheme == "dynamic" - else: - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, 2, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} - if self.block_quant - else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # INPUT_SCALES - if self.quant_config.activation_scheme == "static": - if not self.quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8." 
- ) - - w13_input_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w13_input_scale", w13_input_scale) - set_weight_attrs(w13_input_scale, extra_weight_attrs) - - w2_input_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - - else: - layer.w13_input_scale = None - layer.w2_input_scale = None - - def process_weights_after_loading(self, layer: Module) -> None: - - # If checkpoint is fp16, quantize in place. - if not self.quant_config.is_checkpoint_fp8_serialized: - # If rocm, use float8_e4m3fnuz as dtype - fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn - w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) - w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) - - layer.w13_weight_scale = torch.nn.Parameter( - torch.ones( - layer.num_experts_per_partition, - dtype=torch.float32, - device=w13_weight.device, - ), - requires_grad=False, - ) - - for expert in range(layer.num_experts_per_partition): - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - return - - # If checkpoint is fp8, we need to handle that the - # MoE kernels require single activation scale and single weight - # scale for w13 per expert. - else: - if self.quant_config.activation_scheme == "static": - if layer.w13_input_scale is None or layer.w2_input_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - layer.w13_weight_scale = torch.nn.Parameter( - torch.max(layer.w13_weight_scale, dim=1).values, - requires_grad=False, - ) - if self.block_quant: - # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_fp8_fnuz: - # activation_scheme: dynamic - w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=layer.w13_weight, - weight_scale=layer.w13_weight_scale_inv, - input_scale=None, - ) - w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=layer.w2_weight, - weight_scale=layer.w2_weight_scale_inv, - input_scale=None, - ) - # Reset the parameter - layer.w13_weight = torch.nn.Parameter( - w13_weight, requires_grad=False - ) - layer.w13_weight_scale_inv = torch.nn.Parameter( - w13_weight_scale, requires_grad=False - ) - layer.w13_input_scale = None - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale_inv = torch.nn.Parameter( - w2_weight_scale, requires_grad=False - ) - layer.w2_input_scale = None - if _use_aiter: - layer.w13_weight = torch.nn.Parameter( - shuffle_weight(layer.w13_weight.data, (16, 16)), - requires_grad=False, - ) - layer.w2_weight = torch.nn.Parameter( - shuffle_weight(layer.w2_weight.data, (16, 16)), - requires_grad=False, - ) - return - - def apply( - self, - layer: torch.nn.Module, - hidden_states: torch.Tensor, - topk_output: TopKOutput, - ) -> torch.Tensor: - raise NotImplementedError - - class Fp8KVCacheMethod(BaseKVCacheMethod): """ Supports loading kv-cache scaling factors from FP8 checkpoints. 
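The FP8 changes above, together with the unquantized path in the next file, converge on a single dispatch pattern: the layer's forward calls quant_method.apply(layer, ...), and the expert-parallel execution itself lives in layer.run_moe, so each quantization method only has to select the right weight/scale tensors before delegating back. Below is a minimal, self-contained sketch of that pattern under hypothetical Toy* names; only the apply()/run_moe() split mirrors the real EPMoE / Fp8MoEMethod / UnquantizedFusedMoEMethod code.

import torch


class ToyQuantMethod:
    """Hypothetical stand-in for Fp8MoEMethod / UnquantizedFusedMoEMethod."""

    def __init__(self, block_quant: bool = False):
        self.block_quant = block_quant

    def apply(self, layer: "ToyEPMoE", hidden_states: torch.Tensor) -> torch.Tensor:
        # Pick the scale tensors the grouped GEMM expects, then hand the
        # expert-parallel execution back to the layer (as Fp8MoEMethod.apply
        # does for EPMoE in the diff above).
        layer.w13_weight_scale = (
            layer.w13_weight_scale_inv if self.block_quant else layer.w13_weight_scale
        )
        return layer.run_moe(hidden_states)


class ToyEPMoE(torch.nn.Module):
    """Hypothetical stand-in for EPMoE: owns the weights and the EP kernels."""

    def __init__(self, quant_method: ToyQuantMethod):
        super().__init__()
        self.quant_method = quant_method
        self.w13_weight_scale = torch.ones(4)      # per-tensor scales
        self.w13_weight_scale_inv = torch.ones(4)  # block-wise inverse scales

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Mirrors EPMoE.forward_normal: a single dispatch through quant_method.apply.
        return self.quant_method.apply(self, hidden_states)

    def run_moe(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Stand-in for the grouped-GEMM expert computation; identity here.
        return hidden_states


layer = ToyEPMoE(ToyQuantMethod(block_quant=True))
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
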
diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index ddafcc6f5..121d5b714 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -24,6 +24,7 @@ from sglang.srt.utils import ( ) if TYPE_CHECKING: + from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.topk import TopKOutput has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None @@ -194,6 +195,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: + + from sglang.srt.layers.moe.ep_moe.layer import EPMoE + + if isinstance(layer, EPMoE): + return layer.run_moe( + hidden_states=x, + topk_output=topk_output, + ) + return self.forward( x=x, layer=layer, @@ -354,69 +364,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): raise NotImplementedError("The TPU backend currently does not support MoE.") forward_native = forward_cpu - - -class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): - - def create_weights( - self, - layer: torch.nn.Module, - num_experts_per_partition: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - # down_proj (row parallel) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - hidden_size, - intermediate_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # scale - layer.register_parameter("w13_input_scale", None) - layer.register_parameter("w13_weight_scale", None) - - ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32) - - w2_input_scale = torch.nn.Parameter( - ones_tensor, - requires_grad=False, - ) - layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - - w2_weight_scale = torch.nn.Parameter( - ones_tensor, - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - def apply( - self, - layer: torch.nn.Module, - hidden_states: torch.Tensor, - topk_output: TopKOutput, - ) -> torch.Tensor: - raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index 1c9dc5d33..0a2f555c8 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch from torch.nn import Module @@ -17,6 +17,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs +if TYPE_CHECKING: + from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput + ACTIVATION_SCHEMES = ["static", "dynamic"] logger = logging.getLogger(__name__) @@ 
-84,13 +87,14 @@ class W4AFp8Config(QuantizationConfig): self, layer: torch.nn.Module, prefix: str ) -> Optional[QuantizeMethodBase]: from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.ignored_layers): return UnquantizedLinearMethod() return Fp8LinearMethod(self) - elif isinstance(layer, FusedMoE): + elif isinstance(layer, EPMoE): return W4AFp8MoEMethod(self) return None @@ -105,8 +109,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): def create_weights( self, - layer: Module, - num_experts_per_partition: int, + layer: EPMoE, + num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, @@ -117,7 +121,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( torch.empty( - num_experts_per_partition, + num_experts, intermediate_size * 2, hidden_size // 2, dtype=torch.int8, @@ -130,7 +134,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): # down_proj (row parallel) w2_weight = torch.nn.Parameter( torch.empty( - num_experts_per_partition, + num_experts, hidden_size, intermediate_size // 2, dtype=torch.int8, @@ -142,7 +146,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): w13_weight_scale = torch.nn.Parameter( torch.zeros( - num_experts_per_partition, + num_experts, 2 * intermediate_size, hidden_size // self.quant_config.group_size, dtype=torch.float32, @@ -154,7 +158,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): w2_weight_scale = torch.nn.Parameter( torch.zeros( - num_experts_per_partition, + num_experts, hidden_size, intermediate_size // self.quant_config.group_size, dtype=torch.float32, @@ -166,14 +170,14 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): # Input scales w13_input_scale = torch.nn.Parameter( - torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16), + torch.ones((num_experts, 2), dtype=torch.bfloat16), requires_grad=False, ) layer.register_parameter("w13_input_scale", w13_input_scale) set_weight_attrs(w13_input_scale, extra_weight_attrs) w2_input_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, dtype=torch.bfloat16), + torch.ones(num_experts, dtype=torch.bfloat16), requires_grad=False, ) layer.register_parameter("w2_input_scale", w2_input_scale) @@ -183,25 +187,25 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): device = layer.w13_weight.device self.a_strides1 = torch.full( - (num_experts_per_partition, 3), + (num_experts, 3), hidden_size, device=device, dtype=torch.int64, ) self.c_strides1 = torch.full( - (num_experts_per_partition, 3), + (num_experts, 3), 2 * intermediate_size, device=device, dtype=torch.int64, ) self.a_strides2 = torch.full( - (num_experts_per_partition, 3), + (num_experts, 3), intermediate_size, device=device, dtype=torch.int64, ) self.c_strides2 = torch.full( - (num_experts_per_partition, 3), + (num_experts, 3), hidden_size, device=device, dtype=torch.int64, @@ -212,13 +216,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): self.s_strides2 = self.c_strides2 self.expert_offsets = torch.empty( - (num_experts_per_partition + 1), dtype=torch.int32, device=device + (num_experts + 1), dtype=torch.int32, device=device ) self.problem_sizes1 = torch.empty( - (num_experts_per_partition, 3), dtype=torch.int32, device=device + (num_experts, 3), dtype=torch.int32, device=device ) self.problem_sizes2 = torch.empty( - (num_experts_per_partition, 3), dtype=torch.int32, device=device + 
(num_experts, 3), dtype=torch.int32, device=device ) return @@ -266,3 +270,50 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): [w2_input_scale_max], dtype=dtype, device=device ) layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False) + + def apply( + self, + layer: EPMoE, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + ) -> torch.Tensor: + + # TODO(ch-wan): move it out of this class + from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe + + topk_ids, topk_weights, _ = topk_output + local_topk_ids = topk_ids + if layer.expert_map is not None: + "Translate info from expert_map to topk_ids" + local_topk_ids = torch.where( + layer.expert_map[topk_ids] != layer.num_experts, + layer.expert_map[topk_ids], + layer.num_experts, + ) + + return cutlass_w4a8_moe( + layer.start_expert_id, + layer.end_expert_id, + layer.num_experts, + hidden_states, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale_inv, + layer.w2_weight_scale_inv, + topk_weights, + topk_ids, + local_topk_ids, + self.a_strides1, + self.b_strides1, + self.c_strides1, + self.a_strides2, + self.b_strides2, + self.c_strides2, + self.s_strides13, + self.s_strides2, + self.expert_offsets, + self.problem_sizes1, + self.problem_sizes2, + layer.w13_input_scale, + layer.w2_input_scale, + )
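
The W4AFp8MoEMethod.apply above relies on the expert map built by EPMoE.determine_expert_map: global expert ids owned by the current rank map to local indices in [0, num_local_experts), every other id maps to the sentinel value global_num_experts, and the router's top-k ids are translated with torch.where before being passed to the kernel. A small runnable sketch of that translation, with made-up sizes (8 experts, 4-way expert parallelism, rank 1):

import torch

num_experts = 8          # global expert count (hypothetical)
ep_size, ep_rank = 4, 1  # 4-way expert parallelism, this is rank 1
num_local_experts = num_experts // ep_size

# Non-local experts map to the sentinel num_experts; local experts map to 0..num_local_experts-1.
expert_map = torch.full((num_experts,), num_experts, dtype=torch.int32)
start = ep_rank * num_local_experts
expert_map[start : start + num_local_experts] = torch.arange(
    num_local_experts, dtype=torch.int32
)

topk_ids = torch.tensor([[0, 2], [3, 7]])      # global ids chosen by the router
local_topk_ids = torch.where(
    expert_map[topk_ids] != num_experts,       # expert lives on this rank
    expert_map[topk_ids],                      # -> its local index
    num_experts,                               # -> sentinel, skipped by the kernel
)
print(expert_map)      # tensor([8, 8, 0, 1, 8, 8, 8, 8], dtype=torch.int32)
print(local_topk_ids)  # tensor([[8, 0], [1, 8]], dtype=torch.int32)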