# Source file: sglang/python/sglang/srt/layers/quantization/w4afp8.py
# (File-viewer metadata: 343 lines, 12 KiB, Python.)

from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
# Activation quantization schemes accepted by W4AFp8Config; its constructor
# checks `moe_activation_scheme` against this list.
ACTIVATION_SCHEMES = ["static", "dynamic"]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class W4AFp8Config(QuantizationConfig):
    """Config class for MIXED_PRECISION W4AFp8.

    Linear layers are routed to ``Fp8LinearMethod`` (or left unquantized when
    listed in ``ignored_layers``) and fused-MoE layers are routed to
    ``W4AFp8MoEMethod``; see ``get_quant_method``.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = True,
        is_checkpoint_w4afp8_serialized: bool = True,
        linear_activation_scheme: str = "dynamic",
        moe_activation_scheme: str = "static",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
        group_size: int = 128,
    ) -> None:
        """Store the quantization settings.

        Args:
            is_checkpoint_fp8_serialized: Stored as-is on the instance.
            is_checkpoint_w4afp8_serialized: Stored as-is; when true a
                warning is logged.
            linear_activation_scheme: Activation scheme recorded for linear
                layers (not validated here).
            moe_activation_scheme: Activation scheme recorded for MoE layers;
                must be one of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer-name prefixes that should stay unquantized.
                ``None`` is treated as an empty list.
            weight_block_size: Block size recorded for weight quantization.
                ``None`` selects the default ``[128, 128]``.
            group_size: Group size used to size the MoE weight-scale tensors.

        Raises:
            ValueError: If ``moe_activation_scheme`` is not supported.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
        if is_checkpoint_w4afp8_serialized:
            # The original message stopped mid-sentence after "Please note
            # that"; the sentence is completed here.
            logger.warning(
                "Detected w4afp8 checkpoint. Please note that the format is "
                "experimental and subject to change."
            )
        if moe_activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}")
        self.linear_activation_scheme = linear_activation_scheme
        self.moe_activation_scheme = moe_activation_scheme
        self.ignored_layers = ignored_layers or []
        # Honour an explicitly supplied block size instead of silently
        # discarding it; the historical value [128, 128] remains the default.
        # A copy is stored so later mutation of the caller's list has no effect.
        self.weight_block_size = (
            list(weight_block_size) if weight_block_size is not None else [128, 128]
        )
        self.group_size = group_size

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "w4afp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this config declares as supported."""
        return [torch.bfloat16, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): 90 presumably encodes CUDA compute capability 9.0 --
        confirm against the base-class contract.
        """
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config filenames to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config:
        """Build a config from a checkpoint quantization dict.

        Only the ``quant_method`` key is read. The serialization flags are
        derived by substring tests on it, so a value of ``"w4afp8"`` sets both
        flags because it also contains ``"fp8"``. The activation schemes and
        the weight block size are fixed here and are not taken from ``config``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        return cls(
            is_checkpoint_fp8_serialized="fp8" in quant_method,
            is_checkpoint_w4afp8_serialized="w4afp8" in quant_method,
            linear_activation_scheme="dynamic",
            moe_activation_scheme="static",
            weight_block_size=[128, 128],
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Select the quantization method object for ``layer``.

        Args:
            layer: The module being constructed.
            prefix: The layer's name prefix, matched against
                ``self.ignored_layers`` for linear layers.

        Returns:
            ``UnquantizedLinearMethod`` for skipped linear layers,
            ``Fp8LinearMethod`` for other linear layers, ``W4AFp8MoEMethod``
            for ``FusedMoE`` layers, and ``None`` for anything else.
        """
        # Imported locally, as in the original code. The previously unused
        # local imports (EPMoE, global_server_args_dict) were dropped.
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return W4AFp8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none)."""
        return []
def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
    """Interleave scales in groups of 4 similar to TRT-LLM implementation.

    The first three dimensions of ``scales`` are read as ``(d0, d1, d2)``.
    The last of these is cut into chunks of width 4 (or width 1 when ``d2``
    is not a multiple of 4), the chunk axis is swapped with ``d1``, and each
    ``d1 x width`` slab is flattened.

    Returns:
        A contiguous tensor of shape ``(d0, d2 // width, d1 * width)``.
    """
    dims = scales.shape
    n_outer, n_rows, n_cols = dims[0], dims[1], dims[2]

    # Chunk width falls back to 1 when the last dimension is not divisible by 4.
    width = 1 if n_cols % 4 else 4
    n_chunks = n_cols // width

    # (d0, d1, chunks, width) -> (d0, chunks, d1, width)
    chunked = scales.reshape(n_outer, n_rows, n_chunks, width)
    swapped = chunked.transpose(1, 2)

    # Collapse the trailing (d1, width) pair into one interleaved axis.
    return swapped.reshape(n_outer, n_chunks, n_rows * width).contiguous()
class W4AFp8MoEMethod(FusedMoEMethodBase):
    """Fused-MoE method that allocates W4AFp8 expert weights and scales and
    runs them through ``cutlass_w4a8_moe``."""

    def __init__(self, quant_config: W4AFp8Config):
        self.quant_config = quant_config

    @staticmethod
    def _attach_param(layer, name, data, attrs):
        """Wrap ``data`` as a frozen parameter, register it on ``layer`` under
        ``name`` and tag it with the current contents of ``attrs``."""
        param = Parameter(data, requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, attrs)

    def create_weights(
        self,
        layer: EPMoE,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the expert weights, scales and kernel scratch buffers.

        Registers ``w13_weight``, ``w2_weight``, ``w13_weight_scale_inv``,
        ``w2_weight_scale_inv``, ``w13_input_scale`` and ``w2_input_scale`` on
        ``layer``, and stores stride / offset / problem-size tensors on
        ``self`` for later use by ``apply``. ``params_dtype`` is not used;
        every dtype below is fixed.
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        assert "weight_loader" in extra_weight_attrs

        inter = intermediate_size_per_partition
        group = self.quant_config.group_size

        # Fused gate_up_proj (column parallel). The last dimension is halved
        # and stored as int8 -- presumably two 4-bit values per byte.
        self._attach_param(
            layer,
            "w13_weight",
            torch.empty(num_experts, inter * 2, hidden_size // 2, dtype=torch.int8),
            extra_weight_attrs,
        )
        # down_proj (row parallel)
        self._attach_param(
            layer,
            "w2_weight",
            torch.empty(num_experts, hidden_size, inter // 2, dtype=torch.int8),
            extra_weight_attrs,
        )

        # Everything registered from here on is additionally tagged with the
        # GROUP quant method; the two packed weights above are not.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
        )

        # Per-group weight scales: one value per `group` input channels.
        self._attach_param(
            layer,
            "w13_weight_scale_inv",
            torch.zeros(
                num_experts, 2 * inter, hidden_size // group, dtype=torch.float32
            ),
            extra_weight_attrs,
        )
        self._attach_param(
            layer,
            "w2_weight_scale_inv",
            torch.zeros(
                num_experts, hidden_size, inter // group, dtype=torch.float32
            ),
            extra_weight_attrs,
        )

        # Input scales
        self._attach_param(
            layer,
            "w13_input_scale",
            torch.ones((num_experts, 2), dtype=torch.bfloat16),
            extra_weight_attrs,
        )
        self._attach_param(
            layer,
            "w2_input_scale",
            torch.ones(num_experts, dtype=torch.bfloat16),
            extra_weight_attrs,
        )

        # Pre-populate the strides
        device = layer.w13_weight.device

        def constant_strides(value):
            # (num_experts, 3) int64 tensor with every entry equal to `value`.
            return torch.full(
                (num_experts, 3), value, device=device, dtype=torch.int64
            )

        def int32_scratch(shape):
            # Uninitialised int32 buffer on the weights' device.
            return torch.empty(shape, dtype=torch.int32, device=device)

        self.a_strides1 = constant_strides(hidden_size)
        self.c_strides1 = constant_strides(2 * inter)
        self.a_strides2 = constant_strides(inter)
        self.c_strides2 = constant_strides(hidden_size)
        # The B and scale strides alias (share storage with) the tensors above.
        self.b_strides1 = self.a_strides1
        self.s_strides13 = self.c_strides1
        self.b_strides2 = self.a_strides2
        self.s_strides2 = self.c_strides2

        self.expert_offsets = int32_scratch(num_experts + 1)
        self.problem_sizes1 = int32_scratch((num_experts, 3))
        self.problem_sizes2 = int32_scratch((num_experts, 3))

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded scales into the layout passed to the kernel.

        Weight scales are cast to bfloat16 and passed through
        ``interleave_scales``. Each input-scale parameter is replaced by a
        one-element bfloat16 tensor holding its maximum.
        """
        target_dtype = torch.bfloat16
        target_device = layer.w2_weight.device

        # Interleave w13 (gate_up_proj) and w2 (down_proj) weight scales.
        for attr in ("w13_weight_scale_inv", "w2_weight_scale_inv"):
            interleaved = interleave_scales(getattr(layer, attr).to(target_dtype))
            setattr(layer, attr, Parameter(interleaved, requires_grad=False))

        # Process input scales: reduce each to a single global maximum.
        for attr in ("w13_input_scale", "w2_input_scale"):
            peak = getattr(layer, attr).max().to(target_dtype).item()
            collapsed = torch.tensor(
                [peak], dtype=target_dtype, device=target_device
            )
            setattr(layer, attr, Parameter(collapsed, requires_grad=False))

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        """Remember the runner config; ``apply`` reads its scaling factor."""
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer: EPMoE,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Run the dispatched tokens through the W4A8 CUTLASS MoE kernel.

        Returns a ``StandardCombineInput`` wrapping the kernel output, scaled
        in place by ``routed_scaling_factor`` when that is not ``None``.
        """
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        hidden = dispatch_output.hidden_states
        topk_weights, topk_ids, _ = dispatch_output.topk_output

        if get_moe_expert_parallel_world_size() > 1:
            # Under expert parallelism, ids equal to -1 are remapped to
            # `layer.num_experts`.
            local_topk_ids = torch.where(topk_ids == -1, layer.num_experts, topk_ids)
        else:
            local_topk_ids = topk_ids

        result = cutlass_w4a8_moe(
            layer.start_expert_id,
            layer.end_expert_id,
            layer.num_experts,
            hidden,
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_weight_scale_inv,
            layer.w2_weight_scale_inv,
            topk_weights,
            topk_ids,
            local_topk_ids,
            self.a_strides1,
            self.b_strides1,
            self.c_strides1,
            self.a_strides2,
            self.b_strides2,
            self.c_strides2,
            self.s_strides13,
            self.s_strides2,
            self.expert_offsets,
            self.problem_sizes1,
            self.problem_sizes2,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )

        scaling = self.moe_runner_config.routed_scaling_factor
        if scaling is not None:
            result *= scaling
        return StandardCombineInput(hidden_states=result)