From b4dff7f5ef9edd0ba637f54e5a42b5b072dea239 Mon Sep 17 00:00:00 2001 From: lizhigong <306128847@qq.com> Date: Tue, 21 Oct 2025 16:27:31 +0800 Subject: [PATCH] adaptation w4A8 quantization --- python/sglang/srt/_custom_ops.py | 31 ++ .../srt/layers/quantization/slimquant_w4a8.py | 415 ++++++++++++++++++ .../quantization/slimquant_w4a8_marlin.py | 318 ++++++++++++++ 3 files changed, 764 insertions(+) create mode 100644 python/sglang/srt/layers/quantization/slimquant_w4a8.py create mode 100644 python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index de47707c1..cf63dd6c8 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -5,6 +5,15 @@ from typing import List, Optional, Tuple import torch from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu +try: + from lmslim import quant_ops + from lmslim import quant_tools +except Exception: + print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n") +try: + import lightop +except Exception: + print("INFO: Please install lightop if you want to infer awq of marlin.\n") logger = logging.getLogger(__name__) use_vllm_custom_allreduce = get_bool_env_var( @@ -175,3 +184,25 @@ def mscclpp_allreduce( context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int ) -> None: return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks) + +def triton_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + best_config:Optional[list] = None) -> torch.Tensor: + + return quant_ops.triton_scaled_mm(a, b,scale_a,scale_b,out_dtype,bias,best_config) + +def triton_int8_gemm_helper(m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + use_bias: bool, + out_dtype: type[torch.dtype] = torch.float16, + device: str = "cuda:0", + best_config:Optional[list] = None, + repeat:Optional[int] = 2): + return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config,repeat) \ No newline at end of file diff --git a/python/sglang/srt/layers/quantization/slimquant_w4a8.py b/python/sglang/srt/layers/quantization/slimquant_w4a8.py new file mode 100644 index 000000000..c34ee6f02 --- /dev/null +++ b/python/sglang/srt/layers/quantization/slimquant_w4a8.py @@ -0,0 +1,415 @@ +from typing import Any, Callable, Dict, List, Optional + +import torch +from sglang.srt.layers.linear import set_weight_attrs +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from torch.nn.parameter import Parameter +from sglang.srt.layers.linear import LinearBase +from sglang.srt.layers.quantization.base_config import LinearMethodBase, QuantizationConfig, QuantizeMethodBase, FusedMoEMethodBase +from sglang.srt.layers.parameter import ( + ChannelQuantScaleParameter, + _ColumnvLLMParameter, + RowvLLMParameter, +) +from lmslim.layers.gemm.int8_utils import ( + per_token_group_quant_int8, + per_token_quant_int8) +from sglang.srt import _custom_ops as ops +from vllm.utils import W8a8GetCacheJSON +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig + +import os + +class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for linear layer weights. Uses both column and + row parallelism. + """ + pass + +W8A8_TRITONJSON=W8a8GetCacheJSON() + +def baseline_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + scales= scale_a* scale_b.T + gemmout= torch.mm( + a.to(dtype=torch.float32), b.to(dtype=torch.float32)) + output = (scales *gemmout).to(out_dtype) + if bias is not None: + output = output + bias + return output.to(out_dtype) + + +class SlimQuantW4A8Int8Config(QuantizationConfig): + """Config class for W8A8 Int8 Quantization. + + - Weight: static, per-channel, symmetric + - Activation: dynamic, per-token, symmetric + """ + + def __init__(self): + pass + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def get_name(self) -> str: + return "slimquant_w4a8" + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config": + return cls() + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + + if isinstance(layer, LinearBase): + return SlimQuantW4A8Int8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return SlimQuantW4A8Int8MoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: SlimQuantW4A8Int8Config): + self.quantization_config = quantization_config + self.tritonsingleton= W8a8GetCacheJSON() + self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + n=layer.weight.shape[0] + k=layer.weight.shape[1] + + if self.w8a8_strategy==1: + if {n,k} not in self.tritonsingleton.weight_shapes: + self.tritonsingleton.weight_shapes.append({n,k}) + json_file=self.tritonsingleton.get_w8a8json_name(n,k) + configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k) + + if configs_dict: + self.tritonsingleton.triton_json_dict.update(configs_dict) + + for key, value in configs_dict.items(): + m=int(key.split('_')[0]) + ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value) + else: + weight_data=layer.weight.data + _weight=weight_data.T.contiguous().reshape(n,-1) + layer.weight.data=_weight + + layer.weight = Parameter(layer.weight.t(), requires_grad=False) + layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + weight_loader = extra_weight_attrs.get("weight_loader") + self.logical_widths = output_partition_sizes + + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + input_quant_args: Optional[list[torch.Tensor]] = None, + silu_quant_args: Optional[list[torch.Tensor]] = None + ): + # if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None: + # assert len(input_quant_args) == 2 + # x_q, x_scale = input_quant_args + # elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None: + # x_q, x_scale = silu_quant_args + # else: + x_q, x_scale = per_token_quant_int8(x) + + if self.w8a8_strategy==1: + m=x_q.shape[0] + k=x_q.shape[1] + n=layer.weight.shape[1] + + if len(W8A8_TRITONJSON.triton_json_dict)==0: + best_config=None + + elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict: + if m<=16: + m_=m + elif m<=64: + m_= (m + 3) & -4 #取值到最近的4的倍数 + elif m<=160: + m_=(m + 7) & -8 + + elif m<200: #256 + m_=160 + elif m<480: #512 + m_=256 + elif m<960: #1024 + m_=512 + elif m<2048: + m_=1024 + elif m<4096: + m_=2048 + elif m<6000: + m_=4096 + else: + m_=8192 + + best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"] + + else: + best_config=None + + #if best_config==None: + # print("m:{},n:{},k:{}".format(m,n,k)) + # print("config not found!") + + return ops.triton_scaled_mm(x_q, + layer.weight, + scale_a=x_scale, + scale_b=layer.weight_scale, + out_dtype=x.dtype, + bias=bias,best_config=best_config) + elif self.w8a8_strategy==2: + return ops.cutlass_scaled_mm(x_q, + layer.weight, + scale_a=x_scale, + scale_b=layer.weight_scale, + out_dtype=x.dtype, + bias=bias) + else: + return ops.rocblas_scaled_mm(x_q, + layer.weight, + scale_a=x_scale, + scale_b=layer.weight_scale, + out_dtype=x.dtype, + bias=bias) + + +class SlimQuantW4A8Int8MoEMethod: + """MoE method for W4A8INT8. + Supports loading INT8 checkpoints with static weight scale and + dynamic/static activation scale. + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + Args: + quant_config: The quantization config. + """ + + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): + self.quant_config = quant_config + self.tritonsingleton= W8a8GetCacheJSON() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + tp_size = get_tensor_model_parallel_world_size() + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w13_input_scale = None + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = None + layer.register_parameter("w2_input_scale", w2_input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + E=layer.w13_weight.shape[0] + N1=layer.w13_weight.shape[1] + N2=layer.w2_weight.shape[1] + K=N1//2 + if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes: + self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K]) + + TOPK= self.tritonsingleton.topk + + json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,use_int4_w4a8=True) + configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK) + + #warmup + if configs_dict: + self.tritonsingleton.triton_moejson_dict.update(configs_dict) + + layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) + layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False) + layer.w13_weight_scale = Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) + layer.w2_weight_scale = Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + **_ + ) -> torch.Tensor: + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.") + # Expert selection + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + use_fused_gate=use_fused_gate + ) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_int4_w4a8=True, + per_channel_quant=True, + activation=activation, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + w1_scale=(layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + use_nn_moe=use_nn_moe, + ) diff --git a/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py b/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py new file mode 100644 index 000000000..0d3303380 --- /dev/null +++ b/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py @@ -0,0 +1,318 @@ +from typing import Any, Callable, Dict, List, Optional +from sglang.srt.layers.moe.token_dispatcher.base import CombineInput +from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput +import torch +from sglang.srt import _custom_ops as ops +from sglang.srt.utils import set_weight_attrs +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from torch.nn.parameter import Parameter +from sglang.srt.layers.linear import LinearBase +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.layers.quantization.w4a8_utils import w4a8_weight_repack_impl +from sglang.srt.layers.quantization.base_config import (FusedMoEMethodBase, QuantizeMethodBase) +from sglang.srt.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig + +try: + from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin +except Exception: + print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") + + +class MarlinMoeWorkspace: + """ + Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE. + global_reduce_buffer will take 1.5MB * cus (about 120MB for BW200) memoery in each device + """ + _instances = {} + def __new__(cls, device): + if device not in cls._instances: + instance = super().__new__(cls) + instance._initialized = False + cls._instances[device] = instance + return cls._instances[device] + + def __init__(self, device): + if self._initialized: + return + sms = torch.cuda.get_device_properties(device).multi_processor_count + self.workspace = torch.zeros( + 500, dtype=torch.int, device=device, requires_grad=False + ) + self.global_reduce_buffer = torch.zeros( + sms * 6 * 128 * 512, dtype=torch.int, device=device, requires_grad=False + ) + self._initialized = True + + def get_buffers(self): + return self.workspace, self.global_reduce_buffer + +def baseline_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + scales= scale_a* scale_b.T + gemmout= torch.mm( + a.to(dtype=torch.float32), b.to(dtype=torch.float32)) + output = (scales *gemmout).to(out_dtype) + if bias is not None: + output = output + bias + return output.to(out_dtype) + + +class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig): + """Config class for W4A8 Int8 Quantization. + - Weight: static, per-channel, symmetric + - Activation: dynamic, per-token, symmetric + """ + + def __init__(self): + pass + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def get_name(self) -> str: + return "slimquant_w4a8_marlin" + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8MarlinConfig": + return cls() + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[str]: + if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \ + and user_quant == "slimquant_w4a8_marlin": + return cls.get_name() + return None + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + + if isinstance(layer, LinearBase): + return SlimQuantW4A8Int8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return SlimQuantW4A8Int8MarlinMoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class SlimQuantW4A8Int8MarlinMoEMethod: + """MoE method for W4A8INT8 Marlin. + Supports loading INT8 checkpoints with static weight scale and + dynamic/static activation scale. + Args: + quant_config: The quantization config. + """ + + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + tp_size = get_tensor_model_parallel_world_size() + intermediate_size = intermediate_size_per_partition + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w13_input_scale = None + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = None + layer.register_parameter("w2_input_scale", w2_input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.w13_weight_scale = Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) + layer.w2_weight_scale = Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) + + layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False) + layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + from sglang.srt.layers.moe.topk import apply_topk_weights_cpu + + topk_weights, topk_ids, _ = topk_output + x, topk_weights = apply_topk_weights_cpu( + self.moe_runner_config.apply_router_weight_on_input, topk_weights, x + ) + workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() + output = fused_experts_impl_w4a8_marlin( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + workspace=workspace, + global_reduce_buffer=global_reduce_buffer, + inplace=True, + use_int4_w4a8=True, + per_channel_quant=True, + activation=layer.moe_runner_config.activation, + expert_map=layer.expert_map_gpu, + apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input, + global_num_experts=layer.moe_runner_config.num_experts, + w1_scale=(layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + use_nn_moe=False, + ) + return StandardCombineInput(hidden_states=output) + # def _apply( + # self, + # layer: torch.nn.Module, + # x: torch.Tensor, + # router_logits: torch.Tensor, + # top_k: int, + # #renormalize: bool, + # #use_grouped_topk: bool = False, + # topk_group: Optional[int] = None, + # num_expert_group: Optional[int] = None, + # global_num_experts: int = -1, + # expert_map: Optional[torch.Tensor] = None, + # custom_routing_function: Optional[Callable] = None, + # scoring_func: str = "softmax", + # e_score_correction_bias: Optional[torch.Tensor] = None, + # apply_router_weight_on_input: bool = False, + # activation: str = "silu", + # enable_eplb: bool = False, + # use_nn_moe: Optional[bool] = False, + # routed_scaling_factor: Optional[float] = None, + # use_fused_gate: Optional[bool] = False, + # **_ + # ) -> torch.Tensor: + # from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported) + # from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + # if enable_eplb: + # raise NotImplementedError( + # "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.") + # # Expert selection + # topk_weights, topk_ids = FusedMoE.select_experts( + # hidden_states=x, + # router_logits=router_logits, + # #use_grouped_topk=use_grouped_topk, + # top_k=top_k, + # #renormalize=renormalize, + # topk_group=topk_group, + # num_expert_group=num_expert_group, + # custom_routing_function=custom_routing_function, + # scoring_func=scoring_func, + # e_score_correction_bias=e_score_correction_bias, + # routed_scaling_factor=routed_scaling_factor, + # use_fused_gate=use_fused_gate + # ) + # workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() + # return fused_experts_impl_w4a8_marlin( + # x, + # layer.w13_weight, + # layer.w2_weight, + # topk_weights=topk_weights, + # topk_ids=topk_ids, + # workspace=workspace, + # global_reduce_buffer=global_reduce_buffer, + # inplace=True, + # use_int4_w4a8=True, + # per_channel_quant=True, + # activation=activation, + # expert_map=expert_map, + # apply_router_weight_on_input=apply_router_weight_on_input, + # global_num_experts=global_num_experts, + # w1_scale=(layer.w13_weight_scale), + # w2_scale=(layer.w2_weight_scale), + # a1_scale=layer.w13_input_scale, + # a2_scale=layer.w2_input_scale, + # use_nn_moe=use_nn_moe, + # )