[QuickFix] fix gptq model initialize (#6429)
This commit is contained in:
@@ -25,7 +25,6 @@ try:
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQMarlin24Config,
|
||||
@@ -58,7 +57,11 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
|
||||
CompressedTensorsConfig,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||
from sglang.srt.layers.quantization.gptq import (
|
||||
GPTQConfig,
|
||||
GPTQMarlinConfig,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
from sglang.srt.layers.quantization.modelopt_quant import (
|
||||
ModelOptFp4Config,
|
||||
ModelOptFp8Config,
|
||||
|
||||
@@ -1,21 +1,28 @@
|
||||
import logging
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.linear import LinearBase, set_weight_attrs
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.utils import replace_parameter
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
marlin_moe_permute_scales,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
@@ -27,7 +34,9 @@ try:
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
|
||||
GPTQLinearMethod = MarlinLinearMethod = Any
|
||||
|
||||
FusedMoEMethodBase = QuantizeMethodBase
|
||||
|
||||
class scalar_types:
|
||||
uint4b8 = "uint4b8"
|
||||
@@ -437,3 +446,286 @@ class MarlinConfig(QuantizationConfig):
|
||||
):
|
||||
return MarlinLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE Marlin method with quantization."""
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
intermediate_size = extra_weight_attrs.pop("intermediate_size")
|
||||
|
||||
self.is_k_full = (not self.quant_config.desc_act) or (
|
||||
intermediate_size_per_partition == intermediate_size
|
||||
)
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
scales_size13 = hidden_size // self.quant_config.group_size
|
||||
w2_scales_size = (
|
||||
intermediate_size
|
||||
if self.quant_config.desc_act
|
||||
else intermediate_size_per_partition
|
||||
)
|
||||
scales_size2 = w2_scales_size // self.quant_config.group_size
|
||||
strategy = FusedMoeWeightScaleSupported.GROUP.value
|
||||
else:
|
||||
scales_size13 = 1
|
||||
scales_size2 = 1
|
||||
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
|
||||
|
||||
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size // self.quant_config.pack_factor,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
# down_proj (row parallel)
|
||||
w2_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition // self.quant_config.pack_factor,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
# up_proj scales
|
||||
w13_scales = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.half,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
# down_proj scales
|
||||
w2_scales = torch.nn.Parameter(
|
||||
torch.empty(num_experts, scales_size2, hidden_size, dtype=torch.half),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
# dont shard the w2 scales when running act order
|
||||
set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
|
||||
# up_proj scales
|
||||
w13_qzeros = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
# down_proj scales
|
||||
w2_qzeros = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
scales_size2,
|
||||
hidden_size // self.quant_config.pack_factor,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
# dont shard the w2 scales when running act order
|
||||
set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
|
||||
w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_g_idx", w13_g_idx)
|
||||
set_weight_attrs(w13_g_idx, extra_weight_attrs)
|
||||
w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_g_idx", w2_g_idx)
|
||||
set_weight_attrs(w2_g_idx, extra_weight_attrs)
|
||||
w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
|
||||
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
|
||||
w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
|
||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
# Process act_order
|
||||
if self.quant_config.desc_act:
|
||||
# Get sorting based on g_idx
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
|
||||
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
|
||||
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
|
||||
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
|
||||
for e in range(num_experts):
|
||||
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
|
||||
torch.int32
|
||||
)
|
||||
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
|
||||
torch.int32
|
||||
)
|
||||
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
|
||||
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
|
||||
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
|
||||
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
|
||||
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
|
||||
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
|
||||
else:
|
||||
# Reset g_idx related tensors
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
device = layer.w13_g_idx.device
|
||||
layer.w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
# Repack weights
|
||||
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w13_qweight,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w13_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w2_qweight,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w2_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
# Repack scales
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_scales,
|
||||
size_k=layer.w2_scales.shape[1]
|
||||
* (
|
||||
self.quant_config.group_size
|
||||
if self.quant_config.group_size != -1
|
||||
else self.quant_config.pack_factor
|
||||
),
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
# The input must currently be float16
|
||||
orig_dtype = x.dtype
|
||||
x = x.half()
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
layer.w13_scales,
|
||||
layer.w2_scales,
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
num_bits=self.quant_config.quant_type.size_bits,
|
||||
is_k_full=self.is_k_full,
|
||||
).to(orig_dtype)
|
||||
|
||||
Reference in New Issue
Block a user