2024-12-06 15:42:10 +08:00
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
2025-07-17 00:47:07 -07:00
from __future__ import annotations
2024-12-06 15:42:10 +08:00
import logging
2025-07-19 00:51:15 -07:00
from typing import TYPE_CHECKING , Any , Dict , List , Optional , Union
2024-12-06 15:42:10 +08:00
import torch
2024-12-07 05:18:26 -08:00
import torch . nn . functional as F
2024-12-06 15:42:10 +08:00
from torch . nn import Module
from torch . nn . parameter import Parameter
2025-03-18 06:51:59 +08:00
# The Marlin FP8 kernels live in vllm, which is an optional dependency.
# When it is missing, both entry points are replaced by a stub that raises
# as soon as it is called, and MARLIN_FP8_AVAILABLE records the outcome.
try:
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_fp8_marlin_linear,
        prepare_fp8_layer_for_marlin,
    )
except ImportError:
    MARLIN_FP8_AVAILABLE = False

    def dummy_func(*args, **kwargs):
        """Stand-in for the vllm Marlin helpers; always raises ImportError."""
        raise ImportError(
            "marlin FP8 requires some operators from vllm. Please install vllm."
        )

    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
else:
    MARLIN_FP8_AVAILABLE = True
2025-03-18 06:51:59 +08:00
2025-01-17 22:31:51 +08:00
from sglang . srt . distributed import get_tensor_model_parallel_world_size
2025-07-04 00:51:09 +08:00
from sglang . srt . layers . amx_utils import _amx_process_weight_after_loading
2025-09-05 21:09:09 -07:00
from sglang . srt . layers . moe import MoeRunner , MoeRunnerBackend , MoeRunnerConfig
from sglang . srt . layers . moe . moe_runner . triton import TritonMoeQuantInfo
from sglang . srt . layers . moe . token_dispatcher . base import DispatchOutputChecker
2025-03-09 16:03:32 +08:00
from sglang . srt . layers . parameter import (
BlockQuantScaleParameter ,
ModelWeightParameter ,
PerTensorScaleParameter ,
)
2024-12-06 15:42:10 +08:00
from sglang . srt . layers . quantization . base_config import (
2025-07-17 00:47:07 -07:00
FusedMoEMethodBase ,
LinearMethodBase ,
2024-12-06 15:42:10 +08:00
QuantizationConfig ,
QuantizeMethodBase ,
)
2025-04-16 15:26:49 -07:00
from sglang . srt . layers . quantization . fp8_kernel import (
2025-05-08 08:28:24 +08:00
fp8_dtype ,
is_fp8_fnuz ,
2025-04-16 15:26:49 -07:00
per_token_group_quant_fp8 ,
scaled_fp8_quant ,
)
2024-12-26 00:02:14 +08:00
from sglang . srt . layers . quantization . fp8_utils import (
2025-03-09 16:03:32 +08:00
apply_fp8_linear ,
2025-08-19 02:53:15 +01:00
can_auto_enable_marlin_fp8 ,
2025-03-09 16:03:32 +08:00
cutlass_fp8_supported ,
2025-05-29 00:15:11 -07:00
dispatch_w8a8_block_fp8_linear ,
2025-03-09 16:03:32 +08:00
input_to_float8 ,
2024-12-26 00:02:14 +08:00
normalize_e4m3fn_to_e4m3fnuz ,
)
2025-04-16 15:26:49 -07:00
from sglang . srt . layers . quantization . kv_cache import BaseKVCacheMethod
2025-07-17 00:47:07 -07:00
from sglang . srt . layers . quantization . unquant import UnquantizedLinearMethod
2025-04-16 15:26:49 -07:00
from sglang . srt . layers . quantization . utils import (
all_close_1d ,
convert_to_channelwise ,
is_layer_skipped ,
per_tensor_dequantize ,
requantize_with_max_scale ,
)
2024-12-06 15:42:10 +08:00
from sglang . srt . utils import (
2025-06-18 13:11:50 +08:00
cpu_has_amx_support ,
2024-12-06 15:42:10 +08:00
get_bool_env_var ,
2025-06-18 13:11:50 +08:00
is_cpu ,
2025-03-18 06:51:59 +08:00
is_cuda ,
2024-12-06 15:42:10 +08:00
is_hip ,
2025-06-18 02:24:10 +08:00
is_npu ,
2025-08-27 03:34:29 -07:00
is_sm90_supported ,
is_sm100_supported ,
2025-05-07 18:54:50 -07:00
log_info_on_rank0 ,
2025-07-31 19:56:34 -07:00
next_power_of_2 ,
2024-12-06 15:42:10 +08:00
print_warning_once ,
set_weight_attrs ,
2025-07-04 00:51:09 +08:00
use_intel_amx_backend ,
2024-12-06 15:42:10 +08:00
)
2025-07-17 00:47:07 -07:00
if TYPE_CHECKING :
2025-09-05 21:09:09 -07:00
from sglang . srt . layers . moe . token_dispatcher import (
CombineInput ,
DispatchOutput ,
StandardDispatchOutput ,
)
2025-07-19 00:51:15 -07:00
from sglang . srt . layers . moe . topk import TopKOutput
2025-07-17 00:47:07 -07:00
from sglang . srt . layers . quantization . w4afp8 import W4AFp8Config
2025-03-11 18:12:56 -07:00
# Platform / feature probes. Each helper is called exactly once, at import
# time, and its result is kept in a module-level flag that the rest of the
# file branches on.
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

_is_fp8_fnuz = is_fp8_fnuz()

# Switches read from the environment through get_bool_env_var.
# _use_aiter is additionally gated on _is_hip.
_use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

# aiter is only imported on HIP, and only when one of the switches above
# asks for it.
if _is_hip and (_use_aiter or _use_hip_int4):
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe
    from aiter.ops.shuffle import shuffle_weight

# Activation schemes accepted by Fp8Config.
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = logging.getLogger(__name__)
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint is already
            serialized in FP8 (logged on rank 0 when true).
        activation_scheme: Activation scaling scheme; must be one of
            ``ACTIVATION_SCHEMES``.
        ignored_layers: Layer names passed to ``is_layer_skipped``; linear
            layers that match fall back to ``UnquantizedLinearMethod``.
            ``None`` is stored as an empty list.
        weight_block_size: Optional two-element block shape. When given,
            block-wise weight quantization is enabled.

    Raises:
        ValueError: If ``activation_scheme`` is unknown, or if
            ``weight_block_size`` is given together with a non-FP8
            checkpoint, a length other than 2, or a non-dynamic activation
            scheme.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            log_info_on_rank0(logger, "Detected fp8 checkpoint.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            # Block-wise quantization comes with three extra constraints.
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this config supports."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this config."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config filenames to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> Fp8Config:
        """Build an ``Fp8Config`` from a quantization config dictionary.

        ``quant_method`` and ``activation_scheme`` are looked up with
        ``get_from_keys``; ``ignored_layers`` and ``weight_block_size`` are
        optional and default to ``None``. The checkpoint is treated as FP8
        serialized when ``quant_method`` contains ``"fp8"``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Pick the quantization method for ``layer``.

        Returns ``UnquantizedLinearMethod`` for skipped linear layers,
        ``Fp8LinearMethod`` for other linear layers, ``Fp8MoEMethod`` for
        ``FusedMoE`` layers, and ``None`` for anything else.
        """
        # Imported here rather than at module level.
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for FP8)."""
        return []
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = False
        if _is_cuda and MARLIN_FP8_AVAILABLE:
            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
            auto_enable = can_auto_enable_marlin_fp8()
            self.use_marlin = force_marlin or auto_enable

        self.block_quant = self.quant_config.weight_block_size is not None

        self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()

    def _uses_static_activation(self) -> bool:
        """Return True when the config asks for static activation scaling.

        The scheme is read from ``activation_scheme`` or, for configs that
        expose it under a different name, ``linear_activation_scheme``. A
        missing attribute counts as "not static".
        """
        cfg = self.quant_config
        return (
            getattr(cfg, "activation_scheme", None) == "static"
            or getattr(cfg, "linear_activation_scheme", None) == "static"
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        Always registers ``weight``. For FP8-serialized checkpoints it also
        registers ``weight_scale_inv`` (block quantization) or
        ``weight_scale`` (per-tensor), plus ``input_scale`` (a parameter for
        static activation scaling, otherwise ``None``).

        Raises:
            ValueError: With block quantization, if a partition size that
                must align with the quantization block is not divisible by
                the corresponding block dimension.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        tp_size = get_tensor_model_parallel_world_size()
        if self.block_quant:
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if tp_size > 1 and input_size // input_size_per_partition == tp_size:
                if input_size_per_partition % block_k != 0:
                    raise ValueError(
                        f"Weight input_size_per_partition = "
                        f"{input_size_per_partition} is not divisible by "
                        f"weight quantization block_k = {block_k}."
                    )
            # Required by column parallel or enabling merged weights
            if (
                tp_size > 1 and output_size // output_size_per_partition == tp_size
            ) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}."
                        )

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if self.block_quant:
                if hasattr(self.quant_config, "activation_scheme"):
                    assert self.quant_config.activation_scheme == "dynamic"
                elif hasattr(self.quant_config, "linear_activation_scheme"):
                    assert self.quant_config.linear_activation_scheme == "dynamic"
                # One scale per (block_n x block_k) tile, rounding up.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                # Pre-fill with the lowest float32 value (presumably a
                # "not loaded yet" marker - confirm against the loader).
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale_inv", scale)
            else:
                # One scale per logical (possibly fused) output shard.
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self._uses_static_activation():
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                    weight_loader=weight_loader,
                )

                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize ``layer``'s weights and scales once loading is done.

        * Block quantization: re-wraps ``weight`` / ``weight_scale_inv`` as
          plain parameters (normalizing to e4m3fnuz first when
          ``_is_fp8_fnuz``); on CPU it hands over to the AMX helper and
          returns.
        * Non-FP8 checkpoint: quantizes the weight here (per-channel when
          cutlass or Marlin is used, per-tensor otherwise).
        * FP8 checkpoint: converts the per-shard scales to channel-wise, or
          requantizes with a single max scale, then stores the transposed
          weight.

        When Marlin is enabled the layer is finally repacked for Marlin and
        ``input_scale`` is removed.
        """
        if self.block_quant:
            # If ROCm, normalize the weights and scales to e4m3fnuz
            if _is_fp8_fnuz:
                # activation_scheme: dynamic
                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=None,
                )
                layer.input_scale = None
            elif _is_cpu:
                assert (
                    _is_cpu_amx_available
                ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
                _amx_process_weight_after_loading(layer, ["weight"])
                return
            else:
                weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
        else:
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

            # If checkpoint not serialized fp8, quantize the weights.
            if not self.quant_config.is_checkpoint_fp8_serialized:
                if self.cutlass_fp8_supported or self.use_marlin:
                    # apply per-channel quantization default as
                    # cutlass sgl-kernel and marlin only support per-channel scale
                    qweight, weight_scale = per_token_group_quant_fp8(
                        layer.weight, layer.weight.shape[-1]
                    )
                    weight_scale = weight_scale.t().contiguous()
                else:
                    # per-tensor quantization
                    qweight, weight_scale = input_to_float8(layer.weight)

                # Update the layer with the new values.
                layer.weight = Parameter(qweight.t(), requires_grad=False)
                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
                layer.input_scale = None

            # If checkpoint is fp8, handle that there are N scales for N
            # shards in a fused module
            else:
                layer.weight_scale = Parameter(
                    layer.weight_scale.data, requires_grad=False
                )
                if self._uses_static_activation():
                    layer.input_scale = Parameter(
                        layer.input_scale.data, requires_grad=False
                    )

                # cutlass sgl-kernel and marlin only support per-channel scale
                if self.cutlass_fp8_supported or self.use_marlin:
                    weight = layer.weight
                    weight_scale = convert_to_channelwise(
                        layer.weight_scale, layer.logical_widths
                    )
                else:
                    # Dequant -> Quant with max scale so we can run per tensor.
                    weight = layer.weight
                    weight_scale = layer.weight_scale
                    # If ROCm, normalize the weights and scales to e4m3fnuz
                    if _is_fp8_fnuz:
                        weight, weight_scale, input_scale = (
                            normalize_e4m3fn_to_e4m3fnuz(
                                weight=weight,
                                weight_scale=weight_scale,
                                input_scale=layer.input_scale,
                            )
                        )
                        if input_scale is not None:
                            layer.input_scale = Parameter(
                                input_scale, requires_grad=False
                            )

                    weight_scale, weight = requantize_with_max_scale(
                        weight=weight,
                        weight_scale=weight_scale,
                        logical_widths=layer.logical_widths,
                    )

                # Update layer with new values.
                layer.weight = Parameter(weight.t(), requires_grad=False)
                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
                if self._uses_static_activation():
                    # Collapse the per-shard input scales to their maximum.
                    layer.input_scale = Parameter(
                        layer.input_scale.max(), requires_grad=False
                    )

        if self.use_marlin:
            if self.block_quant:
                layer.weight_block_size = self.quant_config.weight_block_size
            prepare_fp8_layer_for_marlin(layer, not self.block_quant)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatch order: Marlin (when enabled), then the block-quantized
        path (the AMX CPU kernel when ``use_intel_amx_backend(layer)``,
        otherwise the kernel chosen by ``dispatch_w8a8_block_fp8_linear``),
        and finally ``apply_fp8_linear``.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if self.block_quant:
            if use_intel_amx_backend(layer):
                return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
                    x,
                    layer.weight,
                    layer.weight_scale_inv,
                    self.quant_config.weight_block_size,
                    bias,
                    x.dtype,
                    True,  # is_vnni
                )

            return self.w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=None,
                bias=bias,
            )

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False,
        )
2025-07-31 19:56:34 -07:00
def get_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Pick the per-CTA token tile size for an MoE launch.

    The estimate assumes tokens are spread perfectly evenly over the
    experts, is padded with ``next_power_of_2`` and then clamped to the
    8-64 range supported by the kernel.
    """
    # Tokens each expert would receive under a perfectly even routing.
    even_share = (num_tokens * top_k) // num_experts
    padded = next_power_of_2(even_share)

    # Clamp into the [8, 64] window the kernel supports.
    if padded < 8:
        return 8
    if padded > 64:
        return 64
    return padded
2025-07-17 00:47:07 -07:00
class Fp8MoEMethod ( FusedMoEMethodBase ) :
2024-12-06 15:42:10 +08:00
""" MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic / static activation scale .
Also supports loading quantized FP16 / BF16 model checkpoints with dynamic
activation scaling . The weight scaling factor will be initialized after
the model weights are loaded .
Args :
quant_config : The quantization config .
"""
2025-07-17 00:47:07 -07:00
def __init__(self, quant_config: Fp8Config):
    """Store the config and decide which MoE kernel features are usable.

    Args:
        quant_config: The FP8 quantization config.
    """
    self.quant_config = quant_config
    # Block-wise quantization is on whenever a block size is configured.
    self.block_quant = quant_config.weight_block_size is not None
    self.cutlass_fp8_supported = cutlass_fp8_supported()

    # The CUTLASS fused-experts path is opt-in through SGLANG_CUTLASS_MOE
    # and additionally needs cutlass FP8 support, block quantization and
    # an SM90 or SM100 device.
    cutlass_moe_requested = get_bool_env_var("SGLANG_CUTLASS_MOE")
    self.use_cutlass_fused_experts_fp8 = (
        cutlass_moe_requested
        and self.cutlass_fp8_supported
        and self.block_quant
        and (is_sm100_supported() or is_sm90_supported())
    )
2024-12-06 15:42:10 +08:00
def create_weights(
    self,
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register the expert weights, weight scales and input scales on ``layer``.

    Registers ``w13_weight`` / ``w2_weight``, then either the block-wise
    ``*_weight_scale_inv`` parameters or the per-tensor ``*_weight_scale``
    parameters (plus ``*_weight_scale1`` on HIP), and finally the
    ``*_input_scale`` parameters for static activation scaling (set to
    ``None`` otherwise). When the CUTLASS fused-experts path is enabled,
    its auxiliary tensors are allocated on ``self``.

    Raises:
        ValueError: If ``intermediate_size_per_partition`` does not align
            with the quantization block, or if static activation scaling
            is requested for a checkpoint that is not FP8 serialized.
    """
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

    serialized_fp8 = self.quant_config.is_checkpoint_fp8_serialized
    inter_size = intermediate_size_per_partition

    if serialized_fp8:
        params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn

    tp_size = get_tensor_model_parallel_world_size()
    if self.block_quant:
        block_shape = self.quant_config.weight_block_size
        block_n, block_k = block_shape[0], block_shape[1]
        # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
        # Required by column parallel or enabling merged weights
        if inter_size % block_n != 0:
            raise ValueError(
                f"The output_size of gate's and up's weight = "
                f"{intermediate_size_per_partition} is not divisible by "
                f"weight quantization block_n = {block_n}."
            )
        # Required by row parallel
        if tp_size > 1 and inter_size % block_k != 0:
            raise ValueError(
                f"The input_size of down's weight = "
                f"{intermediate_size_per_partition} is not divisible by "
                f"weight quantization block_k = {block_k}."
            )

    # WEIGHTS
    if _is_hip and _use_hip_int4:
        # INT4 MoE weight - INT32 packed
        w13_shape = (num_experts, 2 * inter_size, hidden_size // 8)
        w2_shape = (num_experts, hidden_size, inter_size // 8)
    else:
        w13_shape = (num_experts, 2 * inter_size, hidden_size)
        w2_shape = (num_experts, hidden_size, inter_size)

    w13_weight = torch.nn.Parameter(
        torch.empty(*w13_shape, dtype=params_dtype), requires_grad=False
    )
    w2_weight = torch.nn.Parameter(
        torch.empty(*w2_shape, dtype=params_dtype), requires_grad=False
    )

    # The weights receive the caller's attributes before "quant_method"
    # is added to the dict further down.
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    if self.block_quant:

        def _ceil_div(size, block):
            # Number of blocks needed to cover `size`, rounding up.
            return (size + block - 1) // block

        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * _ceil_div(inter_size, block_n),
                _ceil_div(hidden_size, block_k),
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                _ceil_div(hidden_size, block_n),
                _ceil_div(inter_size, block_k),
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
        assert self.quant_config.activation_scheme == "dynamic"

        if self.use_cutlass_fused_experts_fp8:
            # Auxiliary tensors for the CUTLASS fused-experts path, kept
            # on the method object and allocated in the original order.
            w13_device = w13_weight.device
            w2_device = w2_weight.device

            stride_specs = (
                ("ab_strides1", hidden_size, w13_device),
                ("c_strides1", 2 * inter_size, w13_device),
                ("ab_strides2", inter_size, w2_device),
                ("c_strides2", hidden_size, w2_device),
            )
            for attr_name, stride, device in stride_specs:
                setattr(
                    self,
                    attr_name,
                    torch.full(
                        (num_experts,), stride, device=device, dtype=torch.int64
                    ),
                )

            self.workspace = torch.empty(
                90000, device=w13_device, dtype=torch.uint8
            )

            for attr_name in (
                "a_ptr",
                "b_ptr",
                "out_ptr",
                "a_scales_ptr",
                "b_scales_ptr",
            ):
                setattr(
                    self,
                    attr_name,
                    torch.empty(num_experts, device=w13_device, dtype=torch.int64),
                )

            self.expert_offsets = torch.empty(
                num_experts + 1, device=w13_device, dtype=torch.int32
            )

            for attr_name in ("problem_sizes1", "problem_sizes2"):
                setattr(
                    self,
                    attr_name,
                    torch.empty(
                        num_experts, 3, device=w13_device, dtype=torch.int32
                    ),
                )
    else:
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        if _is_hip:  # _use_aiter: TODO: add check back after triton kernel
            # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
            w13_weight_scale1 = torch.nn.Parameter(
                torch.ones(num_experts, 2 * inter_size, dtype=torch.float32),
                requires_grad=False,
            )
            w2_weight_scale1 = torch.nn.Parameter(
                torch.ones(num_experts, hidden_size, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale1", w13_weight_scale1)
            layer.register_parameter("w2_weight_scale1", w2_weight_scale1)

    # Add the quantization method used (per tensor/grouped/channel)
    # to ensure the weight scales are loaded in properly
    scale_kind = (
        FusedMoeWeightScaleSupported.BLOCK
        if self.block_quant
        else FusedMoeWeightScaleSupported.TENSOR
    )
    extra_weight_attrs.update({"quant_method": scale_kind.value})

    # If loading fp8 checkpoint, pass the weight loaders.
    # If loading an fp16 checkpoint, do not (we will quantize in
    # process_weights_after_loading()
    if serialized_fp8:
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    if _is_hip and _use_hip_int4:
        # NOTE(review): the *_scale1 parameters only exist when the
        # per-tensor branch above ran on HIP - presumably INT4 is never
        # combined with block quantization; confirm.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale1, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale1, extra_weight_attrs)

    # INPUT_SCALES
    if self.quant_config.activation_scheme != "static":
        layer.w13_input_scale = None
        layer.w2_input_scale = None
        return

    if not serialized_fp8:
        raise ValueError(
            "Found static activation scheme for checkpoint that "
            "was not serialized fp8."
        )

    # One input scale per expert for each of the two projections.
    for scale_name in ("w13_input_scale", "w2_input_scale"):
        input_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter(scale_name, input_scale)
        set_weight_attrs(input_scale, extra_weight_attrs)
def process_weights_after_loading ( self , layer : Module ) - > None :
# Post-load hook: rewrites the MoE weights and scales stored on `layer` so the
# FP8 MoE kernels can consume them. Returns None; every result is written back
# onto `layer`. Sections visible in this body, in source order:
#   1. `_is_hip and _use_hip_int4`: hand off to process_weights_hip_int4.
#   2. `self.block_quant`: optional e4m3fnuz normalization of the weights and the
#      `*_weight_scale_inv` tensors, aiter weight pre-shuffle, and the CPU call
#      to _amx_process_weight_after_loading.
#   3. Checkpoint not serialized as FP8: quantize each local expert with
#      scaled_fp8_quant and store the returned per-expert scales.
#   4. FP8 checkpoint: collapse static activation scales to their max, optionally
#      normalize to e4m3fnuz, then requantize w13 so each expert has one scale.
# Raises ValueError when the activation scheme is static but an input scale is None.
2025-06-05 23:00:18 -07:00
if _is_hip and _use_hip_int4 :
2025-03-09 21:46:35 -07:00
self . process_weights_hip_int4 ( layer )
2025-03-06 15:33:02 -08:00
return
2024-12-26 00:02:14 +08:00
# Block quant doesn't need to process weights after loading
if self . block_quant :
2024-12-29 03:23:39 -08:00
# If ROCm, normalize the weights and scales to e4m3fnuz
2025-05-08 08:28:24 +08:00
if _is_fp8_fnuz :
2024-12-29 03:23:39 -08:00
# activation_scheme: dynamic
# input_scale=None is passed, so the third return value is discarded.
w13_weight , w13_weight_scale , _ = normalize_e4m3fn_to_e4m3fnuz (
weight = layer . w13_weight ,
weight_scale = layer . w13_weight_scale_inv ,
input_scale = None ,
)
w2_weight , w2_weight_scale , _ = normalize_e4m3fn_to_e4m3fnuz (
weight = layer . w2_weight ,
weight_scale = layer . w2_weight_scale_inv ,
input_scale = None ,
)
# Reset the parameter
layer . w13_weight = torch . nn . Parameter ( w13_weight , requires_grad = False )
layer . w13_weight_scale_inv = torch . nn . Parameter (
w13_weight_scale , requires_grad = False
)
# No activation scale is kept after normalization on this path.
layer . w13_input_scale = None
layer . w2_weight = torch . nn . Parameter ( w2_weight , requires_grad = False )
layer . w2_weight_scale_inv = torch . nn . Parameter (
w2_weight_scale , requires_grad = False
)
layer . w2_input_scale = None
2025-03-04 03:00:46 -08:00
2025-06-05 23:00:18 -07:00
if _use_aiter :
2025-05-08 08:28:24 +08:00
# Pre-shuffle weights
# The shuffled tensor is assigned to `.data`, so the Parameter objects are kept.
layer . w13_weight . data = shuffle_weight (
layer . w13_weight . contiguous ( ) , ( 16 , 16 )
)
layer . w2_weight . data = shuffle_weight (
layer . w2_weight . contiguous ( ) , ( 16 , 16 )
)
2025-06-28 10:04:29 +08:00
if _is_cpu :
# The CPU path is only accepted when AMX is available.
assert (
_is_cpu_amx_available
) , " Fp8MoEMethod on CPU requires that CPU has AMX support "
2025-07-04 00:51:09 +08:00
_amx_process_weight_after_loading ( layer , [ " w13_weight " , " w2_weight " ] )
2025-06-28 10:04:29 +08:00
2024-12-26 00:02:14 +08:00
return
2025-03-09 21:46:35 -07:00
2024-12-07 05:18:26 -08:00
# If checkpoint is fp16 or bfloat16, quantize in place.
2024-12-06 15:42:10 +08:00
if not self . quant_config . is_checkpoint_fp8_serialized :
2025-05-08 08:28:24 +08:00
# If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
2024-12-06 15:42:10 +08:00
# Destination buffers with the same shape as the loaded weights, in fp8_dtype.
w13_weight = torch . empty_like ( layer . w13_weight . data , dtype = fp8_dtype )
w2_weight = torch . empty_like ( layer . w2_weight . data , dtype = fp8_dtype )
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer . w13_weight_scale = torch . nn . Parameter (
torch . ones (
2025-07-27 01:00:21 -07:00
layer . num_local_experts ,
dtype = torch . float32 ,
device = w13_weight . device ,
2024-12-06 15:42:10 +08:00
) ,
requires_grad = False ,
)
2025-07-27 01:00:21 -07:00
# scaled_fp8_quant returns a (quantized tensor, scale) pair; both halves are
# written into the per-expert slots of the buffers / scale parameters.
for expert in range ( layer . num_local_experts ) :
2025-04-16 15:26:49 -07:00
w13_weight [ expert , : , : ] , layer . w13_weight_scale [ expert ] = (
scaled_fp8_quant ( layer . w13_weight . data [ expert , : , : ] )
)
w2_weight [ expert , : , : ] , layer . w2_weight_scale [ expert ] = (
scaled_fp8_quant ( layer . w2_weight . data [ expert , : , : ] )
)
2024-12-06 15:42:10 +08:00
# Swap the original weights for the freshly quantized buffers.
layer . w13_weight = torch . nn . Parameter ( w13_weight , requires_grad = False )
layer . w2_weight = torch . nn . Parameter ( w2_weight , requires_grad = False )
2024-12-07 05:18:26 -08:00
2025-03-11 18:12:56 -07:00
if _is_hip :
2025-03-09 21:46:35 -07:00
self . process_weights_hip_scale_padding ( layer )
2024-12-06 15:42:10 +08:00
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else :
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self . quant_config . activation_scheme == " static " :
if layer . w13_input_scale is None or layer . w2_input_scale is None :
raise ValueError (
" QuantConfig has static quantization, but found "
" activation scales are None. "
)
# Differing per-expert scales only trigger a warning; the max is used regardless.
if not all_close_1d ( layer . w13_input_scale ) or not all_close_1d (
layer . w2_input_scale
) :
print_warning_once (
" Found input_scales that are not equal for "
" fp8 MoE layer. Using the maximum across experts "
" for each layer. "
)
# `.max()` without a dim reduces over all elements to a single value.
layer . w13_input_scale = torch . nn . Parameter (
layer . w13_input_scale . max ( ) , requires_grad = False
)
layer . w2_input_scale = torch . nn . Parameter (
layer . w2_input_scale . max ( ) , requires_grad = False
)
2024-12-07 05:18:26 -08:00
2024-12-06 15:42:10 +08:00
# If ROCm, normalize the weights and scales to e4m3fnuz
2025-05-08 08:28:24 +08:00
if _is_fp8_fnuz :
2024-12-06 15:42:10 +08:00
# Normalize the weights and scales
w13_weight , w13_weight_scale , w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz (
layer . w13_weight , layer . w13_weight_scale , layer . w13_input_scale
)
)
w2_weight , w2_weight_scale , w2_input_scale = (
normalize_e4m3fn_to_e4m3fnuz (
layer . w2_weight , layer . w2_weight_scale , layer . w2_input_scale
)
)
# Reset the parameter
layer . w13_weight = torch . nn . Parameter ( w13_weight , requires_grad = False )
layer . w13_weight_scale = torch . nn . Parameter (
w13_weight_scale , requires_grad = False
)
# Input scales are only re-wrapped when the helper returned one.
if w13_input_scale is not None :
layer . w13_input_scale = torch . nn . Parameter (
w13_input_scale , requires_grad = False
)
layer . w2_weight = torch . nn . Parameter ( w2_weight , requires_grad = False )
layer . w2_weight_scale = torch . nn . Parameter (
w2_weight_scale , requires_grad = False
)
if w2_input_scale is not None :
layer . w2_input_scale = torch . nn . Parameter (
w2_input_scale , requires_grad = False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer . w13_weight_scale is not None
shard_size = layer . intermediate_size_per_partition
# One value per expert: the max over dim 1 of w13_weight_scale.
max_w13_scales = layer . w13_weight_scale . max ( dim = 1 ) . values
2025-07-27 01:00:21 -07:00
for expert_id in range ( layer . num_local_experts ) :
2024-12-06 15:42:10 +08:00
start = 0
# Two consecutive shards of `shard_size` rows each: dequantize with the
# shard's own scale, then requantize in place with the expert-wide max.
for shard_id in range ( 2 ) :
dq_weight = per_tensor_dequantize (
layer . w13_weight [ expert_id ] [ start : start + shard_size , : ] ,
layer . w13_weight_scale [ expert_id ] [ shard_id ] ,
)
2025-04-16 15:26:49 -07:00
(
layer . w13_weight [ expert_id ] [ start : start + shard_size , : ] ,
_ ,
) = scaled_fp8_quant ( dq_weight , max_w13_scales [ expert_id ] )
2024-12-06 15:42:10 +08:00
start + = shard_size
# The per-shard scales are replaced by the per-expert maxima.
layer . w13_weight_scale = torch . nn . Parameter (
max_w13_scales , requires_grad = False
)
2024-12-07 05:18:26 -08:00
2025-03-11 18:12:56 -07:00
if _is_hip :
2025-03-09 21:46:35 -07:00
self . process_weights_hip_scale_padding ( layer )
2024-12-06 15:42:10 +08:00
return
2025-03-09 21:46:35 -07:00
def process_weights_hip_int4(self, layer: Module):
    """Prepare INT4-weight / FP8-compute MoE tensors on ``layer`` in place.

    Steps performed below:
      1. Replace the w13/w2 weights with their ``shuffle_weight(..., (16, 16))``
         form.
      2. Reduce ``w13_weight_scale`` to one value per expert (max over dim 1)
         and rescale the matching leading-dimension slice of
         ``w13_weight_scale1`` by the ratio of each shard's scale to that max.
      3. Multiply the per-expert FP8 scales into ``w13_weight_scale1`` and
         ``w2_weight_scale1``.
    """
    # TODO: _use_aiter: add after triton kernel added
    # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
    # Weight Permutation
    shuffle_layout = (16, 16)
    permuted_w13 = shuffle_weight(layer.w13_weight.data, shuffle_layout)
    layer.w13_weight = torch.nn.Parameter(permuted_w13, requires_grad=False)
    torch.cuda.empty_cache()
    permuted_w2 = shuffle_weight(layer.w2_weight.data, shuffle_layout)
    layer.w2_weight = torch.nn.Parameter(permuted_w2, requires_grad=False)
    torch.cuda.empty_cache()

    # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
    # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
    # The FP8 weights are not requantized per expert (not directly available);
    # instead, the INT4 w13_weight_scale1 entries of the shard whose FP8 scale
    # differs from the expert max are adjusted.
    assert layer.w13_weight_scale is not None
    shard_size = layer.intermediate_size_per_partition
    max_w13_scales = layer.w13_weight_scale.max(dim=1).values

    for expert_id in range(layer.num_local_experts):
        expert_max = max_w13_scales[expert_id]
        for shard_id in range(2):
            shard_scale = layer.w13_weight_scale[expert_id][shard_id]
            if shard_scale != expert_max:
                # Shard `shard_id` covers [shard_id * shard_size, (shard_id + 1) * shard_size).
                begin = shard_id * shard_size
                ratio = shard_scale / expert_max
                layer.w13_weight_scale1[expert_id][begin : begin + shard_size] *= ratio

    layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)

    # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
    # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
    for expert_id in range(layer.num_local_experts):
        layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
        layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
2025-04-07 00:34:08 -07:00
def process_weights_hip_scale_padding(self, layer: Module):
    """Apply the ROCm-specific weight layout step to ``layer`` in place.

    With ``_use_aiter`` the w13/w2 weights are replaced by their
    ``shuffle_weight(..., (16, 16))`` form and the ``*_weight_scale1`` tensors
    are multiplied by the matching ``*_weight_scale`` (unsqueezed on the last
    dim). Otherwise, if the padding environment flag is set, both weights are
    zero-padded by ``padding_size`` on their last dimension. With neither
    condition, nothing is changed.
    """
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
        padding_size,  # Avoid circular import
    )

    def _as_frozen_param(tensor):
        # Wrap a tensor as a parameter that does not require grad.
        return torch.nn.Parameter(tensor, requires_grad=False)

    if _use_aiter:
        layer.w13_weight = _as_frozen_param(
            shuffle_weight(layer.w13_weight.data, (16, 16))
        )
        torch.cuda.empty_cache()
        layer.w2_weight = _as_frozen_param(
            shuffle_weight(layer.w2_weight.data, (16, 16))
        )
        torch.cuda.empty_cache()

        # ROCm (_use_aiter): using column-wise scaling
        layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
        layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
        return

    # NOTE(review): the literals' surrounding spaces are kept verbatim from the
    # original — verify they are intended.
    if not get_bool_env_var(" SGLANG_MOE_PADDING "):
        return

    # If ROCm, apply weight padding (min. Mem channel contention) only if set
    pad_spec = (0, padding_size)
    layer.w13_weight = _as_frozen_param(
        F.pad(layer.w13_weight.data, pad_spec, " constant ", 0)
    )
    torch.cuda.empty_cache()
    layer.w2_weight = _as_frozen_param(
        F.pad(layer.w2_weight.data, pad_spec, " constant ", 0)
    )
    torch.cuda.empty_cache()
2025-09-05 21:09:09 -07:00
def create_moe_runner(
    self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
    """Remember the runner config and build the Triton-backend MoE runner.

    ``layer`` is not used by this implementation. The config is stored on
    ``self.moe_runner_config`` and the runner on ``self.runner``.
    """
    self.moe_runner_config = moe_runner_config
    backend = MoeRunnerBackend.TRITON
    self.runner = MoeRunner(backend, moe_runner_config)
2024-12-06 15:42:10 +08:00
def apply(
    self,
    layer: torch.nn.Module,
    dispatch_output: DispatchOutput,
) -> CombineInput:
    """Run the FP8 MoE computation for ``dispatch_output`` on ``layer``.

    Backends are tried in this order:
      1. Intel AMX CPU kernel, when ``use_intel_amx_backend(layer)`` is true.
      2. HIP fused experts, when ``_is_hip`` is set and the helper returns a
         result (a ``None`` result falls through).
      3. CUTLASS block-scale FP8 kernel, when
         ``self.use_cutlass_fused_experts_fp8`` is set.
      4. Otherwise the runner created in ``create_moe_runner``.

    The first three paths wrap their output in ``StandardCombineInput``; the
    last returns whatever ``self.runner.run`` returns.
    """
    from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

    hidden = dispatch_output.hidden_states
    routing = dispatch_output.topk_output
    runner_cfg = self.moe_runner_config

    # --- CPU / AMX path ---
    if use_intel_amx_backend(layer):
        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

        topk_weights, topk_ids, _ = routing
        hidden, topk_weights = apply_topk_weights_cpu(
            runner_cfg.apply_router_weight_on_input, topk_weights, hidden
        )
        cpu_out = torch.ops.sgl_kernel.fused_experts_cpu(
            hidden,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            False,  # inplace See [Note] inplace should be False in fused_experts.
            False,  # use_int8_w8a8
            True,  # use_fp8_w8a16
            layer.w13_weight_scale_inv,  # w1_scale
            layer.w2_weight_scale_inv,  # w2_scale
            self.quant_config.weight_block_size,  # block_size
            None,  # a1_scale
            None,  # a2_scale
            True,  # is_vnni
        )
        return StandardCombineInput(hidden_states=cpu_out)

    # --- ROCm path; a None result means "not handled here" ---
    if _is_hip:
        hip_out = self.maybe_apply_hip_fused_experts(
            layer,
            hidden,
            routing,
            runner_cfg.activation,
            runner_cfg.no_combine,
        )
        if hip_out is not None:
            return StandardCombineInput(hidden_states=hip_out)

    # --- CUTLASS block-scale path ---
    if self.use_cutlass_fused_experts_fp8:
        from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

        topk_weights, topk_ids, _ = routing
        cutlass_out = cutlass_fused_experts_fp8(
            hidden,
            layer.w13_weight.transpose(1, 2),
            layer.w2_weight.transpose(1, 2),
            layer.w13_weight_scale_inv.transpose(1, 2),
            layer.w2_weight_scale_inv.transpose(1, 2),
            topk_weights,
            topk_ids,
            self.ab_strides1,
            self.c_strides1,
            self.ab_strides2,
            self.c_strides2,
            self.workspace,
            self.a_ptr,
            self.b_ptr,
            self.out_ptr,
            self.a_scales_ptr,
            self.b_scales_ptr,
            self.expert_offsets,
            self.problem_sizes1,
            self.problem_sizes2,
            use_fp8_blockscale=True,
        )
        return StandardCombineInput(hidden_states=cutlass_out)

    # --- Default path: block-quant layers use the *_scale_inv tensors ---
    if self.block_quant:
        w13_scale = layer.w13_weight_scale_inv
        w2_scale = layer.w2_weight_scale_inv
    else:
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

    quant_info = TritonMoeQuantInfo(
        w13_weight=layer.w13_weight,
        w2_weight=layer.w2_weight,
        use_fp8_w8a8=True,
        w13_scale=w13_scale,
        w2_scale=w2_scale,
        a13_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
        block_shape=self.quant_config.weight_block_size,
    )
    return self.runner.run(dispatch_output, quant_info)
2024-12-06 15:42:10 +08:00
2025-07-31 19:56:34 -07:00
def apply_with_router_logits(
    self,
    layer: torch.nn.Module,
    dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
    """Run the flashinfer ``trtllm_fp8_block_scale_moe`` kernel from router logits.

    The hidden states are quantized per token group with
    ``per_token_group_quant_fp8`` (group size ``weight_block_size[1]``) and
    passed to the kernel together with the router logits taken from the
    dispatch output's top-k output.

    Asserts that the top-k output passes
    ``TopKOutputChecker.format_is_bypassed``, that the configured activation
    equals the expected literal, and that both ``num_expert_group`` and
    ``topk_group`` are set on the top-k config.
    """
    hidden = dispatch_output.hidden_states
    routing = dispatch_output.topk_output

    runner_cfg_activation = self.moe_runner_config.activation
    scaling = self.moe_runner_config.routed_scaling_factor

    from flashinfer.fused_moe import trtllm_fp8_block_scale_moe

    from sglang.srt.layers.moe.topk import TopKOutputChecker

    assert TopKOutputChecker.format_is_bypassed(routing)
    logits = routing.router_logits
    topk_cfg = routing.topk_config

    # NOTE(review): the literals' surrounding spaces are kept verbatim from the
    # original — verify they are intended.
    assert (
        runner_cfg_activation == " silu "
    ), " Only silu is supported for flashinfer blockscale fp8 moe "

    group_size = self.quant_config.weight_block_size[1]
    quantized_hidden, hidden_scales = per_token_group_quant_fp8(hidden, group_size)
    # NOTE: scales of hidden states have to be transposed!
    hidden_scales_t = hidden_scales.t().contiguous()

    assert not (
        topk_cfg.num_expert_group is None or topk_cfg.topk_group is None
    ), " Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None "

    # The routing bias, when present, is cast to the hidden-state dtype.
    if topk_cfg.correction_bias is None:
        routing_bias = None
    else:
        routing_bias = topk_cfg.correction_bias.to(hidden.dtype)

    # A missing scaling factor defaults to 1.0.
    if scaling is None:
        scaling = 1.0

    return trtllm_fp8_block_scale_moe(
        routing_logits=logits.to(torch.float32),
        routing_bias=routing_bias,
        hidden_states=quantized_hidden,
        hidden_states_scale=hidden_scales_t,
        gemm1_weights=layer.w13_weight,
        gemm1_weights_scale=layer.w13_weight_scale_inv,
        gemm2_weights=layer.w2_weight,
        gemm2_weights_scale=layer.w2_weight_scale_inv,
        num_experts=layer.num_experts,
        top_k=topk_cfg.top_k,
        n_group=topk_cfg.num_expert_group,
        topk_group=topk_cfg.topk_group,
        intermediate_size=layer.w2_weight.shape[2],
        local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
        local_num_experts=layer.num_local_experts,
        routed_scaling_factor=scaling,
        tile_tokens_dim=get_tile_tokens_dim(
            hidden.shape[0], topk_cfg.top_k, layer.num_experts
        ),
        routing_method_type=2,  # DeepSeek-styled routing method
        use_shuffled_weight=False,
    )
2025-05-08 08:28:24 +08:00
def maybe_apply_hip_fused_experts(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_output: TopKOutput,
    activation: str = " silu ",
    no_combine: bool = False,
) -> Optional[torch.Tensor]:
    """Run the HIP ``fused_moe`` kernel when a HIP-specific mode is enabled.

    Returns ``None`` when neither ``_use_hip_int4`` nor ``_use_aiter`` is set,
    which tells the caller that nothing was computed here. Otherwise returns
    the ``fused_moe`` result:
      * ``_use_aiter`` with block quant (and not INT4): ``*_weight_scale_inv``
        scales with ``QuantType.per_128x128``.
      * every other enabled case: ``*_weight_scale1`` scales with
        ``QuantType.per_Token``.

    Asserts that ``no_combine`` is false whenever a kernel is run.
    """
    # NOTE(review): the literals' surrounding spaces (the `activation` default,
    # the comparison below, the assertion message) are kept verbatim from the
    # original — verify they are intended.
    topk_weights, topk_ids, _ = topk_output

    if not (_use_hip_int4 or _use_aiter):
        return None

    # TODO: add triton kernel and add check _use_aiter (INT4 path)
    assert not no_combine, f" { no_combine =} is not supported. "

    # Resolved only after the platform check, like the original branches do.
    if activation == " silu ":
        act_type = ActivationType.Silu
    else:
        act_type = ActivationType.Gelu

    # INT4 mode takes precedence over the block-quant check.
    if not _use_hip_int4 and self.block_quant:
        return fused_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            w1_scale=layer.w13_weight_scale_inv,
            w2_scale=layer.w2_weight_scale_inv,
            quant_type=QuantType.per_128x128,
            activation=act_type,
            expert_mask=None,
        )

    return fused_moe(
        x,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights,
        topk_ids,
        quant_type=QuantType.per_Token,
        w1_scale=layer.w13_weight_scale1,
        w2_scale=layer.w2_weight_scale1,
        activation=act_type,
    )
2024-12-06 15:42:10 +08:00
class Fp8KVCacheMethod ( BaseKVCacheMethod ) :
"""
Supports loading kv - cache scaling factors from FP8 checkpoints .
"""
def __init__ ( self , quant_config : Fp8Config ) :
super ( ) . __init__ ( quant_config )