Files
sglang/python/sglang/srt/layers/quantization/fp8.py

1238 lines
49 KiB
Python

# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
# vLLM is an optional dependency: its Marlin FP8 helpers are used when present,
# otherwise both names are bound to a stub that fails loudly when called.
try:
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_fp8_marlin_linear,
        prepare_fp8_layer_for_marlin,
    )
except ImportError:
    MARLIN_FP8_AVAILABLE = False

    def dummy_func(*args, **kwargs):
        """Placeholder for the vLLM Marlin FP8 operators; always raises ImportError."""
        raise ImportError(
            "marlin FP8 requires some operators from vllm. Please install vllm."
        )

    apply_fp8_marlin_linear = dummy_func
    prepare_fp8_layer_for_marlin = dummy_func
else:
    # Only reached when the import above succeeded.
    MARLIN_FP8_AVAILABLE = True
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
from sglang.srt.layers.parameter import (
BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import (
fp8_dtype,
is_fp8_fnuz,
per_token_group_quant_fp8,
scaled_fp8_quant,
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
can_auto_enable_marlin_fp8,
cutlass_fp8_supported,
dispatch_w8a8_block_fp8_linear,
input_to_float8,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import (
all_close_1d,
convert_to_channelwise,
is_layer_skipped,
per_tensor_dequantize,
requantize_with_max_scale,
)
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_cuda,
is_hip,
is_npu,
is_sm90_supported,
is_sm100_supported,
log_info_on_rank0,
next_power_of_2,
print_warning_once,
set_weight_attrs,
use_intel_amx_backend,
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
DispatchOutput,
StandardDispatchOutput,
)
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
# Platform / hardware probes, evaluated once at import time.
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_fp8_fnuz = is_fp8_fnuz()

# Environment-variable switches.
_use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

# aiter is imported only on ROCm, and only when one of the code paths that
# references it (aiter kernels or INT4 weights) is switched on.
if _is_hip and (_use_aiter or _use_hip_int4):
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe
    from aiter.ops.shuffle import shuffle_weight

# Activation quantization schemes accepted by Fp8Config.
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = logging.getLogger(__name__)
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Stores whether the checkpoint already holds FP8 weights, which activation
    scheme ("static" or "dynamic") to use, which layer prefixes to leave
    unquantized, and the optional 2-D block size for block-wise weight scales.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        """Validate and store the FP8 quantization settings.

        Args:
            is_checkpoint_fp8_serialized: True when the checkpoint weights are
                already stored as FP8.
            activation_scheme: One of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer prefixes that must stay unquantized;
                ``None`` is stored as an empty list.
            weight_block_size: ``[block_n, block_k]`` for block-wise
                quantization, or ``None`` to disable it.

        Raises:
            ValueError: If the activation scheme is unknown, or if block-wise
                quantization is requested without an FP8-serialized
                checkpoint, with a block size that is not 2-D, or with a
                non-dynamic activation scheme.
        """
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            log_info_on_rank0(logger, "Detected fp8 checkpoint.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this config declares as supported."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value declared for this method (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config file names to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> Fp8Config:
        """Build an ``Fp8Config`` from a quantization config dictionary.

        The checkpoint counts as FP8-serialized when the ``quant_method``
        entry contains the substring ``"fp8"``. ``ignored_layers`` and
        ``weight_block_size`` are optional keys and default to ``None``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Pick the quantize method for ``layer``.

        Linear layers get ``Fp8LinearMethod`` unless ``prefix`` matches an
        ignored layer, in which case they stay unquantized. ``FusedMoE``
        layers get ``Fp8MoEMethod``. Any other layer type yields ``None``.
        """
        # Imported here rather than at module level.
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for FP8)."""
        return []
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = False
        if _is_cuda and MARLIN_FP8_AVAILABLE:
            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
            auto_enable = can_auto_enable_marlin_fp8()
            self.use_marlin = force_marlin or auto_enable

        self.block_quant = self.quant_config.weight_block_size is not None
        self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()

    def _is_static_activation(self) -> bool:
        """Return True when the config asks for static activation scales.

        The config may expose the scheme as ``activation_scheme`` or as
        ``linear_activation_scheme``; either one being ``"static"`` counts.
        A missing attribute is treated as "not static".
        """
        config = self.quant_config
        return (
            getattr(config, "activation_scheme", None) == "static"
            or getattr(config, "linear_activation_scheme", None) == "static"
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters of a linear layer.

        Always registers ``weight``. For FP8-serialized checkpoints it also
        registers ``weight_scale_inv`` (block-wise) or ``weight_scale``
        (one entry per logical output partition), and ``input_scale``
        (a parameter for static activation, otherwise ``None``).

        Raises:
            ValueError: With block-wise quantization, if a partition size is
                not divisible by the corresponding block dimension.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        tp_size = get_tensor_model_parallel_world_size()
        if self.block_quant:
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if tp_size > 1 and input_size // input_size_per_partition == tp_size:
                if input_size_per_partition % block_k != 0:
                    raise ValueError(
                        f"Weight input_size_per_partition = "
                        f"{input_size_per_partition} is not divisible by "
                        f"weight quantization block_k = {block_k}."
                    )
            # Required by column parallel or enabling merged weights
            if (
                tp_size > 1 and output_size // output_size_per_partition == tp_size
            ) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}."
                        )

        # Recorded on the layer for use after loading and in apply().
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if self.block_quant:
                if hasattr(self.quant_config, "activation_scheme"):
                    assert self.quant_config.activation_scheme == "dynamic"
                elif hasattr(self.quant_config, "linear_activation_scheme"):
                    assert self.quant_config.linear_activation_scheme == "dynamic"
                # One scale per (block_n x block_k) tile, rounded up.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                # Pre-fill with the most negative float32 value.
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale_inv", scale)
            else:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self._is_static_activation():
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded parameters into the layout ``apply`` consumes.

        Block-wise checkpoints are normalized to e4m3fnuz where required, or
        handed to the AMX helper on CPU (which returns early). Otherwise
        non-FP8 checkpoints are quantized here, while FP8 checkpoints have
        their per-shard scales converted to channel-wise scales or merged
        into a single max scale. With Marlin enabled the layer is finally
        repacked and ``input_scale`` is removed.
        """
        if self.block_quant:
            # If ROCm, normalize the weights and scales to e4m3fnuz
            if _is_fp8_fnuz:
                # activation_scheme: dynamic
                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=None,
                )
                layer.input_scale = None
            elif _is_cpu:
                assert (
                    _is_cpu_amx_available
                ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
                _amx_process_weight_after_loading(layer, ["weight"])
                return
            else:
                weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
        else:
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

            # If checkpoint not serialized fp8, quantize the weights.
            if not self.quant_config.is_checkpoint_fp8_serialized:
                if self.cutlass_fp8_supported or self.use_marlin:
                    # apply per-channel quantization default as
                    # cutlass sgl-kernel and marlin only support per-channel scale
                    qweight, weight_scale = per_token_group_quant_fp8(
                        layer.weight, layer.weight.shape[-1]
                    )
                    weight_scale = weight_scale.t().contiguous()
                else:
                    # per-tensor quantization
                    qweight, weight_scale = input_to_float8(layer.weight)
                # Update the layer with the new values.
                layer.weight = Parameter(qweight.t(), requires_grad=False)
                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
                layer.input_scale = None

            # If checkpoint is fp8, handle that there are N scales for N
            # shards in a fused module
            else:
                static_activation = self._is_static_activation()
                layer.weight_scale = Parameter(
                    layer.weight_scale.data, requires_grad=False
                )
                if static_activation:
                    layer.input_scale = Parameter(
                        layer.input_scale.data, requires_grad=False
                    )

                # cutlass sgl-kernel and marlin only support per-channel scale
                if self.cutlass_fp8_supported or self.use_marlin:
                    weight = layer.weight
                    weight_scale = convert_to_channelwise(
                        layer.weight_scale, layer.logical_widths
                    )
                else:
                    # Dequant -> Quant with max scale so we can run per tensor.
                    weight = layer.weight
                    weight_scale = layer.weight_scale
                    # If ROCm, normalize the weights and scales to e4m3fnuz
                    if _is_fp8_fnuz:
                        weight, weight_scale, input_scale = (
                            normalize_e4m3fn_to_e4m3fnuz(
                                weight=weight,
                                weight_scale=weight_scale,
                                input_scale=layer.input_scale,
                            )
                        )
                        if input_scale is not None:
                            layer.input_scale = Parameter(
                                input_scale, requires_grad=False
                            )

                    weight_scale, weight = requantize_with_max_scale(
                        weight=weight,
                        weight_scale=weight_scale,
                        logical_widths=layer.logical_widths,
                    )

                # Update layer with new values.
                layer.weight = Parameter(weight.t(), requires_grad=False)
                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
                if static_activation:
                    # Collapse the per-shard input scales to their maximum.
                    layer.input_scale = Parameter(
                        layer.input_scale.max(), requires_grad=False
                    )

        if self.use_marlin:
            if self.block_quant:
                layer.weight_block_size = self.quant_config.weight_block_size
            prepare_fp8_layer_for_marlin(layer, not self.block_quant)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x``.

        Dispatch order: Marlin kernel if enabled; otherwise, for block-wise
        quantization, the Intel AMX CPU kernel or the dispatched block FP8
        kernel; otherwise the generic ``apply_fp8_linear`` path.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if self.block_quant:
            if use_intel_amx_backend(layer):
                return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
                    x,
                    layer.weight,
                    layer.weight_scale_inv,
                    self.quant_config.weight_block_size,
                    bias,
                    x.dtype,
                    True,  # is_vnni
                )

            return self.w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=None,
                bias=bias,
            )

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False,
        )
def get_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Pick the per-CTA token tile size for the MoE kernel.

    The expected number of tokens per expert (assuming a perfectly even
    routing) is rounded up to a power of two and then clamped to the
    8-64 range supported by the kernel.
    """
    # Even-distribution estimate of how many tokens each expert receives.
    expected_per_expert = (num_tokens * top_k) // num_experts
    padded = next_power_of_2(expected_per_expert)

    # Clamp into [8, 64].
    if padded < 8:
        return 8
    if padded > 64:
        return 64
    return padded
class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
    """Store the config and decide which optional kernel paths are usable."""
    self.quant_config = quant_config
    # Block-wise quantization is active whenever a block size is configured.
    self.block_quant = quant_config.weight_block_size is not None
    self.cutlass_fp8_supported = cutlass_fp8_supported()

    # The CUTLASS fused-experts path is opt-in via SGLANG_CUTLASS_MOE and
    # additionally needs CUTLASS FP8 support, block quantization and an
    # SM100 or SM90 device.
    cutlass_moe_requested = get_bool_env_var("SGLANG_CUTLASS_MOE")
    self.use_cutlass_fused_experts_fp8 = (
        cutlass_moe_requested
        and self.cutlass_fp8_supported
        and self.block_quant
        and (is_sm100_supported() or is_sm90_supported())
    )
def create_weights(
    self,
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register the expert weights, weight scales and input scales on ``layer``.

    Registers ``w13_weight`` / ``w2_weight``, then either block-wise
    ``*_weight_scale_inv`` parameters or per-tensor ``*_weight_scale``
    parameters (plus ``*_weight_scale1`` on ROCm), and finally the input
    scales for the static activation scheme. When the CUTLASS fused-experts
    path is enabled its stride, pointer and problem-size buffers are
    allocated on ``self``.

    Raises:
        ValueError: If block-wise quantization is on and the intermediate
            size is not divisible by the block dimensions, or if the static
            activation scheme is requested for a non-FP8 checkpoint.
    """
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

    def frozen(tensor):
        # Every parameter created here is inference-only.
        return torch.nn.Parameter(tensor, requires_grad=False)

    quant_config = self.quant_config
    inter_size = intermediate_size_per_partition

    if quant_config.is_checkpoint_fp8_serialized:
        params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn

    tp_size = get_tensor_model_parallel_world_size()
    if self.block_quant:
        block_n, block_k = (
            quant_config.weight_block_size[0],
            quant_config.weight_block_size[1],
        )
        # NOTE(HandH1998): To ensure proper alignment of the block-wise
        # quantization scales, the output_size of the weights for both the
        # gate and up layers must be divisible by block_n.
        # Required by column parallel or enabling merged weights
        if inter_size % block_n != 0:
            raise ValueError(
                f"The output_size of gate's and up's weight = "
                f"{intermediate_size_per_partition} is not divisible by "
                f"weight quantization block_n = {block_n}."
            )
        if tp_size > 1:
            # Required by row parallel
            if inter_size % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

    # WEIGHTS
    if _is_hip and _use_hip_int4:
        # INT4 MoE weight - INT32 packed: the innermost dimension is
        # divided by 8.
        w13_inner, w2_inner = hidden_size // 8, inter_size // 8
    else:
        w13_inner, w2_inner = hidden_size, inter_size

    w13_weight = frozen(
        torch.empty(num_experts, 2 * inter_size, w13_inner, dtype=params_dtype)
    )
    w2_weight = frozen(
        torch.empty(num_experts, hidden_size, w2_inner, dtype=params_dtype)
    )

    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    if self.block_quant:
        # One scale per (block_n x block_k) tile, rounded up; w13 stacks the
        # gate and up projections, hence the factor 2.
        w13_weight_scale = frozen(
            torch.ones(
                num_experts,
                2 * ((inter_size + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        )
        w2_weight_scale = frozen(
            torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (inter_size + block_k - 1) // block_k,
                dtype=torch.float32,
            )
        )
        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
        assert quant_config.activation_scheme == "dynamic"

        if self.use_cutlass_fused_experts_fp8:
            w13_device = w13_weight.device
            w2_device = w2_weight.device

            def per_expert_stride(value, device):
                return torch.full(
                    (num_experts,), value, device=device, dtype=torch.int64
                )

            def per_expert_ptr():
                return torch.empty(num_experts, device=w13_device, dtype=torch.int64)

            def per_expert_problem_sizes():
                return torch.empty(
                    num_experts, 3, device=w13_device, dtype=torch.int32
                )

            self.ab_strides1 = per_expert_stride(hidden_size, w13_device)
            self.c_strides1 = per_expert_stride(2 * inter_size, w13_device)
            self.ab_strides2 = per_expert_stride(inter_size, w2_device)
            self.c_strides2 = per_expert_stride(hidden_size, w2_device)
            self.workspace = torch.empty(90000, device=w13_device, dtype=torch.uint8)
            self.a_ptr = per_expert_ptr()
            self.b_ptr = per_expert_ptr()
            self.out_ptr = per_expert_ptr()
            self.a_scales_ptr = per_expert_ptr()
            self.b_scales_ptr = per_expert_ptr()
            self.expert_offsets = torch.empty(
                num_experts + 1, device=w13_device, dtype=torch.int32
            )
            self.problem_sizes1 = per_expert_problem_sizes()
            self.problem_sizes2 = per_expert_problem_sizes()
    else:
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = frozen(torch.ones(num_experts, 2, dtype=torch.float32))
        w2_weight_scale = frozen(torch.ones(num_experts, dtype=torch.float32))
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        if _is_hip:  # _use_aiter: TODO: add check back after triton kernel
            # ROCm - using column scaling, duplicate scaling numbers in case
            # per tensor scaling
            w13_weight_scale1 = frozen(
                torch.ones(num_experts, 2 * inter_size, dtype=torch.float32)
            )
            w2_weight_scale1 = frozen(
                torch.ones(num_experts, hidden_size, dtype=torch.float32)
            )
            layer.register_parameter("w13_weight_scale1", w13_weight_scale1)
            layer.register_parameter("w2_weight_scale1", w2_weight_scale1)

    # Add the quantization method used (per tensor/grouped/channel)
    # to ensure the weight scales are loaded in properly
    if self.block_quant:
        scale_kind = FusedMoeWeightScaleSupported.BLOCK.value
    else:
        scale_kind = FusedMoeWeightScaleSupported.TENSOR.value
    extra_weight_attrs.update({"quant_method": scale_kind})

    # If loading fp8 checkpoint, pass the weight loaders.
    # If loading an fp16 checkpoint, do not (we will quantize in
    # process_weights_after_loading()
    if quant_config.is_checkpoint_fp8_serialized:
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    if _is_hip and _use_hip_int4:
        # From here on the attrs carry the CHANNEL quant method, which the
        # input-scale parameters below also receive.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale1, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale1, extra_weight_attrs)

    # INPUT_SCALES
    if quant_config.activation_scheme != "static":
        layer.w13_input_scale = None
        layer.w2_input_scale = None
        return

    if not quant_config.is_checkpoint_fp8_serialized:
        raise ValueError(
            "Found static activation scheme for checkpoint that "
            "was not serialized fp8."
        )

    w13_input_scale = frozen(torch.ones(num_experts, dtype=torch.float32))
    layer.register_parameter("w13_input_scale", w13_input_scale)
    set_weight_attrs(w13_input_scale, extra_weight_attrs)

    w2_input_scale = frozen(torch.ones(num_experts, dtype=torch.float32))
    layer.register_parameter("w2_input_scale", w2_input_scale)
    set_weight_attrs(w2_input_scale, extra_weight_attrs)
def process_weights_after_loading(self, layer: Module) -> None:
# Post-process the expert weights and scales once the checkpoint is loaded.
# Four paths are visible below: ROCm INT4 (delegated to
# process_weights_hip_int4), block-wise quantization, on-the-fly
# quantization of a non-FP8 checkpoint, and merging of the two per-shard
# w13 scales of an FP8 checkpoint into one scale per expert.
# NOTE(review): indentation is not preserved in this view of the file, so
# the nesting described in the comments below is inferred - confirm it
# against the properly indented source before relying on it.
if _is_hip and _use_hip_int4:
self.process_weights_hip_int4(layer)
return
# Block-wise quantization: only dtype normalization, the optional aiter
# pre-shuffle and the CPU AMX repacking happen here.
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_fp8_fnuz:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
weight_scale=layer.w13_weight_scale_inv,
input_scale=None,
)
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w2_weight,
weight_scale=layer.w2_weight_scale_inv,
input_scale=None,
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
layer.w13_input_scale = None
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
layer.w2_input_scale = None
# NOTE(review): presumably this pre-shuffle sits inside the fnuz branch
# above; it cannot be told from this view - TODO confirm.
if _use_aiter:
# Pre-shuffle weights
layer.w13_weight.data = shuffle_weight(
layer.w13_weight.contiguous(), (16, 16)
)
layer.w2_weight.data = shuffle_weight(
layer.w2_weight.contiguous(), (16, 16)
)
if _is_cpu:
assert (
_is_cpu_amx_available
), "Fp8MoEMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
# Presumably this return ends the whole block-quant path (not only the
# CPU case): the code below reads w13_weight_scale, which create_weights
# does not register for block-wise quantization.
return
# If checkpoint is fp16 or bfloat16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer.w13_weight_scale = torch.nn.Parameter(
torch.ones(
layer.num_local_experts,
dtype=torch.float32,
device=w13_weight.device,
),
requires_grad=False,
)
# Quantize each expert separately; scaled_fp8_quant returns the
# quantized tensor together with its scale.
for expert in range(layer.num_local_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
if _is_hip:
self.process_weights_hip_scale_padding(layer)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.quant_config.activation_scheme == "static":
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
print_warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_fp8_fnuz:
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
)
)
w2_weight, w2_weight_scale, w2_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
)
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False
)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
# For each expert, walk the two w13 shards (rows [0, shard_size) and
# [shard_size, 2 * shard_size)), dequantize each with its own scale and
# requantize it with the expert-wide maximum scale.
for expert_id in range(layer.num_local_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)
if _is_hip:
self.process_weights_hip_scale_padding(layer)
return
def process_weights_hip_int4(self, layer: Module):
    """Prepare INT4 expert weights (INT4 MoE weight, FP8 compute) on ROCm.

    Shuffles both weight tensors, collapses the two per-shard w13 scales
    into one scale per expert, and folds the per-expert scales into the
    column-wise ``*_weight_scale1`` parameters.
    """
    # TODO: _use_aiter: add after triton kernel added
    # Weight permutation; the cache is emptied after each tensor is replaced.
    for weight_name in ("w13_weight", "w2_weight"):
        shuffled = shuffle_weight(getattr(layer, weight_name).data, (16, 16))
        setattr(layer, weight_name, torch.nn.Parameter(shuffled, requires_grad=False))
        torch.cuda.empty_cache()

    # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
    # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
    # We won't do requant each expert's fp8 weight (not direct available),
    # instead we adjust half of INT4 w13_weight_scale1 numbers
    assert layer.w13_weight_scale is not None
    shard_size = layer.intermediate_size_per_partition
    max_w13_scales = layer.w13_weight_scale.max(dim=1).values
    for expert_id in range(layer.num_local_experts):
        expert_max = max_w13_scales[expert_id]
        for shard_id in range(2):
            shard_scale = layer.w13_weight_scale[expert_id][shard_id]
            if shard_scale != expert_max:
                # Rescale only the columns belonging to this shard.
                first = shard_id * shard_size
                layer.w13_weight_scale1[expert_id][
                    first : first + shard_size
                ] *= shard_scale / expert_max

    layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)

    # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
    # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
    for expert_id in range(layer.num_local_experts):
        layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
        layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
def process_weights_hip_scale_padding(self, layer: Module):
    """Apply the ROCm-specific weight layout after FP8 processing.

    With aiter enabled, both weights are shuffled and the per-expert scales
    are folded into the column-wise ``*_weight_scale1`` parameters.
    Otherwise, if ``SGLANG_MOE_PADDING`` is set, both weights are
    zero-padded along their last dimension by ``padding_size``.
    """
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
        padding_size,  # Avoid circular import
    )

    weight_names = ("w13_weight", "w2_weight")

    if _use_aiter:
        for weight_name in weight_names:
            shuffled = shuffle_weight(getattr(layer, weight_name).data, (16, 16))
            setattr(
                layer, weight_name, torch.nn.Parameter(shuffled, requires_grad=False)
            )
            torch.cuda.empty_cache()

        # ROCm (_use_aiter): using column-wise scaling
        layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
        layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
    elif get_bool_env_var("SGLANG_MOE_PADDING"):
        # If ROCm, apply weight padding (min. Mem channel contention) only if set
        for weight_name in weight_names:
            padded = F.pad(
                getattr(layer, weight_name).data, (0, padding_size), "constant", 0
            )
            setattr(layer, weight_name, torch.nn.Parameter(padded, requires_grad=False))
            torch.cuda.empty_cache()
def create_moe_runner(
    self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
    """Remember the runner config and build the Triton-backed MoE runner.

    ``layer`` is accepted for interface compatibility and is not used here.
    """
    self.moe_runner_config = moe_runner_config
    backend = MoeRunnerBackend.TRITON
    self.runner = MoeRunner(backend, self.moe_runner_config)
def apply(
    self,
    layer: torch.nn.Module,
    dispatch_output: DispatchOutput,
) -> CombineInput:
    """Run the FP8 fused-MoE computation for one dispatched batch.

    Backends are tried in this order: Intel AMX CPU kernel, the ROCm fused
    experts helper (used only when it returns a result), the CUTLASS
    block-scale kernel, and finally the Triton runner created in
    ``create_moe_runner``.
    """
    from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

    hidden_states = dispatch_output.hidden_states
    topk_output = dispatch_output.topk_output
    runner_config = self.moe_runner_config

    # --- Intel AMX (CPU) ---
    if use_intel_amx_backend(layer):
        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

        topk_weights, topk_ids, _ = topk_output
        hidden_states, topk_weights = apply_topk_weights_cpu(
            runner_config.apply_router_weight_on_input, topk_weights, hidden_states
        )
        cpu_output = torch.ops.sgl_kernel.fused_experts_cpu(
            hidden_states,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            False,  # inplace See [Note] inplace should be False in fused_experts.
            False,  # use_int8_w8a8
            True,  # use_fp8_w8a16
            layer.w13_weight_scale_inv,  # w1_scale
            layer.w2_weight_scale_inv,  # w2_scale
            self.quant_config.weight_block_size,  # block_size
            None,  # a1_scale
            None,  # a2_scale
            True,  # is_vnni
        )
        return StandardCombineInput(hidden_states=cpu_output)

    # --- ROCm: the helper returns None when it does not handle the call ---
    if _is_hip:
        hip_output = self.maybe_apply_hip_fused_experts(
            layer,
            hidden_states,
            topk_output,
            runner_config.activation,
            runner_config.no_combine,
        )
        if hip_output is not None:
            return StandardCombineInput(hidden_states=hip_output)

    # --- CUTLASS block-scale kernel ---
    if self.use_cutlass_fused_experts_fp8:
        from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

        topk_weights, topk_ids, _ = topk_output
        cutlass_output = cutlass_fused_experts_fp8(
            hidden_states,
            layer.w13_weight.transpose(1, 2),
            layer.w2_weight.transpose(1, 2),
            layer.w13_weight_scale_inv.transpose(1, 2),
            layer.w2_weight_scale_inv.transpose(1, 2),
            topk_weights,
            topk_ids,
            self.ab_strides1,
            self.c_strides1,
            self.ab_strides2,
            self.c_strides2,
            self.workspace,
            self.a_ptr,
            self.b_ptr,
            self.out_ptr,
            self.a_scales_ptr,
            self.b_scales_ptr,
            self.expert_offsets,
            self.problem_sizes1,
            self.problem_sizes2,
            use_fp8_blockscale=True,
        )
        return StandardCombineInput(hidden_states=cutlass_output)

    # --- Triton runner (default) ---
    # Block-wise checkpoints keep their scales under the *_scale_inv names.
    if self.block_quant:
        w13_scale = layer.w13_weight_scale_inv
        w2_scale = layer.w2_weight_scale_inv
    else:
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale

    quant_info = TritonMoeQuantInfo(
        w13_weight=layer.w13_weight,
        w2_weight=layer.w2_weight,
        use_fp8_w8a8=True,
        w13_scale=w13_scale,
        w2_scale=w2_scale,
        a13_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
        block_shape=self.quant_config.weight_block_size,
    )
    return self.runner.run(dispatch_output, quant_info)
def apply_with_router_logits(
    self,
    layer: torch.nn.Module,
    dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
    """Run the MoE layer through flashinfer's ``trtllm_fp8_block_scale_moe``.

    The hidden states are quantized per token group (group size
    ``weight_block_size[1]``) and passed to the kernel together with the raw
    router logits and the top-k configuration.

    Args:
        layer: MoE layer providing the expert weights, their ``*_scale_inv``
            tensors and the expert counts / EP rank.
        dispatch_output: Dispatcher output. Its ``topk_output`` must be in
            the bypassed format and expose ``router_logits`` and
            ``topk_config``.

    Returns:
        The value returned by ``trtllm_fp8_block_scale_moe``.

    Raises:
        AssertionError: If the top-k output is not in the bypassed format,
            the activation is not ``"silu"``, or either of
            ``num_expert_group`` / ``topk_group`` is ``None``.
    """
    hidden_states = dispatch_output.hidden_states
    topk_output = dispatch_output.topk_output
    runner_config = self.moe_runner_config
    activation = runner_config.activation
    routed_scaling_factor = runner_config.routed_scaling_factor

    from flashinfer.fused_moe import trtllm_fp8_block_scale_moe

    from sglang.srt.layers.moe.topk import TopKOutputChecker

    assert TopKOutputChecker.format_is_bypassed(topk_output)
    router_logits = topk_output.router_logits
    topk_config = topk_output.topk_config

    assert (
        activation == "silu"
    ), "Only silu is supported for flashinfer blockscale fp8 moe"

    group_size = self.quant_config.weight_block_size[1]
    quantized, scales = per_token_group_quant_fp8(hidden_states, group_size)
    # NOTE: scales of hidden states have to be transposed!
    scales_t = scales.t().contiguous()

    assert not (
        topk_config.num_expert_group is None or topk_config.topk_group is None
    ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"

    # The bias, when present, is cast to the dtype of the hidden states.
    if topk_config.correction_bias is None:
        routing_bias = None
    else:
        routing_bias = topk_config.correction_bias.to(hidden_states.dtype)

    # A missing scaling factor is treated as 1.0.
    if routed_scaling_factor is None:
        routed_scaling_factor = 1.0

    tile_tokens_dim = get_tile_tokens_dim(
        hidden_states.shape[0], topk_config.top_k, layer.num_experts
    )

    return trtllm_fp8_block_scale_moe(
        routing_logits=router_logits.to(torch.float32),
        routing_bias=routing_bias,
        hidden_states=quantized,
        hidden_states_scale=scales_t,
        gemm1_weights=layer.w13_weight,
        gemm1_weights_scale=layer.w13_weight_scale_inv,
        gemm2_weights=layer.w2_weight,
        gemm2_weights_scale=layer.w2_weight_scale_inv,
        num_experts=layer.num_experts,
        top_k=topk_config.top_k,
        n_group=topk_config.num_expert_group,
        topk_group=topk_config.topk_group,
        intermediate_size=layer.w2_weight.shape[2],
        local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
        local_num_experts=layer.num_local_experts,
        routed_scaling_factor=routed_scaling_factor,
        tile_tokens_dim=tile_tokens_dim,
        routing_method_type=2,  # DeepSeek-styled routing method
        use_shuffled_weight=False,
    )
def maybe_apply_hip_fused_experts(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_output: TopKOutput,
    activation: str = "silu",
    no_combine: bool = False,
) -> Optional[torch.Tensor]:
    """Run the HIP ``fused_moe`` kernel if one of the HIP modes is enabled.

    Kernel selection:

    * ``_use_hip_int4`` set: per-token quantization with the
      ``*_weight_scale1`` tensors.
    * otherwise ``_use_aiter`` set and ``self.block_quant`` true:
      ``per_128x128`` quantization with the ``*_weight_scale_inv`` tensors.
    * otherwise ``_use_aiter`` set: per-token quantization with the
      ``*_weight_scale1`` tensors.

    Args:
        layer: MoE layer holding the expert weights and scales.
        x: Input hidden states.
        topk_output: Three-element top-k result; the first two elements are
            used as the top-k weights and ids.
        activation: ``"silu"`` selects ``ActivationType.Silu``; any other
            value selects ``ActivationType.Gelu``.
        no_combine: Must be false whenever a HIP kernel is selected.

    Returns:
        The ``fused_moe`` result, or ``None`` when neither ``_use_hip_int4``
        nor ``_use_aiter`` is set, so the caller can try another backend.

    Raises:
        AssertionError: If a HIP kernel is selected and ``no_combine`` is
            true.
    """
    topk_weights, topk_ids, _ = topk_output

    # No HIP kernel applies: report that to the caller with ``None``.
    if not (_use_hip_int4 or _use_aiter):
        return None

    # TODO (int4 path): add triton kernel and add check _use_aiter
    assert not no_combine, f"{no_combine=} is not supported."

    if activation == "silu":
        activation_type = ActivationType.Silu
    else:
        activation_type = ActivationType.Gelu

    # Only the aiter path distinguishes block quantization; the int4 path
    # always takes the per-token call below.
    if not _use_hip_int4 and self.block_quant:
        return fused_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            w1_scale=layer.w13_weight_scale_inv,
            w2_scale=layer.w2_weight_scale_inv,
            quant_type=QuantType.per_128x128,
            activation=activation_type,
            expert_mask=None,
        )

    return fused_moe(
        x,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights,
        topk_ids,
        quant_type=QuantType.per_Token,
        w1_scale=layer.w13_weight_scale1,
        w2_scale=layer.w2_weight_scale1,
        activation=activation_type,
    )
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method for FP8 checkpoints.

    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        """Pass ``quant_config`` straight through to ``BaseKVCacheMethod``."""
        super().__init__(quant_config)