Refactor: move all quantization-related code to srt/layer/quantization (#7989)
This commit is contained in:
@@ -1,12 +1,12 @@
|
|||||||
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
|
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
@@ -17,7 +17,6 @@ from sglang.srt.distributed import (
|
|||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
|
||||||
from sglang.srt.layers.parameter import (
|
from sglang.srt.layers.parameter import (
|
||||||
BasevLLMParameter,
|
BasevLLMParameter,
|
||||||
BlockQuantScaleParameter,
|
BlockQuantScaleParameter,
|
||||||
@@ -27,17 +26,14 @@ from sglang.srt.layers.parameter import (
|
|||||||
RowvLLMParameter,
|
RowvLLMParameter,
|
||||||
_ColumnvLLMParameter,
|
_ColumnvLLMParameter,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
QuantizationConfig,
|
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
|
||||||
QuantizeMethodBase,
|
|
||||||
)
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
cpu_has_amx_support,
|
QuantizationConfig,
|
||||||
is_cpu,
|
QuantizeMethodBase,
|
||||||
is_npu,
|
)
|
||||||
set_weight_attrs,
|
|
||||||
use_intel_amx_backend,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -59,7 +55,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
|||||||
"IPEXAWQLinearMethod",
|
"IPEXAWQLinearMethod",
|
||||||
]
|
]
|
||||||
|
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
|
||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
_is_npu = is_npu()
|
_is_npu = is_npu()
|
||||||
|
|
||||||
@@ -110,91 +105,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
|
|||||||
return param[shard_id], loaded_weight
|
return param[shard_id], loaded_weight
|
||||||
|
|
||||||
|
|
||||||
class LinearMethodBase(QuantizeMethodBase):
|
|
||||||
"""Base class for different (maybe quantized) linear methods."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
input_size_per_partition: int,
|
|
||||||
output_partition_sizes: List[int],
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
"""Create weights for a linear layer.
|
|
||||||
The weights will be set as attributes of the layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer: The layer that is using the LinearMethodBase factory.
|
|
||||||
input_size_per_partition: Size of the weight input dim on rank X.
|
|
||||||
output_partition_sizes: Sizes of the output dim of each logical
|
|
||||||
weight on rank X. E.g., output_partition_sizes for QKVLinear
|
|
||||||
is a list contains the width of Wq, Wk, Wv on rank X.
|
|
||||||
input_size: Size of the input dim of the weight across all ranks.
|
|
||||||
output_size: Size of the output dim of the weight across all ranks.
|
|
||||||
params_dtype: Datatype of the parameters.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Apply the weights in layer to the input tensor.
|
|
||||||
Expects create_weights to have been called before on the layer."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedLinearMethod(LinearMethodBase):
|
|
||||||
"""Linear method without quantization."""
|
|
||||||
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
input_size_per_partition: int,
|
|
||||||
output_partition_sizes: List[int],
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
weight = Parameter(
|
|
||||||
torch.empty(
|
|
||||||
sum(output_partition_sizes),
|
|
||||||
input_size_per_partition,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
|
||||||
layer.register_parameter("weight", weight)
|
|
||||||
set_weight_attrs(weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
||||||
if _is_cpu and _is_cpu_amx_available:
|
|
||||||
_amx_process_weight_after_loading(layer, ["weight"])
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
if use_intel_amx_backend(layer):
|
|
||||||
return torch.ops.sgl_kernel.weight_packed_linear(
|
|
||||||
x, layer.weight, bias, True # is_vnni
|
|
||||||
)
|
|
||||||
|
|
||||||
return F.linear(x, layer.weight, bias)
|
|
||||||
|
|
||||||
|
|
||||||
class LinearBase(torch.nn.Module):
|
class LinearBase(torch.nn.Module):
|
||||||
"""Base linear layer.
|
"""Base linear layer.
|
||||||
|
|
||||||
@@ -310,7 +220,7 @@ class ReplicatedLinear(LinearBase):
|
|||||||
assert param.size() == loaded_weight.size()
|
assert param.size() == loaded_weight.size()
|
||||||
param.data.copy_(loaded_weight)
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
output = self.quant_method.apply(self, x, bias)
|
output = self.quant_method.apply(self, x, bias)
|
||||||
@@ -845,7 +755,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional["QuantizationConfig"] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
tp_rank: Optional[int] = None,
|
tp_rank: Optional[int] = None,
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
|
|||||||
@@ -27,22 +27,20 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
|||||||
silu_and_mul_triton_kernel,
|
silu_and_mul_triton_kernel,
|
||||||
tma_align_input_scale,
|
tma_align_input_scale,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
|
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
is_fp8_fnuz,
|
is_fp8_fnuz,
|
||||||
scaled_fp8_quant,
|
|
||||||
sglang_per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8,
|
||||||
sglang_per_token_quant_fp8,
|
sglang_per_token_quant_fp8,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
|
||||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
@@ -53,7 +51,6 @@ from sglang.srt.utils import (
|
|||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
is_hip,
|
is_hip,
|
||||||
is_npu,
|
is_npu,
|
||||||
set_weight_attrs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
@@ -904,324 +901,6 @@ class EPMoE(torch.nn.Module):
|
|||||||
param_data[expert_id] = loaded_weight
|
param_data[expert_id] = loaded_weight
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
|
||||||
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
num_experts_per_partition: int,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
# Fused gate_up_proj (column parallel)
|
|
||||||
w13_weight = torch.nn.Parameter(
|
|
||||||
torch.empty(
|
|
||||||
num_experts_per_partition,
|
|
||||||
2 * intermediate_size,
|
|
||||||
hidden_size,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_weight", w13_weight)
|
|
||||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
# down_proj (row parallel)
|
|
||||||
w2_weight = torch.nn.Parameter(
|
|
||||||
torch.empty(
|
|
||||||
num_experts_per_partition,
|
|
||||||
hidden_size,
|
|
||||||
intermediate_size,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_weight", w2_weight)
|
|
||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
# scale
|
|
||||||
layer.register_parameter("w13_input_scale", None)
|
|
||||||
layer.register_parameter("w13_weight_scale", None)
|
|
||||||
|
|
||||||
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
|
|
||||||
|
|
||||||
w2_input_scale = torch.nn.Parameter(
|
|
||||||
ones_tensor,
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
|
||||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
w2_weight_scale = torch.nn.Parameter(
|
|
||||||
ones_tensor,
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
renormalize: bool,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class Fp8EPMoEMethod(Fp8MoEMethod):
|
|
||||||
"""MoE method for FP8.
|
|
||||||
Supports loading FP8 checkpoints with static weight scale and
|
|
||||||
dynamic/static activation scale.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
quant_config: The quantization config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, quant_config: Fp8Config):
|
|
||||||
self.quant_config = quant_config
|
|
||||||
self.block_quant = self.quant_config.weight_block_size is not None
|
|
||||||
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: Module,
|
|
||||||
num_experts_per_partition: int,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
params_dtype = torch.float8_e4m3fn
|
|
||||||
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
if self.block_quant:
|
|
||||||
block_n, block_k = (
|
|
||||||
self.quant_config.weight_block_size[0],
|
|
||||||
self.quant_config.weight_block_size[1],
|
|
||||||
)
|
|
||||||
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
|
|
||||||
# Required by column parallel or enabling merged weights
|
|
||||||
if intermediate_size % block_n != 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"The output_size of gate's and up's weight = "
|
|
||||||
f"{intermediate_size} is not divisible by "
|
|
||||||
f"weight quantization block_n = {block_n}."
|
|
||||||
)
|
|
||||||
if tp_size > 1:
|
|
||||||
# Required by row parallel
|
|
||||||
if intermediate_size % block_k != 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"The input_size of down's weight = "
|
|
||||||
f"{intermediate_size} is not divisible by "
|
|
||||||
f"weight quantization block_k = {block_k}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# WEIGHTS
|
|
||||||
w13_weight = torch.nn.Parameter(
|
|
||||||
torch.empty(
|
|
||||||
num_experts_per_partition,
|
|
||||||
2 * intermediate_size,
|
|
||||||
hidden_size,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_weight", w13_weight)
|
|
||||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
w2_weight = torch.nn.Parameter(
|
|
||||||
torch.empty(
|
|
||||||
num_experts_per_partition,
|
|
||||||
hidden_size,
|
|
||||||
intermediate_size,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_weight", w2_weight)
|
|
||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
# WEIGHT_SCALES
|
|
||||||
if self.block_quant:
|
|
||||||
w13_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(
|
|
||||||
num_experts_per_partition,
|
|
||||||
2 * ((intermediate_size + block_n - 1) // block_n),
|
|
||||||
(hidden_size + block_k - 1) // block_k,
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
w2_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(
|
|
||||||
num_experts_per_partition,
|
|
||||||
(hidden_size + block_n - 1) // block_n,
|
|
||||||
(intermediate_size + block_k - 1) // block_k,
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
|
|
||||||
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
|
||||||
assert self.quant_config.activation_scheme == "dynamic"
|
|
||||||
else:
|
|
||||||
# WEIGHT_SCALES
|
|
||||||
# Allocate 2 scales for w1 and w3 respectively.
|
|
||||||
w13_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
|
||||||
|
|
||||||
w2_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
|
||||||
# Add the quantization method used (per tensor/grouped/channel)
|
|
||||||
# to ensure the weight scales are loaded in properly
|
|
||||||
extra_weight_attrs.update(
|
|
||||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
|
|
||||||
if self.block_quant
|
|
||||||
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
|
||||||
)
|
|
||||||
# If loading fp8 checkpoint, pass the weight loaders.
|
|
||||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
|
||||||
# process_weights_after_loading()
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
# INPUT_SCALES
|
|
||||||
if self.quant_config.activation_scheme == "static":
|
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
raise ValueError(
|
|
||||||
"Found static activation scheme for checkpoint that "
|
|
||||||
"was not serialized fp8."
|
|
||||||
)
|
|
||||||
|
|
||||||
w13_input_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
|
||||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
w2_input_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
|
||||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
|
||||||
|
|
||||||
else:
|
|
||||||
layer.w13_input_scale = None
|
|
||||||
layer.w2_input_scale = None
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
|
||||||
|
|
||||||
# If checkpoint is fp16, quantize in place.
|
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
# If rocm, use float8_e4m3fnuz as dtype
|
|
||||||
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
|
||||||
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
|
||||||
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
|
||||||
|
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.ones(
|
|
||||||
layer.num_experts_per_partition,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=w13_weight.device,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
for expert in range(layer.num_experts_per_partition):
|
|
||||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
|
||||||
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
|
||||||
)
|
|
||||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
|
||||||
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
|
||||||
)
|
|
||||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
|
||||||
return
|
|
||||||
|
|
||||||
# If checkpoint is fp8, we need to handle that the
|
|
||||||
# MoE kernels require single activation scale and single weight
|
|
||||||
# scale for w13 per expert.
|
|
||||||
else:
|
|
||||||
if self.quant_config.activation_scheme == "static":
|
|
||||||
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
|
||||||
raise ValueError(
|
|
||||||
"QuantConfig has static quantization, but found "
|
|
||||||
"activation scales are None."
|
|
||||||
)
|
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(
|
|
||||||
torch.max(layer.w13_weight_scale, dim=1).values,
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
if self.block_quant:
|
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
|
||||||
if _is_fp8_fnuz:
|
|
||||||
# activation_scheme: dynamic
|
|
||||||
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
|
||||||
weight=layer.w13_weight,
|
|
||||||
weight_scale=layer.w13_weight_scale_inv,
|
|
||||||
input_scale=None,
|
|
||||||
)
|
|
||||||
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
|
||||||
weight=layer.w2_weight,
|
|
||||||
weight_scale=layer.w2_weight_scale_inv,
|
|
||||||
input_scale=None,
|
|
||||||
)
|
|
||||||
# Reset the parameter
|
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
|
||||||
w13_weight, requires_grad=False
|
|
||||||
)
|
|
||||||
layer.w13_weight_scale_inv = torch.nn.Parameter(
|
|
||||||
w13_weight_scale, requires_grad=False
|
|
||||||
)
|
|
||||||
layer.w13_input_scale = None
|
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
|
||||||
layer.w2_weight_scale_inv = torch.nn.Parameter(
|
|
||||||
w2_weight_scale, requires_grad=False
|
|
||||||
)
|
|
||||||
layer.w2_input_scale = None
|
|
||||||
if _use_aiter:
|
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
|
||||||
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.w2_weight = torch.nn.Parameter(
|
|
||||||
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
renormalize: bool,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class DeepEPMoE(EPMoE):
|
class DeepEPMoE(EPMoE):
|
||||||
"""
|
"""
|
||||||
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
|
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
||||||
FusedMoE,
|
FusedMoE,
|
||||||
FusedMoEMethodBase,
|
|
||||||
FusedMoeWeightScaleSupported,
|
FusedMoeWeightScaleSupported,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,11 +30,9 @@ def get_config() -> Optional[Dict[str, Any]]:
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FusedMoE",
|
"FusedMoE",
|
||||||
"FusedMoEMethodBase",
|
|
||||||
"FusedMoeWeightScaleSupported",
|
"FusedMoeWeightScaleSupported",
|
||||||
"override_config",
|
"override_config",
|
||||||
"get_config",
|
"get_config",
|
||||||
"fused_moe",
|
|
||||||
"fused_experts",
|
"fused_experts",
|
||||||
"get_config_file_name",
|
"get_config_file_name",
|
||||||
"moe_align_block_size",
|
"moe_align_block_size",
|
||||||
|
|||||||
@@ -1,60 +1,28 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
|
||||||
|
|
||||||
import importlib
|
import logging
|
||||||
from abc import abstractmethod
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
|
||||||
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
|
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
|
||||||
cpu_has_amx_support,
|
|
||||||
get_bool_env_var,
|
|
||||||
is_cpu,
|
|
||||||
is_hip,
|
|
||||||
set_weight_attrs,
|
|
||||||
use_intel_amx_backend,
|
|
||||||
)
|
|
||||||
|
|
||||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
|
||||||
|
|
||||||
if has_triton_kernels:
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
|
||||||
triton_kernel_moe_forward,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
fused_experts = None # type: ignore
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
|
||||||
|
|
||||||
if _use_aiter:
|
|
||||||
from aiter import ActivationType
|
|
||||||
from aiter.fused_moe import fused_moe
|
|
||||||
from aiter.fused_moe_bf16_asm import ck_moe_2stages
|
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -66,333 +34,6 @@ class FusedMoeWeightScaleSupported(Enum):
|
|||||||
BLOCK = "block"
|
BLOCK = "block"
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
num_experts: int,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
renormalize: bool,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|
||||||
"""MoE method without quantization."""
|
|
||||||
|
|
||||||
def __init__(self, use_triton_kernels: bool = False):
|
|
||||||
super().__init__()
|
|
||||||
self.use_triton_kernels = use_triton_kernels
|
|
||||||
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
num_experts: int,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
# Fused gate_up_proj (column parallel)
|
|
||||||
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
|
|
||||||
if self.use_triton_kernels:
|
|
||||||
w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
|
|
||||||
w13_weight = torch.nn.Parameter(
|
|
||||||
torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w13_weight", w13_weight)
|
|
||||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
# down_proj (row parallel)
|
|
||||||
w2_weight_n, w2_weight_k = (
|
|
||||||
hidden_size,
|
|
||||||
intermediate_size,
|
|
||||||
)
|
|
||||||
if self.use_triton_kernels:
|
|
||||||
w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
|
|
||||||
w2_weight = torch.nn.Parameter(
|
|
||||||
torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("w2_weight", w2_weight)
|
|
||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
||||||
if _use_aiter:
|
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
|
||||||
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
layer.w2_weight = torch.nn.Parameter(
|
|
||||||
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Pack weight for get better performance on CPU
|
|
||||||
if _is_cpu and _is_cpu_amx_available:
|
|
||||||
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
renormalize: bool,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
num_fused_shared_experts: int = 0,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return self.forward(
|
|
||||||
x=x,
|
|
||||||
layer=layer,
|
|
||||||
router_logits=router_logits,
|
|
||||||
top_k=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
topk_group=topk_group,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
num_fused_shared_experts=num_fused_shared_experts,
|
|
||||||
custom_routing_function=custom_routing_function,
|
|
||||||
correction_bias=correction_bias,
|
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
inplace=inplace,
|
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_cuda(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
top_k: int,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
renormalize: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
num_fused_shared_experts: int = 0,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
if self.use_triton_kernels:
|
|
||||||
return triton_kernel_moe_forward(
|
|
||||||
hidden_states=x,
|
|
||||||
w1=layer.w13_weight,
|
|
||||||
w2=layer.w2_weight,
|
|
||||||
gating_output=router_logits,
|
|
||||||
topk=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
topk_weights, topk_ids = select_experts(
|
|
||||||
hidden_states=x,
|
|
||||||
router_logits=router_logits,
|
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
top_k=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
topk_group=topk_group,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
num_fused_shared_experts=num_fused_shared_experts,
|
|
||||||
custom_routing_function=custom_routing_function,
|
|
||||||
correction_bias=correction_bias,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if _use_aiter:
|
|
||||||
assert not no_combine, "unsupported"
|
|
||||||
if apply_router_weight_on_input:
|
|
||||||
assert (
|
|
||||||
topk_weights.dim() == 2
|
|
||||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
|
||||||
_, topk = topk_weights.shape
|
|
||||||
assert (
|
|
||||||
topk == 1
|
|
||||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
|
||||||
x = x * topk_weights.to(x.dtype)
|
|
||||||
topk_weights = torch.ones_like(
|
|
||||||
topk_weights, dtype=torch.float32
|
|
||||||
) # topk_weights must be FP32 (float32)
|
|
||||||
|
|
||||||
return fused_moe(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
activation=(
|
|
||||||
ActivationType.Silu
|
|
||||||
if activation == "silu"
|
|
||||||
else ActivationType.Gelu
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return fused_experts(
|
|
||||||
hidden_states=x,
|
|
||||||
w1=layer.w13_weight,
|
|
||||||
w2=layer.w2_weight,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
inplace=inplace and not no_combine,
|
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
no_combine=no_combine,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_cpu(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
top_k: int,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
renormalize: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
num_fused_shared_experts: int = 0,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
assert activation == "silu", f"activation = {activation} is not supported."
|
|
||||||
|
|
||||||
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
|
|
||||||
topk_weights, topk_ids = select_experts(
|
|
||||||
hidden_states=x,
|
|
||||||
router_logits=router_logits,
|
|
||||||
use_grouped_topk=use_grouped_topk,
|
|
||||||
top_k=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
topk_group=topk_group,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
num_fused_shared_experts=num_fused_shared_experts,
|
|
||||||
custom_routing_function=custom_routing_function,
|
|
||||||
correction_bias=correction_bias,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
|
|
||||||
return torch.ops.sgl_kernel.fused_experts_cpu(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
False, # inplace # See [Note] inplace should be False in fused_experts.
|
|
||||||
False, # use_int8_w8a8
|
|
||||||
False, # use_fp8_w8a16
|
|
||||||
None, # w1_scale
|
|
||||||
None, # w2_scale
|
|
||||||
None, # block_size
|
|
||||||
None, # a1_scale
|
|
||||||
None, # a2_scale
|
|
||||||
True, # is_vnni
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return moe_forward_native(
|
|
||||||
layer,
|
|
||||||
x,
|
|
||||||
use_grouped_topk,
|
|
||||||
top_k,
|
|
||||||
router_logits,
|
|
||||||
renormalize,
|
|
||||||
topk_group,
|
|
||||||
num_expert_group,
|
|
||||||
num_fused_shared_experts,
|
|
||||||
custom_routing_function,
|
|
||||||
correction_bias,
|
|
||||||
activation,
|
|
||||||
apply_router_weight_on_input,
|
|
||||||
inplace,
|
|
||||||
no_combine,
|
|
||||||
routed_scaling_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_npu(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
top_k: int,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
renormalize: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
num_fused_shared_experts: int = 0,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return moe_forward_native(
|
|
||||||
layer,
|
|
||||||
x,
|
|
||||||
use_grouped_topk,
|
|
||||||
top_k,
|
|
||||||
router_logits,
|
|
||||||
renormalize,
|
|
||||||
topk_group,
|
|
||||||
num_expert_group,
|
|
||||||
num_fused_shared_experts,
|
|
||||||
custom_routing_function,
|
|
||||||
correction_bias,
|
|
||||||
activation,
|
|
||||||
apply_router_weight_on_input,
|
|
||||||
inplace,
|
|
||||||
no_combine,
|
|
||||||
routed_scaling_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
|
|
||||||
raise NotImplementedError("The TPU backend currently does not support MoE.")
|
|
||||||
|
|
||||||
forward_native = forward_cpu
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMoE(torch.nn.Module):
|
class FusedMoE(torch.nn.Module):
|
||||||
"""FusedMoE layer for MoE models.
|
"""FusedMoE layer for MoE models.
|
||||||
|
|
||||||
@@ -553,7 +194,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
shard_dim: int,
|
shard_dim: int,
|
||||||
expert_data: torch.Tensor,
|
expert_data: torch.Tensor,
|
||||||
shard_id: str,
|
shard_id: str,
|
||||||
loaded_weight: torch.tensor,
|
loaded_weight: torch.Tensor,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
):
|
):
|
||||||
# Load grouped weight scales for group quantization
|
# Load grouped weight scales for group quantization
|
||||||
@@ -580,7 +221,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
expert_data: torch.Tensor,
|
expert_data: torch.Tensor,
|
||||||
shard_dim: int,
|
shard_dim: int,
|
||||||
shard_id: str,
|
shard_id: str,
|
||||||
loaded_weight: torch.tensor,
|
loaded_weight: torch.Tensor,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
):
|
):
|
||||||
# for per channel weight quantization
|
# for per channel weight quantization
|
||||||
@@ -600,7 +241,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
expert_data: torch.Tensor,
|
expert_data: torch.Tensor,
|
||||||
shard_dim: int,
|
shard_dim: int,
|
||||||
shard_id: str,
|
shard_id: str,
|
||||||
loaded_weight: torch.tensor,
|
loaded_weight: torch.Tensor,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
):
|
):
|
||||||
|
|
||||||
@@ -645,7 +286,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
expert_data: torch.Tensor,
|
expert_data: torch.Tensor,
|
||||||
shard_dim: int,
|
shard_dim: int,
|
||||||
shard_id: str,
|
shard_id: str,
|
||||||
loaded_weight: torch.tensor,
|
loaded_weight: torch.Tensor,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
):
|
):
|
||||||
"""Load w2 weights for down projection.
|
"""Load w2 weights for down projection.
|
||||||
@@ -717,7 +358,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
shard_id: str,
|
shard_id: str,
|
||||||
expert_data: torch.Tensor,
|
expert_data: torch.Tensor,
|
||||||
shard_dim: int,
|
shard_dim: int,
|
||||||
loaded_weight: torch.tensor,
|
loaded_weight: torch.Tensor,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|||||||
@@ -19,15 +19,11 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from sglang.srt.eplb import expert_location_dispatch
|
from sglang.srt.eplb import expert_location_dispatch
|
||||||
from sglang.srt.eplb.expert_distribution import (
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
ExpertDistributionRecorder,
|
|
||||||
get_global_expert_distribution_recorder,
|
|
||||||
)
|
|
||||||
from sglang.srt.eplb.expert_location_dispatch import (
|
from sglang.srt.eplb.expert_location_dispatch import (
|
||||||
ExpertLocationDispatchInfo,
|
ExpertLocationDispatchInfo,
|
||||||
topk_ids_logical_to_physical,
|
topk_ids_logical_to_physical,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
||||||
import builtins
|
import builtins
|
||||||
import inspect
|
import inspect
|
||||||
import re
|
|
||||||
from copy import deepcopy
|
|
||||||
from typing import Callable, Dict, Optional, Type, Union
|
from typing import Callable, Dict, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -45,7 +43,6 @@ except ImportError:
|
|||||||
) = QQQConfig = Int8TpuConfig = DummyConfig
|
) = QQQConfig = Int8TpuConfig = DummyConfig
|
||||||
|
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
|
||||||
from sglang.srt.layers.quantization.awq import AWQConfig
|
from sglang.srt.layers.quantization.awq import AWQConfig
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
||||||
@@ -66,6 +63,10 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
|
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||||
from sglang.srt.layers.quantization.qoq import QoQConfig
|
from sglang.srt.layers.quantization.qoq import QoQConfig
|
||||||
|
from sglang.srt.layers.quantization.utils import (
|
||||||
|
get_dynamic_override,
|
||||||
|
get_linear_quant_method,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||||
@@ -120,99 +121,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
|||||||
return QUANTIZATION_METHODS[quantization]
|
return QUANTIZATION_METHODS[quantization]
|
||||||
|
|
||||||
|
|
||||||
# Match dynamic rules with module name (prefix) and override quantize
|
|
||||||
# config if module (prefix) matches a rule
|
|
||||||
def override_config(config: QuantizationConfig, prefix: str):
|
|
||||||
weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
|
|
||||||
if isinstance(weight_bits, int):
|
|
||||||
config.weight_bits = weight_bits
|
|
||||||
group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
|
|
||||||
if isinstance(group_size, int):
|
|
||||||
config.group_size = group_size
|
|
||||||
desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
|
|
||||||
if isinstance(desc_act, bool):
|
|
||||||
config.desc_act = desc_act
|
|
||||||
|
|
||||||
config.pack_factor = 32 // config.weight_bits # packed into int32
|
|
||||||
if config.get_name() == "gptq_marlin":
|
|
||||||
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
|
|
||||||
if isinstance(is_sym, bool):
|
|
||||||
config.is_sym = is_sym
|
|
||||||
|
|
||||||
if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
|
|
||||||
raise ValueError(
|
|
||||||
"Unsupported quantization config: "
|
|
||||||
f"bits={config.weight_bits}, sym={config.is_sym}"
|
|
||||||
)
|
|
||||||
|
|
||||||
config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
|
|
||||||
elif config.get_name() == "gptq":
|
|
||||||
if config.weight_bits not in [2, 3, 4, 8]:
|
|
||||||
raise ValueError(
|
|
||||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
|
||||||
f"supported for GPTQ, but got {config.weight_bits} bits."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_dynamic_override(
|
|
||||||
config: QuantizationConfig,
|
|
||||||
layer_name: str,
|
|
||||||
key: Optional[str] = None,
|
|
||||||
default_value: Union[int, bool, None] = None,
|
|
||||||
) -> Union[Dict, int, bool, None]:
|
|
||||||
for pattern, pattern_dict in config.dynamic.items():
|
|
||||||
# Negative match: matched modules are excluded from quantized init
|
|
||||||
if pattern.startswith("-:"):
|
|
||||||
if re.match(pattern.removeprefix("-:"), layer_name):
|
|
||||||
return False
|
|
||||||
# Positive match: matched modules have quant properties overrides
|
|
||||||
# base quant config
|
|
||||||
elif re.match(pattern.removeprefix("+:"), layer_name):
|
|
||||||
if key is None:
|
|
||||||
return pattern_dict
|
|
||||||
else:
|
|
||||||
return pattern_dict.get(key, default_value)
|
|
||||||
return default_value
|
|
||||||
|
|
||||||
|
|
||||||
def get_linear_quant_method(
|
|
||||||
config: QuantizationConfig,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
prefix: str,
|
|
||||||
linear_method_cls: type,
|
|
||||||
):
|
|
||||||
# Move import here to avoid circular import. This is only used in monkey patching
|
|
||||||
# of vllm's QuantizationConfig.
|
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
|
||||||
ParallelLMHead,
|
|
||||||
UnquantizedEmbeddingMethod,
|
|
||||||
)
|
|
||||||
|
|
||||||
cloned_config = deepcopy(config)
|
|
||||||
parallel_lm_head_quantized = (
|
|
||||||
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
|
|
||||||
# False = skip module, None = no override, else = Positive match
|
|
||||||
if (
|
|
||||||
get_dynamic_override( # noqa: E712
|
|
||||||
cloned_config, layer_name=prefix # noqa: E712
|
|
||||||
)
|
|
||||||
== False
|
|
||||||
): # noqa: E712
|
|
||||||
if parallel_lm_head_quantized:
|
|
||||||
return UnquantizedEmbeddingMethod()
|
|
||||||
return UnquantizedLinearMethod()
|
|
||||||
|
|
||||||
if prefix:
|
|
||||||
# Dynamic per module/layer rules may override base config
|
|
||||||
override_config(cloned_config, prefix=prefix)
|
|
||||||
|
|
||||||
return linear_method_cls(cloned_config)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def gptq_get_quant_method(self, layer, prefix):
|
def gptq_get_quant_method(self, layer, prefix):
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.linear import (
|
|
||||||
LinearBase,
|
|
||||||
LinearMethodBase,
|
|
||||||
UnquantizedLinearMethod,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
LinearMethodBase,
|
||||||
|
QuantizationConfig,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.utils import is_cuda
|
from sglang.srt.utils import is_cuda
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
@@ -81,7 +82,7 @@ class AWQConfig(QuantizationConfig):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
|
def from_config(cls, config: Dict[str, Any]) -> AWQConfig:
|
||||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||||
@@ -92,7 +93,8 @@ class AWQConfig(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["LinearMethodBase"]:
|
) -> Optional[LinearMethodBase]:
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||||
|
|||||||
@@ -18,14 +18,14 @@ class QuantizeMethodBase(ABC):
|
|||||||
"""Create weights for a layer.
|
"""Create weights for a layer.
|
||||||
|
|
||||||
The weights will be set as attributes of the layer."""
|
The weights will be set as attributes of the layer."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
|
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
|
||||||
"""Apply the weights in layer to the input tensor.
|
"""Apply the weights in layer to the input tensor.
|
||||||
|
|
||||||
Expects create_weights to have been called before on the layer."""
|
Expects create_weights to have been called before on the layer."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError()
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: nn.Module) -> None:
|
def process_weights_after_loading(self, layer: nn.Module) -> None:
|
||||||
"""Process the weight after loading.
|
"""Process the weight after loading.
|
||||||
@@ -35,6 +35,74 @@ class QuantizeMethodBase(ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class LinearMethodBase(QuantizeMethodBase):
|
||||||
|
"""Base class for different (maybe quantized) linear methods."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
input_size_per_partition: int,
|
||||||
|
output_partition_sizes: List[int],
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
"""Create weights for a linear layer.
|
||||||
|
The weights will be set as attributes of the layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer: The layer that is using the LinearMethodBase factory.
|
||||||
|
input_size_per_partition: Size of the weight input dim on rank X.
|
||||||
|
output_partition_sizes: Sizes of the output dim of each logical
|
||||||
|
weight on rank X. E.g., output_partition_sizes for QKVLinear
|
||||||
|
is a list contains the width of Wq, Wk, Wv on rank X.
|
||||||
|
input_size: Size of the input dim of the weight across all ranks.
|
||||||
|
output_size: Size of the output dim of the weight across all ranks.
|
||||||
|
params_dtype: Datatype of the parameters.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Apply the weights in layer to the input tensor.
|
||||||
|
Expects create_weights to have been called before on the layer."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class QuantizationConfig(ABC):
|
class QuantizationConfig(ABC):
|
||||||
"""Base class for quantization configs."""
|
"""Base class for quantization configs."""
|
||||||
|
|
||||||
@@ -46,12 +114,12 @@ class QuantizationConfig(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_name(self) -> str:
|
def get_name(self) -> str:
|
||||||
"""Name of the quantization method."""
|
"""Name of the quantization method."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||||
"""List of supported activation dtypes."""
|
"""List of supported activation dtypes."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -62,19 +130,19 @@ class QuantizationConfig(ABC):
|
|||||||
This requirement is due to the custom CUDA kernels used by the
|
This requirement is due to the custom CUDA kernels used by the
|
||||||
quantization method.
|
quantization method.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_config_filenames() -> List[str]:
|
def get_config_filenames() -> List[str]:
|
||||||
"""List of filenames to search for in the model directory."""
|
"""List of filenames to search for in the model directory."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||||
"""Create a config class from the model's quantization config."""
|
"""Create a config class from the model's quantization config."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||||
@@ -117,7 +185,7 @@ class QuantizationConfig(ABC):
|
|||||||
The quantize method. None if the given layer doesn't support quant
|
The quantize method. None if the given layer doesn't support quant
|
||||||
method.
|
method.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_scaled_act_names(self) -> List[str]:
|
def get_scaled_act_names(self) -> List[str]:
|
||||||
@@ -125,7 +193,7 @@ class QuantizationConfig(ABC):
|
|||||||
|
|
||||||
For now, this is only used by AWQ.
|
For now, this is only used by AWQ.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
|
def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
@@ -7,17 +9,15 @@ import torch
|
|||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.linear import (
|
|
||||||
LinearBase,
|
|
||||||
LinearMethodBase,
|
|
||||||
UnquantizedLinearMethod,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
|
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
FusedMoEMethodBase,
|
||||||
|
LinearMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
||||||
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
@@ -78,7 +78,7 @@ class BlockInt8Config(QuantizationConfig):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config":
|
def from_config(cls, config: Dict[str, Any]) -> BlockInt8Config:
|
||||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||||
is_checkpoint_int8_serialized = "int8" in quant_method
|
is_checkpoint_int8_serialized = "int8" in quant_method
|
||||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||||
@@ -93,7 +93,8 @@ class BlockInt8Config(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
@@ -230,7 +231,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BlockInt8MoEMethod:
|
class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||||
"""MoE method for INT8.
|
"""MoE method for INT8.
|
||||||
Supports loading INT8 checkpoints with static weight scale and
|
Supports loading INT8 checkpoints with static weight scale and
|
||||||
dynamic activation scale.
|
dynamic activation scale.
|
||||||
@@ -242,25 +243,7 @@ class BlockInt8MoEMethod:
|
|||||||
quant_config: The quantization config.
|
quant_config: The quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __init__(self, quant_config: BlockInt8Config):
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
|
||||||
|
|
||||||
if not hasattr(cls, "_initialized"):
|
|
||||||
original_init = cls.__init__
|
|
||||||
new_cls = type(
|
|
||||||
cls.__name__,
|
|
||||||
(FusedMoEMethodBase,),
|
|
||||||
{
|
|
||||||
"__init__": original_init,
|
|
||||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
|
||||||
obj.__init__(*args, **kwargs)
|
|
||||||
return obj
|
|
||||||
return super().__new__(cls)
|
|
||||||
|
|
||||||
def __init__(self, quant_config):
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
assert self.quant_config.weight_block_size is not None
|
assert self.quant_config.weight_block_size is not None
|
||||||
assert self.quant_config.is_checkpoint_int8_serialized
|
assert self.quant_config.is_checkpoint_int8_serialized
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
|
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
@@ -18,12 +19,8 @@ from compressed_tensors.quantization import (
|
|||||||
)
|
)
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from sglang.srt.layers.linear import (
|
|
||||||
LinearBase,
|
|
||||||
LinearMethodBase,
|
|
||||||
UnquantizedLinearMethod,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
LinearMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
@@ -40,6 +37,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
|||||||
is_activation_quantization_format,
|
is_activation_quantization_format,
|
||||||
should_ignore_layer,
|
should_ignore_layer,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import vllm
|
import vllm
|
||||||
@@ -97,7 +95,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.packed_modules_mapping = packed_modules_mapping
|
self.packed_modules_mapping = packed_modules_mapping
|
||||||
|
|
||||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
def get_linear_method(self) -> CompressedTensorsLinearMethod:
|
||||||
return CompressedTensorsLinearMethod(self)
|
return CompressedTensorsLinearMethod(self)
|
||||||
|
|
||||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||||
@@ -117,7 +115,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
|
||||||
# Check if the layer is skipped for quantization.
|
# Check if the layer is skipped for quantization.
|
||||||
# TODO (@robertgshaw2): support module names
|
# TODO (@robertgshaw2): support module names
|
||||||
@@ -138,7 +137,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
|
def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig:
|
||||||
ignore: List[str] = cast(List[str], config.get("ignore", []))
|
ignore: List[str] = cast(List[str], config.get("ignore", []))
|
||||||
quant_format = cast(str, config.get("format"))
|
quant_format = cast(str, config.get("format"))
|
||||||
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
|
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
|
||||||
@@ -357,7 +356,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
|
|
||||||
def _get_scheme_from_parts(
|
def _get_scheme_from_parts(
|
||||||
self, weight_quant: BaseModel, input_quant: BaseModel
|
self, weight_quant: BaseModel, input_quant: BaseModel
|
||||||
) -> "CompressedTensorsScheme":
|
) -> CompressedTensorsScheme:
|
||||||
|
|
||||||
# Detect If Mixed Precision
|
# Detect If Mixed Precision
|
||||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||||
@@ -435,7 +434,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
|
|
||||||
def get_scheme(
|
def get_scheme(
|
||||||
self, layer: torch.nn.Module, layer_name: Optional[str] = None
|
self, layer: torch.nn.Module, layer_name: Optional[str] = None
|
||||||
) -> Optional["CompressedTensorsScheme"]:
|
) -> Optional[CompressedTensorsScheme]:
|
||||||
"""
|
"""
|
||||||
compressed-tensors supports non uniform in the following way:
|
compressed-tensors supports non uniform in the following way:
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -28,17 +30,14 @@ except ImportError:
|
|||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||||
from sglang.srt.layers.linear import (
|
|
||||||
LinearBase,
|
|
||||||
LinearMethodBase,
|
|
||||||
UnquantizedLinearMethod,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.parameter import (
|
from sglang.srt.layers.parameter import (
|
||||||
BlockQuantScaleParameter,
|
BlockQuantScaleParameter,
|
||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
PerTensorScaleParameter,
|
PerTensorScaleParameter,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
FusedMoEMethodBase,
|
||||||
|
LinearMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
@@ -56,6 +55,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.quantization.utils import (
|
from sglang.srt.layers.quantization.utils import (
|
||||||
all_close_1d,
|
all_close_1d,
|
||||||
convert_to_channelwise,
|
convert_to_channelwise,
|
||||||
@@ -77,6 +77,9 @@ from sglang.srt.utils import (
|
|||||||
use_intel_amx_backend,
|
use_intel_amx_backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_npu = is_npu()
|
_is_npu = is_npu()
|
||||||
@@ -152,7 +155,7 @@ class Fp8Config(QuantizationConfig):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
|
def from_config(cls, config: Dict[str, Any]) -> Fp8Config:
|
||||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||||
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
||||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||||
@@ -167,7 +170,8 @@ class Fp8Config(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
@@ -200,7 +204,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
quant_config: The quantization config.
|
quant_config: The quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
|
def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||||
|
|
||||||
@@ -486,7 +490,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Fp8MoEMethod:
|
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||||
"""MoE method for FP8.
|
"""MoE method for FP8.
|
||||||
Supports loading FP8 checkpoints with static weight scale and
|
Supports loading FP8 checkpoints with static weight scale and
|
||||||
dynamic/static activation scale.
|
dynamic/static activation scale.
|
||||||
@@ -499,25 +503,7 @@ class Fp8MoEMethod:
|
|||||||
quant_config: The quantization config.
|
quant_config: The quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __init__(self, quant_config: Fp8Config):
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
|
||||||
|
|
||||||
if not hasattr(cls, "_initialized"):
|
|
||||||
original_init = cls.__init__
|
|
||||||
new_cls = type(
|
|
||||||
cls.__name__,
|
|
||||||
(FusedMoEMethodBase,),
|
|
||||||
{
|
|
||||||
"__init__": original_init,
|
|
||||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
|
||||||
obj.__init__(*args, **kwargs)
|
|
||||||
return obj
|
|
||||||
return super().__new__(cls)
|
|
||||||
|
|
||||||
def __init__(self, quant_config):
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.block_quant = self.quant_config.weight_block_size is not None
|
self.block_quant = self.quant_config.weight_block_size is not None
|
||||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||||
@@ -1169,6 +1155,254 @@ class Fp8MoEMethod:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class Fp8EPMoEMethod(Fp8MoEMethod):
|
||||||
|
"""MoE method for FP8.
|
||||||
|
Supports loading FP8 checkpoints with static weight scale and
|
||||||
|
dynamic/static activation scale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quant_config: The quantization config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, quant_config: Fp8Config):
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.block_quant = self.quant_config.weight_block_size is not None
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: Module,
|
||||||
|
num_experts_per_partition: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
|
|
||||||
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
params_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
if self.block_quant:
|
||||||
|
block_n, block_k = (
|
||||||
|
self.quant_config.weight_block_size[0],
|
||||||
|
self.quant_config.weight_block_size[1],
|
||||||
|
)
|
||||||
|
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
|
||||||
|
# Required by column parallel or enabling merged weights
|
||||||
|
if intermediate_size % block_n != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The output_size of gate's and up's weight = "
|
||||||
|
f"{intermediate_size} is not divisible by "
|
||||||
|
f"weight quantization block_n = {block_n}."
|
||||||
|
)
|
||||||
|
if tp_size > 1:
|
||||||
|
# Required by row parallel
|
||||||
|
if intermediate_size % block_k != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The input_size of down's weight = "
|
||||||
|
f"{intermediate_size} is not divisible by "
|
||||||
|
f"weight quantization block_k = {block_k}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# WEIGHTS
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts_per_partition,
|
||||||
|
2 * intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts_per_partition,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# WEIGHT_SCALES
|
||||||
|
if self.block_quant:
|
||||||
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(
|
||||||
|
num_experts_per_partition,
|
||||||
|
2 * ((intermediate_size + block_n - 1) // block_n),
|
||||||
|
(hidden_size + block_k - 1) // block_k,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(
|
||||||
|
num_experts_per_partition,
|
||||||
|
(hidden_size + block_n - 1) // block_n,
|
||||||
|
(intermediate_size + block_k - 1) // block_k,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
|
||||||
|
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
|
||||||
|
assert self.quant_config.activation_scheme == "dynamic"
|
||||||
|
else:
|
||||||
|
# WEIGHT_SCALES
|
||||||
|
# Allocate 2 scales for w1 and w3 respectively.
|
||||||
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
|
||||||
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
# Add the quantization method used (per tensor/grouped/channel)
|
||||||
|
# to ensure the weight scales are loaded in properly
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
|
||||||
|
if self.block_quant
|
||||||
|
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||||
|
)
|
||||||
|
# If loading fp8 checkpoint, pass the weight loaders.
|
||||||
|
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||||
|
# process_weights_after_loading()
|
||||||
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
# INPUT_SCALES
|
||||||
|
if self.quant_config.activation_scheme == "static":
|
||||||
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
raise ValueError(
|
||||||
|
"Found static activation scheme for checkpoint that "
|
||||||
|
"was not serialized fp8."
|
||||||
|
)
|
||||||
|
|
||||||
|
w13_input_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||||
|
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_input_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(num_experts_per_partition, dtype=torch.float32),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||||
|
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
else:
|
||||||
|
layer.w13_input_scale = None
|
||||||
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
|
|
||||||
|
# If checkpoint is fp16, quantize in place.
|
||||||
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
# If rocm, use float8_e4m3fnuz as dtype
|
||||||
|
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
|
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
|
||||||
|
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
|
||||||
|
|
||||||
|
layer.w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.ones(
|
||||||
|
layer.num_experts_per_partition,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=w13_weight.device,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
for expert in range(layer.num_experts_per_partition):
|
||||||
|
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||||
|
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||||
|
)
|
||||||
|
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||||
|
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||||
|
)
|
||||||
|
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||||
|
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
# If checkpoint is fp8, we need to handle that the
|
||||||
|
# MoE kernels require single activation scale and single weight
|
||||||
|
# scale for w13 per expert.
|
||||||
|
else:
|
||||||
|
if self.quant_config.activation_scheme == "static":
|
||||||
|
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
||||||
|
raise ValueError(
|
||||||
|
"QuantConfig has static quantization, but found "
|
||||||
|
"activation scales are None."
|
||||||
|
)
|
||||||
|
layer.w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.max(layer.w13_weight_scale, dim=1).values,
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
if self.block_quant:
|
||||||
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
|
if _is_fp8_fnuz:
|
||||||
|
# activation_scheme: dynamic
|
||||||
|
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
weight=layer.w13_weight,
|
||||||
|
weight_scale=layer.w13_weight_scale_inv,
|
||||||
|
input_scale=None,
|
||||||
|
)
|
||||||
|
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
weight=layer.w2_weight,
|
||||||
|
weight_scale=layer.w2_weight_scale_inv,
|
||||||
|
input_scale=None,
|
||||||
|
)
|
||||||
|
# Reset the parameter
|
||||||
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
|
w13_weight, requires_grad=False
|
||||||
|
)
|
||||||
|
layer.w13_weight_scale_inv = torch.nn.Parameter(
|
||||||
|
w13_weight_scale, requires_grad=False
|
||||||
|
)
|
||||||
|
layer.w13_input_scale = None
|
||||||
|
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
layer.w2_weight_scale_inv = torch.nn.Parameter(
|
||||||
|
w2_weight_scale, requires_grad=False
|
||||||
|
)
|
||||||
|
layer.w2_input_scale = None
|
||||||
|
if _use_aiter:
|
||||||
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
|
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
|
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||||
"""
|
"""
|
||||||
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
Supports loading kv-cache scaling factors from FP8 checkpoints.
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
@@ -5,7 +7,6 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs
|
|
||||||
from sglang.srt.layers.parameter import (
|
from sglang.srt.layers.parameter import (
|
||||||
BasevLLMParameter,
|
BasevLLMParameter,
|
||||||
ChannelQuantScaleParameter,
|
ChannelQuantScaleParameter,
|
||||||
@@ -16,6 +17,8 @@ from sglang.srt.layers.parameter import (
|
|||||||
permute_param_layout_,
|
permute_param_layout_,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
FusedMoEMethodBase,
|
||||||
|
LinearMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
@@ -34,7 +37,11 @@ from sglang.srt.layers.quantization.marlin_utils import (
|
|||||||
verify_marlin_supported,
|
verify_marlin_supported,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
||||||
from sglang.srt.layers.quantization.utils import replace_parameter, unpack_cols
|
from sglang.srt.layers.quantization.utils import (
|
||||||
|
get_linear_quant_method,
|
||||||
|
replace_parameter,
|
||||||
|
unpack_cols,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
@@ -49,8 +56,6 @@ if _is_cuda:
|
|||||||
from sgl_kernel import fused_marlin_moe
|
from sgl_kernel import fused_marlin_moe
|
||||||
|
|
||||||
|
|
||||||
FusedMoEMethodBase = QuantizeMethodBase
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -179,7 +184,7 @@ class GPTQConfig(QuantizationConfig):
|
|||||||
return ["quantize_config.json"]
|
return ["quantize_config.json"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
|
def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
|
||||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||||
dynamic = {} if dynamic is None else dynamic
|
dynamic = {} if dynamic is None else dynamic
|
||||||
|
|
||||||
@@ -191,10 +196,10 @@ class GPTQConfig(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["LinearMethodBase"]:
|
) -> Optional[LinearMethodBase]:
|
||||||
# Delay the import to avoid circular dependency
|
# Delay the import to avoid circular dependency
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||||
@@ -303,7 +308,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
return ["quantize_config.json"]
|
return ["quantize_config.json"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
|
def from_config(cls, config: Dict[str, Any]) -> GPTQMarlinConfig:
|
||||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||||
dynamic = {} if dynamic is None else dynamic
|
dynamic = {} if dynamic is None else dynamic
|
||||||
|
|
||||||
@@ -354,7 +359,6 @@ class GPTQMarlinConfig(QuantizationConfig):
|
|||||||
) -> Optional[QuantizeMethodBase]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
# Delay the import to avoid circular dependency
|
# Delay the import to avoid circular dependency
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
|
||||||
|
|
||||||
if isinstance(layer, FusedMoE):
|
if isinstance(layer, FusedMoE):
|
||||||
return GPTQMarlinMoEMethod(self)
|
return GPTQMarlinMoEMethod(self)
|
||||||
@@ -832,6 +836,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
# Delay the import to avoid circular dependency
|
# Delay the import to avoid circular dependency
|
||||||
|
from sglang.srt.layers.linear import set_weight_attrs
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
|
|
||||||
intermediate_size = extra_weight_attrs.pop("intermediate_size")
|
intermediate_size = extra_weight_attrs.pop("intermediate_size")
|
||||||
|
|||||||
@@ -1,25 +1,31 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
|
||||||
from sglang.srt.layers.parameter import (
|
from sglang.srt.layers.parameter import (
|
||||||
BasevLLMParameter,
|
BasevLLMParameter,
|
||||||
ChannelQuantScaleParameter,
|
ChannelQuantScaleParameter,
|
||||||
GroupQuantScaleParameter,
|
GroupQuantScaleParameter,
|
||||||
PackedvLLMParameter,
|
PackedvLLMParameter,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
LinearMethodBase,
|
||||||
|
QuantizationConfig,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
||||||
from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
|
from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
|
||||||
from sglang.srt.utils import get_device_capability
|
from sglang.srt.utils import get_device_capability
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -617,7 +623,10 @@ class MarlinConfig(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["MarlinLinearMethod"]:
|
) -> Optional[MarlinLinearMethod]:
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
|
||||||
if isinstance(layer, LinearBase) or (
|
if isinstance(layer, LinearBase) or (
|
||||||
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
|
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
@@ -6,14 +7,11 @@ from typing import Any, Callable, Dict, List, Optional
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.layers.linear import (
|
|
||||||
LinearBase,
|
|
||||||
LinearMethodBase,
|
|
||||||
UnquantizedLinearMethod,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
|
||||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
FusedMoEMethodBase,
|
||||||
|
LinearMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
@@ -23,6 +21,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
is_sm100_supported,
|
is_sm100_supported,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.quantization.utils import (
|
from sglang.srt.layers.quantization.utils import (
|
||||||
convert_to_channelwise,
|
convert_to_channelwise,
|
||||||
is_layer_skipped,
|
is_layer_skipped,
|
||||||
@@ -86,7 +85,7 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
return ["hf_quant_config.json"]
|
return ["hf_quant_config.json"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
|
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
|
||||||
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
|
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
|
||||||
kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
|
kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
|
||||||
"kv_cache_quant_algo"
|
"kv_cache_quant_algo"
|
||||||
@@ -109,7 +108,11 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
|
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if self.exclude_modules and any(
|
if self.exclude_modules and any(
|
||||||
module in prefix
|
module in prefix
|
||||||
or (
|
or (
|
||||||
@@ -125,9 +128,6 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
|
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
|
||||||
return ModelOptFp8KVCacheMethod(self)
|
return ModelOptFp8KVCacheMethod(self)
|
||||||
|
|
||||||
# Add MoE support
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
|
||||||
|
|
||||||
if isinstance(layer, FusedMoE):
|
if isinstance(layer, FusedMoE):
|
||||||
return ModelOptFp8MoEMethod(self)
|
return ModelOptFp8MoEMethod(self)
|
||||||
|
|
||||||
@@ -246,7 +246,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
|
|||||||
super().__init__(quant_config)
|
super().__init__(quant_config)
|
||||||
|
|
||||||
|
|
||||||
class ModelOptFp8MoEMethod:
|
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||||
"""MoE method for ModelOpt FP8.
|
"""MoE method for ModelOpt FP8.
|
||||||
Supports loading FP8 checkpoints with static weight scale and activation scale.
|
Supports loading FP8 checkpoints with static weight scale and activation scale.
|
||||||
|
|
||||||
@@ -254,30 +254,6 @@ class ModelOptFp8MoEMethod:
|
|||||||
quant_config: The ModelOpt quantization config.
|
quant_config: The ModelOpt quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Dynamic class composition pattern.
|
|
||||||
|
|
||||||
This allows us to effectively "inject" FusedMoEMethodBase as a parent class
|
|
||||||
at runtime while avoiding circular import issues.
|
|
||||||
"""
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
|
||||||
|
|
||||||
if not hasattr(cls, "_initialized"):
|
|
||||||
original_init = cls.__init__
|
|
||||||
new_cls = type(
|
|
||||||
cls.__name__,
|
|
||||||
(FusedMoEMethodBase,),
|
|
||||||
{
|
|
||||||
"__init__": original_init,
|
|
||||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
|
||||||
obj.__init__(*args, **kwargs)
|
|
||||||
return obj
|
|
||||||
return super().__new__(cls)
|
|
||||||
|
|
||||||
def __init__(self, quant_config: ModelOptFp8Config):
|
def __init__(self, quant_config: ModelOptFp8Config):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||||
@@ -514,7 +490,7 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
return ["hf_quant_config.json"]
|
return ["hf_quant_config.json"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
|
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
|
||||||
quant_config = cls.get_from_keys(config, ["quantization"])
|
quant_config = cls.get_from_keys(config, ["quantization"])
|
||||||
quant_method = quant_config["quant_algo"]
|
quant_method = quant_config["quant_algo"]
|
||||||
if not quant_method in ["FP8", "NVFP4"]:
|
if not quant_method in ["FP8", "NVFP4"]:
|
||||||
@@ -559,7 +535,8 @@ class ModelOptFp4Config(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
@@ -740,31 +717,13 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
|
|||||||
return out.view(*output_shape)
|
return out.view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
class ModelOptNvFp4FusedMoEMethod:
|
class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||||
"""
|
"""
|
||||||
MoE Method for FP4 Quantization with Blockscales and PerTensorScales
|
MoE Method for FP4 Quantization with Blockscales and PerTensorScales
|
||||||
Args:
|
Args:
|
||||||
quant_config: NVFP4 Quant Config
|
quant_config: NVFP4 Quant Config
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
|
||||||
|
|
||||||
if not hasattr(cls, "_initialized"):
|
|
||||||
original_init = cls.__init__
|
|
||||||
new_cls = type(
|
|
||||||
cls.__name__,
|
|
||||||
(FusedMoEMethodBase,),
|
|
||||||
{
|
|
||||||
"__init__": original_init,
|
|
||||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
|
||||||
obj.__init__(*args, **kwargs)
|
|
||||||
return obj
|
|
||||||
return super().__new__(cls)
|
|
||||||
|
|
||||||
def __init__(self, quant_config: ModelOptFp4Config):
|
def __init__(self, quant_config: ModelOptFp4Config):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
if not is_sm100_supported():
|
if not is_sm100_supported():
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
@@ -7,13 +8,14 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
from sglang.srt.distributed.parallel_state import get_tp_group
|
from sglang.srt.distributed.parallel_state import get_tp_group
|
||||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
|
||||||
from sglang.srt.layers.quantization.awq import AWQConfig
|
from sglang.srt.layers.quantization.awq import AWQConfig
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
FusedMoEMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||||
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.utils import get_device_capability, set_weight_attrs
|
from sglang.srt.utils import get_device_capability, set_weight_attrs
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -118,7 +120,7 @@ class MoeWNA16Config(QuantizationConfig):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
|
def from_config(cls, config: Dict[str, Any]) -> MoeWNA16Config:
|
||||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||||
group_size = cls.get_from_keys(config, ["group_size"])
|
group_size = cls.get_from_keys(config, ["group_size"])
|
||||||
@@ -177,8 +179,9 @@ class MoeWNA16Config(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
# avoid circular import
|
# avoid circular import
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
|
|
||||||
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
|
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
|
||||||
@@ -209,32 +212,13 @@ def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
|
|||||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||||
|
|
||||||
|
|
||||||
class MoeWNA16Method:
|
class MoeWNA16Method(FusedMoEMethodBase):
|
||||||
"""Linear method for MOE WNA16 (W8A16/W4A16) quantization.
|
"""Linear method for MOE WNA16 (W8A16/W4A16) quantization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
|
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
|
||||||
# avoid circular import
|
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
|
||||||
|
|
||||||
if not hasattr(cls, "_initialized"):
|
|
||||||
original_init = cls.__init__
|
|
||||||
new_cls = type(
|
|
||||||
cls.__name__,
|
|
||||||
(FusedMoEMethodBase,),
|
|
||||||
{
|
|
||||||
"__init__": original_init,
|
|
||||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
|
||||||
obj.__init__(*args, **kwargs)
|
|
||||||
return obj
|
|
||||||
return super().__new__(cls)
|
|
||||||
|
|
||||||
def __init__(self, quant_config: MoeWNA16Config):
|
def __init__(self, quant_config: MoeWNA16Config):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
|
||||||
from sglang.srt.layers.linear import LinearMethodBase
|
|
||||||
from sglang.srt.layers.parameter import (
|
from sglang.srt.layers.parameter import (
|
||||||
ChannelQuantScaleParameter,
|
ChannelQuantScaleParameter,
|
||||||
GroupQuantScaleParameter,
|
GroupQuantScaleParameter,
|
||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
LinearMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
@@ -71,7 +72,7 @@ class QoQConfig(QuantizationConfig):
|
|||||||
return 80
|
return 80
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_name(self) -> str:
|
def get_name(cls) -> str:
|
||||||
return "qoq"
|
return "qoq"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -83,7 +84,7 @@ class QoQConfig(QuantizationConfig):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "QoQConfig":
|
def from_config(cls, config: Dict[str, Any]) -> QoQConfig:
|
||||||
weight_bits = cls.get_from_keys(config, ["wbits"])
|
weight_bits = cls.get_from_keys(config, ["wbits"])
|
||||||
group_size = cls.get_from_keys(config, ["group_size"])
|
group_size = cls.get_from_keys(config, ["group_size"])
|
||||||
return cls(weight_bits, group_size)
|
return cls(weight_bits, group_size)
|
||||||
@@ -92,7 +93,7 @@ class QoQConfig(QuantizationConfig):
|
|||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
|
|||||||
515
python/sglang/srt/layers/quantization/unquant.py
Normal file
515
python/sglang/srt/layers/quantization/unquant.py
Normal file
@@ -0,0 +1,515 @@
|
|||||||
|
import importlib
|
||||||
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from sglang.srt.custom_op import CustomOp
|
||||||
|
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||||
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
FusedMoEMethodBase,
|
||||||
|
LinearMethodBase,
|
||||||
|
QuantizeMethodBase,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import (
|
||||||
|
cpu_has_amx_support,
|
||||||
|
get_bool_env_var,
|
||||||
|
is_cpu,
|
||||||
|
is_hip,
|
||||||
|
set_weight_attrs,
|
||||||
|
use_intel_amx_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||||
|
|
||||||
|
|
||||||
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
|
_is_hip = is_hip()
|
||||||
|
_is_cpu = is_cpu()
|
||||||
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
|
if _use_aiter:
|
||||||
|
from aiter import ActivationType
|
||||||
|
from aiter.fused_moe import fused_moe
|
||||||
|
from aiter.fused_moe_bf16_asm import ck_moe_2stages
|
||||||
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
|
|
||||||
|
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
||||||
|
"""Unquantized method for embeddings."""
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
input_size_per_partition: int,
|
||||||
|
output_partition_sizes: List[int],
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
"""Create weights for embedding layer."""
|
||||||
|
weight = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
sum(output_partition_sizes),
|
||||||
|
input_size_per_partition,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||||
|
layer.register_parameter("weight", weight)
|
||||||
|
set_weight_attrs(weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return F.linear(x, layer.weight, bias)
|
||||||
|
|
||||||
|
def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.embedding(input_, layer.weight)
|
||||||
|
|
||||||
|
|
||||||
|
class UnquantizedLinearMethod(LinearMethodBase):
|
||||||
|
"""Linear method without quantization."""
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
input_size_per_partition: int,
|
||||||
|
output_partition_sizes: List[int],
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
weight = Parameter(
|
||||||
|
torch.empty(
|
||||||
|
sum(output_partition_sizes),
|
||||||
|
input_size_per_partition,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||||
|
layer.register_parameter("weight", weight)
|
||||||
|
set_weight_attrs(weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
if _is_cpu and _is_cpu_amx_available:
|
||||||
|
_amx_process_weight_after_loading(layer, ["weight"])
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
if use_intel_amx_backend(layer):
|
||||||
|
return torch.ops.sgl_kernel.weight_packed_linear(
|
||||||
|
x, layer.weight, bias, True # is_vnni
|
||||||
|
)
|
||||||
|
|
||||||
|
return F.linear(x, layer.weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||||
|
"""MoE method without quantization."""
|
||||||
|
|
||||||
|
def __init__(self, use_triton_kernels: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.use_triton_kernels = use_triton_kernels
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
|
if has_triton_kernels:
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
|
||||||
|
triton_kernel_moe_forward,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
triton_kernel_moe_forward = None
|
||||||
|
else:
|
||||||
|
fused_experts = None # type: ignore
|
||||||
|
triton_kernel_moe_forward = None
|
||||||
|
|
||||||
|
self.moe_forward_native = moe_forward_native
|
||||||
|
self.fused_experts = fused_experts
|
||||||
|
self.triton_kernel_moe_forward = triton_kernel_moe_forward
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
# Fused gate_up_proj (column parallel)
|
||||||
|
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
|
||||||
|
if self.use_triton_kernels:
|
||||||
|
w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# down_proj (row parallel)
|
||||||
|
w2_weight_n, w2_weight_k = (
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
)
|
||||||
|
if self.use_triton_kernels:
|
||||||
|
w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
if _use_aiter:
|
||||||
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
|
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
|
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Pack weight for get better performance on CPU
|
||||||
|
if _is_cpu and _is_cpu_amx_available:
|
||||||
|
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
return self.forward(
|
||||||
|
x=x,
|
||||||
|
layer=layer,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
correction_bias=correction_bias,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
inplace=inplace,
|
||||||
|
no_combine=no_combine,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_cuda(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
top_k: int,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
renormalize: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
if self.use_triton_kernels:
|
||||||
|
return self.triton_kernel_moe_forward(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
gating_output=router_logits,
|
||||||
|
topk=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
|
|
||||||
|
topk_weights, topk_ids = select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
correction_bias=correction_bias,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
if _use_aiter:
|
||||||
|
assert not no_combine, "unsupported"
|
||||||
|
if apply_router_weight_on_input:
|
||||||
|
assert (
|
||||||
|
topk_weights.dim() == 2
|
||||||
|
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||||
|
_, topk = topk_weights.shape
|
||||||
|
assert (
|
||||||
|
topk == 1
|
||||||
|
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||||
|
x = x * topk_weights.to(x.dtype)
|
||||||
|
topk_weights = torch.ones_like(
|
||||||
|
topk_weights, dtype=torch.float32
|
||||||
|
) # topk_weights must be FP32 (float32)
|
||||||
|
|
||||||
|
return fused_moe(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
activation=(
|
||||||
|
ActivationType.Silu
|
||||||
|
if activation == "silu"
|
||||||
|
else ActivationType.Gelu
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=inplace and not no_combine,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
no_combine=no_combine,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_cpu(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
top_k: int,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
renormalize: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert activation == "silu", f"activation = {activation} is not supported."
|
||||||
|
|
||||||
|
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
|
|
||||||
|
topk_weights, topk_ids = select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
correction_bias=correction_bias,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
|
||||||
|
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
False, # inplace # See [Note] inplace should be False in fused_experts.
|
||||||
|
False, # use_int8_w8a8
|
||||||
|
False, # use_fp8_w8a16
|
||||||
|
None, # w1_scale
|
||||||
|
None, # w2_scale
|
||||||
|
None, # block_size
|
||||||
|
None, # a1_scale
|
||||||
|
None, # a2_scale
|
||||||
|
True, # is_vnni
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.moe_forward_native(
|
||||||
|
layer,
|
||||||
|
x,
|
||||||
|
use_grouped_topk,
|
||||||
|
top_k,
|
||||||
|
router_logits,
|
||||||
|
renormalize,
|
||||||
|
topk_group,
|
||||||
|
num_expert_group,
|
||||||
|
num_fused_shared_experts,
|
||||||
|
custom_routing_function,
|
||||||
|
correction_bias,
|
||||||
|
activation,
|
||||||
|
apply_router_weight_on_input,
|
||||||
|
inplace,
|
||||||
|
no_combine,
|
||||||
|
routed_scaling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_npu(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
top_k: int,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
renormalize: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return self.moe_forward_native(
|
||||||
|
layer,
|
||||||
|
x,
|
||||||
|
use_grouped_topk,
|
||||||
|
top_k,
|
||||||
|
router_logits,
|
||||||
|
renormalize,
|
||||||
|
topk_group,
|
||||||
|
num_expert_group,
|
||||||
|
num_fused_shared_experts,
|
||||||
|
custom_routing_function,
|
||||||
|
correction_bias,
|
||||||
|
activation,
|
||||||
|
apply_router_weight_on_input,
|
||||||
|
inplace,
|
||||||
|
no_combine,
|
||||||
|
routed_scaling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
|
||||||
|
raise NotImplementedError("The TPU backend currently does not support MoE.")
|
||||||
|
|
||||||
|
forward_native = forward_cpu
|
||||||
|
|
||||||
|
|
||||||
|
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
num_experts_per_partition: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
# Fused gate_up_proj (column parallel)
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts_per_partition,
|
||||||
|
2 * intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# down_proj (row parallel)
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts_per_partition,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# scale
|
||||||
|
layer.register_parameter("w13_input_scale", None)
|
||||||
|
layer.register_parameter("w13_weight_scale", None)
|
||||||
|
|
||||||
|
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
|
||||||
|
|
||||||
|
w2_input_scale = torch.nn.Parameter(
|
||||||
|
ones_tensor,
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||||
|
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
|
ones_tensor,
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
@@ -1,7 +1,11 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from copy import deepcopy
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import List, Mapping, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
@@ -10,6 +14,9 @@ from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
|||||||
from sglang.srt.layers.quantization.scalar_type import ScalarType
|
from sglang.srt.layers.quantization.scalar_type import ScalarType
|
||||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
|
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_npu = is_npu()
|
_is_npu = is_npu()
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
@@ -147,6 +154,94 @@ def replace_parameter(
|
|||||||
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
|
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
|
||||||
|
|
||||||
|
|
||||||
|
# Match dynamic rules with module name (prefix) and override quantize
|
||||||
|
# config if module (prefix) matches a rule
|
||||||
|
def override_config(config: QuantizationConfig, prefix: str):
|
||||||
|
weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
|
||||||
|
if isinstance(weight_bits, int):
|
||||||
|
config.weight_bits = weight_bits
|
||||||
|
group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
|
||||||
|
if isinstance(group_size, int):
|
||||||
|
config.group_size = group_size
|
||||||
|
desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
|
||||||
|
if isinstance(desc_act, bool):
|
||||||
|
config.desc_act = desc_act
|
||||||
|
|
||||||
|
config.pack_factor = 32 // config.weight_bits # packed into int32
|
||||||
|
if config.get_name() == "gptq_marlin":
|
||||||
|
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
|
||||||
|
if isinstance(is_sym, bool):
|
||||||
|
config.is_sym = is_sym
|
||||||
|
|
||||||
|
if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported quantization config: "
|
||||||
|
f"bits={config.weight_bits}, sym={config.is_sym}"
|
||||||
|
)
|
||||||
|
|
||||||
|
config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
|
||||||
|
elif config.get_name() == "gptq":
|
||||||
|
if config.weight_bits not in [2, 3, 4, 8]:
|
||||||
|
raise ValueError(
|
||||||
|
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||||
|
f"supported for GPTQ, but got {config.weight_bits} bits."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dynamic_override(
|
||||||
|
config: QuantizationConfig,
|
||||||
|
layer_name: str,
|
||||||
|
key: Optional[str] = None,
|
||||||
|
default_value: Union[int, bool, None] = None,
|
||||||
|
) -> Union[Dict, int, bool, None]:
|
||||||
|
for pattern, pattern_dict in config.dynamic.items():
|
||||||
|
# Negative match: matched modules are excluded from quantized init
|
||||||
|
if pattern.startswith("-:"):
|
||||||
|
if re.match(pattern.removeprefix("-:"), layer_name):
|
||||||
|
return False
|
||||||
|
# Positive match: matched modules have quant properties overrides
|
||||||
|
# base quant config
|
||||||
|
elif re.match(pattern.removeprefix("+:"), layer_name):
|
||||||
|
if key is None:
|
||||||
|
return pattern_dict
|
||||||
|
else:
|
||||||
|
return pattern_dict.get(key, default_value)
|
||||||
|
return default_value
|
||||||
|
|
||||||
|
|
||||||
|
def get_linear_quant_method(
|
||||||
|
config: QuantizationConfig,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
prefix: str,
|
||||||
|
linear_method_cls: type,
|
||||||
|
):
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
|
from sglang.srt.layers.quantization.unquant import (
|
||||||
|
UnquantizedEmbeddingMethod,
|
||||||
|
UnquantizedLinearMethod,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
|
||||||
|
cloned_config = deepcopy(config)
|
||||||
|
parallel_lm_head_quantized = (
|
||||||
|
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
|
||||||
|
# False = skip module, None = no override, else = Positive match
|
||||||
|
if get_dynamic_override(cloned_config, layer_name=prefix) is False:
|
||||||
|
if parallel_lm_head_quantized:
|
||||||
|
return UnquantizedEmbeddingMethod()
|
||||||
|
return UnquantizedLinearMethod()
|
||||||
|
|
||||||
|
if prefix:
|
||||||
|
# Dynamic per module/layer rules may override base config
|
||||||
|
override_config(cloned_config, prefix=prefix)
|
||||||
|
|
||||||
|
return linear_method_cls(cloned_config)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_pack_factor(num_bits):
|
def get_pack_factor(num_bits):
|
||||||
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
|
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
|
||||||
return 32 // num_bits
|
return 32 // num_bits
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
@@ -5,12 +7,13 @@ import torch
|
|||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
FusedMoEMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
||||||
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
@@ -62,7 +65,7 @@ class W4AFp8Config(QuantizationConfig):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
|
def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config:
|
||||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||||
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
||||||
is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
|
is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
|
||||||
@@ -79,7 +82,8 @@ class W4AFp8Config(QuantizationConfig):
|
|||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self, layer: torch.nn.Module, prefix: str
|
self, layer: torch.nn.Module, prefix: str
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
@@ -94,7 +98,7 @@ class W4AFp8Config(QuantizationConfig):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
class W4AFp8MoEMethod:
|
class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||||
|
|
||||||
def __init__(self, quant_config: W4AFp8Config):
|
def __init__(self, quant_config: W4AFp8Config):
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearMethodBase
|
|
||||||
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
FusedMoEMethodBase,
|
||||||
|
LinearMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
@@ -64,7 +67,7 @@ class W8A8Fp8Config(QuantizationConfig):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
|
def from_config(cls, config: Dict[str, Any]) -> W8A8Fp8Config:
|
||||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||||
is_checkpoint_fp8_serialized = (
|
is_checkpoint_fp8_serialized = (
|
||||||
"compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
|
"compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
|
||||||
@@ -75,7 +78,7 @@ class W8A8Fp8Config(QuantizationConfig):
|
|||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
@@ -183,7 +186,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class W8A8FP8MoEMethod:
|
class W8A8FP8MoEMethod(FusedMoEMethodBase):
|
||||||
"""MoE method for FP8.
|
"""MoE method for FP8.
|
||||||
Supports loading FP8 checkpoints with static weight scale and
|
Supports loading FP8 checkpoints with static weight scale and
|
||||||
dynamic/static activation scale.
|
dynamic/static activation scale.
|
||||||
@@ -194,25 +197,7 @@ class W8A8FP8MoEMethod:
|
|||||||
quant_config: The quantization config.
|
quant_config: The quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __init__(self, quant_config: W8A8Fp8Config):
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
|
||||||
|
|
||||||
if not hasattr(cls, "_initialized"):
|
|
||||||
original_init = cls.__init__
|
|
||||||
new_cls = type(
|
|
||||||
cls.__name__,
|
|
||||||
(FusedMoEMethodBase,),
|
|
||||||
{
|
|
||||||
"__init__": original_init,
|
|
||||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
|
||||||
obj.__init__(*args, **kwargs)
|
|
||||||
return obj
|
|
||||||
return super().__new__(cls)
|
|
||||||
|
|
||||||
def __init__(self, quant_config):
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import sys
|
import sys
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
@@ -11,21 +13,19 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||||
from sglang.srt.layers.linear import (
|
|
||||||
LinearMethodBase,
|
|
||||||
RowParallelLinear,
|
|
||||||
UnquantizedLinearMethod,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.parameter import (
|
from sglang.srt.layers.parameter import (
|
||||||
ChannelQuantScaleParameter,
|
ChannelQuantScaleParameter,
|
||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
PerTensorScaleParameter,
|
PerTensorScaleParameter,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
FusedMoEMethodBase,
|
||||||
|
LinearMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||||
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
apply_module_patch,
|
apply_module_patch,
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
@@ -229,14 +229,14 @@ class W8A8Int8Config(QuantizationConfig):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
|
def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
|
||||||
return cls(config)
|
return cls(config)
|
||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
prefix: str,
|
prefix: str,
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional[QuantizeMethodBase]:
|
||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
@@ -374,7 +374,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class W8A8Int8MoEMethod:
|
class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||||
"""MoE method for INT8.
|
"""MoE method for INT8.
|
||||||
Supports loading INT8 checkpoints with static weight scale and
|
Supports loading INT8 checkpoints with static weight scale and
|
||||||
dynamic/static activation scale.
|
dynamic/static activation scale.
|
||||||
@@ -385,25 +385,7 @@ class W8A8Int8MoEMethod:
|
|||||||
quant_config: The quantization config.
|
quant_config: The quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __init__(self, quant_config: W8A8Int8Config):
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
|
||||||
|
|
||||||
if not hasattr(cls, "_initialized"):
|
|
||||||
original_init = cls.__init__
|
|
||||||
new_cls = type(
|
|
||||||
cls.__name__,
|
|
||||||
(FusedMoEMethodBase,),
|
|
||||||
{
|
|
||||||
"__init__": original_init,
|
|
||||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
|
||||||
obj.__init__(*args, **kwargs)
|
|
||||||
return obj
|
|
||||||
return super().__new__(cls)
|
|
||||||
|
|
||||||
def __init__(self, quant_config):
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
@@ -885,13 +867,15 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
from sglang.srt.layers.linear import RowParallelLinear
|
||||||
|
|
||||||
if isinstance(layer, RowParallelLinear):
|
if isinstance(layer, RowParallelLinear):
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
return self.quant_method.apply(layer, x, bias, tp_rank)
|
return self.quant_method.apply(layer, x, bias, tp_rank)
|
||||||
return self.quant_method.apply(layer, x, bias)
|
return self.quant_method.apply(layer, x, bias)
|
||||||
|
|
||||||
|
|
||||||
class NPU_W8A8MoEMethod:
|
class NPU_W8A8MoEMethod(FusedMoEMethodBase):
|
||||||
"""MoE method for NPU quantization.
|
"""MoE method for NPU quantization.
|
||||||
|
|
||||||
This class search for specific quantization
|
This class search for specific quantization
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from dataclasses import dataclass
|
|||||||
from typing import List, Optional, Sequence, Tuple
|
from typing import List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
method_has_implemented_embedding,
|
method_has_implemented_embedding,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
|
||||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
|
from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
|
||||||
|
|
||||||
DEFAULT_VOCAB_PADDING_SIZE = 64
|
DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||||
@@ -32,44 +32,6 @@ _is_cpu = is_cpu()
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
|
||||||
"""Unquantized method for embeddings."""
|
|
||||||
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
input_size_per_partition: int,
|
|
||||||
output_partition_sizes: List[int],
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
"""Create weights for embedding layer."""
|
|
||||||
weight = Parameter(
|
|
||||||
torch.empty(
|
|
||||||
sum(output_partition_sizes),
|
|
||||||
input_size_per_partition,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
|
||||||
layer.register_parameter("weight", weight)
|
|
||||||
set_weight_attrs(weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return F.linear(x, layer.weight, bias)
|
|
||||||
|
|
||||||
def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
|
|
||||||
return F.embedding(input_, layer.weight)
|
|
||||||
|
|
||||||
|
|
||||||
def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
|
def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
|
||||||
"""Pad the vocab size to the given value."""
|
"""Pad the vocab size to the given value."""
|
||||||
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
||||||
|
|||||||
Reference in New Issue
Block a user