Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -18,6 +18,7 @@ QuantizationMethods = Literal[
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"modelopt_mxfp8",
|
||||
"modelopt_mixed",
|
||||
"gguf",
|
||||
"gptq_marlin",
|
||||
"awq_marlin",
|
||||
@@ -32,6 +33,7 @@ QuantizationMethods = Literal[
|
||||
"mxfp4",
|
||||
"petit_nvfp4",
|
||||
"cpu_awq",
|
||||
"w8a16"
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||
|
||||
@@ -120,12 +122,18 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .gptq import GPTQConfig
|
||||
from .gptq_marlin import GPTQMarlinConfig
|
||||
from .inc import INCConfig
|
||||
from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config
|
||||
from .modelopt import (
|
||||
ModelOptFp8Config,
|
||||
ModelOptMixedPrecisionConfig,
|
||||
ModelOptMxFp8Config,
|
||||
ModelOptNvFp4Config,
|
||||
)
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .mxfp4 import Mxfp4Config
|
||||
from .petit import PetitNvFp4Config
|
||||
from .ptpc_fp8 import PTPCFp8Config
|
||||
from .torchao import TorchAOConfig
|
||||
from .w8a16 import W8a16Config
|
||||
|
||||
method_to_config: dict[str, type[QuantizationConfig]] = {
|
||||
"awq": AWQConfig,
|
||||
@@ -135,6 +143,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"modelopt": ModelOptFp8Config,
|
||||
"modelopt_fp4": ModelOptNvFp4Config,
|
||||
"modelopt_mxfp8": ModelOptMxFp8Config,
|
||||
"modelopt_mixed": ModelOptMixedPrecisionConfig,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
@@ -151,6 +160,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"mxfp4": Mxfp4Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
"cpu_awq": CPUAWQConfig,
|
||||
"w8a16": W8a16Config,
|
||||
}
|
||||
# Update the `method_to_config` with customized quantization methods.
|
||||
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
@@ -9,6 +9,7 @@ from torch.nn import Parameter
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
@@ -60,6 +61,7 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
import ixformer.inference.functions as ixfops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -197,7 +199,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return quant_method
|
||||
elif isinstance(layer, FusedMoE):
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
# from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
|
||||
# if is_layer_skipped(
|
||||
# prefix,
|
||||
@@ -213,9 +215,10 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
# return MoeWNA16Config.from_config(self.full_config).get_quant_method(
|
||||
# layer, prefix
|
||||
# )
|
||||
moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
|
||||
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return moe_quant_method
|
||||
# moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
|
||||
# moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
# return moe_quant_method
|
||||
return AWQMarlinMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -389,13 +392,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
replace_parameter(layer, "qweight", pad_qweight)
|
||||
replace_parameter(layer, "qzeros", pad_qzeros)
|
||||
replace_parameter(layer, "scales", pad_scales)
|
||||
return
|
||||
|
||||
# TODO(gyf) Marlin format is not support for now..
|
||||
device = layer.qweight.device
|
||||
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
return
|
||||
# Allocate marlin workspace
|
||||
layer.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
@@ -811,49 +814,33 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# Assign the value of shared_experts_output to variable shared_experts_input for fusion
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
# return fused_marlin_moe(
|
||||
# x,
|
||||
# layer.w13_qweight,
|
||||
# layer.w2_qweight,
|
||||
# getattr(layer, "w13_bias", None),
|
||||
# getattr(layer, "w2_bias", None),
|
||||
# layer.w13_scales,
|
||||
# layer.w2_scales,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
# input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
# quant_type_id=self.quant_type.id,
|
||||
# apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
# global_num_experts=layer.global_num_experts,
|
||||
# expert_map=layer.expert_map,
|
||||
# w1_zeros=layer.w13_qzeros,
|
||||
# w2_zeros=layer.w2_qzeros,
|
||||
# workspace=layer.workspace,
|
||||
# input_dtype=self.input_dtype,
|
||||
# inplace=not self.moe.disable_inplace,
|
||||
# )
|
||||
|
||||
num_tokens, num_experts = router_logits.shape
|
||||
assert layer.activation.value == "silu", "Only SiLU activation is supported."
|
||||
use_ep = layer.expert_map is not None
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata:
|
||||
if isinstance(attn_metadata, dict):
|
||||
only_decode = (use_ep == False and all(t.num_decodes > 0 and t.num_prefills ==0 for t in list(attn_metadata.values())))
|
||||
else:
|
||||
only_decode = use_ep == False and attn_metadata.num_decodes > 0 and attn_metadata.num_prefills == 0
|
||||
else:
|
||||
only_decode = False
|
||||
|
||||
if use_ep:
|
||||
start_eid = layer.ep_rank * layer.local_num_experts
|
||||
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
|
||||
if layer.apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
num_tokens = topk_ids.shape[0]
|
||||
num_experts = layer.global_num_experts
|
||||
if use_ep:
|
||||
hidden_size = x.shape[1]
|
||||
(
|
||||
@@ -875,7 +862,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
dtype=x.dtype,
|
||||
)
|
||||
else:
|
||||
expand_tokens = num_tokens * top_k
|
||||
expand_tokens = num_tokens * layer.top_k
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
@@ -885,7 +872,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
topk_ids=topk_ids,
|
||||
num_experts=num_experts,
|
||||
)
|
||||
expert_sizes_cpu = expert_sizes_gpu.cpu()
|
||||
|
||||
# expand + reorder
|
||||
# TODO use kernel
|
||||
@@ -893,76 +879,130 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
hidden_states=x,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_tokens=expand_tokens,
|
||||
topk=top_k,
|
||||
topk=layer.top_k,
|
||||
src_to_dst=src_to_dst,
|
||||
)
|
||||
|
||||
# w4a16 group gemm 1
|
||||
# pt_output_1: (expand_tokens, 2n) dtype
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemm(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
# w4a16 group gemm 2 + reorder
|
||||
# pt_output_3: (expand_tokens, k) dtype
|
||||
if use_ep:
|
||||
pt_output_3 = torch.empty(
|
||||
(num_tokens * top_k, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
if only_decode:
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemv(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
output=pt_output_3,
|
||||
)
|
||||
|
||||
reduce_mask = src_to_dst == -1
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=routed_scaling_factor,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemv(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=routed_scaling_factor
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
extra_residual=shared_experts_input,
|
||||
)
|
||||
|
||||
else:
|
||||
expert_sizes_cpu = expert_sizes_gpu.cpu()
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemm(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
# w4a16 group gemm 2 + reorder
|
||||
# pt_output_3: (expand_tokens, k) dtype
|
||||
if use_ep:
|
||||
pt_output_3 = torch.empty(
|
||||
(num_tokens * layer.top_k, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
output=pt_output_3,
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
reduce_mask = src_to_dst == -1
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
extra_residual=shared_experts_input,
|
||||
)
|
||||
return final_hidden_states
|
||||
# return torch.ops.vllm.fused_marlin_moe(
|
||||
# x,
|
||||
# layer.w13_qweight,
|
||||
# layer.w2_qweight,
|
||||
# layer.w13_scales,
|
||||
# layer.w2_scales,
|
||||
# router_logits,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# w1_zeros=layer.w13_qzeros,
|
||||
# w2_zeros=layer.w2_qzeros,
|
||||
# num_bits=self.quant_config.weight_bits,
|
||||
# )
|
||||
|
||||
@@ -18,7 +18,6 @@ from compressed_tensors.quantization import (
|
||||
)
|
||||
from compressed_tensors.transform import TransformConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -52,7 +51,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA16,
|
||||
CompressedTensorsW4A8Int8
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
|
||||
CompressedTensorsLinearTransformMethod,
|
||||
@@ -401,8 +399,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
) -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
or weight_quant.strategy == QuantizationStrategy.GROUP.value
|
||||
)
|
||||
is_tensor = (
|
||||
weight_strategy
|
||||
@@ -420,8 +418,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
) -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
or weight_quant.strategy == QuantizationStrategy.GROUP.value
|
||||
)
|
||||
is_token = (
|
||||
weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
|
||||
@@ -663,12 +661,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
)
|
||||
|
||||
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||
if envs.VLLM_W8A8_LINEAR_USE_W4A8:
|
||||
return CompressedTensorsW4A8Int8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=False,
|
||||
input_symmetric=input_quant.symmetric,
|
||||
)
|
||||
return CompressedTensorsW8A8Int8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=False,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,7 +8,7 @@ from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
|
||||
from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
|
||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8, CompressedTensorsW4A8Int8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
|
||||
|
||||
@@ -28,5 +28,4 @@ __all__ = [
|
||||
"CompressedTensorsW4A4Fp4",
|
||||
"CompressedTensorsW4A8Int",
|
||||
"CompressedTensorsW4A8Fp8",
|
||||
"CompressedTensorsW4A8Int8"
|
||||
]
|
||||
|
||||
@@ -25,11 +25,18 @@ logger = init_logger(__name__)
|
||||
|
||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
def __init__(
|
||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
|
||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool, is_w4a8_linear: bool = False
|
||||
):
|
||||
self.strategy = strategy
|
||||
import vllm.envs as env
|
||||
if env.VLLM_MIX_QUANTIZATION_TYPE == "TENSOR":
|
||||
self.strategy = QuantizationStrategy.TENSOR
|
||||
elif env.VLLM_MIX_QUANTIZATION_TYPE == "CHANNEL":
|
||||
self.strategy = QuantizationStrategy.CHANNEL
|
||||
else:
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
self.is_w4a8_linear = is_w4a8_linear
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@@ -53,16 +60,32 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
input_symmetric=self.input_symmetric,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
remainder = input_size_per_partition % 64
|
||||
if remainder != 0:
|
||||
input_size_per_partition_padded = input_size_per_partition + (64 - remainder)
|
||||
else:
|
||||
input_size_per_partition_padded = input_size_per_partition
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
if self.is_w4a8_linear:
|
||||
# only "NN" is supported
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
input_size_per_partition_padded,
|
||||
sum(output_partition_sizes) // 2,
|
||||
dtype=torch.int8),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition_padded,
|
||||
dtype=torch.int8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
@@ -109,104 +132,4 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
|
||||
class CompressedTensorsW4A8Int8(CompressedTensorsScheme):
|
||||
def __init__(
|
||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
|
||||
):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
self.kernel = init_int8_linear_kernel(
|
||||
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
|
||||
is_static_input_scheme=self.is_static_input_scheme,
|
||||
input_symmetric=self.input_symmetric,
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
# weight = ModelWeightParameter(
|
||||
# data=torch.empty(
|
||||
# sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
|
||||
# ),
|
||||
# input_dim=1,
|
||||
# output_dim=0,
|
||||
# weight_loader=weight_loader,
|
||||
# )
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
sum(output_partition_sizes) // 2,
|
||||
dtype=torch.int8
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
input_zero_point = None
|
||||
input_scale = None
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
)
|
||||
if not self.input_symmetric:
|
||||
# Note: compressed-tensors stores the zp using the same dtype
|
||||
# as the weights
|
||||
# AZP loaded as int8 but used as int32
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
if not hasattr(layer, "azp_adj"):
|
||||
layer.register_parameter("azp_adj", None)
|
||||
|
||||
# Checkpoints are serialized in compressed-tensors format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
return self.kernel.apply_weights(layer, x, bias, self.is_w4a8_linear)
|
||||
|
||||
@@ -23,17 +23,13 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize,
|
||||
FusedMoeWeightScaleSupported,
|
||||
MoEActivation,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_quant_config,
|
||||
@@ -50,9 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
@@ -860,14 +853,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
|
||||
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
|
||||
|
||||
# Setup modular kernel for TP case and naive DP/EP case.
|
||||
# In non-naive DP/EP case, we will create a ModularKernelMethod.
|
||||
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
|
||||
# in both cases.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk = make_fp8_moe_kernel(
|
||||
self.moe_kernel = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
@@ -930,29 +919,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
# TRTLLM does not use Modular Kernel.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
return None
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
|
||||
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
|
||||
a1_scale = layer.w13_input_scale
|
||||
@@ -983,10 +956,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def supports_eplb(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -994,50 +963,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
# TODO(rob): convert this to MK.
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Expected 'silu' activation but got {layer.activation}"
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
if self.block_quant:
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
x=x,
|
||||
w13_weight=layer.w13_weight,
|
||||
w13_weight_scale_inv=layer.w13_weight_scale_inv,
|
||||
w2_weight=layer.w2_weight,
|
||||
w2_weight_scale_inv=layer.w2_weight_scale_inv,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.weight_block_size,
|
||||
routing_method_type=layer.routing_method_type,
|
||||
routed_scaling=layer.routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
return apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -1046,9 +987,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.moe_mk is not None
|
||||
assert not self.is_monolithic
|
||||
return self.moe_mk(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from gguf import GGMLQuantizationType as WeightType
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
@@ -234,7 +235,7 @@ try:
|
||||
op_func=_fused_mul_mat_gguf,
|
||||
fake_impl=_fused_mul_mat_gguf_fake,
|
||||
)
|
||||
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
|
||||
fused_mul_mat_gguf = _fused_mul_mat_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
@@ -365,7 +366,7 @@ try:
|
||||
op_func=_fused_moe_gguf,
|
||||
fake_impl=_fused_moe_gguf_fake,
|
||||
)
|
||||
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
|
||||
fused_moe_gguf = _fused_moe_gguf
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
@@ -410,7 +411,7 @@ try:
|
||||
op_func=_apply_gguf_embedding,
|
||||
fake_impl=_apply_gguf_embedding_fake,
|
||||
)
|
||||
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
|
||||
apply_gguf_embedding = _apply_gguf_embedding
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
@@ -451,6 +452,9 @@ class GGUFLinearMethod(LinearMethodBase):
|
||||
"data_container": [],
|
||||
"shard_id": [],
|
||||
"shard_id_map": {},
|
||||
"params_dtype": params_dtype,
|
||||
"input_size_per_partition" :input_size_per_partition, # restore shape for qkv and merge
|
||||
"output_partition_sizes" :output_partition_sizes,
|
||||
},
|
||||
)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
@@ -664,6 +668,10 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
"""
|
||||
|
||||
def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
|
||||
weight = layer.weight
|
||||
return F.embedding(x, weight)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
qweight = layer.qweight
|
||||
qweight_type = layer.qweight_type.weight_type
|
||||
hidden_size = qweight.tensor_shape[1]
|
||||
|
||||
@@ -128,7 +128,7 @@ class GPTQConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
|
||||
@@ -59,9 +59,164 @@ from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
import ixformer.inference.functions as ixfops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
#[B,K//8,N] ->[B,K,N]
|
||||
# less memmory
|
||||
def unpack_k_batch_opt(packed_w: torch.Tensor, num_bits: int = 4, chunk_size: int = 2) -> torch.Tensor:
|
||||
"""
|
||||
Memory-efficient unpacking for 3D tensor.
|
||||
Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor,
|
||||
without broadcasting huge intermediate tensors (avoids OOM).
|
||||
|
||||
Args:
|
||||
packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
|
||||
num_bits: Number of bits per packed element (e.g., 4 or 2).
|
||||
chunk_size: How many bit groups to unpack at once (tradeoff between speed and memory).
|
||||
|
||||
Returns:
|
||||
unpacked: torch.int8 tensor of shape [B, K, N].
|
||||
"""
|
||||
B, k_packed, N = packed_w.shape
|
||||
pack_factor = 32 // num_bits
|
||||
K = k_packed * pack_factor
|
||||
mask = (1 << num_bits) - 1
|
||||
|
||||
# Allocate output tensor once
|
||||
unpacked = torch.empty((B, K, N), dtype=torch.int8, device=packed_w.device)
|
||||
|
||||
# Process bit chunks iteratively to save memory
|
||||
for i in range(0, pack_factor, chunk_size):
|
||||
# Precompute shifts for this chunk
|
||||
shift_vals = num_bits * torch.arange(i, min(i + chunk_size, pack_factor), device=packed_w.device)
|
||||
# [chunk_size, 1, 1, 1]
|
||||
shifts = shift_vals.view(-1, 1, 1, 1)
|
||||
# Compute small chunk only
|
||||
chunk = ((packed_w.unsqueeze(0) >> shifts) & mask).to(torch.int8)
|
||||
|
||||
# chunk: [chunk_size, B, k_packed, N]
|
||||
# write into output
|
||||
for j in range(chunk.shape[0]):
|
||||
unpacked[:, (i + j)::pack_factor, :] = chunk[j]
|
||||
|
||||
del chunk # release memory early
|
||||
|
||||
return unpacked
|
||||
|
||||
# more memmory
|
||||
def unpack_k_batch(packed_w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
|
||||
"""
|
||||
Efficient vectorized unpacking for 3D tensor (batch version).
|
||||
Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor.
|
||||
|
||||
Args:
|
||||
packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
|
||||
num_bits: Number of bits per packed element (e.g., 4).
|
||||
|
||||
Returns:
|
||||
unpacked: torch.int8 tensor of shape [B, K, N].
|
||||
"""
|
||||
B, k_packed, n = packed_w.shape
|
||||
pack_factor = 32 // num_bits
|
||||
k = k_packed * pack_factor
|
||||
|
||||
mask = (1 << num_bits) - 1
|
||||
|
||||
# [pack_factor, 1, 1, 1]
|
||||
shifts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1, 1)
|
||||
|
||||
# [1, B, k_packed, N]
|
||||
packed_expanded = packed_w.unsqueeze(0)
|
||||
|
||||
# Extract each group of num_bits using bitwise ops
|
||||
unpacked_groups = ((packed_expanded >> shifts) & mask).to(torch.int8)
|
||||
|
||||
# [pack_factor, B, k_packed, N] → [B, K, N]
|
||||
unpacked = unpacked_groups.permute(1, 2, 0, 3).reshape(B, k, n)
|
||||
|
||||
return unpacked
|
||||
|
||||
|
||||
#[B,K,N] ->[B,K,N//8]
|
||||
# less memmory
|
||||
def pack_n_batch_opt(x: torch.Tensor, pack_num: int = 8, order_map=None, chunk_size: int = 2) -> torch.Tensor:
|
||||
"""
|
||||
Memory-efficient batch packing with correct bit order.
|
||||
[B, K, N] int4 -> [B, K, N//pack_num] int32.
|
||||
"""
|
||||
B, K, N = x.shape
|
||||
assert N % pack_num == 0, "N must be divisible by pack_num"
|
||||
cols = N // pack_num
|
||||
unit = 32 // pack_num
|
||||
|
||||
if order_map is None:
|
||||
order_map = list(range(pack_num))
|
||||
order_map = torch.tensor(order_map, device=x.device)
|
||||
|
||||
shifts = unit * torch.arange(pack_num, device=x.device) # always 0..unit*(pack_num-1)
|
||||
packed = torch.zeros((B, K, cols), dtype=torch.int32, device=x.device)
|
||||
x_reshape = x.view(B, K, cols, pack_num) & 0xF
|
||||
|
||||
# process in chunks for memory efficiency
|
||||
for start in range(0, pack_num, chunk_size):
|
||||
end = min(start + chunk_size, pack_num)
|
||||
idx_chunk = order_map[start:end]
|
||||
shift_chunk = shifts[start:end]
|
||||
|
||||
vals = torch.gather(x_reshape, 3, idx_chunk.view(1,1,1,-1).expand(B,K,cols,-1)).to(torch.int32)
|
||||
for j in range(vals.shape[-1]):
|
||||
packed.add_(vals[..., j] << shift_chunk[j])
|
||||
|
||||
return packed
|
||||
|
||||
## more memmory
|
||||
def pack_n_batch(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor:
|
||||
"""
|
||||
Efficient vectorized batch packing: [B, K, N] int4 -> [B, K, N//pack_num] int32.
|
||||
|
||||
Args:
|
||||
x: torch.int32 tensor of shape [B, K, N], each element 0-15 (int4).
|
||||
pack_num: Number of 4-bit elements per packed int32 (default=8).
|
||||
order_map: Optional order of elements within each packed int32.
|
||||
|
||||
Returns:
|
||||
torch.int32 tensor of shape [B, K, N//pack_num].
|
||||
"""
|
||||
|
||||
B, K, N = x.shape
|
||||
assert N % pack_num == 0, "N must be divisible by pack_num"
|
||||
cols = N // pack_num
|
||||
|
||||
if order_map is None:
|
||||
order_map = list(range(pack_num))
|
||||
order_map = torch.tensor(order_map, device=x.device)
|
||||
|
||||
unit = 32 // pack_num # number of bits per element
|
||||
|
||||
# reshape to [B, K, cols, pack_num]
|
||||
pack_num_int = int(pack_num)
|
||||
|
||||
x_reshape = x.view(B, K, cols, pack_num_int)
|
||||
|
||||
# reorder according to order_map
|
||||
x_reorder = torch.gather(
|
||||
x_reshape, 3, order_map.view(1, 1, 1, -1).expand(B, K, cols, -1)
|
||||
)
|
||||
|
||||
# mask low 4 bits
|
||||
x_reorder = x_reorder & 0xF
|
||||
|
||||
# bit shifts [pack_num] -> [1,1,1,pack_num] broadcastable
|
||||
shifts = (unit * torch.arange(pack_num_int, device=x.device)).view(1, 1, 1, -1)
|
||||
|
||||
# shift and sum along last dimension to combine bits
|
||||
packed = (x_reorder << shifts).sum(dim=-1).to(torch.int32)
|
||||
|
||||
return packed
|
||||
|
||||
|
||||
|
||||
def get_moe_quant_method(
|
||||
config: "GPTQMarlinConfig",
|
||||
@@ -495,8 +650,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
self.quant_config = quant_config
|
||||
if self.quant_config.quant_type.size_bits == 4:
|
||||
self.quant_type = scalar_types.uint4b8
|
||||
elif self.quant_config.quant_type.size_bits == 8:
|
||||
self.quant_type = scalar_types.uint8b128
|
||||
# elif self.quant_config.quant_type.size_bits == 8:
|
||||
# self.quant_type = scalar_types.uint8b128
|
||||
else:
|
||||
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
|
||||
self.input_dtype = None
|
||||
@@ -594,7 +749,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
num_experts,
|
||||
scales_size13,
|
||||
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=params_dtype,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -606,7 +761,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
num_experts,
|
||||
scales_size2,
|
||||
hidden_size // self.quant_config.pack_factor,
|
||||
dtype=params_dtype,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -656,7 +811,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
device = layer.w13_qweight.device
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
# layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
|
||||
@@ -673,119 +828,111 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_scales.data = layer.w2_scales.data * 512
|
||||
|
||||
# Process act_order
|
||||
if self.quant_config.desc_act:
|
||||
# if self.quant_config.desc_act:
|
||||
# Get sorting based on g_idx
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
|
||||
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
|
||||
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
|
||||
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
|
||||
for e in range(num_experts):
|
||||
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
|
||||
torch.int32
|
||||
)
|
||||
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
|
||||
torch.int32
|
||||
)
|
||||
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
|
||||
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
|
||||
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
|
||||
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
|
||||
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
|
||||
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
|
||||
else:
|
||||
# Reset g_idx related tensors
|
||||
num_experts = layer.w13_g_idx.shape[0]
|
||||
device = layer.w13_g_idx.device
|
||||
layer.w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
# Repack weights
|
||||
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w13_qweight,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w13_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
layer.w2_qweight,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w2_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
# num_experts = layer.w13_g_idx.shape[0]
|
||||
# w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
|
||||
# w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
|
||||
# w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
|
||||
# w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
|
||||
# for e in range(num_experts):
|
||||
# w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
|
||||
# torch.int32
|
||||
# )
|
||||
# w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
|
||||
# torch.int32
|
||||
# )
|
||||
# w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
|
||||
# w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
|
||||
# replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
|
||||
# replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
|
||||
# replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
|
||||
# replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
|
||||
# else:
|
||||
# # Reset g_idx related tensors
|
||||
# num_experts = layer.w13_g_idx.shape[0]
|
||||
# device = layer.w13_g_idx.device
|
||||
# layer.w13_g_idx = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# layer.w2_g_idx = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# layer.w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# layer.w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
# torch.empty((num_experts, 0), dtype=torch.int32, device=device),
|
||||
# requires_grad=False,
|
||||
# )
|
||||
# # Repack weights
|
||||
# marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
# layer.w13_qweight,
|
||||
# layer.w13_g_idx_sort_indices,
|
||||
# layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
# layer.w13_qweight.shape[2],
|
||||
# self.quant_config.quant_type.size_bits,
|
||||
# )
|
||||
# replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
# marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
# layer.w2_qweight,
|
||||
# layer.w2_g_idx_sort_indices,
|
||||
# layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
# layer.w2_qweight.shape[2],
|
||||
# self.quant_config.quant_type.size_bits,
|
||||
# )
|
||||
# replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
# # Repack scales
|
||||
# marlin_w13_scales = marlin_moe_permute_scales(
|
||||
# s=layer.w13_scales,
|
||||
# size_k=layer.intermediate_size_per_partition,
|
||||
# size_n=layer.w13_scales.shape[2],
|
||||
# group_size=self.quant_config.group_size,
|
||||
# )
|
||||
# replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
# marlin_w2_scales = marlin_moe_permute_scales(
|
||||
# s=layer.w2_scales,
|
||||
# size_k=layer.w2_scales.shape[1]
|
||||
# * (
|
||||
# self.quant_config.group_size
|
||||
# if self.quant_config.group_size != -1
|
||||
# else self.quant_config.pack_factor
|
||||
# ),
|
||||
# size_n=layer.w2_scales.shape[2],
|
||||
# group_size=self.quant_config.group_size,
|
||||
# )
|
||||
# replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
# The modular kernel expects w13_weight and w2_weight,
|
||||
# but GPTQ uses w13_qweight and w2_qweight
|
||||
# Alias for modular kernel
|
||||
layer.w13_weight = layer.w13_qweight
|
||||
# Alias for modular kernel
|
||||
layer.w2_weight = layer.w2_qweight
|
||||
# if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
|
||||
# layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
|
||||
|
||||
# Repack scales
|
||||
marlin_w13_scales = marlin_moe_permute_scales(
|
||||
s=layer.w13_scales,
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
|
||||
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w13_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w13_input_global_scale",
|
||||
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_scales,
|
||||
size_k=layer.w2_scales.shape[1]
|
||||
* (
|
||||
self.quant_config.group_size
|
||||
if self.quant_config.group_size != -1
|
||||
else self.quant_config.pack_factor
|
||||
),
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
|
||||
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w2_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w2_input_global_scale",
|
||||
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
|
||||
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
|
||||
|
||||
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||
# if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
# layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||
if self.quant_config.desc_act:
|
||||
raise NotImplementedError(
|
||||
"GPTQMarlinMoEMethod now not support desc_act. please fix it")
|
||||
w13_qweight_unpacked = unpack_k_batch(layer.w13_qweight)
|
||||
w13_qweight_repacked = pack_n_batch(w13_qweight_unpacked,self.quant_config.pack_factor,order_map=[0, 2, 4, 6, 1, 3, 5, 7])
|
||||
replace_parameter(layer, "w13_qweight", w13_qweight_repacked)
|
||||
|
||||
# quant vllm/model_executor/layers/quantization/utils/quant_utils.py#quantize_weights
|
||||
# if quant_type.has_bias():
|
||||
# w_q += quant_type.bias
|
||||
# use quant_type.bias as zp,(ixformer support)
|
||||
w13_zp = torch.full_like(layer.w13_scales, self.quant_type.bias, dtype=torch.int32)
|
||||
w13_zp_pack = pack_n_batch(w13_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
|
||||
replace_parameter(layer, "w13_qzeros", w13_zp_pack)
|
||||
|
||||
w2_qweight_unpacked = unpack_k_batch(layer.w2_qweight)
|
||||
w2_qweight_repacked = pack_n_batch(w2_qweight_unpacked,self.quant_config.pack_factor,order_map=[0, 2, 4, 6, 1, 3, 5, 7])
|
||||
replace_parameter(layer, "w2_qweight", w2_qweight_repacked)
|
||||
|
||||
w2_zp = torch.full_like(layer.w2_scales, self.quant_type.bias, dtype=torch.int32)
|
||||
w2_zp_pack = pack_n_batch(w2_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
|
||||
replace_parameter(layer, "w2_qzeros", w2_zp_pack)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
@@ -900,30 +1047,165 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# Assign the value of shared_experts_output to variable shared_experts_input for fusion
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
getattr(layer, "w13_bias", None),
|
||||
getattr(layer, "w2_bias", None),
|
||||
layer.w13_scales,
|
||||
layer.w2_scales,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
g_idx1=layer.w13_g_idx,
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
is_k_full=self.is_k_full,
|
||||
input_dtype=self.input_dtype,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
assert layer.activation.value == "silu", "Only SiLU activation is supported."
|
||||
use_ep = layer.expert_map is not None
|
||||
|
||||
if use_ep:
|
||||
start_eid = layer.ep_rank * layer.local_num_experts
|
||||
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
|
||||
|
||||
if layer.apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"GPTQMarlinMoEMethod Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
if (hasattr(layer, "w13_bias") and layer.w13_bias is not None) or (hasattr(layer, "w2_bias") and layer.w2_bias is not None):
|
||||
raise NotImplementedError(
|
||||
"GPTQMarlinMoEMethod moe_w4a16_group_gemm not supported bias, please fix this")
|
||||
|
||||
num_tokens = topk_ids.shape[0]
|
||||
num_experts = layer.global_num_experts
|
||||
|
||||
if use_ep:
|
||||
hidden_size = x.shape[1]
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
expert_sizes_gpu,
|
||||
expert_sizes_cpu,
|
||||
expand_tokens,
|
||||
) = ixfops.moe_compute_token_index_ep(
|
||||
topk_ids=topk_ids,
|
||||
num_experts=num_experts,
|
||||
start_expert_id=start_eid,
|
||||
end_expert_id=end_eid,
|
||||
)
|
||||
if expert_sizes_cpu.sum() == 0:
|
||||
return torch.zeros(
|
||||
(num_tokens, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
else:
|
||||
expand_tokens = num_tokens * layer.top_k
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
expert_sizes_gpu,
|
||||
expert_sizes_cpu,
|
||||
) = ixfops.moe_compute_token_index(
|
||||
topk_ids=topk_ids,
|
||||
num_experts=num_experts,
|
||||
)
|
||||
expert_sizes_cpu = expert_sizes_gpu.cpu()
|
||||
|
||||
# expand + reorder
|
||||
# TODO use kernel
|
||||
expand_hidden_states = ixfops.moe_expand_input(
|
||||
hidden_states=x,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_tokens=expand_tokens,
|
||||
topk=layer.top_k,
|
||||
src_to_dst=src_to_dst,
|
||||
)
|
||||
|
||||
# w4a16 group gemm 1
|
||||
# pt_output_1: (expand_tokens, 2n) dtype
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemm(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
# w4a16 group gemm 2 + reorder
|
||||
# pt_output_3: (expand_tokens, k) dtype
|
||||
if use_ep:
|
||||
pt_output_3 = torch.empty(
|
||||
(num_tokens * layer.top_k, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
output=pt_output_3,
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
reduce_mask = src_to_dst == -1
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
extra_residual=shared_experts_input,
|
||||
)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# return torch.ops.vllm.fused_marlin_moe(
|
||||
# x,
|
||||
# layer.w13_qweight,
|
||||
# layer.w2_qweight,
|
||||
# getattr(layer, "w13_bias", None),
|
||||
# getattr(layer, "w2_bias", None),
|
||||
# layer.w13_scales,
|
||||
# layer.w2_scales,
|
||||
# router_logits,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# quant_type_id=self.quant_type.id,
|
||||
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
# global_num_experts=global_num_experts,
|
||||
# expert_map=expert_map,
|
||||
# g_idx1=layer.w13_g_idx,
|
||||
# g_idx2=layer.w2_g_idx,
|
||||
# sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
# sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
# workspace=layer.workspace,
|
||||
# is_k_full=self.is_k_full)
|
||||
|
||||
@@ -12,8 +12,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -24,14 +23,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
|
||||
Fp8MoeBackend,
|
||||
convert_to_fp8_moe_kernel_format,
|
||||
make_fp8_moe_kernel,
|
||||
make_fp8_moe_quant_config,
|
||||
select_fp8_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
|
||||
NvFp4MoeBackend,
|
||||
convert_to_nvfp4_moe_kernel_format,
|
||||
is_global_sf_supported_for_nvfp4_backend,
|
||||
make_nvfp4_moe_kernel,
|
||||
@@ -49,13 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
flashinfer_trtllm_fp4_moe,
|
||||
flashinfer_trtllm_fp4_routed_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_fi_trtllm_fp8_per_tensor_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
process_fp8_input_tensor_strategy_moe,
|
||||
@@ -114,6 +104,8 @@ QUANT_ALGOS = [
|
||||
"NVFP4",
|
||||
# MXFP8
|
||||
"MXFP8",
|
||||
# MIXED_PRECISION,
|
||||
"MIXED_PRECISION",
|
||||
]
|
||||
KV_CACHE_QUANT_ALGOS = ["FP8"]
|
||||
|
||||
@@ -181,7 +173,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
# handle kv-cache first so we can focus only on weight quantization thereafter
|
||||
if isinstance(layer, Attention):
|
||||
if isinstance(layer, (Attention, MLAAttention)):
|
||||
return self.KVCacheMethodCls(self)
|
||||
|
||||
# handle exclusion
|
||||
@@ -235,6 +227,26 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
|
||||
self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)
|
||||
|
||||
@staticmethod
|
||||
def _extract_modelopt_quant_algo(
|
||||
hf_quant_cfg: dict[str, Any] | None,
|
||||
) -> str | None:
|
||||
"""Extract upper-cased quant_algo from a modelopt config.
|
||||
|
||||
Returns the quant_algo string (upper-cased), or None if the config
|
||||
is not a modelopt config.
|
||||
"""
|
||||
if hf_quant_cfg is None:
|
||||
return None
|
||||
if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
|
||||
return None
|
||||
if "quantization" in hf_quant_cfg:
|
||||
quant_config = hf_quant_cfg["quantization"]
|
||||
if isinstance(quant_config, dict):
|
||||
return str(quant_config.get("quant_algo", "")).upper()
|
||||
return None
|
||||
return str(hf_quant_cfg.get("quant_algo", "")).upper()
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return ["hf_quant_config.json"]
|
||||
@@ -272,10 +284,20 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
# "exclude_modules" is the key in the legacy hf_quant_config.json
|
||||
exclude_modules = quant_config.get("exclude_modules", [])
|
||||
else:
|
||||
# Compressed-tensors style format:
|
||||
# Compressed-tensors style format (config.json quantization_config):
|
||||
# {"quant_algo": "...", "quant_method": "modelopt"}
|
||||
quant_method = config.get("quant_algo")
|
||||
kv_cache_quant_method = config.get("kv_cache_quant_algo")
|
||||
|
||||
# "kv_cache_scheme" (a dict) instead of "kv_cache_quant_algo" (a string).
|
||||
kv_cache_scheme = config.get("kv_cache_scheme")
|
||||
if isinstance(kv_cache_scheme, dict) and (
|
||||
kv_cache_scheme.get("type") == "float"
|
||||
and kv_cache_scheme.get("num_bits") == 8
|
||||
):
|
||||
kv_cache_quant_method = "FP8"
|
||||
else:
|
||||
kv_cache_quant_method = None
|
||||
|
||||
# "ignore" is the key in config.json
|
||||
exclude_modules = config.get("ignore", [])
|
||||
group_size_raw = config.get("group_size")
|
||||
@@ -379,32 +401,9 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
"""Detect if this ModelOpt config should be used based on
|
||||
quantization config."""
|
||||
|
||||
if hf_quant_cfg is None:
|
||||
return None
|
||||
|
||||
# Use the community standard 'quant_method'
|
||||
quant_method = hf_quant_cfg.get("quant_method", "").lower()
|
||||
|
||||
# Only proceed if the method is explicitly "modelopt"
|
||||
if quant_method != "modelopt":
|
||||
return None
|
||||
|
||||
# Look for ModelOpt-specific config structure
|
||||
if "quantization" in hf_quant_cfg:
|
||||
quant_config = hf_quant_cfg["quantization"]
|
||||
if isinstance(quant_config, dict):
|
||||
quant_algo = str(quant_config.get("quant_algo", ""))
|
||||
if quant_algo.upper() == "FP8":
|
||||
return "modelopt"
|
||||
else:
|
||||
# Check for compressed-tensors style config with specific quant_algo
|
||||
quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
|
||||
if quant_algo.upper() == "FP8":
|
||||
return "modelopt"
|
||||
|
||||
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
|
||||
if algo is not None and algo == "FP8":
|
||||
return "modelopt"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -737,7 +736,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
@@ -745,9 +744,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
@@ -862,16 +861,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# Setup modular kernel.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
assert self.experts_cls is not None
|
||||
self.moe_kernel = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
w13 = layer.w13_weight
|
||||
@@ -904,9 +902,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
a1_scale = layer.w13_input_scale
|
||||
@@ -920,10 +916,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -931,28 +923,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
|
||||
)
|
||||
# TODO(rob): this validation should happen at kernel selection
|
||||
# time in the oracle rather than here.
|
||||
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
assert layer.activation in SUPPORTED_ACTIVATIONS, (
|
||||
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
|
||||
f"TRTLLM FP4 MoE, {layer.activation} found instead."
|
||||
)
|
||||
return apply_fi_trtllm_fp8_per_tensor_moe(
|
||||
layer=layer,
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
def apply(
|
||||
@@ -964,25 +948,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
|
||||
# TODO(rob): this validation should happen at kernel selection
|
||||
# time in the oracle rather than here.
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
assert layer.activation in (
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.RELU2_NO_MUL,
|
||||
), (
|
||||
"Expected activation to be in ('silu', 'relu2_no_mul'),"
|
||||
f"but got {layer.activation}"
|
||||
)
|
||||
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
@@ -1031,32 +1003,9 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
"""Detect if this ModelOpt FP4 config should be used based on
|
||||
quantization config."""
|
||||
if hf_quant_cfg is None:
|
||||
return None
|
||||
|
||||
# Use the community standard 'quant_method'
|
||||
quant_method = hf_quant_cfg.get("quant_method", "").lower()
|
||||
|
||||
# Only proceed if the method is explicitly "modelopt"
|
||||
if quant_method != "modelopt":
|
||||
return None
|
||||
|
||||
# Look for ModelOpt-specific config structure
|
||||
if "quantization" in hf_quant_cfg:
|
||||
quant_config = hf_quant_cfg["quantization"]
|
||||
if isinstance(quant_config, dict):
|
||||
quant_algo = quant_config.get("quant_algo", "")
|
||||
if "NVFP4" in quant_algo:
|
||||
return "modelopt_fp4"
|
||||
else:
|
||||
# Check for compressed-tensors style config with specific
|
||||
# quant_algo field
|
||||
quant_algo = hf_quant_cfg.get("quant_algo", "")
|
||||
if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
|
||||
return "modelopt_fp4"
|
||||
|
||||
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
|
||||
if algo is not None and ("NVFP4" in algo or "FP4" in algo):
|
||||
return "modelopt_fp4"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -1249,17 +1198,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
@@ -1434,51 +1373,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
|
||||
replace_parameter(layer, "w2_input_scale", a2_scale)
|
||||
|
||||
# Setup modular kernel for TP case and naive DP/EP case.
|
||||
# In non-naive DP/EP case, we will create a ModularKernelMethod.
|
||||
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
|
||||
# in both cases.
|
||||
# Setup modular kernel.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
|
||||
@property
|
||||
def do_post_quant_allgather(self):
|
||||
return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
def prepare_dp_allgather_tensor(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
|
||||
if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
raise RuntimeError(
|
||||
"prepare_dp_allgather_tensor is only supported for "
|
||||
"FlashInfer TRTLLM NVFP4 MoE backend."
|
||||
)
|
||||
|
||||
import flashinfer
|
||||
|
||||
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
|
||||
hidden_states,
|
||||
layer.a1_gscale,
|
||||
is_sf_swizzled_layout=False,
|
||||
assert self.experts_cls is not None
|
||||
self.moe_kernel = make_nvfp4_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
experts_cls=self.experts_cls,
|
||||
shared_experts=layer.shared_experts,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
)
|
||||
extra_tensors: list[torch.Tensor] = [hidden_states_sf]
|
||||
return hidden_states_fp4, extra_tensors
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
return make_nvfp4_moe_quant_config(
|
||||
backend=self.nvfp4_backend,
|
||||
w13_scale=layer.w13_weight_scale,
|
||||
@@ -1493,13 +1399,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
def supports_eplb(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return (
|
||||
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
and not self.moe.moe_parallel_config.enable_eplb
|
||||
)
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -1507,22 +1406,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.is_monolithic
|
||||
assert (
|
||||
self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
|
||||
and not layer.enable_eplb
|
||||
)
|
||||
|
||||
return flashinfer_trtllm_fp4_moe(
|
||||
layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=layer.top_k,
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
custom_routing_function=layer.custom_routing_function,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
def apply(
|
||||
@@ -1534,33 +1431,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert not self.is_monolithic
|
||||
|
||||
# EPLB path
|
||||
if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
assert layer.enable_eplb
|
||||
return flashinfer_trtllm_fp4_routed_moe(
|
||||
layer=layer,
|
||||
x=x,
|
||||
topk_ids=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
top_k=layer.top_k,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
)
|
||||
else:
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
|
||||
@@ -1619,31 +1502,9 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
"""Detect if this ModelOpt MXFP8 config should be used based on
|
||||
quantization config."""
|
||||
if hf_quant_cfg is None:
|
||||
return None
|
||||
|
||||
# Use the community standard 'quant_method'
|
||||
quant_method = hf_quant_cfg.get("quant_method", "").lower()
|
||||
|
||||
# Only proceed if the method is explicitly "modelopt"
|
||||
if quant_method != "modelopt":
|
||||
return None
|
||||
|
||||
# Look for ModelOpt-specific config structure
|
||||
if "quantization" in hf_quant_cfg:
|
||||
quant_config = hf_quant_cfg["quantization"]
|
||||
if isinstance(quant_config, dict):
|
||||
quant_algo = str(quant_config.get("quant_algo", "")).upper()
|
||||
if "MXFP8" in quant_algo:
|
||||
return "modelopt_mxfp8"
|
||||
else:
|
||||
# Check for compressed-tensors style config with specific quant_algo
|
||||
quant_algo = str(hf_quant_cfg.get("quant_algo", "")).upper()
|
||||
if "MXFP8" in quant_algo:
|
||||
return "modelopt_mxfp8"
|
||||
|
||||
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
|
||||
if algo is not None and "MXFP8" in algo:
|
||||
return "modelopt_mxfp8"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -1841,3 +1702,188 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
|
||||
# Register the method classes for ModelOptMxFp8Config
|
||||
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
|
||||
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
|
||||
|
||||
|
||||
class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
|
||||
"""Config class for ModelOpt MIXED_PRECISION.
|
||||
|
||||
Supports checkpoints where different layers use different quantization
|
||||
algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts).
|
||||
The per-layer algorithm is specified in the ``quantized_layers`` dict
|
||||
inside ``config.json``'s ``quantization_config`` (preferred) or the
|
||||
legacy ``hf_quant_config.json``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_quant_method: str | None,
|
||||
exclude_modules: list[str],
|
||||
quantized_layers: dict[str, dict[str, Any]],
|
||||
fp8_config: ModelOptFp8Config,
|
||||
nvfp4_config: ModelOptNvFp4Config,
|
||||
) -> None:
|
||||
super().__init__(exclude_modules)
|
||||
self.kv_cache_quant_method = kv_cache_quant_method
|
||||
self.quantized_layers = quantized_layers
|
||||
self.fp8_config = fp8_config
|
||||
self.nvfp4_config = nvfp4_config
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "modelopt_mixed"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 89
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
|
||||
if algo is not None and algo == "MIXED_PRECISION":
|
||||
return "modelopt_mixed"
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _from_config(
|
||||
cls,
|
||||
*,
|
||||
quant_method: str,
|
||||
kv_cache_quant_method: str | None,
|
||||
exclude_modules: list[str],
|
||||
original_config: dict[str, Any],
|
||||
group_size: int | None,
|
||||
**kwargs: Any,
|
||||
) -> "ModelOptMixedPrecisionConfig":
|
||||
if "quantization" in original_config:
|
||||
quantized_layers = original_config["quantization"].get(
|
||||
"quantized_layers", {}
|
||||
)
|
||||
else:
|
||||
quantized_layers = original_config.get("quantized_layers", {})
|
||||
|
||||
if not quantized_layers:
|
||||
raise ValueError(
|
||||
"MIXED_PRECISION quant_algo requires a non-empty "
|
||||
"'quantized_layers' mapping in the quantization config."
|
||||
)
|
||||
|
||||
# Determine group_size from the first NVFP4 entry if not provided.
|
||||
if group_size is None:
|
||||
for layer_info in quantized_layers.values():
|
||||
if layer_info.get("quant_algo", "").upper() == "NVFP4":
|
||||
group_size = layer_info.get("group_size", 16)
|
||||
break
|
||||
if group_size is None:
|
||||
group_size = 16
|
||||
|
||||
fp8_config = ModelOptFp8Config(
|
||||
quant_method="FP8",
|
||||
is_checkpoint_fp8_serialized=True,
|
||||
kv_cache_quant_method=kv_cache_quant_method,
|
||||
exclude_modules=[],
|
||||
)
|
||||
nvfp4_config = ModelOptNvFp4Config(
|
||||
is_checkpoint_nvfp4_serialized=True,
|
||||
kv_cache_quant_algo=kv_cache_quant_method,
|
||||
exclude_modules=[],
|
||||
group_size=group_size,
|
||||
)
|
||||
|
||||
return cls(
|
||||
kv_cache_quant_method=kv_cache_quant_method,
|
||||
exclude_modules=exclude_modules,
|
||||
quantized_layers=quantized_layers,
|
||||
fp8_config=fp8_config,
|
||||
nvfp4_config=nvfp4_config,
|
||||
)
|
||||
|
||||
def _resolve_quant_algo(self, prefix: str) -> str | None:
|
||||
"""Look up the quant_algo for a vLLM-side layer prefix.
|
||||
|
||||
Tries three strategies in order:
|
||||
1. Direct lookup in ``quantized_layers``.
|
||||
2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``).
|
||||
3. Prefix-based lookup for FusedMoE (any child key starts with
|
||||
``prefix + "."``).
|
||||
|
||||
Returns the upper-cased quant_algo string, or *None* if the prefix
|
||||
is not found.
|
||||
"""
|
||||
# 1. Direct lookup
|
||||
if prefix in self.quantized_layers:
|
||||
return self.quantized_layers[prefix]["quant_algo"].upper()
|
||||
|
||||
# 2. Packed / fused layer lookup
|
||||
proj_name = prefix.rsplit(".", 1)[-1]
|
||||
if self.packed_modules_mapping and proj_name in self.packed_modules_mapping:
|
||||
algos: set[str] = set()
|
||||
base = prefix.rsplit(".", 1)[0]
|
||||
for shard_name in self.packed_modules_mapping[proj_name]:
|
||||
shard_prefix = f"{base}.{shard_name}"
|
||||
if shard_prefix in self.quantized_layers:
|
||||
algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper())
|
||||
if len(algos) == 1:
|
||||
return algos.pop()
|
||||
if len(algos) > 1:
|
||||
raise ValueError(
|
||||
f"Mixed quant_algo within fused layer {prefix}: "
|
||||
f"{algos}. All shards must use the same quantization."
|
||||
)
|
||||
|
||||
# 3. Prefix-based lookup (for FusedMoE / parent modules)
|
||||
prefix_dot = prefix + "."
|
||||
for key, info in self.quantized_layers.items():
|
||||
if key.startswith(prefix_dot):
|
||||
return info["quant_algo"].upper()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> "QuantizeMethodBase | None":
|
||||
"""Return quantize-method based on layer."""
|
||||
# KV-cache quantization
|
||||
if isinstance(layer, Attention):
|
||||
if self.kv_cache_quant_method:
|
||||
return ModelOptFp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
# Excluded layers
|
||||
if self.is_layer_excluded(prefix):
|
||||
if isinstance(layer, LinearBase):
|
||||
return UnquantizedLinearMethod()
|
||||
return None
|
||||
|
||||
quant_algo = self._resolve_quant_algo(prefix)
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if quant_algo == "FP8":
|
||||
return ModelOptFp8LinearMethod(self.fp8_config)
|
||||
if quant_algo == "NVFP4":
|
||||
return ModelOptNvFp4LinearMethod(self.nvfp4_config)
|
||||
# Layer not in quantized_layers — leave unquantized
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
if quant_algo == "FP8":
|
||||
return ModelOptFp8MoEMethod(
|
||||
quant_config=self.fp8_config,
|
||||
moe_config=layer.moe_config,
|
||||
)
|
||||
if quant_algo == "NVFP4":
|
||||
return ModelOptNvFp4FusedMoE(
|
||||
quant_config=self.nvfp4_config,
|
||||
moe_config=layer.moe_config,
|
||||
)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
super().apply_vllm_mapper(hf_to_vllm_mapper)
|
||||
if self.quantized_layers:
|
||||
self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers)
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
@@ -77,6 +78,8 @@ class Mxfp4Backend(Enum):
|
||||
# Triton Backend
|
||||
TRITON = 6
|
||||
|
||||
CK = 7
|
||||
|
||||
|
||||
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
|
||||
"""
|
||||
@@ -167,9 +170,15 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
|
||||
elif current_platform.is_xpu():
|
||||
logger.info_once("Using xpu backend on XPU")
|
||||
return Mxfp4Backend.MARLIN
|
||||
elif current_platform.is_rocm() and has_triton_kernels():
|
||||
logger.info_once("Using Triton backend")
|
||||
return Mxfp4Backend.TRITON
|
||||
elif current_platform.is_rocm():
|
||||
from vllm.platforms.rocm import on_gfx950
|
||||
|
||||
if rocm_aiter_ops.is_enabled() and on_gfx950():
|
||||
logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)")
|
||||
return Mxfp4Backend.CK
|
||||
elif has_triton_kernels():
|
||||
logger.info_once("Using Triton backend")
|
||||
return Mxfp4Backend.TRITON
|
||||
|
||||
return Mxfp4Backend.NONE
|
||||
|
||||
@@ -257,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
||||
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
|
||||
self.moe_mk: mk.FusedMoEModularKernel | None = None
|
||||
self.moe_kernel: mk.FusedMoEKernel | None = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -338,6 +347,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
|
||||
self.intermediate_pad = (
|
||||
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
|
||||
)
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
@@ -427,7 +440,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
self.moe_mk = mk.FusedMoEModularKernel(
|
||||
self.moe_kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
MarlinExperts(
|
||||
self.moe,
|
||||
@@ -776,7 +789,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
self.moe_mk = mk.FusedMoEModularKernel(
|
||||
self.moe_kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
FlashInferExperts(
|
||||
moe_config=self.moe,
|
||||
@@ -784,6 +797,66 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
),
|
||||
shared_experts=None,
|
||||
)
|
||||
elif self.mxfp4_backend == Mxfp4Backend.CK:
|
||||
if layer.w13_bias is not None:
|
||||
layer.w13_bias.data = layer.w13_bias.data.to(torch.float32)
|
||||
if layer.w2_bias.data is not None:
|
||||
layer.w2_bias.data = layer.w2_bias.data.to(torch.float32)
|
||||
|
||||
e, n, k = layer.w13_weight.shape
|
||||
layer.w13_weight.view(torch.uint8).copy_(
|
||||
layer.w13_weight.data.view(torch.uint8)
|
||||
.view(e, n // 2, 2, k)
|
||||
.permute(0, 2, 1, 3)
|
||||
.contiguous()
|
||||
.view(e, n, k)
|
||||
)
|
||||
layer.w13_weight_scale.data = (
|
||||
layer.w13_weight_scale.data.view(e, n // 2, 2, -1)
|
||||
.permute(0, 2, 1, 3)
|
||||
.contiguous()
|
||||
.view(e, n, -1)
|
||||
)
|
||||
layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2)
|
||||
layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2)
|
||||
|
||||
layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
|
||||
layer.w13_weight, 16, True
|
||||
)
|
||||
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
|
||||
layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]),
|
||||
self.num_experts,
|
||||
True,
|
||||
)
|
||||
|
||||
layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
|
||||
layer.w2_weight, 16, False
|
||||
)
|
||||
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
|
||||
layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]),
|
||||
self.num_experts,
|
||||
False,
|
||||
)
|
||||
|
||||
layer.w13_bias.data = (
|
||||
layer.w13_bias.data.view(-1, n // 2, 2)
|
||||
.permute(0, 2, 1)
|
||||
.contiguous()
|
||||
.view(-1, n)
|
||||
)
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
shuffled_w13_scale, requires_grad=False
|
||||
)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
shuffled_w2_scale, requires_grad=False
|
||||
)
|
||||
# replace_parameter(layer, "w13_bias", w13_bias)
|
||||
# replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
|
||||
# replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
|
||||
# replace_parameter(layer, "w13_weight", w13_weight)
|
||||
# replace_parameter(layer, "w2_weight", w2_weight)
|
||||
|
||||
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
@@ -792,18 +865,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
|
||||
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
|
||||
|
||||
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object
|
||||
# (stored in self.fused_experts) to determine if the MoE has a
|
||||
# batched activation format. As self.fused_experts is not
|
||||
# initialized at this point, we resort to checking the MoE config
|
||||
# directly.
|
||||
is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
|
||||
is_batched_moe = self.moe.use_deepep_ll_kernels
|
||||
if is_batched_moe:
|
||||
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
|
||||
else:
|
||||
num_warps = 8
|
||||
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
layer.w13_weight, layer.w13_weight_scale, num_warps
|
||||
)
|
||||
@@ -817,13 +888,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
self.w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
|
||||
)
|
||||
|
||||
self.w13_weight = w13_weight
|
||||
self.w2_weight = w2_weight
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
layer.w13_weight = w13_weight
|
||||
layer.w2_weight = w2_weight
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
|
||||
@@ -862,6 +933,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
elif self.mxfp4_backend in [
|
||||
Mxfp4Backend.SM100_FI_MXFP4_BF16,
|
||||
Mxfp4Backend.SM90_FI_MXFP4_BF16,
|
||||
Mxfp4Backend.CK,
|
||||
]:
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
@@ -882,9 +954,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
== mk.FusedMoEActivationFormat.BatchedExperts
|
||||
@@ -929,10 +1001,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
if self.moe.is_lora_enabled:
|
||||
return False
|
||||
return (
|
||||
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
|
||||
or self.mxfp4_backend == Mxfp4Backend.TRITON
|
||||
or self.mxfp4_backend == Mxfp4Backend.CK
|
||||
)
|
||||
|
||||
def apply(
|
||||
@@ -968,8 +1043,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
or self.mxfp4_backend == Mxfp4Backend.MARLIN
|
||||
)
|
||||
|
||||
assert self.moe_mk is not None
|
||||
return self.moe_mk(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@@ -1054,6 +1129,27 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
tune_max_num_tokens=max(self.max_capture_size, 1),
|
||||
)[0]
|
||||
return trtllm_gen_output
|
||||
elif self.mxfp4_backend == Mxfp4Backend.CK:
|
||||
topk_weights, topk_ids = rocm_aiter_ops.fused_topk(
|
||||
x, router_logits, layer.top_k, True
|
||||
)
|
||||
output = rocm_aiter_ops.fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"),
|
||||
quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"),
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
doweight_stage1=False,
|
||||
hidden_pad=self.hidden_pad // 128 * 128,
|
||||
intermediate_pad=self.intermediate_pad // 64 * 64 * 2,
|
||||
bias1=layer.w13_bias,
|
||||
bias2=layer.w2_bias,
|
||||
)
|
||||
return output
|
||||
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
|
||||
triton_kernel_moe_forward,
|
||||
@@ -1162,7 +1258,7 @@ class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
|
||||
topk_weights=routing_weights,
|
||||
topk_ids=selected_experts,
|
||||
n_experts_per_token=layer.top_k,
|
||||
activation=layer.activation,
|
||||
activation=layer.activation.value,
|
||||
num_experts=layer.local_num_experts,
|
||||
is_mxfp4=True,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
@@ -26,10 +25,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PTPCFp8Config(Fp8Config):
|
||||
"""Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""
|
||||
|
||||
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
|
||||
)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
@@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig):
|
||||
self.kv_cache_group = kv_cache_group
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.pack_method = pack_method
|
||||
self.dynamic_mxfp4_quant = False
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
self.hf_config = get_config(
|
||||
model=model_name,
|
||||
trust_remote_code=False, # or get from model_config if available
|
||||
revision=revision,
|
||||
config_format="auto",
|
||||
)
|
||||
|
||||
quant_config = getattr(self.hf_config, "quantization_config", None)
|
||||
if quant_config is not None:
|
||||
quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
|
||||
model_type = self.hf_config.model_type
|
||||
if quant_dtype == "fp4" and model_type == "deepseek_v3":
|
||||
self.dynamic_mxfp4_quant = True
|
||||
|
||||
def get_linear_method(self) -> "QuarkLinearMethod":
|
||||
return QuarkLinearMethod(self)
|
||||
@@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig):
|
||||
if should_ignore_layer(
|
||||
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
if (
|
||||
"self_attn" not in prefix # only quantize attention projections
|
||||
or not getattr(self, "dynamic_mxfp4_quant", False)
|
||||
or not isinstance(layer, LinearBase) # Ignore other methods
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
scheme = self.get_scheme(
|
||||
layer=layer,
|
||||
layer_name=prefix,
|
||||
dynamic_mxfp4_quant=True,
|
||||
)
|
||||
layer.scheme = scheme
|
||||
return QuarkLinearMethod(self)
|
||||
if isinstance(layer, LinearBase):
|
||||
scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
layer.scheme = scheme
|
||||
@@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig):
|
||||
)
|
||||
return global_quant_config
|
||||
|
||||
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
|
||||
def _get_scheme_from_config(
|
||||
self, config: dict[str, Any], dynamic_mxfp4_quant: bool = False
|
||||
) -> "QuarkScheme":
|
||||
if config.get("output_tensors") or config.get("bias"):
|
||||
raise NotImplementedError(
|
||||
"Currently, Quark models with output_tensors "
|
||||
@@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig):
|
||||
input_symmetric=input_config.get("symmetric"),
|
||||
)
|
||||
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
|
||||
return QuarkOCP_MX(weight_config, input_config)
|
||||
return QuarkOCP_MX(
|
||||
weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
"No quark compatible scheme was found. "
|
||||
@@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig):
|
||||
f"Input config: {input_config}"
|
||||
)
|
||||
|
||||
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
|
||||
def get_scheme(
|
||||
self, layer: torch.nn.Module, layer_name: str, dynamic_mxfp4_quant: bool = False
|
||||
) -> "QuarkScheme":
|
||||
layer_quant_config = self._find_matched_config(layer_name, layer)
|
||||
|
||||
# Find the quant_scheme
|
||||
scheme = self._get_scheme_from_config(layer_quant_config)
|
||||
scheme = self._get_scheme_from_config(
|
||||
layer_quant_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
|
||||
)
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
self._check_scheme_supported(scheme.get_min_capability())
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_fp8_moe_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
OCP_MX_Scheme,
|
||||
@@ -49,7 +50,11 @@ from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
|
||||
__all__ = [
|
||||
"QuarkMoEMethod",
|
||||
"QuarkOCP_MX_MoEMethod",
|
||||
"QuarkOCP_MX_MoEMethod_OSS",
|
||||
]
|
||||
|
||||
|
||||
class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
@@ -71,14 +76,30 @@ class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
"output_tensors and bias "
|
||||
"quantized are not supported"
|
||||
)
|
||||
|
||||
weight_config = layer_quant_config.get("weight")
|
||||
input_config = layer_quant_config.get("input_tensors")
|
||||
|
||||
if quant_config._is_fp8_w4a8(weight_config, input_config):
|
||||
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||
elif quant_config._is_fp8_w8a8(weight_config, input_config):
|
||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
|
||||
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
|
||||
emulate = not current_platform.supports_mx() or not (
|
||||
rocm_aiter_ops.is_fused_moe_enabled()
|
||||
)
|
||||
if (
|
||||
input_config.get("dtype") == "fp8_e4m3"
|
||||
and not input_config.get("is_dynamic")
|
||||
and not emulate
|
||||
):
|
||||
return QuarkOCP_MX_MoEMethod_OSS(
|
||||
weight_config, input_config, module.moe_config
|
||||
)
|
||||
else:
|
||||
return QuarkOCP_MX_MoEMethod(
|
||||
weight_config, input_config, module.moe_config
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Unsupported FusedMoe scheme")
|
||||
|
||||
@@ -706,13 +727,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
get_current_vllm_config().model_config.hf_config, "model_type", None
|
||||
)
|
||||
|
||||
self._emulate = (
|
||||
self.emulate = (
|
||||
not current_platform.supports_mx()
|
||||
or not self.ocp_mx_scheme.startswith("w_mxfp4")
|
||||
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
|
||||
|
||||
self.emulate = True if self.model_type == "gpt_oss" else self._emulate
|
||||
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||
@@ -753,6 +772,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
)
|
||||
|
||||
params_dtype = torch.uint8
|
||||
self.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
if self.model_type == "gpt_oss":
|
||||
if current_platform.is_rocm():
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
@@ -765,6 +785,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
|
||||
|
||||
self.unpadded_hidden_size = extra_weight_attrs.get(
|
||||
"unpadded_hidden_size", hidden_size
|
||||
)
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
@@ -991,30 +1015,20 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if not self.emulate:
|
||||
if (
|
||||
self.model_type == "gpt_oss"
|
||||
and self.mxfp4_backend == Mxfp4Backend.TRITON
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Triton kernel implemented fused MoE for GPT_OSS model "
|
||||
"in Quark(MoE) format is not integrated or provided yet."
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
|
||||
return rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
return rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@@ -1031,3 +1045,133 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
expert_map=layer.expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
|
||||
class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
|
||||
def __init__(
|
||||
self,
|
||||
weight_config: dict[str, Any],
|
||||
input_config: dict[str, Any],
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(weight_config, input_config, moe)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
w13_bias = layer.w13_bias.to(torch.float32)
|
||||
w2_bias = layer.w2_bias.to(torch.float32)
|
||||
|
||||
layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
|
||||
layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
|
||||
|
||||
# FIXME warp need to be adjusted based on batch size
|
||||
# only apply to batched mode
|
||||
if self.moe.use_ep:
|
||||
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
|
||||
else:
|
||||
num_warps = 8
|
||||
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
layer.w13_weight, layer.w13_weight_scale, num_warps
|
||||
)
|
||||
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
||||
layer.w2_weight, layer.w2_weight_scale, num_warps
|
||||
)
|
||||
|
||||
self.w13_weight_triton_tensor = w13_weight
|
||||
self.w2_weight_triton_tensor = w2_weight
|
||||
|
||||
# need to delete the original weights to save memory on single GPU
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
layer.w13_weight = None
|
||||
layer.w2_weight = None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.static_input_scales:
|
||||
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None."
|
||||
)
|
||||
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
|
||||
layer.w2_input_scale
|
||||
):
|
||||
logger.warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer."
|
||||
)
|
||||
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max().to(torch.float32), requires_grad=False
|
||||
)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max().to(torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
from triton_kernels.numerics import InFlexData
|
||||
|
||||
lhs_data13 = InFlexData(scale=layer.w13_input_scale)
|
||||
lhs_data2 = InFlexData(scale=layer.w2_input_scale)
|
||||
|
||||
self.w13_precision_config = PrecisionConfig(
|
||||
weight_scale=w13_scale,
|
||||
flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13),
|
||||
)
|
||||
|
||||
self.w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale,
|
||||
flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2),
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return mxfp4_w4a8_moe_quant_config(
|
||||
w1_scale=self.w13_precision_config,
|
||||
w2_scale=self.w2_precision_config,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
block_shape=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return True
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet."
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
|
||||
triton_kernel_moe_forward,
|
||||
)
|
||||
|
||||
return triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=self.w13_weight_triton_tensor,
|
||||
w2=self.w2_weight_triton_tensor,
|
||||
gating_output=router_logits,
|
||||
topk=layer.top_k,
|
||||
renormalize=layer.renormalize,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
unpadded_N_w1=self.intermediate_size_per_partition * 2,
|
||||
unpadded_K_w1=self.unpadded_hidden_size,
|
||||
unpadded_N_w2=self.unpadded_hidden_size,
|
||||
unpadded_K_w2=self.intermediate_size_per_partition,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
OCP_MX_Scheme,
|
||||
)
|
||||
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .quark_scheme import QuarkScheme
|
||||
@@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError):
|
||||
|
||||
class QuarkOCP_MX(QuarkScheme):
|
||||
def __init__(
|
||||
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
|
||||
self,
|
||||
weight_quant_spec: dict[str, Any],
|
||||
input_quant_spec: dict[str, Any],
|
||||
dynamic_mxfp4_quant: bool = False,
|
||||
):
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
self.qscheme = "per_group"
|
||||
self.weight_quant_spec = weight_quant_spec
|
||||
self.input_quant_spec = input_quant_spec
|
||||
|
||||
self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
|
||||
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
|
||||
|
||||
@@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme):
|
||||
layer.weight_scale.data, requires_grad=False
|
||||
)
|
||||
else:
|
||||
if self.rocm_use_aiter_fp4_asm_gemm:
|
||||
if self.dynamic_mxfp4_quant:
|
||||
w_q, w_s = dynamic_mxfp4_quant(layer.weight)
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
w_s.T.contiguous(), requires_grad=False
|
||||
)
|
||||
layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
|
||||
elif self.rocm_use_aiter_fp4_asm_gemm:
|
||||
# shuffle weight scale
|
||||
weight_scale_shuffle = layer.weight_scale.data
|
||||
sm, sn = weight_scale_shuffle.shape
|
||||
@@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme):
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
if self.dynamic_mxfp4_quant:
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# WEIGHT
|
||||
weight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.packed_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, kwargs)
|
||||
else:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
# WEIGHT
|
||||
weight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
packed_dim=1,
|
||||
packed_factor=self.packed_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // OCP_MX_BLOCK_SIZE,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
|
||||
@@ -6,28 +6,18 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
align_fp4_moe_weights_for_fi,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import (
|
||||
has_flashinfer_cutlass_fused_moe,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
@@ -42,92 +32,15 @@ __all__ = [
|
||||
"reorder_w1w3_to_w3w1",
|
||||
]
|
||||
|
||||
#
|
||||
# Methods used by the oracle for kernel selection.
|
||||
#
|
||||
|
||||
|
||||
def _supports_current_device() -> bool:
|
||||
"""Supports only Blackwell-family GPUs."""
|
||||
p = current_platform
|
||||
return p.is_cuda() and p.is_device_capability_family(100)
|
||||
|
||||
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
"""Supports non-gated MoE."""
|
||||
return True
|
||||
|
||||
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports Nvfp4 quantization."""
|
||||
SUPPORTED_W_A = [
|
||||
(kNvfp4Static, kNvfp4Dynamic),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
|
||||
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
"""Monolithic kernels need to express router support."""
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
RoutingMethodType.Llama4,
|
||||
]
|
||||
|
||||
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""
|
||||
TRTLLM is a monolithic kernel that requires dispatch_router_logits() for
|
||||
the naive dispatch/combine path. DeepEP HT only implements dispatch() for
|
||||
the modular kernel path, so TRTLLM is incompatible with DeepEP HT.
|
||||
"""
|
||||
return not moe_parallel_config.use_deepep_ht_kernels
|
||||
|
||||
|
||||
def is_supported_config_trtllm(
|
||||
moe_config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
|
||||
"""
|
||||
|
||||
def _make_reason(reason: str) -> str:
|
||||
return f"kernel does not support {reason}"
|
||||
|
||||
if not _supports_current_device():
|
||||
return False, _make_reason(f"current device {current_platform.device_name}")
|
||||
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
|
||||
return False, _make_reason("no act_and_mul MLP layer")
|
||||
elif not _supports_activation(moe_config.activation):
|
||||
return False, _make_reason(f"{moe_config.activation} activation")
|
||||
elif not _supports_quant_scheme(weight_key, activation_key):
|
||||
return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
|
||||
elif not _supports_parallel_config(moe_config.moe_parallel_config):
|
||||
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
|
||||
elif not _supports_routing_method(moe_config.routing_method):
|
||||
return False, _make_reason(f"routing method {moe_config.routing_method}")
|
||||
elif activation_format != mk.FusedMoEActivationFormat.Standard:
|
||||
return False, _make_reason(f"activation format {activation_format}")
|
||||
elif moe_config.hidden_dim % 512 != 0:
|
||||
return False, _make_reason(
|
||||
f"hidden_dim must be divisible by 512, found {moe_config.hidden_dim}"
|
||||
)
|
||||
|
||||
return True, None
|
||||
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
|
||||
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
|
||||
return (
|
||||
envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
and has_flashinfer_cutlass_fused_moe()
|
||||
and current_platform.is_cuda()
|
||||
and current_platform.has_device_capability(100)
|
||||
)
|
||||
|
||||
|
||||
def reorder_w1w3_to_w3w1(
|
||||
@@ -276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe(
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_trtllm_fp4_moe(
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
custom_routing_function: object | None,
|
||||
e_score_correction_bias: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply FlashInfer TensorRT-LLM FP4 MoE kernel.
|
||||
|
||||
Args:
|
||||
layer: The MoE layer with weights and scales
|
||||
x: Input tensor
|
||||
router_logits: Router logits for expert selection
|
||||
top_k: Number of experts to select per token
|
||||
activation: Activation function to use
|
||||
global_num_experts: Total number of experts across all ranks
|
||||
num_expert_group: Number of expert groups (for grouped routing)
|
||||
topk_group: Top-k within each group
|
||||
custom_routing_function: Custom routing function (e.g., Llama4)
|
||||
e_score_correction_bias: Optional routing bias correction
|
||||
|
||||
Returns:
|
||||
Output tensor from the MoE layer
|
||||
"""
|
||||
import flashinfer
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
|
||||
SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||
assert activation in SUPPORTED_ACTIVATIONS, (
|
||||
f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
|
||||
f"TRTLLM FP4 MoE, {activation} found instead."
|
||||
)
|
||||
|
||||
# Quantize input to FP4
|
||||
if isinstance(x, tuple):
|
||||
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
|
||||
else:
|
||||
# hidden_states is the already quantized
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
|
||||
x, layer.a1_gscale, is_sf_swizzled_layout=False
|
||||
)
|
||||
|
||||
# Determine routing method type
|
||||
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
|
||||
routing_method_type = layer.routing_method_type
|
||||
if use_llama4_routing:
|
||||
routing_method_type = flashinfer.RoutingMethodType.Llama4
|
||||
|
||||
# Cast to Fp32 (required by kernel).
|
||||
router_logits = (
|
||||
router_logits.to(torch.float32)
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3
|
||||
else router_logits
|
||||
)
|
||||
|
||||
# Determine activation type
|
||||
activation_type = activation_to_flashinfer_int(layer.activation)
|
||||
|
||||
# Call TRT-LLM FP4 block-scale MoE kernel
|
||||
out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states_fp4,
|
||||
hidden_states_scale=hidden_states_scale_linear_fp4.view(
|
||||
torch.float8_e4m3fn
|
||||
).reshape(*hidden_states_fp4.shape[:-1], -1),
|
||||
gemm1_weights=layer.w13_weight.data,
|
||||
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=layer.w2_weight.data,
|
||||
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=layer.g1_scale_c.data,
|
||||
output1_scale_gate_scalar=layer.g1_alphas.data,
|
||||
output2_scale_scalar=layer.g2_alphas.data,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group if num_expert_group is not None else 0,
|
||||
topk_group=topk_group if topk_group is not None else 0,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=routing_method_type,
|
||||
do_finalize=True,
|
||||
activation_type=activation_type,
|
||||
)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def flashinfer_trtllm_fp4_routed_moe(
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
top_k: int,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
|
||||
input top k expert indices and scores rather than computing
|
||||
top k expert indices from scores.
|
||||
|
||||
Args:
|
||||
layer: The MoE layer with weights and scales
|
||||
x: Input tensor
|
||||
topk_ids: Ids of selected experts
|
||||
top_k: Number of experts to select per token
|
||||
activation: Activation function to use
|
||||
global_num_experts: Total number of experts across all ranks
|
||||
|
||||
Returns:
|
||||
Output tensor from the MoE layer
|
||||
"""
|
||||
import flashinfer
|
||||
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
|
||||
assert activation == MoEActivation.SILU, (
|
||||
"Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
|
||||
f"{activation} found instead."
|
||||
)
|
||||
|
||||
# Pack top k ids and expert weights into a single int32 tensor, as
|
||||
# required by TRT-LLM
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
torch.bfloat16
|
||||
).view(torch.int16)
|
||||
|
||||
if isinstance(x, tuple):
|
||||
# Hidden_states is the already quantized
|
||||
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
|
||||
else:
|
||||
# Quantize input to FP4
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
|
||||
x, layer.a1_gscale, is_sf_swizzled_layout=False
|
||||
)
|
||||
|
||||
# Call TRT-LLM FP4 block-scale MoE kernel
|
||||
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
|
||||
topk_ids=packed_tensor,
|
||||
routing_bias=None,
|
||||
hidden_states=hidden_states_fp4,
|
||||
hidden_states_scale=hidden_states_scale_linear_fp4.view(
|
||||
torch.float8_e4m3fn
|
||||
).reshape(*hidden_states_fp4.shape[:-1], -1),
|
||||
gemm1_weights=layer.w13_weight.data,
|
||||
gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm1_bias=None,
|
||||
gemm1_alpha=None,
|
||||
gemm1_beta=None,
|
||||
gemm1_clamp_limit=None,
|
||||
gemm2_weights=layer.w2_weight.data,
|
||||
gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
|
||||
gemm2_bias=None,
|
||||
output1_scale_scalar=layer.g1_scale_c.data,
|
||||
output1_scale_gate_scalar=layer.g1_alphas.data,
|
||||
output2_scale_scalar=layer.g2_alphas.data,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
n_group=0,
|
||||
topk_group=0,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
routed_scaling_factor=None,
|
||||
routing_method_type=1,
|
||||
do_finalize=True,
|
||||
)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
|
||||
backend: "NvFp4MoeBackend",
|
||||
layer: "FusedMoE",
|
||||
@@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
|
||||
)
|
||||
)
|
||||
layer.intermediate_size_per_partition = padded_intermediate
|
||||
layer.moe_config.intermediate_size_per_partition = padded_intermediate
|
||||
|
||||
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
|
||||
w13,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum):
|
||||
|
||||
|
||||
def activation_to_flashinfer_int(activation: MoEActivation) -> int:
|
||||
return activation_to_flashinfer_type(activation).value
|
||||
|
||||
|
||||
def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType":
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
# silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
|
||||
@@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int:
|
||||
MoEActivation.GELU: ActivationType.Geglu,
|
||||
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
|
||||
}
|
||||
return ACTIVATION_TO_FI_ACTIVATION[activation].value
|
||||
return ACTIVATION_TO_FI_ACTIVATION[activation]
|
||||
|
||||
|
||||
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
|
||||
)
|
||||
|
||||
|
||||
def register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer: torch.nn.Module,
|
||||
w13_scale: torch.Tensor,
|
||||
w13_input_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
) -> None:
|
||||
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
|
||||
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
|
||||
w13_scale=w13_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
layer.w2_input_scale_inv = 1.0 / w2_input_scale
|
||||
layer.output1_scales_gate_scalar = g1_alphas
|
||||
|
||||
if layer.activation.is_gated:
|
||||
layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
|
||||
else:
|
||||
layer.output1_scales_scalar = (
|
||||
torch.ones_like(g1_alphas) * layer.w2_input_scale_inv
|
||||
)
|
||||
layer.output2_scales_scalar = g2_alphas
|
||||
|
||||
|
||||
def apply_fi_trtllm_fp8_per_tensor_moe(
    layer: torch.nn.Module,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    routing_bias: torch.Tensor | None,
    top_k: int,
    num_expert_group: int | None,
    topk_group: int | None,
    global_num_experts: int,
    apply_router_weight_on_input: bool,
) -> torch.Tensor:
    """Run the FlashInfer TRT-LLM FP8 per-tensor MoE kernel for *layer*.

    Requires that ``register_scales_for_trtllm_fp8_per_tensor_moe`` has
    already attached the ``output*_scales_*`` attributes to the layer.
    Returns the MoE output for *hidden_states*.
    """
    from flashinfer.fused_moe import RoutingMethodType

    # Imported for its side effect: registers the torch.ops.vllm custom op
    # called below.
    import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
    from vllm.model_executor.models.llama4 import Llama4MoE

    # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
    assert (
        hasattr(layer, "output1_scales_scalar")
        and hasattr(layer, "output1_scales_gate_scalar")
        and hasattr(layer, "output2_scales_scalar")
    )

    # Llama4 routing is the only method allowed to carry a custom routing
    # function; every other routing method must leave it unset.
    if layer.routing_method_type == RoutingMethodType.Llama4:
        assert (
            not layer.renormalize
            and layer.custom_routing_function == Llama4MoE.custom_routing_function
        ), (
            "FusedMoE flashinfer kernels with Llama4 routing method are only "
            "supported for Llama4"
        )
    else:
        assert layer.custom_routing_function is None, (
            "Custom routing function is only supported for Llama4"
        )
    activation_type = activation_to_flashinfer_int(layer.activation)

    return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
        routing_logits=router_logits,
        routing_bias=routing_bias,
        hidden_states=hidden_states,
        input_scale=layer.w13_input_scale,
        gemm1_weights=layer.w13_weight,
        gemm2_weights=layer.w2_weight,
        output1_scales_scalar=layer.output1_scales_scalar,
        output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
        output2_scales_scalar=layer.output2_scales_scalar,
        num_experts=global_num_experts,
        top_k=top_k,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=layer.intermediate_size_per_partition,
        # Experts are laid out contiguously across EP ranks.
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        use_routing_scales_on_input=apply_router_weight_on_input,
        routing_method_type=layer.routing_method_type,
        activation_type=activation_type,
    )
|
||||
|
||||
|
||||
def make_fp8_moe_alpha_scales_for_fi(
    w13_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the per-GEMM alpha scales (weight scale * input scale) used by
    the FlashInfer FP8 MoE kernels, squeezed to drop singleton dims."""
    alpha_gemm1 = torch.squeeze(w13_scale * w13_input_scale)
    alpha_gemm2 = torch.squeeze(w2_scale * w2_input_scale)
    return alpha_gemm1, alpha_gemm2
|
||||
|
||||
|
||||
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
||||
backend_map = {
|
||||
"throughput": FlashinferMoeBackend.CUTLASS,
|
||||
@@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
min_alignment,
|
||||
)
|
||||
layer.intermediate_size_per_partition = new_intermediate
|
||||
layer.moe_config.intermediate_size_per_partition = new_intermediate
|
||||
|
||||
# FI kernels require W31 layout rather than W13.
|
||||
if layer.moe_config.is_act_and_mul:
|
||||
@@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
|
||||
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
|
||||
# and registration of alpha scales. Note that we do not register
|
||||
# as nn.Parameters since they are not needed for weight-reloading.
|
||||
# and registration of alpha scales.
|
||||
if is_trtllm and not block_quant:
|
||||
assert w13_input_scale is not None
|
||||
assert w2_input_scale is not None
|
||||
|
||||
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer,
|
||||
w13_scale=w13_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
|
||||
# Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel.
|
||||
# Some FP8 models have near-zero block scales (~1e-23) for dead/unused
|
||||
|
||||
@@ -53,7 +53,10 @@ logger = init_logger(__name__)
|
||||
def is_fp8(x: torch.dtype | torch.Tensor) -> bool:
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.dtype
|
||||
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
|
||||
try:
|
||||
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
# We need to pass in the is_hopper flag as argument because the function
|
||||
|
||||
373
vllm/model_executor/layers/quantization/utils/gguf_utils.py
Normal file
373
vllm/model_executor/layers/quantization/utils/gguf_utils.py
Normal file
@@ -0,0 +1,373 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from gguf.constants import GGMLQuantizationType
|
||||
|
||||
def get_awq_format(w, group_size=128, w_bit=4):
|
||||
org_w_shape = w.shape
|
||||
ori_w_dtype = torch.get_default_dtype()
|
||||
assert w_bit == 4
|
||||
assert w.shape[1] % group_size == 0
|
||||
|
||||
in_features = org_w_shape[1]
|
||||
w = w.reshape(-1, group_size)
|
||||
assert torch.isnan(w).sum() == 0
|
||||
|
||||
max_val = w.amax(dim=1, keepdim=True)
|
||||
min_val = w.amin(dim=1, keepdim=True)
|
||||
max_int = 2**w_bit - 1
|
||||
min_int = 0
|
||||
scales = (max_val - min_val).clamp(min=1e-5) / max_int
|
||||
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
|
||||
w = (
|
||||
torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
|
||||
) * scales
|
||||
zeros = zeros.view(org_w_shape[0], -1)
|
||||
scales = scales.view(org_w_shape[0], -1)
|
||||
w = w.reshape(org_w_shape)
|
||||
assert torch.isnan(scales).sum() == 0
|
||||
assert torch.isnan(w).sum() == 0
|
||||
|
||||
scales = scales.t().contiguous() # input // group, o
|
||||
zeros = zeros.t().contiguous() # input // group, o
|
||||
|
||||
# from auto awq
|
||||
scale_zeros = zeros * scales
|
||||
scales = scales.clone().to(ori_w_dtype)
|
||||
|
||||
pack_num = 32 // w_bit
|
||||
intweight = []
|
||||
for idx in range(in_features):
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(w[:, idx] + scale_zeros[idx // group_size])
|
||||
/ scales[idx // group_size]
|
||||
).to(torch.int)[:, None]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.to(dtype=torch.int32)
|
||||
|
||||
qweight = torch.zeros(
|
||||
(intweight.shape[0], intweight.shape[1] // 32 * w_bit),
|
||||
dtype=torch.int32,
|
||||
device=intweight.device,
|
||||
)
|
||||
|
||||
for col in range(intweight.shape[1] // pack_num):
|
||||
order_map = [0, 2, w_bit, 6, 1, 3, 5, 7]
|
||||
for i in range(pack_num):
|
||||
qweight_col = intweight[:, col * pack_num + order_map[i]]
|
||||
qweight[:, col] |= qweight_col << (i * w_bit)
|
||||
|
||||
zeros = zeros.to(dtype=torch.int32, device=qweight.device)
|
||||
|
||||
qzeros = torch.zeros(
|
||||
(zeros.shape[0], zeros.shape[1] // 32 * w_bit),
|
||||
dtype=torch.int32,
|
||||
device=zeros.device,
|
||||
)
|
||||
|
||||
for col in range(zeros.shape[1] // pack_num):
|
||||
order_map = [0, 2, w_bit, 6, 1, 3, 5, 7]
|
||||
for i in range(pack_num):
|
||||
qzero_col = zeros[:, col * pack_num + order_map[i]]
|
||||
qzeros[:, col] |= qzero_col << (i * w_bit)
|
||||
|
||||
return qweight, qzeros, scales
|
||||
|
||||
# Bytes per quantization block for each GGML tensor type. The legacy types
# (Q4_0/Q5_0/Q8_0) use 32-weight blocks; the *_K types use 256-weight
# super-blocks (see QK_K below), which is where the 256-based terms come from.
GGML_BLOCK_SIZES = {
    "F32": 4,  # one float32 per value
    "F16": 2,  # one float16 per value
    "Q4_0": 2 + 16,  # fp16 scale + 32 packed 4-bit weights
    "Q5_0": 2 + 4 + 16,  # fp16 scale + 32 high bits + 32 packed low nibbles
    "Q8_0": 2 + 32,  # fp16 scale + 32 int8 weights
    "Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
    "Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
    "Q4_K": 2 + 2 + 12 + 256 // 2,
    "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
    "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
    "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
}
|
||||
|
||||
def dequantize_f32(data):
    """View the raw buffer *data* as an array of little-endian float32 values."""
    values = np.frombuffer(data, dtype=np.float32)
    return values
|
||||
|
||||
def dequantize_f16(data):
    """View the raw buffer *data* as an array of little-endian float16 values."""
    values = np.frombuffer(data, dtype=np.float16)
    return values
|
||||
|
||||
def dequantize_q4_0(data):
    """Dequantize a GGML Q4_0 buffer (fp16 scale + 32 x 4-bit weights/block)."""
    num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"]

    raw_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)
    raw_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)

    scales = raw_f16[:, :1].astype(np.float32)
    qs = raw_u8[:, 2:]

    # Low nibbles come first, then high nibbles; both are offset by -8.
    low = (qs & 0xf).astype(np.int8) - 8
    high = (qs >> 4).astype(np.int8) - 8
    return np.concatenate([scales * low, scales * high], axis=1)
|
||||
|
||||
def dequantize_q5_0(data):
    """Dequantize a GGML Q5_0 buffer (fp16 scale + 32 x 5-bit weights/block)."""
    num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"]

    # Block layout: fp16 scale, 4 bytes of high (5th) bits, 16 bytes of
    # packed low nibbles.
    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)
    qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]
    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]

    # One high bit per weight, LSB-first.
    bits = np.unpackbits(qh, axis=-1, bitorder="little")

    # Combine nibble + high bit, then shift to the signed range [-16, 15].
    x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16
    x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16

    return np.concatenate([
        scales * x0,
        scales * x1,
    ], axis=1)
|
||||
|
||||
def dequantize_q8_0(data):
    """Dequantize a GGML Q8_0 buffer: per-block fp16 scale times 32 int8 quants."""
    num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]

    as_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)
    as_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)

    block_scales = as_f16[:, :1].astype(np.float32)
    quants = as_i8[:, 2:]
    return block_scales * quants
|
||||
|
||||
def dequantize_q2_k(data):
    """Dequantize a GGML Q2_K buffer into float32.

    Each 256-weight super-block stores 16 scale/min byte pairs (low/high
    nibble), 64 bytes of 2-bit quants, and fp16 d/dmin super-scales at the
    end of the block.
    """
    block_size = GGML_BLOCK_SIZES["Q2_K"]
    num_blocks = len(data) // block_size

    # View the same buffer as fp16 (for the trailing d/dmin) and as bytes.
    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)

    dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
    d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
    scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
    qs = data_u8[:, 16:80].reshape(num_blocks, 64)

    # Unpack the 16 groups of 16 weights; each byte holds four 2-bit quants
    # at shifts 0/2/4/6, in the group order below.
    tmp = np.stack([
        qs[:, 00:16] >> 0,
        qs[:, 16:32] >> 0,
        qs[:, 00:16] >> 2,
        qs[:, 16:32] >> 2,
        qs[:, 00:16] >> 4,
        qs[:, 16:32] >> 4,
        qs[:, 00:16] >> 6,
        qs[:, 16:32] >> 6,
        qs[:, 32:48] >> 0,
        qs[:, 48:64] >> 0,
        qs[:, 32:48] >> 2,
        qs[:, 48:64] >> 2,
        qs[:, 32:48] >> 4,
        qs[:, 48:64] >> 4,
        qs[:, 32:48] >> 6,
        qs[:, 48:64] >> 6,
    ], axis=1)

    # weight = d * scale_lo_nibble * q2 - dmin * scale_hi_nibble
    return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
|
||||
|
||||
|
||||
def dequantize_q3_k(data):
    """Dequantize a GGML Q3_K buffer into float32.

    Each 256-weight super-block stores 32 bytes of high-bit masks, 64 bytes
    of 2-bit low quants, 12 bytes of packed 6-bit scales, and an fp16
    super-scale at the end of the block.
    """
    block_size = GGML_BLOCK_SIZES["Q3_K"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)

    d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
    # High (3rd) bit of every quant, one bit per weight, LSB-first.
    bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
    # Precompute 4 - 4*bit so subtracting it applies the +/-4 high-bit offset.
    bits = 4 ^ (bits << 2)
    qs = data_u8[:, 32:32 + 64].astype(np.int16)
    # Reassemble the 16 six-bit scales from three packed 4-byte lanes.
    a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
    scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
    scales[:, 0] = (a & 15) | ((c & 3) << 4)
    scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
    scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
    scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
    scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)

    # weight = d * (scale - 32) * (low 2 bits minus the high-bit offset),
    # group order mirroring dequantize_q2_k.
    return d * (scales - 32) * np.stack([
        (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
        (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
        (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
        (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
        (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
        (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
        (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
        (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
        (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
        (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
        (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
        (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
        (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
        (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
        (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
        (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
    ], axis=1)
|
||||
|
||||
def dequantize_q4_k(data, device=None):
    """Dequantize a GGML Q4_K buffer into float32.

    Each 256-weight super-block stores fp16 super-scale/super-min, 12 bytes
    of packed 6-bit scales/offsets and 128 bytes of 4-bit quants. If
    *device* is given, the result is returned as a torch tensor on it.
    """
    block_size = GGML_BLOCK_SIZES["Q4_K"]
    num_blocks = len(data) // block_size
    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
    # Casting to float32 because float16 is very slow on CPU
    scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
    scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
    qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
    qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
    # Dequantize scales and offsets (6 bits and 4 + 2 bits)
    factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)
    offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)
    # Interleave low and high quantized bits
    qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
    # Dequantize final weights using scales and offsets
    weight = factors * qs2 - offsets
    if device is None:
        return weight
    return torch.from_numpy(weight).to(device=device)
|
||||
|
||||
def dequantize_q5_k(data):
    """Dequantize a GGML Q5_K buffer into float32.

    Each 256-weight super-block stores fp16 d/dmin, 12 bytes of packed
    6-bit scales, 32 bytes of high (5th) bits and 128 bytes of 4-bit low
    nibbles.
    """
    block_size = GGML_BLOCK_SIZES["Q5_K"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)

    d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
    dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
    scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
    qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1)
    qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32)

    # One high bit per weight per group, LSB-first.
    bits = np.unpackbits(qh, axis=-1, bitorder="little")

    qs_hi_4 = qs >> 4
    qs_lo_4 = qs & 15

    # The 8 six-bit scales/mins: the first 8 scale bytes carry 6 bits
    # directly, the remaining ones are split across low/high nibbles.
    scales_lo_6 = scales[:, :8] & 63
    scales_hi_6 = scales[:, :8] >> 6
    scales_lo_4 = scales[:, 8:] & 15
    scales_hi_4 = scales[:, 8:] >> 4

    # Per-group minimum offsets (m1..m8).
    m1 = dmin * scales_lo_6[:, 4]
    m2 = dmin * scales_lo_6[:, 5]
    m3 = dmin * scales_lo_6[:, 6]
    m4 = dmin * scales_lo_6[:, 7]
    m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
    m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
    m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
    m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))

    # Per-group scales (d1..d8).
    d1 = d * scales_lo_6[:, 0]
    d2 = d * scales_lo_6[:, 1]
    d3 = d * scales_lo_6[:, 2]
    d4 = d * scales_lo_6[:, 3]
    d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
    d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
    d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
    d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))

    # weight = d_i * (nibble + high_bit*16) - m_i, group by group.
    return np.concatenate([
        d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
        d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
        d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
        d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
        d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
        d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
        d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
        d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
    ], axis=1)
|
||||
|
||||
def dequantize_q6_k(data, device = None):
    """Dequantize a GGML Q6_K buffer into float32.

    Each 256-weight super-block stores 128 bytes of low nibbles, 64 bytes
    of high 2-bit pairs, 16 signed int8 group scales, and an fp16
    super-scale at the end. If *device* is given, the result is returned as
    a torch tensor on it.
    """
    block_size = GGML_BLOCK_SIZES["Q6_K"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
    data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)

    scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)
    # TODO use uint8 and cast later?
    ql = data_u8[:, :128].astype(np.int16)
    qh = data_u8[:, 128:192].astype(np.int16)
    sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)

    # Unpack bits, subtraction requires signed data type
    # NOTE: `|` binds looser than `-`, so each line computes
    # lo | (((hi & 3) << 4) - 32). Because the subtracted constant never
    # touches the low nibble's bits, this equals ((lo | hi << 4)) - 32,
    # i.e. the signed 6-bit value in [-32, 31].
    q1 = (ql[:, :32 ] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32
    q2 = (ql[:, 32:64 ] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32
    q3 = (ql[:, :32 ] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32
    q4 = (ql[:, 32:64 ] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32
    q5 = (ql[:, 64:96 ] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32
    q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32
    q7 = (ql[:, 64:96 ] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32
    q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32

    # Dequantize
    weight = scales * np.concatenate([
        sc[:, 0] * q1[:, :16],
        sc[:, 1] * q1[:, 16:],
        sc[:, 2] * q2[:, :16],
        sc[:, 3] * q2[:, 16:],
        sc[:, 4] * q3[:, :16],
        sc[:, 5] * q3[:, 16:],
        sc[:, 6] * q4[:, :16],
        sc[:, 7] * q4[:, 16:],
        sc[:, 8] * q5[:, :16],
        sc[:, 9] * q5[:, 16:],
        sc[:, 10] * q6[:, :16],
        sc[:, 11] * q6[:, 16:],
        sc[:, 12] * q7[:, :16],
        sc[:, 13] * q7[:, 16:],
        sc[:, 14] * q8[:, :16],
        sc[:, 15] * q8[:, 16:],
    ], axis=1)

    if device is None:
        return weight
    return torch.from_numpy(weight).to(device=device)
|
||||
|
||||
# Number of weights per K-quant super-block.
QK_K = 256
# Non-linear 4-bit codebook used by the IQ4 family of quant types.
kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
|
||||
|
||||
def dequantize_iq4_xs(data):
    """Dequantize a GGML IQ4_XS buffer into a flat float32 array.

    Each 256-weight super-block stores an fp16 super-scale, 16 bits of
    high scale bits, 4 bytes of low scale nibbles and 128 bytes of 4-bit
    indices into the kvalues_iq4nl codebook. Unlike the other dequantizers
    here, the result is flattened.
    """
    block_size = GGML_BLOCK_SIZES["IQ4_XS"]
    num_blocks = len(data) // block_size

    # Strided views: the fp16 d sits at byte offset 0 of each block, the
    # uint16 scales_h at byte offset 2.
    d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1)
    scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:]
    scales_l = data_u8[:, :4].reshape(num_blocks, 4)
    qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8)

    # Reassemble the eight 6-bit group scales (4 low bits + 2 high bits).
    ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8)
    for ib in range(QK_K // 32):
        ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4)

    dl = (d * (ls - 32)).reshape(num_blocks, -1, 1)

    qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf
    qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4

    # Look each 4-bit index up in the non-linear codebook, group by group.
    y = np.zeros((num_blocks, QK_K), dtype=np.float32)
    for ib in range(QK_K // 32):
        y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]]
        y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]]

    return y.flatten()
|
||||
|
||||
# Dispatch table: GGML quantization type id -> dequantizer for its raw bytes.
GGML_DEQUANTIZE = {
    int(GGMLQuantizationType.F32): dequantize_f32,
    int(GGMLQuantizationType.F16): dequantize_f16,
    int(GGMLQuantizationType.Q4_0): dequantize_q4_0,
    int(GGMLQuantizationType.Q5_0): dequantize_q5_0,
    int(GGMLQuantizationType.Q8_0): dequantize_q8_0,
    int(GGMLQuantizationType.Q2_K): dequantize_q2_k,
    int(GGMLQuantizationType.Q3_K): dequantize_q3_k,
    int(GGMLQuantizationType.Q4_K): dequantize_q4_k,
    int(GGMLQuantizationType.Q5_K): dequantize_q5_k,
    int(GGMLQuantizationType.Q6_K): dequantize_q6_k,
    int(GGMLQuantizationType.IQ4_XS): dequantize_iq4_xs,
}
|
||||
|
||||
|
||||
def dequant_gguf(data, type, shape):
    """Dequantize raw GGUF tensor bytes of GGML type id *type* and reshape
    to *shape* as a torch tensor.

    NOTE: ``type`` shadows the builtin; kept for caller compatibility.
    """
    dequantize = GGML_DEQUANTIZE[type]
    raw_values = dequantize(data)
    return torch.from_numpy(raw_values).view(shape)
|
||||
@@ -255,18 +255,6 @@ def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tenso
|
||||
return w2_packed.size(1) * marlin_tile_size
|
||||
|
||||
|
||||
def marlin_make_workspace(
    output_size_per_partition: int, device: torch.device
) -> torch.Tensor:
    """Allocate the zeroed int32 workspace buffer used by the Marlin kernels."""
    thread_n_slots = output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
    max_workspace_size = thread_n_slots * GPTQ_MARLIN_MAX_PARALLEL
    return torch.zeros(
        max_workspace_size, dtype=torch.int, device=device, requires_grad=False
    )
|
||||
|
||||
|
||||
def marlin_make_workspace_new(
|
||||
device: torch.device, max_blocks_per_sm: int = 1
|
||||
) -> torch.Tensor:
|
||||
@@ -297,12 +285,6 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
|
||||
)
|
||||
|
||||
|
||||
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    """Return a zero-element int Parameter used as a zero-point placeholder."""
    empty_zp = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(empty_zp, requires_grad=False)
|
||||
|
||||
|
||||
def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Sort *g_idx* ascending; return (sorted values, int32 sort permutation)."""
    perm = torch.argsort(g_idx)
    perm = perm.to(torch.int)
    return g_idx[perm], perm
|
||||
|
||||
@@ -175,7 +175,7 @@ try:
|
||||
op_func=_dequant_mxfp4,
|
||||
fake_impl=_dequant_mxfp4_fake,
|
||||
)
|
||||
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
|
||||
dequant_mxfp4 = None
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
@@ -185,6 +185,6 @@ try:
|
||||
op_func=_quant_dequant_mxfp4,
|
||||
fake_impl=_quant_dequant_mxfp4_fake,
|
||||
)
|
||||
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
|
||||
quant_dequant_mxfp4 = None
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
@@ -271,12 +271,12 @@ def scaled_quantize(
|
||||
If None, uses input dtype. Use torch.float32 for higher precision.
|
||||
"""
|
||||
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||
assert quant_dtype.is_floating_point, (
|
||||
"currently `scaled_quantize` only supports floating point dtypes "
|
||||
"but could be extended to support other dtypes"
|
||||
)
|
||||
# assert quant_dtype.is_floating_point, (
|
||||
# "currently `scaled_quantize` only supports floating point dtypes "
|
||||
# "but could be extended to support other dtypes"
|
||||
# )
|
||||
|
||||
finfo = torch.finfo(quant_dtype)
|
||||
finfo = torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype)
|
||||
|
||||
# Convert to compute dtype if specified
|
||||
x_compute = x if compute_dtype is None else x.to(compute_dtype)
|
||||
|
||||
114
vllm/model_executor/layers/quantization/w8a16.py
Normal file
114
vllm/model_executor/layers/quantization/w8a16.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class W8a16Config(QuantizationConfig):
    """Config class for W8a16.

    Weight-only int8 quantization with 16-bit (fp16/bf16) activations.
    The config carries no tunable parameters.
    """

    def __init__(
        self,
    ) -> None:
        pass

    def __repr__(self) -> str:
        return ("W8a16Config")

    def get_name(self) -> str:
        # Must match the "w8a16" key registered in QUANTIZATION_METHODS.
        return "w8a16"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    def get_min_capability(self) -> int:
        # 75 == compute capability 7.5 in vLLM's capability encoding
        # (NOTE(review): verify against the base-class convention).
        return 75

    @staticmethod
    def get_config_filenames():
        # No on-disk quantization config file is required.
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8a16Config":
        # The config dict is ignored: W8a16 has no tunable parameters.
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["W8a16LinearMethod"]:
        # Only linear layers are quantized; all other layers keep their
        # unquantized implementation.
        if isinstance(layer, LinearBase):
            return W8a16LinearMethod(self)
        return None


    def get_scaled_act_names(self) -> List[str]:
        return []
|
||||
|
||||
|
||||
class W8a16LinearMethod(LinearMethodBase):
    """Linear method for w8a16.

    Stores the weight as int8 of shape [out, in] with one dequantization
    scale per output channel; activations stay in the model's 16-bit dtype.
    """

    def __init__(self, quant_config: W8a16Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the int8 weight and per-channel scale parameters on *layer*."""
        output_size_per_partition = sum(output_partition_sizes)
        # Quantized weight: [out, in], int8.
        weight = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            weight, {
                "input_dim": 1,
                "output_dim": 0,
            })

        # One dequantization scale per output channel, shape [1, out],
        # stored in the activation dtype.
        scales = Parameter(
            torch.empty(
                1,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": None,
            "output_dim": 1,
        })

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)


    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the w8a16 linear transform of *x*, adding *bias* if given."""
        qweight = layer.weight
        scales = layer.scales
        # Output keeps x's leading dims, with the last dim replaced by
        # the number of output channels (qweight rows).
        out_shape = (x.shape[:-1] + (qweight.shape[-2],))
        # Flatten all leading dims so the kernel sees a 2-D [tokens, in] input.
        reshaped_x = x.reshape(-1, x.shape[-1])
        # "TN" layout flag passed to the custom kernel — presumably
        # activation row-major x transposed weight; TODO confirm against the
        # _custom_ops.linear_w8a16 contract.
        out = ops.linear_w8a16(reshaped_x, qweight, scales, format="TN")
        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)
|
||||
Reference in New Issue
Block a user