Refactor: move all quantization-related code to srt/layer/quantization (#7989)
@@ -1,12 +1,12 @@
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

from __future__ import annotations

import itertools
import logging
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter

from sglang.srt.distributed import (
@@ -17,7 +17,6 @@ from sglang.srt.distributed import (
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
@@ -27,17 +26,14 @@ from sglang.srt.layers.parameter import (
    RowvLLMParameter,
    _ColumnvLLMParameter,
)
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.utils import (
    cpu_has_amx_support,
    is_cpu,
    is_npu,
    set_weight_attrs,
    use_intel_amx_backend,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs

if TYPE_CHECKING:
    from sglang.srt.layers.quantization.base_config import (
        QuantizationConfig,
        QuantizeMethodBase,
    )

logger = logging.getLogger(__name__)

@@ -59,7 +55,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [
    "IPEXAWQLinearMethod",
]

_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_npu = is_npu()

@@ -110,91 +105,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    return param[shard_id], loaded_weight

class LinearMethodBase(QuantizeMethodBase):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
"""Create weights for a linear layer.
|
||||
The weights will be set as attributes of the layer.
|
||||
|
||||
Args:
|
||||
layer: The layer that is using the LinearMethodBase factory.
|
||||
input_size_per_partition: Size of the weight input dim on rank X.
|
||||
output_partition_sizes: Sizes of the output dim of each logical
|
||||
weight on rank X. E.g., output_partition_sizes for QKVLinear
|
||||
is a list containing the widths of Wq, Wk, Wv on rank X.
|
||||
input_size: Size of the input dim of the weight across all ranks.
|
||||
output_size: Size of the output dim of the weight across all ranks.
|
||||
params_dtype: Datatype of the parameters.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UnquantizedLinearMethod(LinearMethodBase):
|
||||
"""Linear method without quantization."""
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
weight = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
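# Editorial note (not part of this commit): the input_dim/output_dim attributes set
# above are read by the weight loaders to decide which dimension of the loaded
# checkpoint tensor to narrow for tensor parallelism; the weight is laid out
# [output, input], so e.g. a column-parallel layer shards dim 0 and a row-parallel
# layer shards dim 1.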
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if _is_cpu and _is_cpu_amx_available:
|
||||
_amx_process_weight_after_loading(layer, ["weight"])
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if use_intel_amx_backend(layer):
|
||||
return torch.ops.sgl_kernel.weight_packed_linear(
|
||||
x, layer.weight, bias, True # is_vnni
|
||||
)
|
||||
|
||||
return F.linear(x, layer.weight, bias)
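# Editorial sketch (not part of this commit): outside the Intel AMX fast path, the
# unquantized method is just a dense matmul over the weight registered in
# create_weights(). A minimal stand-alone equivalent:
#
#     import torch
#     import torch.nn.functional as F
#
#     weight = torch.randn(16 + 8 + 8, 32)  # [sum(output_partition_sizes), input_size_per_partition]
#     x = torch.randn(4, 32)
#     out = F.linear(x, weight)             # what UnquantizedLinearMethod.apply computes
#     assert out.shape == (4, 32)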
|
||||
|
||||
|
||||
class LinearBase(torch.nn.Module):
|
||||
"""Base linear layer.
|
||||
|
||||
@@ -310,7 +220,7 @@ class ReplicatedLinear(LinearBase):
|
||||
assert param.size() == loaded_weight.size()
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
assert self.quant_method is not None
|
||||
output = self.quant_method.apply(self, x, bias)
|
||||
@@ -845,7 +755,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: Optional["QuantizationConfig"] = None,
|
||||
prefix: str = "",
|
||||
tp_rank: Optional[int] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
|
||||
@@ -27,22 +27,20 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
silu_and_mul_triton_kernel,
|
||||
tma_align_input_scale,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
is_fp8_fnuz,
|
||||
scaled_fp8_quant,
|
||||
sglang_per_token_group_quant_fp8,
|
||||
sglang_per_token_quant_fp8,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -53,7 +51,6 @@ from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
is_hip,
|
||||
is_npu,
|
||||
set_weight_attrs,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
@@ -904,324 +901,6 @@ class EPMoE(torch.nn.Module):
|
||||
param_data[expert_id] = loaded_weight
|
||||
|
||||
|
||||
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts_per_partition: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts_per_partition,
|
||||
2 * intermediate_size,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts_per_partition,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# scale
|
||||
layer.register_parameter("w13_input_scale", None)
|
||||
layer.register_parameter("w13_weight_scale", None)
|
||||
|
||||
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(
|
||||
ones_tensor,
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
ones_tensor,
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Fp8EPMoEMethod(Fp8MoEMethod):
|
||||
"""MoE method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: Module,
|
||||
num_experts_per_partition: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.block_quant:
|
||||
block_n, block_k = (
|
||||
self.quant_config.weight_block_size[0],
|
||||
self.quant_config.weight_block_size[1],
|
||||
)
|
||||
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
|
||||
# Required by column parallel or enabling merged weights
|
||||
if intermediate_size % block_n != 0:
|
||||
raise ValueError(
|
||||
f"The output_size of gate's and up's weight = "
|
||||
f"{intermediate_size} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}."
|
||||
)
|
||||
if tp_size > 1:
|
||||
# Required by row parallel
|
||||
if intermediate_size % block_k != 0:
|
||||
raise ValueError(
|
||||
f"The input_size of down's weight = "
|
||||
f"{intermediate_size} is not divisible by "
|
||||
f"weight quantization block_k = {block_k}."
|
||||
)
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts_per_partition,
|
||||
2 * intermediate_size,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts_per_partition,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
if self.block_quant:
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts_per_partition,
|
||||
2 * ((intermediate_size + block_n - 1) // block_n),
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts_per_partition,
|
||||
(hidden_size + block_n - 1) // block_n,
|
||||
(intermediate_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
||||
assert self.quant_config.activation_scheme == "dynamic"
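# Editorial example (not part of this commit): worked shapes for the block-wise
# scales above, assuming weight_block_size = [128, 128] (block_n, block_k),
# E experts per partition, intermediate_size = 1536 and hidden_size = 4096:
#
#     w13_weight_scale_inv: (E, 2 * ceil(1536 / 128), ceil(4096 / 128)) = (E, 24, 32)
#     w2_weight_scale_inv:  (E, ceil(4096 / 128), ceil(1536 / 128))     = (E, 32, 12)
#
# i.e. one float32 scale per (block_n x block_k) tile of each expert weight.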
|
||||
else:
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
# Add the quantization method used (per tensor/grouped/channel)
|
||||
# to ensure the weight scales are loaded in properly
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
|
||||
if self.block_quant
|
||||
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||
)
|
||||
# If loading fp8 checkpoint, pass the weight loaders.
|
||||
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()).
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"Found static activation scheme for checkpoint that "
|
||||
"was not serialized fp8."
|
||||
)
|
||||
|
||||
w13_input_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
else:
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# If rocm, use float8_e4m3fnuz as dtype
|
||||
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
layer.num_experts_per_partition,
|
||||
dtype=torch.float32,
|
||||
device=w13_weight.device,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
for expert in range(layer.num_experts_per_partition):
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
return
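# Editorial sketch (not part of this commit): per expert, scaled_fp8_quant() above
# performs dynamic per-tensor fp8 quantization, roughly equivalent to the following
# (the real kernel is fused and runs on-device):
#
#     def naive_scaled_fp8_quant(w):
#         fp8_max = torch.finfo(torch.float8_e4m3fn).max            # 448.0
#         scale = w.abs().amax().float().clamp(min=1e-12) / fp8_max
#         q = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
#         return q, scale                                           # dequant: q.float() * scale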
|
||||
|
||||
# If checkpoint is fp8, we need to handle that the
|
||||
# MoE kernels require single activation scale and single weight
|
||||
# scale for w13 per expert.
|
||||
else:
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None."
|
||||
)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
torch.max(layer.w13_weight_scale, dim=1).values,
|
||||
requires_grad=False,
|
||||
)
|
||||
if self.block_quant:
|
||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||
if _is_fp8_fnuz:
|
||||
# activation_scheme: dynamic
|
||||
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.w13_weight,
|
||||
weight_scale=layer.w13_weight_scale_inv,
|
||||
input_scale=None,
|
||||
)
|
||||
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.w2_weight,
|
||||
weight_scale=layer.w2_weight_scale_inv,
|
||||
input_scale=None,
|
||||
)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
w13_weight, requires_grad=False
|
||||
)
|
||||
layer.w13_weight_scale_inv = torch.nn.Parameter(
|
||||
w13_weight_scale, requires_grad=False
|
||||
)
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale_inv = torch.nn.Parameter(
|
||||
w2_weight_scale, requires_grad=False
|
||||
)
|
||||
layer.w2_input_scale = None
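# Editorial sketch (not part of this commit): normalize_e4m3fn_to_e4m3fnuz() exists
# because ROCm's float8_e4m3fnuz decodes the same bit pattern to half the value of
# float8_e4m3fn (different exponent bias), so an fp8-serialized checkpoint can be
# reused by reinterpreting the bits and doubling the scales, roughly (NaN
# re-encoding omitted):
#
#     w_fnuz = w_fn.view(torch.int8).view(torch.float8_e4m3fnuz)
#     weight_scale_fnuz = weight_scale_fn * 2.0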
|
||||
if _use_aiter:
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
||||
requires_grad=False,
|
||||
)
|
||||
return
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DeepEPMoE(EPMoE):
|
||||
"""
|
||||
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
|
||||
|
||||
@@ -9,7 +9,6 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
|
||||
@@ -31,11 +30,9 @@ def get_config() -> Optional[Dict[str, Any]]:
|
||||
|
||||
__all__ = [
|
||||
"FusedMoE",
|
||||
"FusedMoEMethodBase",
|
||||
"FusedMoeWeightScaleSupported",
|
||||
"override_config",
|
||||
"get_config",
|
||||
"fused_moe",
|
||||
"fused_experts",
|
||||
"get_config_file_name",
|
||||
"moe_align_block_size",
|
||||
|
||||
@@ -1,60 +1,28 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
||||
|
||||
import importlib
|
||||
from abc import abstractmethod
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
is_cpu,
|
||||
is_hip,
|
||||
set_weight_attrs,
|
||||
use_intel_amx_backend,
|
||||
)
|
||||
|
||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
if has_triton_kernels:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
else:
|
||||
fused_experts = None # type: ignore
|
||||
|
||||
import logging
|
||||
from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_cpu = is_cpu()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _use_aiter:
|
||||
from aiter import ActivationType
|
||||
from aiter.fused_moe import fused_moe
|
||||
from aiter.fused_moe_bf16_asm import ck_moe_2stages
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -66,333 +34,6 @@ class FusedMoeWeightScaleSupported(Enum):
|
||||
BLOCK = "block"
|
||||
|
||||
|
||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def __init__(self, use_triton_kernels: bool = False):
|
||||
super().__init__()
|
||||
self.use_triton_kernels = use_triton_kernels
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
|
||||
if self.use_triton_kernels:
|
||||
w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight_n, w2_weight_k = (
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
)
|
||||
if self.use_triton_kernels:
|
||||
w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
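# Editorial note (not part of this commit): the expert weights created above follow
# the convention
#
#     w13_weight: [num_experts, 2 * intermediate_size, hidden_size]  # gate_proj and up_proj stacked
#     w2_weight:  [num_experts, hidden_size, intermediate_size]      # down_proj
#
# with the last two dims swapped when use_triton_kernels is set, since that backend
# consumes the transposed layout.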
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if _use_aiter:
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Pack weights to get better performance on CPU
|
||||
if _is_cpu and _is_cpu_amx_available:
|
||||
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
|
||||
|
||||
return
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.forward(
|
||||
x=x,
|
||||
layer=layer,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if self.use_triton_kernels:
|
||||
return triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
if _use_aiter:
|
||||
assert not no_combine, "unsupported"
|
||||
if apply_router_weight_on_input:
|
||||
assert (
|
||||
topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
x = x * topk_weights.to(x.dtype)
|
||||
topk_weights = torch.ones_like(
|
||||
topk_weights, dtype=torch.float32
|
||||
) # topk_weights must be FP32 (float32)
|
||||
|
||||
return fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=(
|
||||
ActivationType.Silu
|
||||
if activation == "silu"
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
)
|
||||
else:
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", f"activation = {activation} is not supported."
|
||||
|
||||
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
# TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
|
||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
False, # inplace # See [Note] inplace should be False in fused_experts.
|
||||
False, # use_int8_w8a8
|
||||
False, # use_fp8_w8a16
|
||||
None, # w1_scale
|
||||
None, # w2_scale
|
||||
None, # block_size
|
||||
None, # a1_scale
|
||||
None, # a2_scale
|
||||
True, # is_vnni
|
||||
)
|
||||
else:
|
||||
return moe_forward_native(
|
||||
layer,
|
||||
x,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
num_fused_shared_experts,
|
||||
custom_routing_function,
|
||||
correction_bias,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
inplace,
|
||||
no_combine,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
def forward_npu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return moe_forward_native(
|
||||
layer,
|
||||
x,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
num_fused_shared_experts,
|
||||
custom_routing_function,
|
||||
correction_bias,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
inplace,
|
||||
no_combine,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError("The TPU backend currently does not support MoE.")
|
||||
|
||||
forward_native = forward_cpu
|
||||
|
||||
|
||||
class FusedMoE(torch.nn.Module):
|
||||
"""FusedMoE layer for MoE models.
|
||||
|
||||
@@ -553,7 +194,7 @@ class FusedMoE(torch.nn.Module):
|
||||
shard_dim: int,
|
||||
expert_data: torch.Tensor,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.tensor,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
):
|
||||
# Load grouped weight scales for group quantization
|
||||
@@ -580,7 +221,7 @@ class FusedMoE(torch.nn.Module):
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.tensor,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
):
|
||||
# for per channel weight quantization
|
||||
@@ -600,7 +241,7 @@ class FusedMoE(torch.nn.Module):
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.tensor,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
):
|
||||
|
||||
@@ -645,7 +286,7 @@ class FusedMoE(torch.nn.Module):
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
shard_id: str,
|
||||
loaded_weight: torch.tensor,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
):
|
||||
"""Load w2 weights for down projection.
|
||||
@@ -717,7 +358,7 @@ class FusedMoE(torch.nn.Module):
|
||||
shard_id: str,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
loaded_weight: torch.tensor,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
):
|
||||
|
||||
|
||||
@@ -19,15 +19,11 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.eplb import expert_location_dispatch
|
||||
from sglang.srt.eplb.expert_distribution import (
|
||||
ExpertDistributionRecorder,
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.eplb.expert_location_dispatch import (
|
||||
ExpertLocationDispatchInfo,
|
||||
topk_ids_logical_to_physical,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
||||
import builtins
|
||||
import inspect
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Dict, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
@@ -45,7 +43,6 @@ except ImportError:
|
||||
) = QQQConfig = Int8TpuConfig = DummyConfig
|
||||
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.awq import AWQConfig
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
||||
@@ -66,6 +63,10 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
from sglang.srt.layers.quantization.qoq import QoQConfig
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
get_dynamic_override,
|
||||
get_linear_quant_method,
|
||||
)
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
@@ -120,99 +121,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
return QUANTIZATION_METHODS[quantization]
|
||||
|
||||
|
||||
# Match dynamic rules with module name (prefix) and override quantize
|
||||
# config if module (prefix) matches a rule
|
||||
def override_config(config: QuantizationConfig, prefix: str):
|
||||
weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
|
||||
if isinstance(weight_bits, int):
|
||||
config.weight_bits = weight_bits
|
||||
group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
|
||||
if isinstance(group_size, int):
|
||||
config.group_size = group_size
|
||||
desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
|
||||
if isinstance(desc_act, bool):
|
||||
config.desc_act = desc_act
|
||||
|
||||
config.pack_factor = 32 // config.weight_bits # packed into int32
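# Editorial example (not part of this commit): with 4-bit weights the line above
# gives pack_factor = 32 // 4 = 8, i.e. eight quantized values per packed int32.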
|
||||
if config.get_name() == "gptq_marlin":
|
||||
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
|
||||
if isinstance(is_sym, bool):
|
||||
config.is_sym = is_sym
|
||||
|
||||
if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
|
||||
raise ValueError(
|
||||
"Unsupported quantization config: "
|
||||
f"bits={config.weight_bits}, sym={config.is_sym}"
|
||||
)
|
||||
|
||||
config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
|
||||
elif config.get_name() == "gptq":
|
||||
if config.weight_bits not in [2, 3, 4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||
f"supported for GPTQ, but got {config.weight_bits} bits."
|
||||
)
|
||||
|
||||
|
||||
def get_dynamic_override(
|
||||
config: QuantizationConfig,
|
||||
layer_name: str,
|
||||
key: Optional[str] = None,
|
||||
default_value: Union[int, bool, None] = None,
|
||||
) -> Union[Dict, int, bool, None]:
|
||||
for pattern, pattern_dict in config.dynamic.items():
|
||||
# Negative match: matched modules are excluded from quantized init
|
||||
if pattern.startswith("-:"):
|
||||
if re.match(pattern.removeprefix("-:"), layer_name):
|
||||
return False
|
||||
# Positive match: matched modules have quant properties that override the
# base quant config
|
||||
elif re.match(pattern.removeprefix("+:"), layer_name):
|
||||
if key is None:
|
||||
return pattern_dict
|
||||
else:
|
||||
return pattern_dict.get(key, default_value)
|
||||
return default_value
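# --- Editorial example (not part of this commit) ---------------------------------
# Usage sketch for get_dynamic_override() with a hypothetical set of dynamic rules:
#
#     config.dynamic = {
#         r"+:model\.layers\.0\..*": {"bits": 8},  # positive match: override properties
#         r"-:lm_head": {},                        # negative match: exclude from quantization
#     }
#
#     get_dynamic_override(config, "model.layers.0.self_attn.q_proj", "bits", 4)  # -> 8
#     get_dynamic_override(config, "lm_head")                                     # -> False
#     get_dynamic_override(config, "model.layers.1.mlp.up_proj", "bits", 4)       # -> 4 (default)
# ----------------------------------------------------------------------------------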
|
||||
|
||||
|
||||
def get_linear_quant_method(
|
||||
config: QuantizationConfig,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
linear_method_cls: type,
|
||||
):
|
||||
# Move import here to avoid circular import. This is only used in monkey patching
|
||||
# of vllm's QuantizationConfig.
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
UnquantizedEmbeddingMethod,
|
||||
)
|
||||
|
||||
cloned_config = deepcopy(config)
|
||||
parallel_lm_head_quantized = (
|
||||
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
||||
)
|
||||
|
||||
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
|
||||
# False = skip module, None = no override, else = Positive match
|
||||
if (
|
||||
get_dynamic_override( # noqa: E712
|
||||
cloned_config, layer_name=prefix # noqa: E712
|
||||
)
|
||||
== False
|
||||
): # noqa: E712
|
||||
if parallel_lm_head_quantized:
|
||||
return UnquantizedEmbeddingMethod()
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
if prefix:
|
||||
# Dynamic per module/layer rules may override base config
|
||||
override_config(cloned_config, prefix=prefix)
|
||||
|
||||
return linear_method_cls(cloned_config)
|
||||
return None
|
||||
|
||||
|
||||
def gptq_get_quant_method(self, layer, prefix):
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
)
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
@@ -81,7 +82,7 @@ class AWQConfig(QuantizationConfig):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
|
||||
def from_config(cls, config: Dict[str, Any]) -> AWQConfig:
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
@@ -92,7 +93,8 @@ class AWQConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["LinearMethodBase"]:
|
||||
) -> Optional[LinearMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
|
||||
@@ -18,14 +18,14 @@ class QuantizeMethodBase(ABC):
|
||||
"""Create weights for a layer.
|
||||
|
||||
The weights will be set as attributes of the layer."""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
def process_weights_after_loading(self, layer: nn.Module) -> None:
|
||||
"""Process the weight after loading.
|
||||
@@ -35,6 +35,74 @@ class QuantizeMethodBase(ABC):
|
||||
return
|
||||
|
||||
|
||||
class LinearMethodBase(QuantizeMethodBase):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
"""Create weights for a linear layer.
|
||||
The weights will be set as attributes of the layer.
|
||||
|
||||
Args:
|
||||
layer: The layer that is using the LinearMethodBase factory.
|
||||
input_size_per_partition: Size of the weight input dim on rank X.
|
||||
output_partition_sizes: Sizes of the output dim of each logical
|
||||
weight on rank X. E.g., output_partition_sizes for QKVLinear
|
||||
is a list containing the widths of Wq, Wk, Wv on rank X.
|
||||
input_size: Size of the input dim of the weight across all ranks.
|
||||
output_size: Size of the output dim of the weight across all ranks.
|
||||
params_dtype: Datatype of the parameters.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError()
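# --- Editorial sketch (not part of this commit) -----------------------------------
# Example of how a concrete linear method is written against the interface above:
# create_weights() registers parameters on the layer, process_weights_after_loading()
# may rewrite them, and apply() consumes them. The per-tensor int8 scheme below is a
# hypothetical illustration, not an SGLang quantization method.
class ToyInt8LinearMethod(LinearMethodBase):
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight = torch.nn.Parameter(
            torch.empty(
                sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Symmetric per-tensor int8 quantization of the loaded weight.
        scale = layer.weight.abs().amax().float().clamp(min=1e-8) / 127.0
        q = torch.round(layer.weight.float() / scale).clamp(-127, 127).to(torch.int8)
        layer.weight = torch.nn.Parameter(q, requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(scale, requires_grad=False)

    def apply(self, layer, x, bias=None):
        # Dequantize on the fly; a real method would call a fused kernel instead.
        w = layer.weight.to(x.dtype) * layer.weight_scale.to(x.dtype)
        return torch.nn.functional.linear(x, w, bias)
# -----------------------------------------------------------------------------------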
|
||||
|
||||
|
||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class QuantizationConfig(ABC):
|
||||
"""Base class for quantization configs."""
|
||||
|
||||
@@ -46,12 +114,12 @@ class QuantizationConfig(ABC):
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@@ -62,19 +130,19 @@ class QuantizationConfig(ABC):
|
||||
This requirement is due to the custom CUDA kernels used by the
|
||||
quantization method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||
@@ -117,7 +185,7 @@ class QuantizationConfig(ABC):
|
||||
The quantize method. None if the given layer doesn't support quant
|
||||
method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
@@ -125,7 +193,7 @@ class QuantizationConfig(ABC):
|
||||
|
||||
For now, this is only used by AWQ.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
@@ -7,17 +9,15 @@ import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
@@ -78,7 +78,7 @@ class BlockInt8Config(QuantizationConfig):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config":
|
||||
def from_config(cls, config: Dict[str, Any]) -> BlockInt8Config:
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_int8_serialized = "int8" in quant_method
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
@@ -93,7 +93,8 @@ class BlockInt8Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
@@ -230,7 +231,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
|
||||
class BlockInt8MoEMethod:
|
||||
class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for INT8.
|
||||
Supports loading INT8 checkpoints with static weight scale and
|
||||
dynamic activation scale.
|
||||
@@ -242,25 +243,7 @@ class BlockInt8MoEMethod:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
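# Editorial note (not part of this commit): the __new__ hack deleted above re-created
# the class on the fly as a FusedMoEMethodBase subclass (via type()), presumably to
# dodge a circular import of the MoE base class; with FusedMoEMethodBase now living
# in base_config, BlockInt8MoEMethod inherits from it directly and the default
# __new__ is sufficient.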
|
||||
|
||||
def __init__(self, quant_config):
|
||||
def __init__(self, quant_config: BlockInt8Config):
|
||||
self.quant_config = quant_config
|
||||
assert self.quant_config.weight_block_size is not None
|
||||
assert self.quant_config.is_checkpoint_int8_serialized
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
@@ -18,12 +19,8 @@ from compressed_tensors.quantization import (
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
@@ -40,6 +37,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
||||
is_activation_quantization_format,
|
||||
should_ignore_layer,
|
||||
)
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
|
||||
try:
|
||||
import vllm
|
||||
@@ -97,7 +95,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
self.config = config
|
||||
self.packed_modules_mapping = packed_modules_mapping
|
||||
|
||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
||||
def get_linear_method(self) -> CompressedTensorsLinearMethod:
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
@@ -117,7 +115,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
|
||||
# Check if the layer is skipped for quantization.
|
||||
# TODO (@robertgshaw2): support module names
|
||||
@@ -138,7 +137,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
|
||||
def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig:
|
||||
ignore: List[str] = cast(List[str], config.get("ignore", []))
|
||||
quant_format = cast(str, config.get("format"))
|
||||
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
|
||||
@@ -357,7 +356,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
def _get_scheme_from_parts(
|
||||
self, weight_quant: BaseModel, input_quant: BaseModel
|
||||
) -> "CompressedTensorsScheme":
|
||||
) -> CompressedTensorsScheme:
|
||||
|
||||
# Detect If Mixed Precision
|
||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
@@ -435,7 +434,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
def get_scheme(
|
||||
self, layer: torch.nn.Module, layer_name: Optional[str] = None
|
||||
) -> Optional["CompressedTensorsScheme"]:
|
||||
) -> Optional[CompressedTensorsScheme]:
|
||||
"""
|
||||
compressed-tensors supports non-uniform quantization in the following way:
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -28,17 +30,14 @@ except ImportError:
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||
from sglang.srt.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
@@ -56,6 +55,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
all_close_1d,
|
||||
convert_to_channelwise,
|
||||
@@ -77,6 +77,9 @@ from sglang.srt.utils import (
|
||||
use_intel_amx_backend,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
@@ -152,7 +155,7 @@ class Fp8Config(QuantizationConfig):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
|
||||
def from_config(cls, config: Dict[str, Any]) -> Fp8Config:
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
@@ -167,7 +170,8 @@ class Fp8Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
@@ -200,7 +204,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
|
||||
def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
|
||||
@@ -486,7 +490,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
|
||||
class Fp8MoEMethod:
|
||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
@@ -499,25 +503,7 @@ class Fp8MoEMethod:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config):
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
@@ -1169,6 +1155,254 @@ class Fp8MoEMethod:
|
||||
return None
|
||||
|
||||
|
||||
class Fp8EPMoEMethod(Fp8MoEMethod):
|
||||
"""MoE method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: Module,
|
||||
num_experts_per_partition: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.block_quant:
|
||||
block_n, block_k = (
|
||||
self.quant_config.weight_block_size[0],
|
||||
self.quant_config.weight_block_size[1],
|
||||
)
|
||||
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
|
||||
# Required by column parallel or enabling merged weights
|
||||
if intermediate_size % block_n != 0:
|
||||
raise ValueError(
|
||||
f"The output_size of gate's and up's weight = "
|
||||
f"{intermediate_size} is not divisible by "
|
||||
f"weight quantization block_n = {block_n}."
|
||||
)
|
||||
if tp_size > 1:
|
||||
# Required by row parallel
|
||||
if intermediate_size % block_k != 0:
|
||||
raise ValueError(
|
||||
f"The input_size of down's weight = "
|
||||
f"{intermediate_size} is not divisible by "
|
||||
f"weight quantization block_k = {block_k}."
|
||||
)
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts_per_partition,
|
||||
2 * intermediate_size,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts_per_partition,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
if self.block_quant:
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts_per_partition,
|
||||
2 * ((intermediate_size + block_n - 1) // block_n),
|
||||
(hidden_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts_per_partition,
|
||||
(hidden_size + block_n - 1) // block_n,
|
||||
(intermediate_size + block_k - 1) // block_k,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
else:
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
# Add the quantization method used (per tensor/grouped/channel)
|
||||
# to ensure the weight scales are loaded in properly
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
|
||||
if self.block_quant
|
||||
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||
)
|
||||
# If loading fp8 checkpoint, pass the weight loaders.
|
||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||
# process_weights_after_loading()
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"Found static activation scheme for checkpoint that "
|
||||
"was not serialized fp8."
|
||||
)
|
||||
|
||||
w13_input_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
else:
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# If rocm, use float8_e4m3fnuz as dtype
|
||||
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
layer.num_experts_per_partition,
|
||||
dtype=torch.float32,
|
||||
device=w13_weight.device,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
for expert in range(layer.num_experts_per_partition):
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
return
|
||||
|
||||
# If checkpoint is fp8, we need to handle that the
|
||||
# MoE kernels require single activation scale and single weight
|
||||
# scale for w13 per expert.
|
||||
else:
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None."
|
||||
)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
torch.max(layer.w13_weight_scale, dim=1).values,
|
||||
requires_grad=False,
|
||||
)
|
||||
if self.block_quant:
|
||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||
if _is_fp8_fnuz:
|
||||
# activation_scheme: dynamic
|
||||
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.w13_weight,
|
||||
weight_scale=layer.w13_weight_scale_inv,
|
||||
input_scale=None,
|
||||
)
|
||||
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.w2_weight,
|
||||
weight_scale=layer.w2_weight_scale_inv,
|
||||
input_scale=None,
|
||||
)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
w13_weight, requires_grad=False
|
||||
)
|
||||
layer.w13_weight_scale_inv = torch.nn.Parameter(
|
||||
w13_weight_scale, requires_grad=False
|
||||
)
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale_inv = torch.nn.Parameter(
|
||||
w2_weight_scale, requires_grad=False
|
||||
)
|
||||
layer.w2_input_scale = None
|
||||
if _use_aiter:
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
||||
requires_grad=False,
|
||||
)
|
||||
return
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from fractions import Fraction
|
||||
@@ -5,7 +7,6 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs
|
||||
from sglang.srt.layers.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
@@ -16,6 +17,8 @@ from sglang.srt.layers.parameter import (
|
||||
permute_param_layout_,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
@@ -34,7 +37,11 @@ from sglang.srt.layers.quantization.marlin_utils import (
|
||||
verify_marlin_supported,
|
||||
)
|
||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
||||
from sglang.srt.layers.quantization.utils import replace_parameter, unpack_cols
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
get_linear_quant_method,
|
||||
replace_parameter,
|
||||
unpack_cols,
|
||||
)
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
@@ -49,8 +56,6 @@ if _is_cuda:
|
||||
from sgl_kernel import fused_marlin_moe
|
||||
|
||||
|
||||
FusedMoEMethodBase = QuantizeMethodBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -179,7 +184,7 @@ class GPTQConfig(QuantizationConfig):
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
|
||||
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
@@ -191,10 +196,10 @@ class GPTQConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["LinearMethodBase"]:
|
||||
) -> Optional[LinearMethodBase]:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
@@ -303,7 +308,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
|
||||
def from_config(cls, config: Dict[str, Any]) -> GPTQMarlinConfig:
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
@@ -354,7 +359,6 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
@@ -832,6 +836,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.linear import set_weight_attrs
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
|
||||
intermediate_size = extra_weight_attrs.pop("intermediate_size")
|
||||
|
||||
@@ -1,25 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||
from sglang.srt.layers.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
)
|
||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
||||
from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.utils import get_device_capability
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
except ImportError:
|
||||
@@ -617,7 +623,10 @@ class MarlinConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["MarlinLinearMethod"]:
|
||||
) -> Optional[MarlinLinearMethod]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
if isinstance(layer, LinearBase) or (
|
||||
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
|
||||
):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
@@ -6,14 +7,11 @@ from typing import Any, Callable, Dict, List, Optional
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
@@ -23,6 +21,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
is_sm100_supported,
|
||||
)
|
||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
convert_to_channelwise,
|
||||
is_layer_skipped,
|
||||
@@ -86,7 +85,7 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
|
||||
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
|
||||
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
|
||||
kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
|
||||
"kv_cache_quant_algo"
|
||||
@@ -109,7 +108,11 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if self.exclude_modules and any(
|
||||
module in prefix
|
||||
or (
|
||||
@@ -125,9 +128,6 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
|
||||
# Add MoE support
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return ModelOptFp8MoEMethod(self)
|
||||
|
||||
@@ -246,7 +246,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
|
||||
super().__init__(quant_config)
|
||||
|
||||
|
||||
class ModelOptFp8MoEMethod:
|
||||
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for ModelOpt FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and activation scale.
|
||||
|
||||
@@ -254,30 +254,6 @@ class ModelOptFp8MoEMethod:
|
||||
quant_config: The ModelOpt quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""
|
||||
Dynamic class composition pattern.
|
||||
|
||||
This allows us to effectively "inject" FusedMoEMethodBase as a parent class
|
||||
at runtime while avoiding circular import issues.
|
||||
"""
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
@@ -514,7 +490,7 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
return ["hf_quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
|
||||
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
|
||||
quant_config = cls.get_from_keys(config, ["quantization"])
|
||||
quant_method = quant_config["quant_algo"]
|
||||
if not quant_method in ["FP8", "NVFP4"]:
|
||||
@@ -559,7 +535,8 @@ class ModelOptFp4Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
@@ -740,31 +717,13 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
|
||||
return out.view(*output_shape)
|
||||
|
||||
|
||||
class ModelOptNvFp4FusedMoEMethod:
|
||||
class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
"""
|
||||
MoE Method for FP4 Quantization with Blockscales and PerTensorScales
|
||||
Args:
|
||||
quant_config: NVFP4 Quant Config
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config: ModelOptFp4Config):
|
||||
self.quant_config = quant_config
|
||||
if not is_sm100_supported():
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
@@ -7,13 +8,14 @@ import torch
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
from sglang.srt.distributed.parallel_state import get_tp_group
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.awq import AWQConfig
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.utils import get_device_capability, set_weight_attrs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -118,7 +120,7 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
|
||||
def from_config(cls, config: Dict[str, Any]) -> MoeWNA16Config:
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
@@ -177,8 +179,9 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
# avoid circular import
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
|
||||
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
|
||||
@@ -209,32 +212,13 @@ def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
|
||||
class MoeWNA16Method:
|
||||
class MoeWNA16Method(FusedMoEMethodBase):
|
||||
"""Linear method for MOE WNA16 (W8A16/W4A16) quantization.
|
||||
|
||||
Args:
|
||||
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# avoid circular import
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config: MoeWNA16Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.linear import LinearMethodBase
|
||||
from sglang.srt.layers.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
@@ -71,7 +72,7 @@ class QoQConfig(QuantizationConfig):
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> str:
|
||||
def get_name(cls) -> str:
|
||||
return "qoq"
|
||||
|
||||
@classmethod
|
||||
@@ -83,7 +84,7 @@ class QoQConfig(QuantizationConfig):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QoQConfig":
|
||||
def from_config(cls, config: Dict[str, Any]) -> QoQConfig:
|
||||
weight_bits = cls.get_from_keys(config, ["wbits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits, group_size)
|
||||
@@ -92,7 +93,7 @@ class QoQConfig(QuantizationConfig):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
|
||||
515
python/sglang/srt/layers/quantization/unquant.py
Normal file
515
python/sglang/srt/layers/quantization/unquant.py
Normal file
@@ -0,0 +1,515 @@
|
||||
import importlib
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
LinearMethodBase,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
is_cpu,
|
||||
is_hip,
|
||||
set_weight_attrs,
|
||||
use_intel_amx_backend,
|
||||
)
|
||||
|
||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||
|
||||
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_hip = is_hip()
|
||||
_is_cpu = is_cpu()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _use_aiter:
|
||||
from aiter import ActivationType
|
||||
from aiter.fused_moe import fused_moe
|
||||
from aiter.fused_moe_bf16_asm import ck_moe_2stages
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
|
||||
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
||||
"""Unquantized method for embeddings."""
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
"""Create weights for embedding layer."""
|
||||
weight = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return F.linear(x, layer.weight, bias)
|
||||
|
||||
def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
|
||||
return F.embedding(input_, layer.weight)
|
||||
|
||||
|
||||
class UnquantizedLinearMethod(LinearMethodBase):
|
||||
"""Linear method without quantization."""
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
weight = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if _is_cpu and _is_cpu_amx_available:
|
||||
_amx_process_weight_after_loading(layer, ["weight"])
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if use_intel_amx_backend(layer):
|
||||
return torch.ops.sgl_kernel.weight_packed_linear(
|
||||
x, layer.weight, bias, True # is_vnni
|
||||
)
|
||||
|
||||
return F.linear(x, layer.weight, bias)
|
||||
|
||||
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def __init__(self, use_triton_kernels: bool = False):
|
||||
super().__init__()
|
||||
self.use_triton_kernels = use_triton_kernels
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
if has_triton_kernels:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
else:
|
||||
triton_kernel_moe_forward = None
|
||||
else:
|
||||
fused_experts = None # type: ignore
|
||||
triton_kernel_moe_forward = None
|
||||
|
||||
self.moe_forward_native = moe_forward_native
|
||||
self.fused_experts = fused_experts
|
||||
self.triton_kernel_moe_forward = triton_kernel_moe_forward
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
|
||||
if self.use_triton_kernels:
|
||||
w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight_n, w2_weight_k = (
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
)
|
||||
if self.use_triton_kernels:
|
||||
w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if _use_aiter:
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
||||
requires_grad=False,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Pack weight for get better performance on CPU
|
||||
if _is_cpu and _is_cpu_amx_available:
|
||||
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
|
||||
|
||||
return
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
return self.forward(
|
||||
x=x,
|
||||
layer=layer,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if self.use_triton_kernels:
|
||||
return self.triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
)
|
||||
else:
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
if _use_aiter:
|
||||
assert not no_combine, "unsupported"
|
||||
if apply_router_weight_on_input:
|
||||
assert (
|
||||
topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
x = x * topk_weights.to(x.dtype)
|
||||
topk_weights = torch.ones_like(
|
||||
topk_weights, dtype=torch.float32
|
||||
) # topk_weights must be FP32 (float32)
|
||||
|
||||
return fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=(
|
||||
ActivationType.Silu
|
||||
if activation == "silu"
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
)
|
||||
else:
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", f"activation = {activation} is not supported."
|
||||
|
||||
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
|
||||
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
# TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
|
||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
False, # inplace # See [Note] inplace should be False in fused_experts.
|
||||
False, # use_int8_w8a8
|
||||
False, # use_fp8_w8a16
|
||||
None, # w1_scale
|
||||
None, # w2_scale
|
||||
None, # block_size
|
||||
None, # a1_scale
|
||||
None, # a2_scale
|
||||
True, # is_vnni
|
||||
)
|
||||
else:
|
||||
return self.moe_forward_native(
|
||||
layer,
|
||||
x,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
num_fused_shared_experts,
|
||||
custom_routing_function,
|
||||
correction_bias,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
inplace,
|
||||
no_combine,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
def forward_npu(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.moe_forward_native(
|
||||
layer,
|
||||
x,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
num_fused_shared_experts,
|
||||
custom_routing_function,
|
||||
correction_bias,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
inplace,
|
||||
no_combine,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError("The TPU backend currently does not support MoE.")
|
||||
|
||||
forward_native = forward_cpu
|
||||
|
||||
|
||||
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts_per_partition: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts_per_partition,
|
||||
2 * intermediate_size,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts_per_partition,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# scale
|
||||
layer.register_parameter("w13_input_scale", None)
|
||||
layer.register_parameter("w13_weight_scale", None)
|
||||
|
||||
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(
|
||||
ones_tensor,
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
ones_tensor,
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
@@ -1,7 +1,11 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from types import MappingProxyType
|
||||
from typing import List, Mapping, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
@@ -10,6 +14,9 @@ from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.layers.quantization.scalar_type import ScalarType
|
||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
@@ -147,6 +154,94 @@ def replace_parameter(
|
||||
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
|
||||
|
||||
|
||||
# Match dynamic rules with module name (prefix) and override quantize
|
||||
# config if module (prefix) matches a rule
|
||||
def override_config(config: QuantizationConfig, prefix: str):
|
||||
weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
|
||||
if isinstance(weight_bits, int):
|
||||
config.weight_bits = weight_bits
|
||||
group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
|
||||
if isinstance(group_size, int):
|
||||
config.group_size = group_size
|
||||
desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
|
||||
if isinstance(desc_act, bool):
|
||||
config.desc_act = desc_act
|
||||
|
||||
config.pack_factor = 32 // config.weight_bits # packed into int32
|
||||
if config.get_name() == "gptq_marlin":
|
||||
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
|
||||
if isinstance(is_sym, bool):
|
||||
config.is_sym = is_sym
|
||||
|
||||
if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
|
||||
raise ValueError(
|
||||
"Unsupported quantization config: "
|
||||
f"bits={config.weight_bits}, sym={config.is_sym}"
|
||||
)
|
||||
|
||||
config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
|
||||
elif config.get_name() == "gptq":
|
||||
if config.weight_bits not in [2, 3, 4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||
f"supported for GPTQ, but got {config.weight_bits} bits."
|
||||
)
|
||||
|
||||
|
||||
def get_dynamic_override(
|
||||
config: QuantizationConfig,
|
||||
layer_name: str,
|
||||
key: Optional[str] = None,
|
||||
default_value: Union[int, bool, None] = None,
|
||||
) -> Union[Dict, int, bool, None]:
|
||||
for pattern, pattern_dict in config.dynamic.items():
|
||||
# Negative match: matched modules are excluded from quantized init
|
||||
if pattern.startswith("-:"):
|
||||
if re.match(pattern.removeprefix("-:"), layer_name):
|
||||
return False
|
||||
# Positive match: matched modules have quant properties overrides
|
||||
# base quant config
|
||||
elif re.match(pattern.removeprefix("+:"), layer_name):
|
||||
if key is None:
|
||||
return pattern_dict
|
||||
else:
|
||||
return pattern_dict.get(key, default_value)
|
||||
return default_value
|
||||
|
||||
|
||||
def get_linear_quant_method(
|
||||
config: QuantizationConfig,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
linear_method_cls: type,
|
||||
):
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.quantization.unquant import (
|
||||
UnquantizedEmbeddingMethod,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
cloned_config = deepcopy(config)
|
||||
parallel_lm_head_quantized = (
|
||||
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
||||
)
|
||||
|
||||
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
|
||||
# False = skip module, None = no override, else = Positive match
|
||||
if get_dynamic_override(cloned_config, layer_name=prefix) is False:
|
||||
if parallel_lm_head_quantized:
|
||||
return UnquantizedEmbeddingMethod()
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
if prefix:
|
||||
# Dynamic per module/layer rules may override base config
|
||||
override_config(cloned_config, prefix=prefix)
|
||||
|
||||
return linear_method_cls(cloned_config)
|
||||
return None
|
||||
|
||||
|
||||
def get_pack_factor(num_bits):
|
||||
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
|
||||
return 32 // num_bits
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -5,12 +7,13 @@ import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
@@ -62,7 +65,7 @@ class W4AFp8Config(QuantizationConfig):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
|
||||
def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config:
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
||||
is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
|
||||
@@ -79,7 +82,8 @@ class W4AFp8Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
@@ -94,7 +98,7 @@ class W4AFp8Config(QuantizationConfig):
|
||||
return []
|
||||
|
||||
|
||||
class W4AFp8MoEMethod:
|
||||
class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, quant_config: W4AFp8Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.linear import LinearMethodBase
|
||||
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
@@ -64,7 +67,7 @@ class W8A8Fp8Config(QuantizationConfig):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
|
||||
def from_config(cls, config: Dict[str, Any]) -> W8A8Fp8Config:
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = (
|
||||
"compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
|
||||
@@ -75,7 +78,7 @@ class W8A8Fp8Config(QuantizationConfig):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
@@ -183,7 +186,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
|
||||
class W8A8FP8MoEMethod:
|
||||
class W8A8FP8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
@@ -194,25 +197,7 @@ class W8A8FP8MoEMethod:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config):
|
||||
def __init__(self, quant_config: W8A8Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from types import MappingProxyType
|
||||
@@ -11,21 +13,19 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||
from sglang.srt.layers.linear import (
|
||||
LinearMethodBase,
|
||||
RowParallelLinear,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
LinearMethodBase,
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.utils import (
|
||||
apply_module_patch,
|
||||
cpu_has_amx_support,
|
||||
@@ -229,14 +229,14 @@ class W8A8Int8Config(QuantizationConfig):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
|
||||
def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
|
||||
return cls(config)
|
||||
|
||||
def get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> Optional[QuantizeMethodBase]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
@@ -374,7 +374,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
|
||||
class W8A8Int8MoEMethod:
|
||||
class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for INT8.
|
||||
Supports loading INT8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
@@ -385,25 +385,7 @@ class W8A8Int8MoEMethod:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config):
|
||||
def __init__(self, quant_config: W8A8Int8Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
@@ -885,13 +867,15 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.linear import RowParallelLinear
|
||||
|
||||
if isinstance(layer, RowParallelLinear):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
return self.quant_method.apply(layer, x, bias, tp_rank)
|
||||
return self.quant_method.apply(layer, x, bias)
|
||||
|
||||
|
||||
class NPU_W8A8MoEMethod:
|
||||
class NPU_W8A8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for NPU quantization.
|
||||
|
||||
This class search for specific quantization
|
||||
|
||||
@@ -5,7 +5,6 @@ from dataclasses import dataclass
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
method_has_implemented_embedding,
|
||||
)
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
|
||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
|
||||
|
||||
DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||
@@ -32,44 +32,6 @@ _is_cpu = is_cpu()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
||||
"""Unquantized method for embeddings."""
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
"""Create weights for embedding layer."""
|
||||
weight = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return F.linear(x, layer.weight, bias)
|
||||
|
||||
def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
|
||||
return F.embedding(input_, layer.weight)
|
||||
|
||||
|
||||
def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
|
||||
"""Pad the vocab size to the given value."""
|
||||
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
||||
|
||||
Reference in New Issue
Block a user