238 lines
11 KiB
Python
238 lines
11 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import Any, Callable, Optional
|
|
|
|
import torch
|
|
|
|
import vllm.model_executor.layers.fused_moe # noqa
|
|
from vllm import _custom_ops as ops
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
|
FusedMoeWeightScaleSupported)
|
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
from vllm.platforms import current_platform
|
|
|
|
logger = init_logger(__name__)

# Names exported by `from <this module> import *`.
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"]
|
|
|
|
|
|
class QuarkMoEMethod(FusedMoEMethodBase):
    """Common base for fused-MoE methods driven by a Quark quantization config.

    Concrete subclasses are obtained through :meth:`get_moe_method`, which
    inspects the config entry matched to a given MoE layer.
    """

    @staticmethod
    def get_moe_method(
            quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
            module: torch.nn.Module,
            layer_name: str) -> "QuarkMoEMethod":
        """Return the fused-MoE method matching the layer's Quark config.

        Args:
            quant_config: The Quark config. Its ``_find_matched_config``
                selects the per-layer entry, and ``_is_fp8_w8a8`` classifies
                the weight/input schemes of that entry.
            module: The MoE module being quantized.
            layer_name: Name used to look up the layer's config entry.

        Returns:
            A :class:`QuarkW8A8Fp8MoEMethod` when the entry is FP8 W8A8.

        Raises:
            NotImplementedError: If the entry quantizes output tensors or
                bias.
            RuntimeError: For any weight/input scheme other than FP8 W8A8.
        """
        matched_cfg = quant_config._find_matched_config(layer_name, module)

        # Only weight and input-activation quantization is handled for MoE.
        if matched_cfg.get("output_tensors") or matched_cfg.get("bias"):
            raise NotImplementedError("Currently, Quark models with "
                                      "output_tensors and bias "
                                      "quantized are not supported")

        w_cfg = matched_cfg.get("weight")
        a_cfg = matched_cfg.get("input_tensors")

        if not quant_config._is_fp8_w8a8(w_cfg, a_cfg):
            raise RuntimeError("Unsupported FusedMoe scheme")
        return QuarkW8A8Fp8MoEMethod(w_cfg, a_cfg)
|
|
|
|
|
|
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
    """Fused-MoE method for Quark FP8 W8A8 (per-tensor) checkpoints.

    Expert weights are allocated as ``torch.float8_e4m3fn``.
    - ``w2`` has one float32 scale per expert.
    - The fused ``w13`` (w1/w3) projection starts with two scales per
      expert. After loading, these are merged into one.
    - Input-activation scales are static (loaded per expert, then reduced
      to a single value) unless the input config sets ``is_dynamic``.
    """

    def __init__(self, weight_config: dict[str, Any],
                 input_config: dict[str, Any]):
        """Validate and store the Quark weight / input-tensor configs.

        Args:
            weight_config: Quark quantization config for the expert weights.
            input_config: Quark quantization config for the input
                activations.

        Raises:
            ValueError: If either config's ``qscheme`` is not
                ``"per_tensor"``.
        """
        self.weight_quant = weight_config
        self.input_quant = input_config

        weight_qscheme = self.weight_quant.get("qscheme")
        input_qscheme = self.input_quant.get("qscheme")
        if not (weight_qscheme == "per_tensor"
                and input_qscheme == "per_tensor"):
            raise ValueError(
                "For FP8 Fused MoE layers, only per-tensor scales "
                "for weights and activations are supported. Found "
                f"{weight_qscheme}, {input_qscheme}")  # noqa E501

        # A missing or falsy "is_dynamic" means the activation scales are
        # static and come from the checkpoint.
        self.static_input_scales = not self.input_quant.get("is_dynamic")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the FP8 expert weights and their scales on ``layer``.

        Registered parameters:
        - ``w13_weight``: ``(num_experts, 2 * intermediate_size_per_partition,
          hidden_size)``.
        - ``w2_weight``: ``(num_experts, hidden_size,
          intermediate_size_per_partition)``.
        - ``w13_weight_scale``: ``(num_experts, 2)`` float32.
        - ``w2_weight_scale``: ``(num_experts,)`` float32.
        - ``w13_input_scale`` / ``w2_input_scale``: ``(num_experts,)``
          float32. These are registered only when input scales are static;
          otherwise both attributes are set to ``None``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Number of experts to allocate storage for.
            hidden_size: Size of the expert input/output dimension.
            intermediate_size_per_partition: Size of each of the w1 / w3
                halves of the fused ``w13`` projection.
            params_dtype: Ignored. Weights are always ``float8_e4m3fn``.
            **extra_weight_attrs: Attributes attached to every registered
                parameter via ``set_weight_attrs``. Scale and input-scale
                parameters additionally get the per-tensor
                ``quant_method``.
        """
        # Checkpoint weights are FP8 e4m3fn regardless of the model dtype.
        params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly.
        # NOTE: this update happens after the weights above were tagged, so
        # only the scale / input-scale parameters carry "quant_method".
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.static_input_scales:
            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            # Dynamic activation quantization: no scales to load.
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded checkpoint tensors into the FP8 kernel's layout.

        1. With static input scales, reduce each per-expert activation-scale
           vector to its maximum. A warning is logged once if the values
           differ.
        2. On platforms where ``current_platform.is_fp8_fnuz()`` is true,
           convert weights and scales with ``normalize_e4m3fn_to_e4m3fnuz``.
        3. Merge each expert's two ``w13`` scales into their maximum. Both
           the w1 and w3 halves are dequantized with their own scale and
           re-quantized with the merged one.

        Raises:
            ValueError: If input scales are static but either input-scale
                attribute is ``None``.
        """
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        if current_platform.is_fp8_fnuz():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameters. Input scales are only replaced when the
            # normalizer returned one (they stay None with dynamic scales).
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
                                                        requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.local_num_experts):
            # Shard 0 / 1 are the w1 / w3 halves along dim 1 of w13_weight.
            for shard_id in range(2):
                start = shard_id * shard_size
                end = start + shard_size
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:end, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][start:end, :], _ = \
                    ops.scaled_fp8_quant(dq_weight,
                                         max_w13_scales[expert_id])

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """Route ``x`` to its top-k experts and run the FP8 fused-MoE kernel.

        Expert selection is delegated to ``FusedMoE.select_experts``. The
        expert computation goes to ``fused_experts`` with
        ``use_fp8_w8a8=True`` and this layer's weight / input scales.

        ``activation`` is forwarded to the kernel. It was previously
        accepted but dropped, so the kernel default was always used.

        ``inplace=True`` is passed to ``fused_experts``. Presumably ``x``'s
        buffer may therefore be reused for the output (TODO confirm against
        ``fused_experts``).

        Returns:
            The tensor returned by ``fused_experts``.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            use_fp8_w8a8=True,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale)
|