diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 25104bd80..85ec5bd80 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -359,7 +359,17 @@ class ModelConfig: if hf_api.file_exists(self.model_path, "hf_quant_config.json"): quant_cfg = modelopt_quant_config elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")): - quant_cfg = modelopt_quant_config + quant_config_file = os.path.join( + self.model_path, "hf_quant_config.json" + ) + with open(quant_config_file) as f: + quant_config_dict = json.load(f) + json_quant_configs = quant_config_dict["quantization"] + quant_algo = json_quant_configs.get("quant_algo", None) + if quant_algo == "MIXED_PRECISION": + quant_cfg = {"quant_method": "w4afp8"} + else: + quant_cfg = modelopt_quant_config return quant_cfg # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py @@ -389,6 +399,7 @@ class ModelConfig: "w8a8_fp8", "moe_wna16", "qoq", + "w4afp8", ] compatible_quantization_methods = { "modelopt_fp4": ["modelopt"], diff --git a/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py new file mode 100644 index 000000000..0a2b44bd1 --- /dev/null +++ b/python/sglang/srt/layers/moe/cutlass_w4a8_moe.py @@ -0,0 +1,215 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Cutlass W4A8 MoE kernel.""" +from typing import Optional + +import torch +from sgl_kernel import ( + cutlass_w4a8_moe_mm, + get_cutlass_w4a8_moe_mm_data, + sgl_per_tensor_quant_fp8, + silu_and_mul, +) + +from sglang.srt.layers.moe.ep_moe.kernels import ( + post_reorder_triton_kernel, + pre_reorder_triton_kernel_for_cutlass_moe, + run_cutlass_moe_ep_preproess, +) + + +def cutlass_w4a8_moe( + start_expert_id: int, + end_expert_id: int, + total_num_experts: int, + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids_: torch.Tensor, + local_topk_ids: torch.Tensor, + a_strides1: torch.Tensor, + b_strides1: torch.Tensor, + c_strides1: torch.Tensor, + a_strides2: torch.Tensor, + b_strides2: torch.Tensor, + c_strides2: torch.Tensor, + s_strides13: torch.Tensor, + s_strides2: torch.Tensor, + expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + """ + This function computes a w4a8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with CUTLASS + grouped gemm. + + Parameters: + - a (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1_q (torch.Tensor): The first set of int4-quantized expert weights. + Shape: [num_experts, N * 2, K // 2] + (the weights are passed transposed and int4-packed) + - w2_q (torch.Tensor): The second set of int4-quantized expert weights. + Shape: [num_experts, K, N // 2] + (the weights are passed transposed and int4-packed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts, K // 512, N * 8] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts, N // 512, K * 4] + - topk_weights (torch.Tensor): The weights of each token->expert mapping. 
+ - a_strides1 (torch.Tensor): The input strides of the first grouped gemm. + - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm. + - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. + - a_strides2 (torch.Tensor): The input strides of the second grouped gemm. + - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm. + - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. + - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm. + - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [1, K] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [1, N] + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is 1. + + Returns: + - torch.Tensor: The fp8 output tensor after applying the MoE layer. + """ + assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" + assert w1_q.dtype == torch.int8 + assert w2_q.dtype == torch.int8 + assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1" + assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" + assert ( + w1_scale.shape[1] == w1_q.shape[2] * 2 / 512 + and w1_scale.shape[2] == w1_q.shape[1] * 4 + ), "W1 scale shape mismatch" + assert ( + w2_scale.shape[1] == w2_q.shape[2] * 2 / 512 + and w2_scale.shape[2] == w2_q.shape[1] * 4 + ), "W2 scale shape mismatch" + + assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch" + assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch" + assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch" + assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch" + num_experts = w1_q.size(0) + m = a.size(0) + k = w1_q.size(2) * 2 # w1_q is transposed and packed + n = w2_q.size(2) * 2 # w2_q is transposed and packed + topk = topk_ids_.size(1) + + if apply_router_weight_on_input: + assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1" + + device = a.device + + _, src2dst, _ = run_cutlass_moe_ep_preproess( + local_topk_ids, + num_experts, + ) + + gateup_input = torch.empty( + (m * topk, k), + device=device, + dtype=torch.float8_e4m3fn, + ) + + pre_reorder_triton_kernel_for_cutlass_moe[(m,)]( + a, + gateup_input, + src2dst, + local_topk_ids, + a1_scale, + total_num_experts, + topk, + k, + BLOCK_SIZE=512, + ) + + # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel, + # they are kept to allow for a quick switch of the permutation logic + # from the current triton kernel implementation to the cutlass-based one if needed. 
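+    # get_cutlass_w4a8_moe_mm_data below computes expert_offsets and the per-expert
+    # problem sizes (problem_sizes1 / problem_sizes2) that the two grouped GEMM
+    # calls consume; the token permutation itself is done by the triton
+    # pre/post-reorder kernels in this function.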
+ a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) + get_cutlass_w4a8_moe_mm_data( + local_topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + num_experts, + n, + k, + ) + + c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half) + c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half) + + cutlass_w4a8_moe_mm( + c1, + gateup_input, + w1_q, + a1_scale.float(), + w1_scale, + expert_offsets[:-1], + problem_sizes1, + a_strides1, + b_strides1, + c_strides1, + s_strides13, + 128, + topk, + ) + + intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half) + silu_and_mul(c1, intermediate) + + intermediate_q = torch.empty( + intermediate.shape, dtype=torch.float8_e4m3fn, device=device + ) + sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True) + + cutlass_w4a8_moe_mm( + c2, + intermediate_q, + w2_q, + a2_scale.float(), + w2_scale, + expert_offsets[:-1], + problem_sizes2, + a_strides2, + b_strides2, + c_strides2, + s_strides2, + 128, + topk, + ) + + output = torch.empty_like(a) + post_reorder_triton_kernel[(m,)]( + c2, + output, + src2dst, + topk_ids_, + topk_weights, + start_expert_id, + end_expert_id, + topk, + k, + 0, + BLOCK_SIZE=512, + ) + return output diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 01bdf226c..1d661931c 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -146,6 +146,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks): def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int): reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True) + seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64) src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32) @@ -158,9 +159,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int): compute_src2dst_triton_kernel[grid]( reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE ) + return reorder_topk_ids, src2dst, seg_indptr +def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int): + reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True) + + seg_indptr = torch.zeros( + local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64 + ) + src2dst = torch.empty( + local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32 + ) + + BLOCK_SIZE = 512 + grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),) + compute_src2dst_triton_kernel[grid]( + reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE + ) + + return reorder_topk_ids, src2dst, seg_indptr + + +@triton.jit +def pre_reorder_triton_kernel_for_cutlass_moe( + input_ptr, + gateup_input_ptr, + src2dst_ptr, + topk_ids_ptr, + a1_scales_ptr, + num_experts, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + OutDtype = gateup_input_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + + src_ptr = input_ptr + src_idx * hidden_size + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id != num_experts: + if a1_scales_ptr is not None: + scale = 1.0 / tl.load(a1_scales_ptr) + else: + scale = 1.0 + + dst_idx = tl.load(src2dst_ptr + idx) + dst_ptr = 
gateup_input_ptr + dst_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32) + out_data = (in_data * scale).to(OutDtype) + tl.store(dst_ptr + offset, out_data, mask=mask) + + @triton.jit def pre_reorder_triton_kernel( input_ptr, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 848c4581c..568337fe9 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -12,6 +12,7 @@ from sglang.srt.distributed import ( ) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo +from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe from sglang.srt.layers.moe.ep_moe.kernels import ( ep_gather, ep_scatter, @@ -20,6 +21,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( moe_ep_deepgemm_preprocess, post_reorder_triton_kernel, pre_reorder_triton_kernel, + pre_reorder_triton_kernel_for_cutlass_moe, + run_cutlass_moe_ep_preproess, run_moe_ep_preproess, silu_and_mul_masked_post_quant_fwd, silu_and_mul_triton_kernel, @@ -41,6 +44,7 @@ from sglang.srt.layers.quantization.fp8_kernel import ( sglang_per_token_quant_fp8, ) from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import ( @@ -191,7 +195,7 @@ class EPMoE(torch.nn.Module): num_fused_shared_experts == 0 ), "num_fused_shared_experts is not supported in EP" self.num_fused_shared_experts = num_fused_shared_experts - self.num_experts_per_partition = self.num_experts // self.tp_size + self.num_experts_per_partition, self.expert_map = self.determine_expert_map() self.start_expert_id = self.tp_rank * self.num_experts_per_partition self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1 @@ -215,6 +219,18 @@ class EPMoE(torch.nn.Module): self.use_block_quant = False self.block_shape = None self.activation_scheme = None + self.use_w4afp8 = False + elif isinstance(quant_config, W4AFp8Config): + self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod( + quant_config + ) + self.use_w4afp8 = True + self.use_fp8_w8a8 = False + self.use_block_quant = False + self.fp8_dtype = torch.float8_e4m3fn + self.w13_weight_scale = None + self.w2_weight_scale = None + self.activation_scheme = quant_config.moe_activation_scheme else: self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod( quant_config @@ -228,6 +244,7 @@ class EPMoE(torch.nn.Module): ) self.fp8_dtype = torch.float8_e4m3fn self.activation_scheme = quant_config.activation_scheme + self.use_w4afp8 = False self.quant_method.create_weights( layer=self, @@ -253,6 +270,49 @@ class EPMoE(torch.nn.Module): self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale, ) + # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43 + # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank. 
+ def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]: + """ + Calculates how many experts should be assigned to each rank for EP and + creates a mapping from global to local expert index. Experts are + distributed evenly across ranks. Any remaining are assigned to the + last rank. + + Returns: + Tuple[int, Optional[torch.Tensor]]: A tuple containing: + - local_num_experts (int): The number of experts assigned + to the current rank. + - expert_map (Optional[torch.Tensor]): A tensor of shape + (global_num_experts,) mapping from global to local index. + Contains global_num_experts for experts not assigned to the current rank. + Returns None if ep_size is 1. + """ + ep_size = self.tp_size + ep_rank = self.tp_rank + global_num_experts = self.num_experts + + assert ep_size > 0 + if ep_size == 1: + return (global_num_experts, None) + + local_num_experts = global_num_experts // ep_size + + expert_map = torch.full( + (global_num_experts,), self.num_experts, dtype=torch.int32 + ) + if ep_rank < (ep_size - 1): + expert_map[ + ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts + ] = torch.arange(0, local_num_experts, dtype=torch.int32) + else: + local_num_experts = global_num_experts - ep_rank * local_num_experts + + expert_map[-local_num_experts:] = torch.arange( + 0, local_num_experts, dtype=torch.int32 + ) + return (local_num_experts, expert_map) + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8: return self.forward_deepgemm(hidden_states, router_logits) @@ -440,6 +500,51 @@ class EPMoE(torch.nn.Module): ), ) + if self.use_w4afp8: + local_topk_ids = topk_ids + if self.expert_map is not None: + "Translate info from expert_map to topk_ids" + local_topk_ids = torch.where( + self.expert_map[topk_ids] != self.num_experts, + self.expert_map[topk_ids], + self.num_experts, + ) + + output = cutlass_w4a8_moe( + self.start_expert_id, + self.end_expert_id, + self.num_experts, + hidden_states, + self.w13_weight, + self.w2_weight, + self.w13_weight_scale_inv, + self.w2_weight_scale_inv, + topk_weights, + topk_ids, + local_topk_ids, + self.quant_method.a_strides1, + self.quant_method.b_strides1, + self.quant_method.c_strides1, + self.quant_method.a_strides2, + self.quant_method.b_strides2, + self.quant_method.c_strides2, + self.quant_method.s_strides13, + self.quant_method.s_strides2, + self.quant_method.expert_offsets, + self.quant_method.problem_sizes1, + self.quant_method.problem_sizes2, + self.w13_input_scale, + self.w2_input_scale, + ) + return output + + if self.grouped_gemm_runner is None: + self.grouped_gemm_runner = GroupedGemmRunner( + hidden_states.device, + use_flashinfer=False, # TODO: use flashinfer + use_per_token_if_dynamic=self.use_per_token_if_dynamic, + ) + reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( topk_ids, self.num_experts ) @@ -449,7 +554,7 @@ class EPMoE(torch.nn.Module): device=hidden_states.device, dtype=( self.fp8_dtype - if (self.use_fp8_w8a8 and not self.use_block_quant) + if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant) else hidden_states.dtype ), ) @@ -656,6 +761,23 @@ class EPMoE(torch.nn.Module): ] ] + @classmethod + def make_expert_input_scale_params_mapping( + cls, + num_experts: int, + ) -> List[Tuple[str, str, int, str]]: + # (param_name, weight_name, expert_id, shard_id) + return [ + ( + "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_", + f"experts.{expert_id}.{shard_id}.", + expert_id, + shard_id, + 
) + for expert_id in range(num_experts) + for shard_id in ["w1", "w2", "w3"] + ] + def weight_loader( self, param: torch.nn.Parameter, @@ -727,6 +849,15 @@ class EPMoE(torch.nn.Module): # Input scales can be loaded directly and should be equal. if "input_scale" in weight_name: + if self.use_w4afp8: + if shard_id == "w1": + param_data[expert_id][0] = loaded_weight + elif shard_id == "w3": + param_data[expert_id][1] = loaded_weight + else: + param_data[expert_id] = loaded_weight + return + if ( (shard_id == "w1" or shard_id == "w3") and param_data[expert_id] != 1 @@ -752,6 +883,13 @@ class EPMoE(torch.nn.Module): ] = loaded_weight else: # w2 param_data[expert_id] = loaded_weight + elif self.use_w4afp8: + if shard_id == "w1": + param_data[expert_id][: self.intermediate_size, :] = loaded_weight + elif shard_id == "w3": + param_data[expert_id][self.intermediate_size :, :] = loaded_weight + else: + param_data[expert_id] = loaded_weight # If we are in merged column case (gate_up_proj) else: if shard_id in ("w1", "w3"): diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 68b4826d0..4ee498169 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.modelopt_quant import ( ) from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config from sglang.srt.layers.quantization.qoq import QoQConfig +from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config @@ -82,6 +83,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "moe_wna16": MoeWNA16Config, "compressed-tensors": CompressedTensorsConfig, "qoq": QoQConfig, + "w4afp8": W4AFp8Config, } # VLLM-dependent quantization methods diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index fa67bba4d..4d886de91 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1,7 +1,7 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.nn.functional as F @@ -200,7 +200,7 @@ class Fp8LinearMethod(LinearMethodBase): quant_config: The quantization config. 
""" - def __init__(self, quant_config: Fp8Config): + def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]): self.quant_config = quant_config self.cutlass_fp8_supported = cutlass_fp8_supported() @@ -286,7 +286,10 @@ class Fp8LinearMethod(LinearMethodBase): if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE if self.block_quant: - assert self.quant_config.activation_scheme == "dynamic" + if hasattr(self.quant_config, "activation_scheme"): + assert self.quant_config.activation_scheme == "dynamic" + elif hasattr(self.quant_config, "linear_activation_scheme"): + assert self.quant_config.linear_activation_scheme == "dynamic" scale = BlockQuantScaleParameter( data=torch.empty( (output_size_per_partition + block_n - 1) // block_n, @@ -308,7 +311,13 @@ class Fp8LinearMethod(LinearMethodBase): layer.register_parameter("weight_scale", scale) # INPUT ACTIVATION SCALE - if self.quant_config.activation_scheme == "static": + if ( + hasattr(self.quant_config, "activation_scheme") + and self.quant_config.activation_scheme == "static" + ) or ( + hasattr(self.quant_config, "linear_activation_scheme") + and self.quant_config.linear_activation_scheme == "static" + ): scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, @@ -371,7 +380,13 @@ class Fp8LinearMethod(LinearMethodBase): layer.weight_scale = torch.nn.Parameter( layer.weight_scale.data, requires_grad=False ) - if self.quant_config.activation_scheme == "static": + if ( + hasattr(self.quant_config, "activation_scheme") + and self.quant_config.activation_scheme == "static" + ) or ( + hasattr(self.quant_config, "linear_activation_scheme") + and self.quant_config.linear_activation_scheme == "static" + ): layer.input_scale = torch.nn.Parameter( layer.input_scale.data, requires_grad=False ) @@ -405,7 +420,13 @@ class Fp8LinearMethod(LinearMethodBase): # Update layer with new values. 
layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) - if self.quant_config.activation_scheme == "static": + if ( + hasattr(self.quant_config, "activation_scheme") + and self.quant_config.activation_scheme == "static" + ) or ( + hasattr(self.quant_config, "linear_activation_scheme") + and self.quant_config.linear_activation_scheme == "static" + ): layer.input_scale = Parameter( layer.input_scale.max(), requires_grad=False ) diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py new file mode 100644 index 000000000..c2820bdfc --- /dev/null +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -0,0 +1,264 @@ +import logging +from typing import Any, Dict, List, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod +from sglang.srt.layers.quantization.utils import is_layer_skipped +from sglang.srt.utils import set_weight_attrs + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = logging.getLogger(__name__) + + +class W4AFp8Config(QuantizationConfig): + """Config class for MIXED_PRECISION W4AFp8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = True, + is_checkpoint_w4afp8_serialized: bool = True, + linear_activation_scheme: str = "dynamic", + moe_activation_scheme: str = "static", + ignored_layers: Optional[List[str]] = None, + weight_block_size: Optional[List[int]] = None, + group_size: int = 128, + ) -> None: + super().__init__() + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized + if is_checkpoint_w4afp8_serialized: + logger.warning("Detected w4afp8 checkpoint. 
Please note that") + if moe_activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}") + self.linear_activation_scheme = linear_activation_scheme + self.moe_activation_scheme = moe_activation_scheme + self.ignored_layers = ignored_layers or [] + self.weight_block_size = [128, 128] + self.group_size = group_size + + @classmethod + def get_name(cls) -> str: + return "w4afp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.float8_e4m3fn] + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = "fp8" in quant_method + is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method + linear_activation_scheme = "dynamic" + moe_activation_scheme = "static" + weight_block_size = [128, 128] + return cls( + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + is_checkpoint_w4afp8_serialized=is_checkpoint_w4afp8_serialized, + linear_activation_scheme=linear_activation_scheme, + moe_activation_scheme=moe_activation_scheme, + weight_block_size=weight_block_size, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return Fp8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return W4AFp8MoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class W4AFp8MoEMethod: + + def __init__(self, quant_config: W4AFp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + assert "weight_loader" in extra_weight_attrs + + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + intermediate_size * 2, + hidden_size // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size // self.quant_config.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts_per_partition, + hidden_size, + intermediate_size // self.quant_config.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # Input scales + w13_input_scale = torch.nn.Parameter( + 
torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16), + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.bfloat16), + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + # Pre-populate the strides + device = layer.w13_weight.device + + self.a_strides1 = torch.full( + (num_experts_per_partition, 3), + hidden_size, + device=device, + dtype=torch.int64, + ) + self.c_strides1 = torch.full( + (num_experts_per_partition, 3), + 2 * intermediate_size, + device=device, + dtype=torch.int64, + ) + self.a_strides2 = torch.full( + (num_experts_per_partition, 3), + intermediate_size, + device=device, + dtype=torch.int64, + ) + self.c_strides2 = torch.full( + (num_experts_per_partition, 3), + hidden_size, + device=device, + dtype=torch.int64, + ) + self.b_strides1 = self.a_strides1 + self.s_strides13 = self.c_strides1 + self.b_strides2 = self.a_strides2 + self.s_strides2 = self.c_strides2 + + self.expert_offsets = torch.empty( + (num_experts_per_partition + 1), dtype=torch.int32, device=device + ) + self.problem_sizes1 = torch.empty( + (num_experts_per_partition, 3), dtype=torch.int32, device=device + ) + self.problem_sizes2 = torch.empty( + (num_experts_per_partition, 3), dtype=torch.int32, device=device + ) + + return + + def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor: + """Interleave scales in groups of 4 similar to TRT-LLM implementation.""" + s_shape = scales.shape + # Reshape to separate groups of 4 + scales_interleaved = scales.reshape( + s_shape[0], s_shape[1], (s_shape[2] // 4), 4 + ) + # Permute dimensions to interleave + scales_interleaved = scales_interleaved.permute(0, 2, 1, 3) + # Reshape back to original dimensions but with interleaved values + scales_interleaved = scales_interleaved.reshape( + s_shape[0], s_shape[2] // 4, s_shape[1] * 4 + ) + return scales_interleaved.contiguous() + + def process_weights_after_loading(self, layer: Module) -> None: + dtype = torch.bfloat16 + device = layer.w2_weight.device + + # Interleave w13_weight_scale (gate_up_proj) + w13_weight_scale = layer.w13_weight_scale_inv.to(dtype) + w13_weight_scale = self._interleave_scales(w13_weight_scale) + layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False) + + # Interleave w2_weight_scale (down_proj) + w2_weight_scale = layer.w2_weight_scale_inv.to(dtype) + w2_weight_scale = self._interleave_scales(w2_weight_scale) + layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False) + + # Process input scales + w13_input_scale_max = layer.w13_input_scale.max().to(dtype).item() + new_w13_input_scale = torch.tensor( + [w13_input_scale_max], + dtype=dtype, + device=device, + ) + layer.w13_input_scale = Parameter(new_w13_input_scale, requires_grad=False) + + w2_input_scale_max = layer.w2_input_scale.max().to(dtype).item() + new_w2_input_scale = torch.tensor( + [w2_input_scale_max], dtype=dtype, device=device + ) + layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index cc53f62c2..1784ee132 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -2363,6 +2363,12 @@ class DeepseekV2ForCausalLM(nn.Module): 
ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, ) + if self.quant_config and self.quant_config.get_name() == "w4afp8": + expert_params_mapping += ( + get_moe_impl_class().make_expert_input_scale_params_mapping( + num_experts=self.config.n_routed_experts + ) + ) # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and ( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7a5db80a7..d9b2d09e9 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -708,6 +708,7 @@ class ServerArgs: "w8a8_fp8", "moe_wna16", "qoq", + "w4afp8", ], help="The quantization method.", ) diff --git a/python/sglang/test/test_cutlass_w4a8_moe.py b/python/sglang/test/test_cutlass_w4a8_moe.py new file mode 100644 index 000000000..acf8a27b9 --- /dev/null +++ b/python/sglang/test/test_cutlass_w4a8_moe.py @@ -0,0 +1,281 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import pytest +import torch + +from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe +from sglang.srt.layers.moe.topk import select_experts + + +def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor: + if int4_values_interleaved.shape[-1] % 2 != 0: + raise ValueError( + "the last dim size of int4_values_interleaved tensor must be even." + ) + + input_tensor_int8 = int4_values_interleaved.to(torch.int8) + + low_nibbles = input_tensor_int8[..., 0::2] + high_nibbles = input_tensor_int8[..., 1::2] + + packed_tensor = (high_nibbles << 4) | (low_nibbles & 0x0F) + + return packed_tensor.to(torch.int8) + + +def pack_interleave(num_experts, ref_weight, ref_scale): + n, k = ref_weight.shape[1], ref_weight.shape[2] + + weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda() + w_q = weight.view((num_experts, n, k // 2)).view(torch.int8) + w_q = w_q.contiguous() + + scale_interleaved = ref_scale.reshape( + ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4 + ) # [E, N, K/4, 4] + scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K/4, N, 4] + scale_interleaved = scale_interleaved.reshape( + ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4 + ) # [E, K/4, N*4] + w_scale = scale_interleaved.contiguous() + + return w_q, w_scale + + +@pytest.mark.parametrize("M", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("N", [2048]) +@pytest.mark.parametrize("K", [7168]) +@pytest.mark.parametrize("E", [256]) +@pytest.mark.parametrize("ep_size", [8]) +@pytest.mark.parametrize("topk", [8]) +@pytest.mark.parametrize("group_size", [128]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype): + local_e = E // ep_size + + debug = False + if debug: + a = torch.ones((M, K), dtype=dtype, device="cuda") * 0.001 + ref_weight_1 = torch.ones((local_e, N * 2, K), dtype=torch.int8, device="cuda") + ref_weight_2 = torch.ones((local_e, K, N), dtype=torch.int8, device="cuda") + a1_scale = torch.ones(1, dtype=torch.float32, device="cuda") + a2_scale = torch.ones(1, dtype=torch.float32, device="cuda") + scale_1 = torch.ones( + (local_e, N * 2, K // group_size), dtype=dtype, device="cuda" + ) + scale_2 = torch.ones((local_e, K, N // group_size), dtype=dtype, device="cuda") + else: + a = torch.randn(M, K, dtype=dtype, device="cuda") + ref_weight_1 = torch.randint( + -8, 8, (local_e, N * 2, K), dtype=torch.int8, 
device="cuda" + ) + ref_weight_2 = torch.randint( + -8, 8, (local_e, K, N), dtype=torch.int8, device="cuda" + ) + affine_coeff = 0.005 + a1_scale = torch.randn(1, dtype=torch.float32, device="cuda") + a2_scale = torch.randn(1, dtype=torch.float32, device="cuda") + scale_1 = ( + torch.randn(local_e, N * 2, K // group_size, dtype=dtype, device="cuda") + * affine_coeff + ) + scale_2 = ( + torch.randn(local_e, K, N // group_size, dtype=dtype, device="cuda") + * affine_coeff + ) + + w1_q, w1_scale = pack_interleave(local_e, ref_weight_1, scale_1) + w2_q, w2_scale = pack_interleave(local_e, ref_weight_2, scale_2) + + device = "cuda" + a_strides1 = torch.full((local_e, 3), K, device=device, dtype=torch.int64) + c_strides1 = torch.full((local_e, 3), 2 * N, device=device, dtype=torch.int64) + a_strides2 = torch.full((local_e, 3), N, device=device, dtype=torch.int64) + c_strides2 = torch.full((local_e, 3), K, device=device, dtype=torch.int64) + b_strides1 = a_strides1 + s_strides13 = c_strides1 + b_strides2 = a_strides2 + s_strides2 = c_strides2 + + score = torch.randn((M, E), dtype=dtype, device=device) + topk_weights, topk_ids = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + use_grouped_topk=False, + renormalize=False, + ) + expert_map = torch.arange(E, dtype=torch.int32, device=device) + expert_map[local_e:] = E + + output = cutlass_moe( + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + a_strides1, + b_strides1, + c_strides1, + a_strides2, + b_strides2, + c_strides2, + s_strides13, + s_strides2, + 0, + local_e - 1, + E, + a1_scale, + a2_scale, + expert_map, + ) + + ref_output = ref( + a, + local_e, + topk_weights, + topk_ids, + ref_weight_1, + ref_weight_2, + scale_1, + scale_2, + has_pre_quant=True, + has_alpha=True, + pre_quant_scale_1=a1_scale, + pre_quant_scale_2=a2_scale, + alpha_1=a1_scale, + alpha_2=a2_scale, + ) + + # compare + torch.cuda.synchronize() + + # compare final output + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + print("SUCCESS: Final output tensors are close.") + + +def cutlass_moe( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids_: torch.Tensor, + a_strides1: torch.Tensor, + b_strides1: torch.Tensor, + c_strides1: torch.Tensor, + a_strides2: torch.Tensor, + b_strides2: torch.Tensor, + c_strides2: torch.Tensor, + s_strides13: torch.Tensor, + s_strides2: torch.Tensor, + start_expert_id: int, + end_expert_id: int, + E: int, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, +): + local_topk_ids = topk_ids_ + local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E) + device = a.device + + local_num_experts = end_expert_id - start_expert_id + 1 + expert_offsets = torch.empty( + (local_num_experts + 1), dtype=torch.int32, device=device + ) + problem_sizes1 = torch.empty( + (local_num_experts, 3), dtype=torch.int32, device=device + ) + problem_sizes2 = torch.empty( + (local_num_experts, 3), dtype=torch.int32, device=device + ) + return cutlass_w4a8_moe( + start_expert_id, + end_expert_id, + E, + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids_, + local_topk_ids, + a_strides1, + b_strides1, + c_strides1, + a_strides2, + b_strides2, + c_strides2, + s_strides13, + s_strides2, + expert_offsets, + problem_sizes1, + problem_sizes2, + 
a1_scale, + a2_scale, + apply_router_weight_on_input, + ) + + +def ref( + x: torch.Tensor, + num_experts: int, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ref_weight_1: torch.Tensor, + ref_weight_2: torch.Tensor, + ref_weight_scale_1: torch.Tensor, + ref_weight_scale_2: torch.Tensor, + has_pre_quant: bool = False, + has_alpha: bool = False, + pre_quant_scale_1: Optional[torch.Tensor] = None, + pre_quant_scale_2: Optional[torch.Tensor] = None, + alpha_1: Optional[torch.Tensor] = None, + alpha_2: Optional[torch.Tensor] = None, +): + results = torch.zeros_like(x) + dtype = x.dtype + for e_idx in range(num_experts): + mask = topk_ids == e_idx + activated_tokens = mask.sum(1).bool() + act = x[activated_tokens, :] + if act.shape[0] == 0: + continue + final_scale = (topk_weights * mask).sum(1)[activated_tokens].unsqueeze(1) + + act = ( + torch.clamp((act / pre_quant_scale_1.float()), -448.0, 448.0) + .to(torch.float8_e4m3fn) + .to(dtype) + ) + w3_w1 = ref_weight_1[e_idx] + ref_w_scale_repeat = ( + ref_weight_scale_1[e_idx].repeat_interleave(128, dim=1).to(float) + ) + w3_w1 = (w3_w1.to(float) * ref_w_scale_repeat).to(dtype) + fc1 = ((torch.matmul(act, w3_w1.T)) * alpha_1).to(torch.float16) + + gate, fc1 = fc1.chunk(2, dim=-1) + fc1 = fc1 * torch.nn.functional.silu(gate) + act = (fc1 / pre_quant_scale_2.float()).to(torch.float8_e4m3fn) + act = act.to(dtype) + + w2 = ref_weight_2[e_idx] + ref_w_scale_repeat = ( + ref_weight_scale_2[e_idx].repeat_interleave(128, dim=1).to(float) + ) + w2 = (w2.to(float) * ref_w_scale_repeat).to(dtype) + fc2 = (torch.matmul(act, w2.T) * alpha_2).to(torch.float16) + + results[activated_tokens, :] += (fc2 * final_scale).to(results.dtype) + + return results
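
Note (illustration, not part of the patch): the layout conventions this change relies on — int4 weights packed two-per-int8 byte along K, and per-group scales interleaved in blocks of 4 — can be reproduced with the short standalone sketch below. It mirrors pack_int4_values_to_int8 / pack_interleave in the new test and W4AFp8MoEMethod._interleave_scales; the toy shapes are chosen only so the reshapes divide evenly.

import torch

def pack_int4_to_int8(w_int4: torch.Tensor) -> torch.Tensor:
    # Pack pairs of int4 values (held in int8 containers) into one int8 byte along
    # the last dim: element 2i -> low nibble, element 2i+1 -> high nibble,
    # matching pack_int4_values_to_int8 in test_cutlass_w4a8_moe.py.
    assert w_int4.shape[-1] % 2 == 0
    w = w_int4.to(torch.int8)
    low, high = w[..., 0::2], w[..., 1::2]
    return ((high << 4) | (low & 0x0F)).to(torch.int8)

def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
    # Interleave per-group scales in blocks of 4 along the group dim, mirroring
    # W4AFp8MoEMethod._interleave_scales:
    # [E, N, K/group_size] -> [E, K/(4*group_size), N*4].
    e, n, k_groups = scales.shape
    s = scales.reshape(e, n, k_groups // 4, 4).permute(0, 2, 1, 3)
    return s.reshape(e, k_groups // 4, n * 4).contiguous()

# Toy shapes: 1 expert, N=8 output rows, K=512 input columns, group_size=128.
E, N, K, group_size = 1, 8, 512, 128
w = torch.randint(-8, 8, (E, N, K), dtype=torch.int8)   # unpacked int4 weights
scales = torch.rand(E, N, K // group_size)              # per-group dequant scales
w_packed = pack_int4_to_int8(w)                         # [1, 8, 256] int8
scales_interleaved = interleave_scales(scales)          # [1, 1, 32]
print(w_packed.shape, scales_interleaved.shape)

With group_size=128 this interleaving is what produces the scale shapes asserted in cutlass_w4a8_moe: [E, K // 512, N * 8] for w1_scale and [E, N // 512, K * 4] for w2_scale.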