Upgrade to vllm 0.17.0 corex v4.1 overlay

Date: 2026-04-29 19:38:22 +08:00
Parent: 8fac6062e4
Commit: 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -18,6 +18,7 @@ QuantizationMethods = Literal[
"modelopt",
"modelopt_fp4",
"modelopt_mxfp8",
"modelopt_mixed",
"gguf",
"gptq_marlin",
"awq_marlin",
@@ -32,6 +33,7 @@ QuantizationMethods = Literal[
"mxfp4",
"petit_nvfp4",
"cpu_awq",
"w8a16"
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -120,12 +122,18 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .gptq import GPTQConfig
from .gptq_marlin import GPTQMarlinConfig
from .inc import INCConfig
from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config
from .modelopt import (
ModelOptFp8Config,
ModelOptMixedPrecisionConfig,
ModelOptMxFp8Config,
ModelOptNvFp4Config,
)
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config
from .torchao import TorchAOConfig
from .w8a16 import W8a16Config
method_to_config: dict[str, type[QuantizationConfig]] = {
"awq": AWQConfig,
@@ -135,6 +143,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"modelopt": ModelOptFp8Config,
"modelopt_fp4": ModelOptNvFp4Config,
"modelopt_mxfp8": ModelOptMxFp8Config,
"modelopt_mixed": ModelOptMixedPrecisionConfig,
"gguf": GGUFConfig,
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
@@ -151,6 +160,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"mxfp4": Mxfp4Config,
"petit_nvfp4": PetitNvFp4Config,
"cpu_awq": CPUAWQConfig,
"w8a16": W8a16Config,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
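# Illustrative lookup for the two methods added above (a minimal sketch; assumes
# this resolver is re-exported from the package __init__ as in upstream vLLM):
#
#     from vllm.model_executor.layers.quantization import get_quantization_config
#     assert get_quantization_config("w8a16") is W8a16Config
#     assert get_quantization_config("modelopt_mixed") is ModelOptMixedPrecisionConfig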

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
@@ -9,6 +9,7 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@@ -60,6 +61,7 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.models.utils import WeightsMapper
import ixformer.inference.functions as ixfops
logger = init_logger(__name__)
@@ -197,7 +199,7 @@ class AWQMarlinConfig(QuantizationConfig):
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
# from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
# if is_layer_skipped(
# prefix,
@@ -213,9 +215,10 @@ class AWQMarlinConfig(QuantizationConfig):
# return MoeWNA16Config.from_config(self.full_config).get_quant_method(
# layer, prefix
# )
moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
# moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
# moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
# return moe_quant_method
return AWQMarlinMoEMethod(self, layer.moe_config)
return None
@classmethod
@@ -389,13 +392,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
replace_parameter(layer, "qweight", pad_qweight)
replace_parameter(layer, "qzeros", pad_qzeros)
replace_parameter(layer, "scales", pad_scales)
return
# TODO(gyf): Marlin format is not supported for now.
device = layer.qweight.device
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
return
# Allocate marlin workspace
layer.workspace = marlin_make_workspace_new(device)
@@ -811,49 +814,33 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
# shared_experts_output is passed in as shared_experts_input so it can be fused into the final reduction
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# return fused_marlin_moe(
# x,
# layer.w13_qweight,
# layer.w2_qweight,
# getattr(layer, "w13_bias", None),
# getattr(layer, "w2_bias", None),
# layer.w13_scales,
# layer.w2_scales,
# topk_weights,
# topk_ids,
# input_global_scale1=getattr(layer, "w13_input_global_scale", None),
# input_global_scale2=getattr(layer, "w2_input_global_scale", None),
# quant_type_id=self.quant_type.id,
# apply_router_weight_on_input=layer.apply_router_weight_on_input,
# global_num_experts=layer.global_num_experts,
# expert_map=layer.expert_map,
# w1_zeros=layer.w13_qzeros,
# w2_zeros=layer.w2_qzeros,
# workspace=layer.workspace,
# input_dtype=self.input_dtype,
# inplace=not self.moe.disable_inplace,
# )
num_tokens, num_experts = router_logits.shape
assert layer.activation.value == "silu", "Only SiLU activation is supported."
use_ep = layer.expert_map is not None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata:
if isinstance(attn_metadata, dict):
only_decode = (not use_ep and all(t.num_decodes > 0 and t.num_prefills == 0 for t in attn_metadata.values()))
else:
only_decode = not use_ep and attn_metadata.num_decodes > 0 and attn_metadata.num_prefills == 0
else:
only_decode = False
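# only_decode: the whole batch is in decode (no prefills) and expert parallelism
# is off, which selects the grouped-GEMV path below instead of the grouped-GEMM path.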
if use_ep:
start_eid = layer.ep_rank * layer.local_num_experts
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
if layer.apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
num_tokens = topk_ids.shape[0]
num_experts = layer.global_num_experts
if use_ep:
hidden_size = x.shape[1]
(
@@ -875,7 +862,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
dtype=x.dtype,
)
else:
expand_tokens = num_tokens * top_k
expand_tokens = num_tokens * layer.top_k
(
src_to_dst,
sorted_token_ids,
@@ -885,7 +872,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
num_experts=num_experts,
)
expert_sizes_cpu = expert_sizes_gpu.cpu()
# expand + reorder
# TODO use kernel
@@ -893,76 +879,130 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
hidden_states=x,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
topk=top_k,
topk=layer.top_k,
src_to_dst=src_to_dst,
)
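# expand_hidden_states now presumably holds one row per (token, expert) pair,
# grouped so each expert's rows are contiguous for the grouped GEMMs below.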
# w4a16 group gemm 1
# pt_output_1: (expand_tokens, 2n) dtype
pt_output_1 = ixfops.moe_w4a16_group_gemm(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
# w4a16 group gemm 2 + reorder
# pt_output_3: (expand_tokens, k) dtype
if use_ep:
pt_output_3 = torch.empty(
(num_tokens * top_k, hidden_size),
device=x.device,
dtype=x.dtype,
)
ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
if only_decode:
pt_output_1 = ixfops.moe_w4a16_group_gemv(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
output=pt_output_3,
)
reduce_mask = src_to_dst == -1
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, top_k, -1),
topk_weight=topk_weights,
scaling_factor=routed_scaling_factor,
mask=reduce_mask,
)
else:
pt_output_3 = ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
pt_output_3 = ixfops.moe_w4a16_group_gemv(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, top_k, -1),
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=routed_scaling_factor
scaling_factor=layer.routed_scaling_factor,
extra_residual=shared_experts_input,
)
else:
expert_sizes_cpu = expert_sizes_gpu.cpu()
pt_output_1 = ixfops.moe_w4a16_group_gemm(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
# w4a16 group gemm 2 + reorder
# pt_output_3: (expand_tokens, k) dtype
if use_ep:
pt_output_3 = torch.empty(
(num_tokens * layer.top_k, hidden_size),
device=x.device,
dtype=x.dtype,
)
ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
output=pt_output_3,
tokens_per_experts_gpu=expert_sizes_gpu,
)
reduce_mask = src_to_dst == -1
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
mask=reduce_mask,
)
else:
pt_output_3 = ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
extra_residual=shared_experts_input,
)
return final_hidden_states
# return torch.ops.vllm.fused_marlin_moe(
# x,
# layer.w13_qweight,
# layer.w2_qweight,
# layer.w13_scales,
# layer.w2_scales,
# router_logits,
# topk_weights,
# topk_ids,
# w1_zeros=layer.w13_qzeros,
# w2_zeros=layer.w2_qzeros,
# num_bits=self.quant_config.weight_bits,
# )

View File

@@ -18,7 +18,6 @@ from compressed_tensors.quantization import (
)
from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -52,7 +51,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16,
CompressedTensorsW4A8Int8
)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod,
@@ -401,8 +399,8 @@ class CompressedTensorsConfig(QuantizationConfig):
) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
or weight_quant.strategy == QuantizationStrategy.GROUP.value
)
is_tensor = (
weight_strategy
@@ -420,8 +418,8 @@ class CompressedTensorsConfig(QuantizationConfig):
) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
or weight_quant.strategy == QuantizationStrategy.GROUP.value
)
is_token = (
weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
@@ -663,12 +661,6 @@ class CompressedTensorsConfig(QuantizationConfig):
)
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
if envs.VLLM_W8A8_LINEAR_USE_W4A8:
return CompressedTensorsW4A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False,
input_symmetric=input_quant.symmetric,
)
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False,

View File

@@ -8,7 +8,7 @@ from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8, CompressedTensorsW4A8Int8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
@@ -28,5 +28,4 @@ __all__ = [
"CompressedTensorsW4A4Fp4",
"CompressedTensorsW4A8Int",
"CompressedTensorsW4A8Fp8",
"CompressedTensorsW4A8Int8"
]

View File

@@ -25,11 +25,18 @@ logger = init_logger(__name__)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def __init__(
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool, is_w4a8_linear: bool = False
):
self.strategy = strategy
import vllm.envs as env
if env.VLLM_MIX_QUANTIZATION_TYPE == "TENSOR":
self.strategy = QuantizationStrategy.TENSOR
elif env.VLLM_MIX_QUANTIZATION_TYPE == "CHANNEL":
self.strategy = QuantizationStrategy.CHANNEL
else:
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
self.is_w4a8_linear = is_w4a8_linear
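# Assumption based on the override above: VLLM_MIX_QUANTIZATION_TYPE is an
# opt-in environment variable; setting it to "TENSOR" or "CHANNEL" forces that
# weight-quantization strategy regardless of what the checkpoint declares.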
@classmethod
def get_min_capability(cls) -> int:
@@ -53,16 +60,32 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric=self.input_symmetric,
module_name=self.__class__.__name__,
)
remainder = input_size_per_partition % 64
if remainder != 0:
input_size_per_partition_padded = input_size_per_partition + (64 - remainder)
else:
input_size_per_partition_padded = input_size_per_partition
# WEIGHT
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
if self.is_w4a8_linear:
# only "NN" is supported
weight = ModelWeightParameter(data=torch.empty(
input_size_per_partition_padded,
sum(output_partition_sizes) // 2,
dtype=torch.int8),
input_dim=0,
output_dim=1,
weight_loader=weight_loader,
)
else:
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition_padded,
dtype=torch.int8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
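# Shape sketch with assumed example sizes: input_size_per_partition=4095 pads up
# to 4096 (next multiple of 64); with sum(output_partition_sizes)=11008 the
# w4a8 weight is [4096, 5504] int8 (two 4-bit values per byte along N), while
# the plain w8a8 weight is [11008, 4096] int8.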
@@ -109,104 +132,4 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
class CompressedTensorsW4A8Int8(CompressedTensorsScheme):
def __init__(
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
@classmethod
def get_min_capability(cls) -> int:
# turing and up
return 75
def create_weights(
self,
layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
layer.logical_widths = output_partition_sizes
self.kernel = init_int8_linear_kernel(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=self.input_symmetric,
module_name=self.__class__.__name__,
)
# WEIGHT
# weight = ModelWeightParameter(
# data=torch.empty(
# sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
# ),
# input_dim=1,
# output_dim=0,
# weight_loader=weight_loader,
# )
weight = ModelWeightParameter(
data=torch.empty(
input_size_per_partition,
sum(output_partition_sizes) // 2,
dtype=torch.int8
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
input_zero_point = None
input_scale = None
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
)
if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
)
layer.register_parameter("input_zero_point", input_zero_point)
layer.register_parameter("input_scale", input_scale)
if not hasattr(layer, "azp_adj"):
layer.register_parameter("azp_adj", None)
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
return self.kernel.apply_weights(layer, x, bias, self.is_w4a8_linear)
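# (The added is_w4a8_linear flag presumably tells the int8 linear kernel to use
# the packed 4-bit weight layout set up in create_weights above; this is an
# assumption based on this diff, not on the kernel's code.)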

View File

@@ -23,17 +23,13 @@ from vllm.model_executor.layers.batch_invariant import (
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported,
MoEActivation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
@@ -50,9 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
@@ -860,14 +853,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk = make_fp8_moe_kernel(
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
@@ -930,29 +919,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
# TRTLLM does not use Modular Kernel.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
@@ -983,10 +956,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def supports_eplb(self) -> bool:
return True
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
layer: FusedMoE,
@@ -994,50 +963,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
assert layer.activation == MoEActivation.SILU, (
f"Expected 'silu' activation but got {layer.activation}"
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv,
w2_weight=layer.w2_weight,
w2_weight_scale_inv=layer.w2_weight_scale_inv,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=layer.routing_method_type,
routed_scaling=layer.routed_scaling_factor,
)
else:
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
def apply(
self,
layer: FusedMoE,
@@ -1046,9 +987,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.moe_mk is not None
assert not self.is_monolithic
return self.moe_mk(
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,

View File

@@ -7,6 +7,7 @@ from typing import Any
import gguf
import torch
import torch.nn.functional as F
from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter
@@ -234,7 +235,7 @@ try:
op_func=_fused_mul_mat_gguf,
fake_impl=_fused_mul_mat_gguf_fake,
)
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
fused_mul_mat_gguf = _fused_mul_mat_gguf
except AttributeError as error:
raise error
@@ -365,7 +366,7 @@ try:
op_func=_fused_moe_gguf,
fake_impl=_fused_moe_gguf_fake,
)
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
fused_moe_gguf = _fused_moe_gguf
except AttributeError as error:
raise error
@@ -410,7 +411,7 @@ try:
op_func=_apply_gguf_embedding,
fake_impl=_apply_gguf_embedding_fake,
)
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
apply_gguf_embedding = _apply_gguf_embedding
except AttributeError as error:
raise error
@@ -451,6 +452,9 @@ class GGUFLinearMethod(LinearMethodBase):
"data_container": [],
"shard_id": [],
"shard_id_map": {},
"params_dtype": params_dtype,
"input_size_per_partition" :input_size_per_partition, # restore shape for qkv and merge
"output_partition_sizes" :output_partition_sizes,
},
)
set_weight_attrs(qweight, extra_weight_attrs)
@@ -664,6 +668,10 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
"""
def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
weight = layer.weight
return F.embedding(x, weight)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type
hidden_size = qweight.tensor_shape[1]

View File

@@ -128,7 +128,7 @@ class GPTQConfig(QuantizationConfig):
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
return [torch.bfloat16, torch.half]
@classmethod
# Need to figure it out

View File

@@ -59,9 +59,164 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils.collection_utils import is_list_of
import ixformer.inference.functions as ixfops
logger = init_logger(__name__)
# [B, K//8, N] -> [B, K, N]
# less memory
def unpack_k_batch_opt(packed_w: torch.Tensor, num_bits: int = 4, chunk_size: int = 2) -> torch.Tensor:
"""
Memory-efficient unpacking for 3D tensor.
Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor,
without broadcasting huge intermediate tensors (avoids OOM).
Args:
packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
num_bits: Number of bits per packed element (e.g., 4 or 2).
chunk_size: How many bit groups to unpack at once (tradeoff between speed and memory).
Returns:
unpacked: torch.int8 tensor of shape [B, K, N].
"""
B, k_packed, N = packed_w.shape
pack_factor = 32 // num_bits
K = k_packed * pack_factor
mask = (1 << num_bits) - 1
# Allocate output tensor once
unpacked = torch.empty((B, K, N), dtype=torch.int8, device=packed_w.device)
# Process bit chunks iteratively to save memory
for i in range(0, pack_factor, chunk_size):
# Precompute shifts for this chunk
shift_vals = num_bits * torch.arange(i, min(i + chunk_size, pack_factor), device=packed_w.device)
# [chunk_size, 1, 1, 1]
shifts = shift_vals.view(-1, 1, 1, 1)
# Compute small chunk only
chunk = ((packed_w.unsqueeze(0) >> shifts) & mask).to(torch.int8)
# chunk: [chunk_size, B, k_packed, N]
# write into output
for j in range(chunk.shape[0]):
unpacked[:, (i + j)::pack_factor, :] = chunk[j]
del chunk # release memory early
return unpacked
# more memory
def unpack_k_batch(packed_w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
"""
Efficient vectorized unpacking for 3D tensor (batch version).
Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor.
Args:
packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
num_bits: Number of bits per packed element (e.g., 4).
Returns:
unpacked: torch.int8 tensor of shape [B, K, N].
"""
B, k_packed, n = packed_w.shape
pack_factor = 32 // num_bits
k = k_packed * pack_factor
mask = (1 << num_bits) - 1
# [pack_factor, 1, 1, 1]
shifts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1, 1)
# [1, B, k_packed, N]
packed_expanded = packed_w.unsqueeze(0)
# Extract each group of num_bits using bitwise ops
unpacked_groups = ((packed_expanded >> shifts) & mask).to(torch.int8)
# [pack_factor, B, k_packed, N] → [B, K, N]
unpacked = unpacked_groups.permute(1, 2, 0, 3).reshape(B, k, n)
return unpacked
# [B, K, N] -> [B, K, N//8]
# less memory
def pack_n_batch_opt(x: torch.Tensor, pack_num: int = 8, order_map=None, chunk_size: int = 2) -> torch.Tensor:
"""
Memory-efficient batch packing with correct bit order.
[B, K, N] int4 -> [B, K, N//pack_num] int32.
"""
B, K, N = x.shape
assert N % pack_num == 0, "N must be divisible by pack_num"
cols = N // pack_num
unit = 32 // pack_num
if order_map is None:
order_map = list(range(pack_num))
order_map = torch.tensor(order_map, device=x.device)
shifts = unit * torch.arange(pack_num, device=x.device) # always 0..unit*(pack_num-1)
packed = torch.zeros((B, K, cols), dtype=torch.int32, device=x.device)
x_reshape = x.view(B, K, cols, pack_num) & 0xF
# process in chunks for memory efficiency
for start in range(0, pack_num, chunk_size):
end = min(start + chunk_size, pack_num)
idx_chunk = order_map[start:end]
shift_chunk = shifts[start:end]
vals = torch.gather(x_reshape, 3, idx_chunk.view(1,1,1,-1).expand(B,K,cols,-1)).to(torch.int32)
for j in range(vals.shape[-1]):
packed.add_(vals[..., j] << shift_chunk[j])
return packed
# more memory
def pack_n_batch(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor:
"""
Efficient vectorized batch packing: [B, K, N] int4 -> [B, K, N//pack_num] int32.
Args:
x: torch.int32 tensor of shape [B, K, N], each element 0-15 (int4).
pack_num: Number of 4-bit elements per packed int32 (default=8).
order_map: Optional order of elements within each packed int32.
Returns:
torch.int32 tensor of shape [B, K, N//pack_num].
"""
B, K, N = x.shape
assert N % pack_num == 0, "N must be divisible by pack_num"
cols = N // pack_num
if order_map is None:
order_map = list(range(pack_num))
order_map = torch.tensor(order_map, device=x.device)
unit = 32 // pack_num # number of bits per element
# reshape to [B, K, cols, pack_num]
pack_num_int = int(pack_num)
x_reshape = x.view(B, K, cols, pack_num_int)
# reorder according to order_map
x_reorder = torch.gather(
x_reshape, 3, order_map.view(1, 1, 1, -1).expand(B, K, cols, -1)
)
# mask low 4 bits
x_reorder = x_reorder & 0xF
# bit shifts [pack_num] -> [1,1,1,pack_num] broadcastable
shifts = (unit * torch.arange(pack_num_int, device=x.device)).view(1, 1, 1, -1)
# shift and sum along last dimension to combine bits
packed = (x_reorder << shifts).sum(dim=-1).to(torch.int32)
return packed
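# A minimal round-trip sanity check for the helpers above (illustrative sketch
# only; _pack_k_reference and _check_pack_roundtrip are hypothetical helpers,
# not part of this module, and assume torch is imported at module scope).
def _pack_k_reference(vals: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    """Pack 4-bit values along K into int32 words, the inverse of unpack_k_batch."""
    B, K, N = vals.shape
    pack_factor = 32 // num_bits
    grouped = vals.view(B, K // pack_factor, pack_factor, N).to(torch.int64)
    shifts = (num_bits * torch.arange(pack_factor, device=vals.device)).view(1, 1, -1, 1)
    # Disjoint bit ranges, so summing is equivalent to OR-ing the shifted nibbles.
    return (grouped << shifts).sum(dim=2).to(torch.int32)
def _check_pack_roundtrip() -> None:
    x = torch.randint(0, 16, (2, 64, 8), dtype=torch.int32)  # [B, K, N] 4-bit values
    packed = _pack_k_reference(x)                            # [2, 8, 8] int32
    assert torch.equal(unpack_k_batch(packed).to(torch.int32), x)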
def get_moe_quant_method(
config: "GPTQMarlinConfig",
@@ -495,8 +650,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8
elif self.quant_config.quant_type.size_bits == 8:
self.quant_type = scalar_types.uint8b128
# elif self.quant_config.quant_type.size_bits == 8:
# self.quant_type = scalar_types.uint8b128
else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.input_dtype = None
@@ -594,7 +749,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_experts,
scales_size13,
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
dtype=params_dtype,
dtype=torch.int32,
),
requires_grad=False,
)
@@ -606,7 +761,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_experts,
scales_size2,
hidden_size // self.quant_config.pack_factor,
dtype=params_dtype,
dtype=torch.int32,
),
requires_grad=False,
)
@@ -656,7 +811,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace_new(device, 4)
# layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
@@ -673,119 +828,111 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_scales.data = layer.w2_scales.data * 512
# Process act_order
if self.quant_config.desc_act:
# if self.quant_config.desc_act:
# Get sorting based on g_idx
num_experts = layer.w13_g_idx.shape[0]
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
torch.int32
)
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
torch.int32
)
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
else:
# Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
# Repack weights
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# num_experts = layer.w13_g_idx.shape[0]
# w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
# w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
# w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
# w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
# for e in range(num_experts):
# w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
# torch.int32
# )
# w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
# torch.int32
# )
# w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
# w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
# replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
# replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
# replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
# replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
# else:
# # Reset g_idx related tensors
# num_experts = layer.w13_g_idx.shape[0]
# device = layer.w13_g_idx.device
# layer.w13_g_idx = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# layer.w2_g_idx = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# layer.w13_g_idx_sort_indices = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# layer.w2_g_idx_sort_indices = torch.nn.Parameter(
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
# requires_grad=False,
# )
# # Repack weights
# marlin_w13_qweight = ops.gptq_marlin_moe_repack(
# layer.w13_qweight,
# layer.w13_g_idx_sort_indices,
# layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
# layer.w13_qweight.shape[2],
# self.quant_config.quant_type.size_bits,
# )
# replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
# marlin_w2_qweight = ops.gptq_marlin_moe_repack(
# layer.w2_qweight,
# layer.w2_g_idx_sort_indices,
# layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
# layer.w2_qweight.shape[2],
# self.quant_config.quant_type.size_bits,
# )
# replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# # Repack scales
# marlin_w13_scales = marlin_moe_permute_scales(
# s=layer.w13_scales,
# size_k=layer.intermediate_size_per_partition,
# size_n=layer.w13_scales.shape[2],
# group_size=self.quant_config.group_size,
# )
# replace_parameter(layer, "w13_scales", marlin_w13_scales)
# marlin_w2_scales = marlin_moe_permute_scales(
# s=layer.w2_scales,
# size_k=layer.w2_scales.shape[1]
# * (
# self.quant_config.group_size
# if self.quant_config.group_size != -1
# else self.quant_config.pack_factor
# ),
# size_n=layer.w2_scales.shape[2],
# group_size=self.quant_config.group_size,
# )
# replace_parameter(layer, "w2_scales", marlin_w2_scales)
# The modular kernel expects w13_weight and w2_weight,
# but GPTQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer.w13_weight = layer.w13_qweight
# Alias for modular kernel
layer.w2_weight = layer.w2_qweight
# if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
# layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.w2_scales.shape[1]
* (
self.quant_config.group_size
if self.quant_config.group_size != -1
else self.quant_config.pack_factor
),
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
# if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
# layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
if self.quant_config.desc_act:
raise NotImplementedError(
"GPTQMarlinMoEMethod now not support desc_act. please fix it")
w13_qweight_unpacked = unpack_k_batch(layer.w13_qweight)
w13_qweight_repacked = pack_n_batch(w13_qweight_unpacked, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
replace_parameter(layer, "w13_qweight", w13_qweight_repacked)
# See quantize_weights in vllm/model_executor/layers/quantization/utils/quant_utils.py:
#     if quant_type.has_bias():
#         w_q += quant_type.bias
# so quant_type.bias is used here as the zero point (ixformer supports this).
w13_zp = torch.full_like(layer.w13_scales, self.quant_type.bias, dtype=torch.int32)
w13_zp_pack = pack_n_batch(w13_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
replace_parameter(layer, "w13_qzeros", w13_zp_pack)
w2_qweight_unpacked = unpack_k_batch(layer.w2_qweight)
w2_qweight_repacked = pack_n_batch(w2_qweight_unpacked, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
replace_parameter(layer, "w2_qweight", w2_qweight_repacked)
w2_zp = torch.full_like(layer.w2_scales, self.quant_type.bias, dtype=torch.int32)
w2_zp_pack = pack_n_batch(w2_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
replace_parameter(layer, "w2_qzeros", w2_zp_pack)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -900,30 +1047,165 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
# shared_experts_output is passed in as shared_experts_input so it can be fused into the final reduction
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
getattr(layer, "w13_bias", None),
getattr(layer, "w2_bias", None),
layer.w13_scales,
layer.w2_scales,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
inplace=not self.moe.disable_inplace,
assert layer.activation.value == "silu", "Only SiLU activation is supported."
use_ep = layer.expert_map is not None
if use_ep:
start_eid = layer.ep_rank * layer.local_num_experts
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
if layer.apply_router_weight_on_input:
raise NotImplementedError(
"GPTQMarlinMoEMethod Apply router weight on input is not supported for"
"fused Marlin MoE method.")
if (hasattr(layer, "w13_bias") and layer.w13_bias is not None) or (hasattr(layer, "w2_bias") and layer.w2_bias is not None):
raise NotImplementedError(
"GPTQMarlinMoEMethod moe_w4a16_group_gemm not supported bias, please fix this")
num_tokens = topk_ids.shape[0]
num_experts = layer.global_num_experts
if use_ep:
hidden_size = x.shape[1]
(
src_to_dst,
sorted_token_ids,
expert_sizes_gpu,
expert_sizes_cpu,
expand_tokens,
) = ixfops.moe_compute_token_index_ep(
topk_ids=topk_ids,
num_experts=num_experts,
start_expert_id=start_eid,
end_expert_id=end_eid,
)
if expert_sizes_cpu.sum() == 0:
return torch.zeros(
(num_tokens, hidden_size),
device=x.device,
dtype=x.dtype,
)
else:
expand_tokens = num_tokens * layer.top_k
(
src_to_dst,
sorted_token_ids,
expert_sizes_gpu,
expert_sizes_cpu,
) = ixfops.moe_compute_token_index(
topk_ids=topk_ids,
num_experts=num_experts,
)
expert_sizes_cpu = expert_sizes_gpu.cpu()
# expand + reorder
# TODO use kernel
expand_hidden_states = ixfops.moe_expand_input(
hidden_states=x,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
topk=layer.top_k,
src_to_dst=src_to_dst,
)
# w4a16 group gemm 1
# pt_output_1: (expand_tokens, 2n) dtype
pt_output_1 = ixfops.moe_w4a16_group_gemm(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
# w4a16 group gemm 2 + reorder
# pt_output_3: (expand_tokens, k) dtype
if use_ep:
pt_output_3 = torch.empty(
(num_tokens * layer.top_k, hidden_size),
device=x.device,
dtype=x.dtype,
)
ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
output=pt_output_3,
tokens_per_experts_gpu=expert_sizes_gpu,
)
reduce_mask = src_to_dst == -1
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
mask=reduce_mask,
)
else:
pt_output_3 = ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
extra_residual=shared_experts_input,
)
return final_hidden_states
# return torch.ops.vllm.fused_marlin_moe(
# x,
# layer.w13_qweight,
# layer.w2_qweight,
# getattr(layer, "w13_bias", None),
# getattr(layer, "w2_bias", None),
# layer.w13_scales,
# layer.w2_scales,
# router_logits,
# topk_weights,
# topk_ids,
# quant_type_id=self.quant_type.id,
# apply_router_weight_on_input=apply_router_weight_on_input,
# global_num_experts=global_num_experts,
# expert_map=expert_map,
# g_idx1=layer.w13_g_idx,
# g_idx2=layer.w2_g_idx,
# sort_indices1=layer.w13_g_idx_sort_indices,
# sort_indices2=layer.w2_g_idx_sort_indices,
# workspace=layer.workspace,
# is_k_full=self.is_k_full)

View File

@@ -12,8 +12,7 @@ from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -24,14 +23,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
@@ -49,13 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
process_fp8_input_tensor_strategy_moe,
@@ -114,6 +104,8 @@ QUANT_ALGOS = [
"NVFP4",
# MXFP8
"MXFP8",
# MIXED_PRECISION
"MIXED_PRECISION",
]
KV_CACHE_QUANT_ALGOS = ["FP8"]
@@ -181,7 +173,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
# handle kv-cache first so we can focus only on weight quantization thereafter
if isinstance(layer, Attention):
if isinstance(layer, (Attention, MLAAttention)):
return self.KVCacheMethodCls(self)
# handle exclusion
@@ -235,6 +227,26 @@ class ModelOptQuantConfigBase(QuantizationConfig):
self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)
@staticmethod
def _extract_modelopt_quant_algo(
hf_quant_cfg: dict[str, Any] | None,
) -> str | None:
"""Extract upper-cased quant_algo from a modelopt config.
Returns the quant_algo string (upper-cased), or None if the config
is not a modelopt config.
"""
if hf_quant_cfg is None:
return None
if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
return None
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
return str(quant_config.get("quant_algo", "")).upper()
return None
return str(hf_quant_cfg.get("quant_algo", "")).upper()
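# Illustrative inputs/outputs for this helper (assumed example dicts, not taken
# from real checkpoints):
#     {"quant_method": "modelopt", "quantization": {"quant_algo": "fp8"}}  -> "FP8"
#     {"quant_method": "modelopt", "quant_algo": "nvfp4"}                  -> "NVFP4"
#     {"quant_method": "awq"}                                              -> None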
@staticmethod
def get_config_filenames() -> list[str]:
return ["hf_quant_config.json"]
@@ -272,10 +284,20 @@ class ModelOptQuantConfigBase(QuantizationConfig):
# "exclude_modules" is the key in the legacy hf_quant_config.json
exclude_modules = quant_config.get("exclude_modules", [])
else:
# Compressed-tensors style format:
# Compressed-tensors style format (config.json quantization_config):
# {"quant_algo": "...", "quant_method": "modelopt"}
quant_method = config.get("quant_algo")
kv_cache_quant_method = config.get("kv_cache_quant_algo")
# "kv_cache_scheme" (a dict) instead of "kv_cache_quant_algo" (a string).
kv_cache_scheme = config.get("kv_cache_scheme")
if isinstance(kv_cache_scheme, dict) and (
kv_cache_scheme.get("type") == "float"
and kv_cache_scheme.get("num_bits") == 8
):
kv_cache_quant_method = "FP8"
else:
kv_cache_quant_method = None
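# Illustrative config.json fragment handled by the branch above (assumed
# example, not from a real checkpoint):
#     {"quant_method": "modelopt", "quant_algo": "NVFP4",
#      "kv_cache_scheme": {"type": "float", "num_bits": 8}, "ignore": []}
# which yields kv_cache_quant_method == "FP8".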
# "ignore" is the key in config.json
exclude_modules = config.get("ignore", [])
group_size_raw = config.get("group_size")
@@ -379,32 +401,9 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
"""Detect if this ModelOpt config should be used based on
quantization config."""
if hf_quant_cfg is None:
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = str(quant_config.get("quant_algo", ""))
if quant_algo.upper() == "FP8":
return "modelopt"
else:
# Check for compressed-tensors style config with specific quant_algo
quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
if quant_algo.upper() == "FP8":
return "modelopt"
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and algo == "FP8":
return "modelopt"
return None
@classmethod
@@ -737,7 +736,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
@@ -745,9 +744,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
@@ -862,16 +861,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
assert self.experts_cls is not None
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13 = layer.w13_weight
@@ -904,9 +902,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
a1_scale = layer.w13_input_scale
@@ -920,10 +916,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
a2_scale=a2_scale,
)
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
layer: FusedMoE,
@@ -931,28 +923,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
)
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert layer.activation in SUPPORTED_ACTIVATIONS, (
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"TRTLLM FP4 MoE, {layer.activation} found instead."
)
return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
top_k=layer.top_k,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
@@ -964,25 +948,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
# TODO(rob): this validation should happen at kernel selection
# time in the oracle rather than here.
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
assert layer.activation in (
MoEActivation.SILU,
MoEActivation.RELU2_NO_MUL,
), (
"Expected activation to be in ('silu', 'relu2_no_mul'),"
f"but got {layer.activation}"
)
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
@@ -1031,32 +1003,9 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
"""Detect if this ModelOpt FP4 config should be used based on
quantization config."""
if hf_quant_cfg is None:
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = quant_config.get("quant_algo", "")
if "NVFP4" in quant_algo:
return "modelopt_fp4"
else:
# Check for compressed-tensors style config with specific
# quant_algo field
quant_algo = hf_quant_cfg.get("quant_algo", "")
if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
return "modelopt_fp4"
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and ("NVFP4" in algo or "FP4" in algo):
return "modelopt_fp4"
return None
@classmethod
@@ -1249,17 +1198,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
@@ -1434,51 +1373,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
replace_parameter(layer, "w2_input_scale", a2_scale)
# Setup modular kernel for TP case and naive DP/EP case.
# In non-naive DP/EP case, we will create a ModularKernelMethod.
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
# Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
@property
def do_post_quant_allgather(self):
return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
def prepare_dp_allgather_tensor(
self,
layer: FusedMoE,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
raise RuntimeError(
"prepare_dp_allgather_tensor is only supported for "
"FlashInfer TRTLLM NVFP4 MoE backend."
)
import flashinfer
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
hidden_states,
layer.a1_gscale,
is_sf_swizzled_layout=False,
assert self.experts_cls is not None
self.moe_kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
shared_experts=layer.shared_experts,
routing_tables=layer._maybe_init_expert_routing_tables(),
)
extra_tensors: list[torch.Tensor] = [hidden_states_sf]
return hidden_states_fp4, extra_tensors
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale,
@@ -1493,13 +1399,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def supports_eplb(self) -> bool:
return True
@property
def is_monolithic(self) -> bool:
return (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not self.moe.moe_parallel_config.enable_eplb
)
def apply_monolithic(
self,
layer: FusedMoE,
@@ -1507,22 +1406,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert (
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
and not layer.enable_eplb
)
return flashinfer_trtllm_fp4_moe(
layer=layer,
x=x,
router_logits=router_logits,
top_k=layer.top_k,
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
@@ -1534,33 +1431,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
# EPLB path
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
assert layer.enable_eplb
return flashinfer_trtllm_fp4_routed_moe(
layer=layer,
x=x,
topk_ids=topk_ids,
topk_weights=topk_weights,
top_k=layer.top_k,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
)
else:
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
@@ -1619,31 +1502,9 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
"""Detect if this ModelOpt MXFP8 config should be used based on
quantization config."""
if hf_quant_cfg is None:
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = str(quant_config.get("quant_algo", "")).upper()
if "MXFP8" in quant_algo:
return "modelopt_mxfp8"
else:
# Check for compressed-tensors style config with specific quant_algo
quant_algo = str(hf_quant_cfg.get("quant_algo", "")).upper()
if "MXFP8" in quant_algo:
return "modelopt_mxfp8"
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and "MXFP8" in algo:
return "modelopt_mxfp8"
return None
@classmethod
@@ -1841,3 +1702,188 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
# Register the method classes for ModelOptMxFp8Config
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
"""Config class for ModelOpt MIXED_PRECISION.
Supports checkpoints where different layers use different quantization
algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts).
The per-layer algorithm is specified in the ``quantized_layers`` dict
inside ``config.json``'s ``quantization_config`` (preferred) or the
legacy ``hf_quant_config.json``.
"""
def __init__(
self,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
quantized_layers: dict[str, dict[str, Any]],
fp8_config: ModelOptFp8Config,
nvfp4_config: ModelOptNvFp4Config,
) -> None:
super().__init__(exclude_modules)
self.kv_cache_quant_method = kv_cache_quant_method
self.quantized_layers = quantized_layers
self.fp8_config = fp8_config
self.nvfp4_config = nvfp4_config
def get_name(self) -> QuantizationMethods:
return "modelopt_mixed"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 89
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
if algo is not None and algo == "MIXED_PRECISION":
return "modelopt_mixed"
return None
@classmethod
def _from_config(
cls,
*,
quant_method: str,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
original_config: dict[str, Any],
group_size: int | None,
**kwargs: Any,
) -> "ModelOptMixedPrecisionConfig":
if "quantization" in original_config:
quantized_layers = original_config["quantization"].get(
"quantized_layers", {}
)
else:
quantized_layers = original_config.get("quantized_layers", {})
if not quantized_layers:
raise ValueError(
"MIXED_PRECISION quant_algo requires a non-empty "
"'quantized_layers' mapping in the quantization config."
)
# Determine group_size from the first NVFP4 entry if not provided.
if group_size is None:
for layer_info in quantized_layers.values():
if layer_info.get("quant_algo", "").upper() == "NVFP4":
group_size = layer_info.get("group_size", 16)
break
if group_size is None:
group_size = 16
fp8_config = ModelOptFp8Config(
quant_method="FP8",
is_checkpoint_fp8_serialized=True,
kv_cache_quant_method=kv_cache_quant_method,
exclude_modules=[],
)
nvfp4_config = ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=kv_cache_quant_method,
exclude_modules=[],
group_size=group_size,
)
return cls(
kv_cache_quant_method=kv_cache_quant_method,
exclude_modules=exclude_modules,
quantized_layers=quantized_layers,
fp8_config=fp8_config,
nvfp4_config=nvfp4_config,
)
def _resolve_quant_algo(self, prefix: str) -> str | None:
"""Look up the quant_algo for a vLLM-side layer prefix.
Tries three strategies in order:
1. Direct lookup in ``quantized_layers``.
2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``).
3. Prefix-based lookup for FusedMoE (any child key starts with
``prefix + "."``).
Returns the upper-cased quant_algo string, or *None* if the prefix
is not found.
"""
# 1. Direct lookup
if prefix in self.quantized_layers:
return self.quantized_layers[prefix]["quant_algo"].upper()
# 2. Packed / fused layer lookup
proj_name = prefix.rsplit(".", 1)[-1]
if self.packed_modules_mapping and proj_name in self.packed_modules_mapping:
algos: set[str] = set()
base = prefix.rsplit(".", 1)[0]
for shard_name in self.packed_modules_mapping[proj_name]:
shard_prefix = f"{base}.{shard_name}"
if shard_prefix in self.quantized_layers:
algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper())
if len(algos) == 1:
return algos.pop()
if len(algos) > 1:
raise ValueError(
f"Mixed quant_algo within fused layer {prefix}: "
f"{algos}. All shards must use the same quantization."
)
# 3. Prefix-based lookup (for FusedMoE / parent modules)
prefix_dot = prefix + "."
for key, info in self.quantized_layers.items():
if key.startswith(prefix_dot):
return info["quant_algo"].upper()
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
"""Return quantize-method based on layer."""
# KV-cache quantization
if isinstance(layer, Attention):
if self.kv_cache_quant_method:
return ModelOptFp8KVCacheMethod(self)
return None
# Excluded layers
if self.is_layer_excluded(prefix):
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
return None
quant_algo = self._resolve_quant_algo(prefix)
if isinstance(layer, LinearBase):
if quant_algo == "FP8":
return ModelOptFp8LinearMethod(self.fp8_config)
if quant_algo == "NVFP4":
return ModelOptNvFp4LinearMethod(self.nvfp4_config)
# Layer not in quantized_layers — leave unquantized
return UnquantizedLinearMethod()
if isinstance(layer, FusedMoE):
if quant_algo == "FP8":
return ModelOptFp8MoEMethod(
quant_config=self.fp8_config,
moe_config=layer.moe_config,
)
if quant_algo == "NVFP4":
return ModelOptNvFp4FusedMoE(
quant_config=self.nvfp4_config,
moe_config=layer.moe_config,
)
return None
return None
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
super().apply_vllm_mapper(hf_to_vllm_mapper)
if self.quantized_layers:
self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers)
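# Illustrative sketch (hypothetical layer names, not taken from a real
# checkpoint): a MIXED_PRECISION checkpoint carries a per-layer mapping such as
#
#   "quantized_layers": {
#       "model.layers.0.self_attn.q_proj": {"quant_algo": "FP8"},
#       "model.layers.0.self_attn.k_proj": {"quant_algo": "FP8"},
#       "model.layers.0.self_attn.v_proj": {"quant_algo": "FP8"},
#       "model.layers.0.mlp.experts": {"quant_algo": "NVFP4", "group_size": 16},
#   }
#
# With packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]},
# _resolve_quant_algo("model.layers.0.self_attn.qkv_proj") unfuses the packed
# projection, finds that all three shards agree, and returns "FP8";
# "model.layers.0.mlp.experts" resolves to "NVFP4" either directly or via the
# prefix lookup used for FusedMoE children.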

View File

@@ -6,6 +6,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
@@ -77,6 +78,8 @@ class Mxfp4Backend(Enum):
# Triton Backend
TRITON = 6
CK = 7
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
"""
@@ -167,9 +170,15 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
elif current_platform.is_xpu():
logger.info_once("Using xpu backend on XPU")
return Mxfp4Backend.MARLIN
elif current_platform.is_rocm() and has_triton_kernels():
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
elif current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx950
if rocm_aiter_ops.is_enabled() and on_gfx950():
logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)")
return Mxfp4Backend.CK
elif has_triton_kernels():
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
return Mxfp4Backend.NONE
@@ -257,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
self.moe_mk: mk.FusedMoEModularKernel | None = None
self.moe_kernel: mk.FusedMoEKernel | None = None
def create_weights(
self,
@@ -338,6 +347,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
self.intermediate_pad = (
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
)
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
@@ -427,7 +440,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel(
self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
MarlinExperts(
self.moe,
@@ -776,7 +789,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel(
self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
FlashInferExperts(
moe_config=self.moe,
@@ -784,6 +797,66 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
),
shared_experts=None,
)
elif self.mxfp4_backend == Mxfp4Backend.CK:
if layer.w13_bias is not None:
layer.w13_bias.data = layer.w13_bias.data.to(torch.float32)
if layer.w2_bias is not None:
layer.w2_bias.data = layer.w2_bias.data.to(torch.float32)
e, n, k = layer.w13_weight.shape
layer.w13_weight.view(torch.uint8).copy_(
layer.w13_weight.data.view(torch.uint8)
.view(e, n // 2, 2, k)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, k)
)
layer.w13_weight_scale.data = (
layer.w13_weight_scale.data.view(e, n // 2, 2, -1)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, -1)
)
layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2)
layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2)
layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
layer.w13_weight, 16, True
)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]),
self.num_experts,
True,
)
layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
layer.w2_weight, 16, False
)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]),
self.num_experts,
False,
)
layer.w13_bias.data = (
layer.w13_bias.data.view(-1, n // 2, 2)
.permute(0, 2, 1)
.contiguous()
.view(-1, n)
)
layer.w13_weight_scale = torch.nn.Parameter(
shuffled_w13_scale, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
shuffled_w2_scale, requires_grad=False
)
# replace_parameter(layer, "w13_bias", w13_bias)
# replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
# replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
# replace_parameter(layer, "w13_weight", w13_weight)
# replace_parameter(layer, "w2_weight", w2_weight)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
@@ -792,18 +865,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object
# (stored in self.fused_experts) to determine if the MoE has a
# batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config
# directly.
is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
is_batched_moe = self.moe.use_deepep_ll_kernels
if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
@@ -817,13 +888,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
self.w13_weight = w13_weight
self.w2_weight = w2_weight
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = w13_weight
layer.w2_weight = w2_weight
else:
raise ValueError(
f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
@@ -862,6 +933,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
elif self.mxfp4_backend in [
Mxfp4Backend.SM100_FI_MXFP4_BF16,
Mxfp4Backend.SM90_FI_MXFP4_BF16,
Mxfp4Backend.CK,
]:
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
@@ -882,9 +954,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
@@ -929,10 +1001,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
@property
def is_monolithic(self) -> bool:
if self.moe.is_lora_enabled:
return False
return (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or self.mxfp4_backend == Mxfp4Backend.TRITON
or self.mxfp4_backend == Mxfp4Backend.CK
)
def apply(
@@ -968,8 +1043,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
or self.mxfp4_backend == Mxfp4Backend.MARLIN
)
assert self.moe_mk is not None
return self.moe_mk(
assert self.moe_kernel is not None
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -1054,6 +1129,27 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
tune_max_num_tokens=max(self.max_capture_size, 1),
)[0]
return trtllm_gen_output
elif self.mxfp4_backend == Mxfp4Backend.CK:
topk_weights, topk_ids = rocm_aiter_ops.fused_topk(
x, router_logits, layer.top_k, True
)
output = rocm_aiter_ops.fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"),
quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
doweight_stage1=False,
hidden_pad=self.hidden_pad // 128 * 128,
intermediate_pad=self.intermediate_pad // 64 * 64 * 2,
bias1=layer.w13_bias,
bias2=layer.w2_bias,
)
return output
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward,
@@ -1162,7 +1258,7 @@ class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
topk_weights=routing_weights,
topk_ids=selected_experts,
n_experts_per_token=layer.top_k,
activation=layer.activation,
activation=layer.activation.value,
num_experts=layer.local_num_experts,
is_mxfp4=True,
)

View File

@@ -7,7 +7,6 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
@@ -26,10 +25,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
class PTPCFp8Config(Fp8Config):
"""Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""

View File

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig):
self.kv_cache_group = kv_cache_group
self.kv_cache_config = kv_cache_config
self.pack_method = pack_method
self.dynamic_mxfp4_quant = False
def maybe_update_config(self, model_name: str, revision: str | None = None):
self.hf_config = get_config(
model=model_name,
trust_remote_code=False, # or get from model_config if available
revision=revision,
config_format="auto",
)
quant_config = getattr(self.hf_config, "quantization_config", None)
if quant_config is not None:
quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
model_type = self.hf_config.model_type
if quant_dtype == "fp4" and model_type == "deepseek_v3":
self.dynamic_mxfp4_quant = True
def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self)
@@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig):
if should_ignore_layer(
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
):
return UnquantizedLinearMethod()
if (
"self_attn" not in prefix # only quantize attention projections
or not getattr(self, "dynamic_mxfp4_quant", False)
or not isinstance(layer, LinearBase) # Ignore other methods
):
return UnquantizedLinearMethod()
scheme = self.get_scheme(
layer=layer,
layer_name=prefix,
dynamic_mxfp4_quant=True,
)
layer.scheme = scheme
return QuarkLinearMethod(self)
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
layer.scheme = scheme
@@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig):
)
return global_quant_config
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
def _get_scheme_from_config(
self, config: dict[str, Any], dynamic_mxfp4_quant: bool = False
) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with output_tensors "
@@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig):
input_symmetric=input_config.get("symmetric"),
)
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX(weight_config, input_config)
return QuarkOCP_MX(
weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
)
raise NotImplementedError(
"No quark compatible scheme was found. "
@@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig):
f"Input config: {input_config}"
)
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
def get_scheme(
self, layer: torch.nn.Module, layer_name: str, dynamic_mxfp4_quant: bool = False
) -> "QuarkScheme":
layer_quant_config = self._find_matched_config(layer_name, layer)
# Find the quant_scheme
scheme = self._get_scheme_from_config(layer_quant_config)
scheme = self._get_scheme_from_config(
layer_quant_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
)
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())

View File

@@ -5,8 +5,8 @@ from typing import Any
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
@@ -49,7 +50,11 @@ from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
__all__ = [
"QuarkMoEMethod",
"QuarkOCP_MX_MoEMethod",
"QuarkOCP_MX_MoEMethod_OSS",
]
class QuarkMoEMethod(FusedMoEMethodBase):
@@ -71,14 +76,30 @@ class QuarkMoEMethod(FusedMoEMethodBase):
"output_tensors and bias "
"quantized are not supported"
)
weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors")
if quant_config._is_fp8_w4a8(weight_config, input_config):
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
emulate = not current_platform.supports_mx() or not (
rocm_aiter_ops.is_fused_moe_enabled()
)
if (
input_config.get("dtype") == "fp8_e4m3"
and not input_config.get("is_dynamic")
and not emulate
):
return QuarkOCP_MX_MoEMethod_OSS(
weight_config, input_config, module.moe_config
)
else:
return QuarkOCP_MX_MoEMethod(
weight_config, input_config, module.moe_config
)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
@@ -706,13 +727,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
get_current_vllm_config().model_config.hf_config, "model_type", None
)
self._emulate = (
self.emulate = (
not current_platform.supports_mx()
or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
self.emulate = True if self.model_type == "gpt_oss" else self._emulate
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
@@ -753,6 +772,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
params_dtype = torch.uint8
self.intermediate_size_per_partition = intermediate_size_per_partition
if self.model_type == "gpt_oss":
if current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
@@ -765,6 +785,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
else:
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
self.unpadded_hidden_size = extra_weight_attrs.get(
"unpadded_hidden_size", hidden_size
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
@@ -991,30 +1015,20 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.emulate:
if (
self.model_type == "gpt_oss"
and self.mxfp4_backend == Mxfp4Backend.TRITON
):
raise NotImplementedError(
"Triton kernel implemented fused MoE for GPT_OSS model "
"in Quark(MoE) format is not integrated or provided yet."
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
else:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -1031,3 +1045,133 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
def __init__(
self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(weight_config, input_config, moe)
def process_weights_after_loading(self, layer):
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
w13_bias = layer.w13_bias.to(torch.float32)
w2_bias = layer.w2_bias.to(torch.float32)
layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
# FIXME: num_warps needs to be adjusted based on batch size;
# this only applies to batched mode.
if self.moe.use_ep:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps
)
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
# need to delete the original weights to save memory on single GPU
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
if self.static_input_scales:
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max().to(torch.float32), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max().to(torch.float32), requires_grad=False
)
from triton_kernels.numerics import InFlexData
lhs_data13 = InFlexData(scale=layer.w13_input_scale)
lhs_data2 = InFlexData(scale=layer.w2_input_scale)
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale,
flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13),
)
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale,
flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2),
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return mxfp4_w4a8_moe_quant_config(
w1_scale=self.w13_precision_config,
w2_scale=self.w2_precision_config,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
block_shape=None,
)
@property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet."
)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward,
)
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w2=self.w2_weight_triton_tensor,
gating_output=router_logits,
topk=layer.top_k,
renormalize=layer.renormalize,
global_num_experts=layer.global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
unpadded_N_w1=self.intermediate_size_per_partition * 2,
unpadded_K_w1=self.unpadded_hidden_size,
unpadded_N_w2=self.unpadded_hidden_size,
unpadded_K_w2=self.intermediate_size_per_partition,
)

View File

@@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
)
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PackedvLLMParameter,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from .quark_scheme import QuarkScheme
@@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError):
class QuarkOCP_MX(QuarkScheme):
def __init__(
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
self,
weight_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any],
dynamic_mxfp4_quant: bool = False,
):
self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
@@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme):
layer.weight_scale.data, requires_grad=False
)
else:
if self.rocm_use_aiter_fp4_asm_gemm:
if self.dynamic_mxfp4_quant:
w_q, w_s = dynamic_mxfp4_quant(layer.weight)
layer.weight_scale = torch.nn.Parameter(
w_s.T.contiguous(), requires_grad=False
)
layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
elif self.rocm_use_aiter_fp4_asm_gemm:
# shuffle weight scale
weight_scale_shuffle = layer.weight_scale.data
sm, sn = weight_scale_shuffle.shape
@@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme):
weight_loader: Callable,
**kwargs,
):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
if self.dynamic_mxfp4_quant:
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
# WEIGHT
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=self.packed_factor,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, kwargs)
else:
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT SCALE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# WEIGHT
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=self.packed_factor,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def apply_weights(
self,

View File

@@ -6,28 +6,18 @@ from typing import TYPE_CHECKING
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
align_fp4_moe_weights_for_fi,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_cutlass_fused_moe,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
@@ -42,92 +32,15 @@ __all__ = [
"reorder_w1w3_to_w3w1",
]
#
# Methods used by the oracle for kernel selection.
#
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
def _supports_no_act_and_mul() -> bool:
"""Supports non-gated MoE."""
return True
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Nvfp4 quantization."""
SUPPORTED_W_A = [
(kNvfp4Static, kNvfp4Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
def _supports_routing_method(
routing_method: RoutingMethodType,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.Llama4,
]
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""
TRTLLM is a monolithic kernel that requires dispatch_router_logits() for
the naive dispatch/combine path. DeepEP HT only implements dispatch() for
the modular kernel path, so TRTLLM is incompatible with DeepEP HT.
"""
return not moe_parallel_config.use_deepep_ht_kernels
def is_supported_config_trtllm(
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
"""
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not _supports_current_device():
return False, _make_reason(f"current device {current_platform.device_name}")
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not _supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not _supports_quant_scheme(weight_key, activation_key):
return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
elif not _supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
elif not _supports_routing_method(moe_config.routing_method):
return False, _make_reason(f"routing method {moe_config.routing_method}")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason(f"activation format {activation_format}")
elif moe_config.hidden_dim % 512 != 0:
return False, _make_reason(
f"hidden_dim must be divisible by 512, found {moe_config.hidden_dim}"
)
return True, None
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.has_device_capability(100)
)
def reorder_w1w3_to_w3w1(
@@ -276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe(
)
def flashinfer_trtllm_fp4_moe(
layer: torch.nn.Module,
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor,
top_k: int,
activation: MoEActivation,
global_num_experts: int,
num_expert_group: int | None,
topk_group: int | None,
custom_routing_function: object | None,
e_score_correction_bias: torch.Tensor | None,
) -> torch.Tensor:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
router_logits: Router logits for expert selection
top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks
num_expert_group: Number of expert groups (for grouped routing)
topk_group: Top-k within each group
custom_routing_function: Custom routing function (e.g., Llama4)
e_score_correction_bias: Optional routing bias correction
Returns:
Output tensor from the MoE layer
"""
import flashinfer
from vllm.model_executor.models.llama4 import Llama4MoE
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert activation in SUPPORTED_ACTIVATIONS, (
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
f"TRTLLM FP4 MoE, {activation} found instead."
)
# Quantize input to FP4
if isinstance(x, tuple):
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# hidden_states is already quantized
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Determine routing method type
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
routing_method_type = layer.routing_method_type
if use_llama4_routing:
routing_method_type = flashinfer.RoutingMethodType.Llama4
# Cast to Fp32 (required by kernel).
router_logits = (
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
)
# Determine activation type
activation_type = activation_to_flashinfer_int(layer.activation)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).reshape(*hidden_states_fp4.shape[:-1], -1),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=num_expert_group if num_expert_group is not None else 0,
topk_group=topk_group if topk_group is not None else 0,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
routing_method_type=routing_method_type,
do_finalize=True,
activation_type=activation_type,
)[0]
return out
def flashinfer_trtllm_fp4_routed_moe(
layer: torch.nn.Module,
x: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
top_k: int,
activation: MoEActivation,
global_num_experts: int,
) -> torch.Tensor:
"""
Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
input top k expert indices and scores rather than computing
top k expert indices from scores.
Args:
layer: The MoE layer with weights and scales
x: Input tensor
topk_ids: Ids of selected experts
top_k: Number of experts to select per token
activation: Activation function to use
global_num_experts: Total number of experts across all ranks
Returns:
Output tensor from the MoE layer
"""
import flashinfer
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
assert activation == MoEActivation.SILU, (
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
f"{activation} found instead."
)
# Pack top k ids and expert weights into a single int32 tensor, as
# required by TRT-LLM
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
).view(torch.int16)
if isinstance(x, tuple):
# hidden_states is already quantized
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
topk_ids=packed_tensor,
routing_bias=None,
hidden_states=hidden_states_fp4,
hidden_states_scale=hidden_states_scale_linear_fp4.view(
torch.float8_e4m3fn
).reshape(*hidden_states_fp4.shape[:-1], -1),
gemm1_weights=layer.w13_weight.data,
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
gemm1_bias=None,
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=layer.w2_weight.data,
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
gemm2_bias=None,
output1_scale_scalar=layer.g1_scale_c.data,
output1_scale_gate_scalar=layer.g1_alphas.data,
output2_scale_scalar=layer.g2_alphas.data,
num_experts=global_num_experts,
top_k=top_k,
n_group=0,
topk_group=0,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
routing_method_type=1,
do_finalize=True,
)[0]
return out
def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
backend: "NvFp4MoeBackend",
layer: "FusedMoE",
@@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
)
)
layer.intermediate_size_per_partition = padded_intermediate
layer.moe_config.intermediate_size_per_partition = padded_intermediate
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
w13,

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import TYPE_CHECKING
import torch
@@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
if TYPE_CHECKING:
from flashinfer.fused_moe.core import ActivationType
logger = init_logger(__name__)
@@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum):
def activation_to_flashinfer_int(activation: MoEActivation) -> int:
return activation_to_flashinfer_type(activation).value
def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType":
from flashinfer.fused_moe.core import ActivationType
# silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
@@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int:
MoEActivation.GELU: ActivationType.Geglu,
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
}
return ACTIVATION_TO_FI_ACTIVATION[activation].value
return ACTIVATION_TO_FI_ACTIVATION[activation]
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
@@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
)
def register_scales_for_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> None:
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
layer.w2_input_scale_inv = 1.0 / w2_input_scale
layer.output1_scales_gate_scalar = g1_alphas
if layer.activation.is_gated:
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
else:
layer.output1_scales_scalar = (
torch.ones_like(g1_alphas) * layer.w2_input_scale_inv
)
layer.output2_scales_scalar = g2_alphas
def apply_fi_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
top_k: int,
num_expert_group: int | None,
topk_group: int | None,
global_num_experts: int,
apply_router_weight_on_input: bool,
) -> torch.Tensor:
from flashinfer.fused_moe import RoutingMethodType
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
from vllm.model_executor.models.llama4 import Llama4MoE
# Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert (
hasattr(layer, "output1_scales_scalar")
and hasattr(layer, "output1_scales_gate_scalar")
and hasattr(layer, "output2_scales_scalar")
)
if layer.routing_method_type == RoutingMethodType.Llama4:
assert (
not layer.renormalize
and layer.custom_routing_function == Llama4MoE.custom_routing_function
), (
"FusedMoE flashinfer kernels with Llama4 routing method are only "
"supported for Llama4"
)
else:
assert layer.custom_routing_function is None, (
"Custom routing function is only supported for Llama4"
)
activation_type = activation_to_flashinfer_int(layer.activation)
return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
input_scale=layer.w13_input_scale,
gemm1_weights=layer.w13_weight,
gemm2_weights=layer.w2_weight,
output1_scales_scalar=layer.output1_scales_scalar,
output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
output2_scales_scalar=layer.output2_scales_scalar,
num_experts=global_num_experts,
top_k=top_k,
num_expert_group=num_expert_group,
topk_group=topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
use_routing_scales_on_input=apply_router_weight_on_input,
routing_method_type=layer.routing_method_type,
activation_type=activation_type,
)
def make_fp8_moe_alpha_scales_for_fi(
w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
g1_alphas = (w13_scale * w13_input_scale).squeeze()
g2_alphas = (w2_scale * w2_input_scale).squeeze()
return g1_alphas, g2_alphas
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
@@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi(
min_alignment,
)
layer.intermediate_size_per_partition = new_intermediate
layer.moe_config.intermediate_size_per_partition = new_intermediate
# FI kernels require W31 layout rather than W13.
if layer.moe_config.is_act_and_mul:
@@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi(
w13_scale = swap_w13_to_w31(w13_scale)
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
# and registration of alpha scales. Note that we do not register
# as nn.Parameters since they are not needed for weight-reloading.
# and registration of alpha scales.
if is_trtllm and not block_quant:
assert w13_input_scale is not None
assert w2_input_scale is not None
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
# Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel.
# Some FP8 models have near-zero block scales (~1e-23) for dead/unused

View File

@@ -53,7 +53,10 @@ logger = init_logger(__name__)
def is_fp8(x: torch.dtype | torch.Tensor) -> bool:
if isinstance(x, torch.Tensor):
x = x.dtype
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
try:
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
except Exception:
return False
# We need to pass in the is_hopper flag as argument because the function

View File

@@ -0,0 +1,373 @@
import torch
import numpy as np
from gguf.constants import GGMLQuantizationType
def get_awq_format(w, group_size=128, w_bit=4):
org_w_shape = w.shape
ori_w_dtype = torch.get_default_dtype()
assert w_bit == 4
assert w.shape[1] % group_size == 0
in_features = org_w_shape[1]
w = w.reshape(-1, group_size)
assert torch.isnan(w).sum() == 0
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2**w_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
w = (
torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
) * scales
zeros = zeros.view(org_w_shape[0], -1)
scales = scales.view(org_w_shape[0], -1)
w = w.reshape(org_w_shape)
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0
scales = scales.t().contiguous() # input // group, o
zeros = zeros.t().contiguous() # input // group, o
# from auto awq
scale_zeros = zeros * scales
scales = scales.clone().to(ori_w_dtype)
pack_num = 32 // w_bit
intweight = []
for idx in range(in_features):
intweight.append(
torch.round(
(w[:, idx] + scale_zeros[idx // group_size])
/ scales[idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.to(dtype=torch.int32)
qweight = torch.zeros(
(intweight.shape[0], intweight.shape[1] // 32 * w_bit),
dtype=torch.int32,
device=intweight.device,
)
for col in range(intweight.shape[1] // pack_num):
order_map = [0, 2, 4, 6, 1, 3, 5, 7]  # AWQ interleaved packing order (w_bit is asserted to be 4)
for i in range(pack_num):
qweight_col = intweight[:, col * pack_num + order_map[i]]
qweight[:, col] |= qweight_col << (i * w_bit)
zeros = zeros.to(dtype=torch.int32, device=qweight.device)
qzeros = torch.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * w_bit),
dtype=torch.int32,
device=zeros.device,
)
for col in range(zeros.shape[1] // pack_num):
order_map = [0, 2, 4, 6, 1, 3, 5, 7]  # AWQ interleaved packing order (w_bit is asserted to be 4)
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * w_bit)
return qweight, qzeros, scales
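# Usage sketch (illustrative, not exercised by this file): pack a weight
# matrix of shape [out_features, in_features] into the AWQ GEMM layout.
#
#   qweight, qzeros, scales = get_awq_format(w, group_size=128, w_bit=4)
#   # qweight: [in_features, out_features // 8]  int32, eight 4-bit values per word
#   # qzeros:  [in_features // group_size, out_features // 8]  int32
#   # scales:  [in_features // group_size, out_features]  in torch.get_default_dtype()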
GGML_BLOCK_SIZES = {
"F32": 4,
"F16": 2,
"Q4_0": 2 + 16,
"Q5_0": 2 + 4 + 16,
"Q8_0": 2 + 32,
"Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
"Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
"Q4_K": 2 + 2 + 12 + 256 // 2,
"Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
"Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
"IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
}
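# The sizes above are bytes per quantization block: the "_0" formats store 32
# weights per block (e.g. Q8_0 = 2-byte fp16 scale + 32 int8 values), while the
# K-quants store 256 weights per super-block (e.g. Q4_K = two fp16 scales,
# 12 bytes of packed sub-scales and 128 bytes of 4-bit values).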
def dequantize_f32(data):
return np.frombuffer(data, dtype=np.float32)
def dequantize_f16(data):
return np.frombuffer(data, dtype=np.float16)
def dequantize_q4_0(data):
num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"]
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32)
qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:]
return np.concatenate([
scales * ((qs & 0xf).astype(np.int8) - 8),
scales * ((qs >> 4).astype(np.int8) - 8),
], axis=1)
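# Each nibble decodes as scale * (q - 8); the low nibbles of the 16 packed
# bytes come first, then the high nibbles, matching the concatenation above.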
def dequantize_q5_0(data):
num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"]
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)
qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]
qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]
bits = np.unpackbits(qh, axis=-1, bitorder="little")
x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16
x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16
return np.concatenate([
scales * x0,
scales * x1,
], axis=1)
def dequantize_q8_0(data):
num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32)
qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
return scales * qs
def dequantize_q2_k(data):
block_size = GGML_BLOCK_SIZES["Q2_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
qs = data_u8[:, 16:80].reshape(num_blocks, 64)
tmp = np.stack([
qs[:, 00:16] >> 0,
qs[:, 16:32] >> 0,
qs[:, 00:16] >> 2,
qs[:, 16:32] >> 2,
qs[:, 00:16] >> 4,
qs[:, 16:32] >> 4,
qs[:, 00:16] >> 6,
qs[:, 16:32] >> 6,
qs[:, 32:48] >> 0,
qs[:, 48:64] >> 0,
qs[:, 32:48] >> 2,
qs[:, 48:64] >> 2,
qs[:, 32:48] >> 4,
qs[:, 48:64] >> 4,
qs[:, 32:48] >> 6,
qs[:, 48:64] >> 6,
], axis=1)
return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
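# Q2_K layout: 256 weights per block in 16 sub-blocks of 16; each sub-block has
# a 4-bit scale and a 4-bit min packed in `scales`, applied to the 2-bit quants
# as d * scale * q - dmin * min.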
def dequantize_q3_k(data):
block_size = GGML_BLOCK_SIZES["Q3_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
bits = 4 ^ (bits << 2)
qs = data_u8[:, 32:32 + 64].astype(np.int16)
a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
scales[:, 0] = (a & 15) | ((c & 3) << 4)
scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)
return d * (scales - 32) * np.stack([
(((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
(((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
(((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
(((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
(((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
(((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
(((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
(((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
(((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
(((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
(((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
(((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
(((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
(((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
(((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
(((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
], axis=1)
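# Q4_K super-block (256 weights, 144 bytes): f16 d, f16 dmin, 12 bytes of packed
# 6-bit sub-block scales/offsets, then 128 bytes of 4-bit quants.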
def dequantize_q4_k(data, device=None):
block_size = GGML_BLOCK_SIZES["Q4_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
# Casting to float32 because float16 is very slow on CPU
scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
# Dequantize scales and offsets (6 bits and 4 + 2 bits)
factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)
offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)
# Interleave low and high quantized bits
qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
# Dequantize final weights using scales and offsets
weight = factors * qs2 - offsets
if device is None:
return weight
return torch.from_numpy(weight).to(device=device)
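# Q5_K super-block (256 weights, 176 bytes): f16 d, f16 dmin, 12 bytes of packed
# 6-bit scales, 32 bytes of 5th (high) bits, then 128 bytes of low nibbles.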
def dequantize_q5_k(data):
block_size = GGML_BLOCK_SIZES["Q5_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1)
qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32)
bits = np.unpackbits(qh, axis=-1, bitorder="little")
qs_hi_4 = qs >> 4
qs_lo_4 = qs & 15
scales_lo_6 = scales[:, :8] & 63
scales_hi_6 = scales[:, :8] >> 6
scales_lo_4 = scales[:, 8:] & 15
scales_hi_4 = scales[:, 8:] >> 4
m1 = dmin * scales_lo_6[:, 4]
m2 = dmin * scales_lo_6[:, 5]
m3 = dmin * scales_lo_6[:, 6]
m4 = dmin * scales_lo_6[:, 7]
m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))
d1 = d * scales_lo_6[:, 0]
d2 = d * scales_lo_6[:, 1]
d3 = d * scales_lo_6[:, 2]
d4 = d * scales_lo_6[:, 3]
d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))
return np.concatenate([
d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
], axis=1)
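# Q6_K super-block (256 weights, 210 bytes): 128 bytes of low 4 bits, 64 bytes of
# upper 2 bits, 16 signed int8 sub-block scales, then the f16 super-scale; values
# carry an implicit offset of -32.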
def dequantize_q6_k(data, device=None):
block_size = GGML_BLOCK_SIZES["Q6_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)
scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)
# TODO use uint8 and cast later?
ql = data_u8[:, :128].astype(np.int16)
qh = data_u8[:, 128:192].astype(np.int16)
sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)
# Unpack bits, subtraction requires signed data type
q1 = (ql[:, :32 ] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32
q2 = (ql[:, 32:64 ] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32
q3 = (ql[:, :32 ] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32
q4 = (ql[:, 32:64 ] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32
q5 = (ql[:, 64:96 ] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32
q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32
q7 = (ql[:, 64:96 ] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32
q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32
# Dequantize
weight = scales * np.concatenate([
sc[:, 0] * q1[:, :16],
sc[:, 1] * q1[:, 16:],
sc[:, 2] * q2[:, :16],
sc[:, 3] * q2[:, 16:],
sc[:, 4] * q3[:, :16],
sc[:, 5] * q3[:, 16:],
sc[:, 6] * q4[:, :16],
sc[:, 7] * q4[:, 16:],
sc[:, 8] * q5[:, :16],
sc[:, 9] * q5[:, 16:],
sc[:, 10] * q6[:, :16],
sc[:, 11] * q6[:, 16:],
sc[:, 12] * q7[:, :16],
sc[:, 13] * q7[:, 16:],
sc[:, 14] * q8[:, :16],
sc[:, 15] * q8[:, 16:],
], axis=1)
if device is None:
return weight
return torch.from_numpy(weight).to(device=device)
QK_K = 256
kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
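# IQ4_XS super-block (256 weights, 136 bytes): f16 d, 16-bit packed high scale bits,
# 4 bytes of low scale nibbles, then 128 bytes of 4-bit indices into the non-linear
# lookup table kvalues_iq4nl.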
def dequantize_iq4_xs(data):
block_size = GGML_BLOCK_SIZES["IQ4_XS"]
num_blocks = len(data) // block_size
d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1)
scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:]
scales_l = data_u8[:, :4].reshape(num_blocks, 4)
qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8)
ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8)
for ib in range(QK_K // 32):
ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4)
dl = (d * (ls - 32)).reshape(num_blocks, -1, 1)
qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf
qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4
y = np.zeros((num_blocks, QK_K), dtype=np.float32)
for ib in range(QK_K // 32):
y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]]
y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]]
return y.flatten()
GGML_DEQUANTIZE = {
int(GGMLQuantizationType.F32): dequantize_f32,
int(GGMLQuantizationType.F16): dequantize_f16,
int(GGMLQuantizationType.Q4_0): dequantize_q4_0,
int(GGMLQuantizationType.Q5_0): dequantize_q5_0,
int(GGMLQuantizationType.Q8_0): dequantize_q8_0,
int(GGMLQuantizationType.Q2_K): dequantize_q2_k,
int(GGMLQuantizationType.Q3_K): dequantize_q3_k,
int(GGMLQuantizationType.Q4_K): dequantize_q4_k,
int(GGMLQuantizationType.Q5_K): dequantize_q5_k,
int(GGMLQuantizationType.Q6_K): dequantize_q6_k,
int(GGMLQuantizationType.IQ4_XS): dequantize_iq4_xs,
}
def dequant_gguf(data, type, shape):
values = GGML_DEQUANTIZE[type](data)
values = torch.from_numpy(values).view(shape)
return values
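# Illustrative usage sketch only (the names below are hypothetical and not part of
# this module): given a raw tensor blob read from a GGUF file, its GGML type id and
# its logical shape, the dequantized torch tensor could be obtained as
#   weight = dequant_gguf(raw_bytes, int(ggml_type), tensor_shape)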

View File

@@ -255,18 +255,6 @@ def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tenso
return w2_packed.size(1) * marlin_tile_size
def marlin_make_workspace(
output_size_per_partition: int, device: torch.device
) -> torch.Tensor:
max_workspace_size = (
output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
) * GPTQ_MARLIN_MAX_PARALLEL
return torch.zeros(
max_workspace_size, dtype=torch.int, device=device, requires_grad=False
)
def marlin_make_workspace_new(
device: torch.device, max_blocks_per_sm: int = 1
) -> torch.Tensor:
@@ -297,12 +285,6 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
)
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(
torch.empty(0, dtype=torch.int, device=device), requires_grad=False
)
def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices

View File

@@ -175,7 +175,7 @@ try:
op_func=_dequant_mxfp4,
fake_impl=_dequant_mxfp4_fake,
)
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
dequant_mxfp4 = None
except AttributeError as error:
raise error
@@ -185,6 +185,6 @@ try:
op_func=_quant_dequant_mxfp4,
fake_impl=_quant_dequant_mxfp4_fake,
)
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
quant_dequant_mxfp4 = None
except AttributeError as error:
raise error

View File

@@ -271,12 +271,12 @@ def scaled_quantize(
If None, uses input dtype. Use torch.float32 for higher precision.
"""
group_shape = _normalize_quant_group_shape(x, group_shape)
assert quant_dtype.is_floating_point, (
"currently `scaled_quantize` only supports floating point dtypes "
"but could be extended to support other dtypes"
)
# assert quant_dtype.is_floating_point, (
# "currently `scaled_quantize` only supports floating point dtypes "
# "but could be extended to support other dtypes"
# )
finfo = torch.finfo(quant_dtype)
finfo = torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype)
# Convert to compute dtype if specified
x_compute = x if compute_dtype is None else x.to(compute_dtype)
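# Illustrative sketch (not part of the patch): with the assertion relaxed, the
# quantization range comes from finfo for floating-point targets and iinfo for
# integer targets, e.g.
#   torch.finfo(torch.float8_e4m3fn).max -> 448.0
#   torch.iinfo(torch.int8).max          -> 127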

View File

@@ -0,0 +1,114 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class W8a16Config(QuantizationConfig):
"""Config class for W8a16.
"""
def __init__(
self,
) -> None:
pass
def __repr__(self) -> str:
return ("W8a16Config")
def get_name(self) -> str:
return "w8a16"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
def get_min_capability(self) -> int:
return 75
@staticmethod
def get_config_filenames():
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "W8a16Config":
return cls()
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["W8a16LinearMethod"]:
if isinstance(layer, LinearBase):
return W8a16LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class W8a16LinearMethod(LinearMethodBase):
"""Linear method for w8a16.
"""
def __init__(self, quant_config: W8a16Config):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(
torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.int8,
),
requires_grad=False,
)
set_weight_attrs(
weight, {
"input_dim": 1,
"output_dim": 0,
})
scales = Parameter(
torch.empty(
1,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": None,
"output_dim": 1,
})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.weight
scales = layer.scales
out_shape = (x.shape[:-1] + (qweight.shape[-2],))
reshaped_x = x.reshape(-1, x.shape[-1])
out = ops.linear_w8a16(reshaped_x, qweight, scales, format="TN")
if bias is not None:
out = out + bias
return out.reshape(out_shape)
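# Illustrative packing sketch only (assuming symmetric per-output-channel int8
# quantization; not part of this file). Given a float weight w_fp of shape
# [out_features, in_features]:
#   scales_col = w_fp.abs().amax(dim=1, keepdim=True) / 127.0          # [out, 1]
#   qweight = (w_fp / scales_col).round().clamp(-128, 127).to(torch.int8)
# A checkpoint would then provide "weight" as qweight ([out, in], int8) and
# "scales" as scales_col.T ([1, out], params_dtype) to match create_weights above.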