Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -6,28 +6,18 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
align_fp4_moe_weights_for_fi,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import (
|
||||
has_flashinfer_cutlass_fused_moe,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
@@ -42,92 +32,15 @@ __all__ = [
|
||||
"reorder_w1w3_to_w3w1",
|
||||
]
|
||||
|
||||
#
|
||||
# Methods used by the oracle for kernel selection.
|
||||
#
|
||||
|
||||
|
||||
def _supports_current_device() -> bool:
    """Restrict this kernel to CUDA devices in the Blackwell family (SM100)."""
    platform = current_platform
    if not platform.is_cuda():
        return False
    return platform.is_device_capability_family(100)
|
||||
|
||||
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
"""Supports non-gated MoE."""
|
||||
return True
|
||||
|
||||
|
||||
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Only static-NVFP4 weights with dynamic-NVFP4 activations are supported."""
    supported_pairs = [
        (kNvfp4Static, kNvfp4Dynamic),
    ]
    return (weight_key, activation_key) in supported_pairs
|
||||
|
||||
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
    """SiLU and non-gated ReLU^2 are the only supported activations."""
    supported = (MoEActivation.SILU, MoEActivation.RELU2_NO_MUL)
    return activation in supported
|
||||
|
||||
|
||||
def _supports_routing_method(
    routing_method: RoutingMethodType,
) -> bool:
    """Monolithic kernels need to express router support."""
    # NOTE(rob): potentially allow others here. This is a conservative list.
    supported_routing = (
        RoutingMethodType.DeepSeekV3,
        RoutingMethodType.Renormalize,
        RoutingMethodType.RenormalizeNaive,
        RoutingMethodType.Llama4,
    )
    return routing_method in supported_routing
|
||||
|
||||
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """
    TRTLLM is a monolithic kernel that requires dispatch_router_logits() for
    the naive dispatch/combine path. DeepEP HT only implements dispatch() for
    the modular kernel path, so TRTLLM is incompatible with DeepEP HT.
    """
    if moe_parallel_config.use_deepep_ht_kernels:
        return False
    return True
|
||||
|
||||
|
||||
def is_supported_config_trtllm(
    moe_config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
    activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
    """
    This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config

    Returns (True, None) when every check passes, otherwise (False, reason).
    """

    def _unsupported(reason: str) -> tuple[bool, str]:
        # Reason string format is shared by all rejection paths.
        return False, f"kernel does not support {reason}"

    # Guard clauses, checked in the same order as the oracle helpers above.
    if not _supports_current_device():
        return _unsupported(f"current device {current_platform.device_name}")
    if not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
        return _unsupported("no act_and_mul MLP layer")
    if not _supports_activation(moe_config.activation):
        return _unsupported(f"{moe_config.activation} activation")
    if not _supports_quant_scheme(weight_key, activation_key):
        return _unsupported(f"quantization scheme {weight_key}x{activation_key}")
    if not _supports_parallel_config(moe_config.moe_parallel_config):
        return _unsupported(f"parallel config {moe_config.moe_parallel_config}")
    if not _supports_routing_method(moe_config.routing_method):
        return _unsupported(f"routing method {moe_config.routing_method}")
    if activation_format != mk.FusedMoEActivationFormat.Standard:
        return _unsupported(f"activation format {activation_format}")
    if moe_config.hidden_dim % 512 != 0:
        return _unsupported(
            f"hidden_dim must be divisible by 512, found {moe_config.hidden_dim}"
        )

    return True, None
|
||||
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
    """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
    if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
        return False
    if not has_flashinfer_cutlass_fused_moe():
        return False
    # Requires a CUDA device with compute capability >= 100.
    return current_platform.is_cuda() and current_platform.has_device_capability(100)
|
||||
|
||||
|
||||
def reorder_w1w3_to_w3w1(
|
||||
@@ -276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe(
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_trtllm_fp4_moe(
    layer: torch.nn.Module,
    x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    router_logits: torch.Tensor,
    top_k: int,
    activation: MoEActivation,
    global_num_experts: int,
    num_expert_group: int | None,
    topk_group: int | None,
    custom_routing_function: object | None,
    e_score_correction_bias: torch.Tensor | None,
) -> torch.Tensor:
    """
    Apply FlashInfer TensorRT-LLM FP4 MoE kernel.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor, or a pre-quantized (fp4_tensor, block_scale) pair
        router_logits: Router logits for expert selection
        top_k: Number of experts to select per token
        activation: Activation function to use (SILU or RELU2_NO_MUL)
        global_num_experts: Total number of experts across all ranks
        num_expert_group: Number of expert groups (for grouped routing)
        topk_group: Top-k within each group
        custom_routing_function: Custom routing function (e.g., Llama4)
        e_score_correction_bias: Optional routing bias correction

    Returns:
        Output tensor from the MoE layer
    """
    import flashinfer

    from vllm.model_executor.models.llama4 import Llama4MoE

    SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
    assert activation in SUPPORTED_ACTIVATIONS, (
        f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
        f"TRTLLM FP4 MoE, {activation} found instead."
    )

    # Accept either a pre-quantized pair or raw activations.
    if isinstance(x, tuple):
        # Caller already quantized the input to FP4.
        hidden_states_fp4, hidden_states_scale_linear_fp4 = x
    else:
        # Quantize the raw input to FP4 here.
        (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
            x, layer.a1_gscale, is_sf_swizzled_layout=False
        )

    # Determine routing method type; a Llama4 custom routing function
    # overrides the layer's configured routing method.
    use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
    routing_method_type = layer.routing_method_type
    if use_llama4_routing:
        routing_method_type = flashinfer.RoutingMethodType.Llama4

    # Cast to Fp32 (required by kernel).
    router_logits = (
        router_logits.to(torch.float32)
        if routing_method_type == RoutingMethodType.DeepSeekV3
        else router_logits
    )

    # Determine activation type.
    # NOTE(review): converts layer.activation, not the asserted `activation`
    # argument — presumably identical; confirm.
    activation_type = activation_to_flashinfer_int(layer.activation)

    # Call TRT-LLM FP4 block-scale MoE kernel
    out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
        routing_logits=router_logits,
        routing_bias=e_score_correction_bias,
        hidden_states=hidden_states_fp4,
        # Block scales are reinterpreted as fp8_e4m3 and flattened to match
        # the kernel's expected per-token layout.
        hidden_states_scale=hidden_states_scale_linear_fp4.view(
            torch.float8_e4m3fn
        ).reshape(*hidden_states_fp4.shape[:-1], -1),
        gemm1_weights=layer.w13_weight.data,
        gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.w2_weight.data,
        gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        # Kernel expects 0 (not None) when grouped routing is unused.
        n_group=num_expert_group if num_expert_group is not None else 0,
        topk_group=topk_group if topk_group is not None else 0,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        routing_method_type=routing_method_type,
        do_finalize=True,
        activation_type=activation_type,
    )[0]

    return out
|
||||
|
||||
|
||||
def flashinfer_trtllm_fp4_routed_moe(
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    top_k: int,
    activation: MoEActivation,
    global_num_experts: int,
) -> torch.Tensor:
    """
    Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
    input top k expert indices and scores rather than computing
    top k expert indices from scores.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor (or a pre-quantized (fp4_tensor, block_scale) pair)
        topk_ids: Ids of selected experts
        topk_weights: Routing weights for the selected experts
        top_k: Number of experts to select per token
        activation: Activation function to use (must be SILU)
        global_num_experts: Total number of experts across all ranks

    Returns:
        Output tensor from the MoE layer
    """
    import flashinfer

    # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
    assert activation == MoEActivation.SILU, (
        "Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
        f"{activation} found instead."
    )

    # Pack top k ids and expert weights into a single int32 tensor, as
    # required by TRT-LLM: expert id in the high 16 bits, the raw bf16
    # bit pattern of the routing weight in the low 16 bits.
    packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
        torch.bfloat16
    ).view(torch.int16)

    if isinstance(x, tuple):
        # Hidden_states is the already quantized
        hidden_states_fp4, hidden_states_scale_linear_fp4 = x
    else:
        # Quantize input to FP4
        (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
            x, layer.a1_gscale, is_sf_swizzled_layout=False
        )

    # Call TRT-LLM FP4 block-scale MoE kernel
    out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
        topk_ids=packed_tensor,
        routing_bias=None,
        hidden_states=hidden_states_fp4,
        # Block scales reinterpreted as fp8_e4m3 and flattened per token.
        hidden_states_scale=hidden_states_scale_linear_fp4.view(
            torch.float8_e4m3fn
        ).reshape(*hidden_states_fp4.shape[:-1], -1),
        gemm1_weights=layer.w13_weight.data,
        gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.w2_weight.data,
        gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=0,
        topk_group=0,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        # NOTE(review): hard-coded routing method (1 appears to be
        # Renormalize in FlashInfer's enum) — confirm against the FlashInfer
        # version in use.
        routing_method_type=1,
        do_finalize=True,
    )[0]

    return out
|
||||
|
||||
|
||||
def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
|
||||
backend: "NvFp4MoeBackend",
|
||||
layer: "FusedMoE",
|
||||
@@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
|
||||
)
|
||||
)
|
||||
layer.intermediate_size_per_partition = padded_intermediate
|
||||
layer.moe_config.intermediate_size_per_partition = padded_intermediate
|
||||
|
||||
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
|
||||
w13,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum):
|
||||
|
||||
|
||||
def activation_to_flashinfer_int(activation: MoEActivation) -> int:
    """Map a vLLM MoE activation to the integer value FlashInfer expects."""
    fi_activation = activation_to_flashinfer_type(activation)
    return fi_activation.value
|
||||
|
||||
|
||||
def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType":
|
||||
from flashinfer.fused_moe.core import ActivationType
|
||||
|
||||
# silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
|
||||
@@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int:
|
||||
MoEActivation.GELU: ActivationType.Geglu,
|
||||
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
|
||||
}
|
||||
return ACTIVATION_TO_FI_ACTIVATION[activation].value
|
||||
return ACTIVATION_TO_FI_ACTIVATION[activation]
|
||||
|
||||
|
||||
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
|
||||
)
|
||||
|
||||
|
||||
def register_scales_for_trtllm_fp8_per_tensor_moe(
    layer: torch.nn.Module,
    w13_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> None:
    """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
    gemm1_alphas, gemm2_alphas = make_fp8_moe_alpha_scales_for_fi(
        w13_scale=w13_scale,
        w13_input_scale=w13_input_scale,
        w2_scale=w2_scale,
        w2_input_scale=w2_input_scale,
    )
    w2_input_scale_inv = 1.0 / w2_input_scale
    layer.w2_input_scale_inv = w2_input_scale_inv
    layer.output1_scales_gate_scalar = gemm1_alphas

    # Gated MLPs fold the gemm1 alpha into output1; non-gated MLPs carry
    # only the inverse gemm2 input scale.
    if layer.activation.is_gated:
        layer.output1_scales_scalar = gemm1_alphas * w2_input_scale_inv
    else:
        layer.output1_scales_scalar = (
            torch.ones_like(gemm1_alphas) * w2_input_scale_inv
        )
    layer.output2_scales_scalar = gemm2_alphas
|
||||
|
||||
|
||||
def apply_fi_trtllm_fp8_per_tensor_moe(
    layer: torch.nn.Module,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    routing_bias: torch.Tensor | None,
    top_k: int,
    num_expert_group: int | None,
    topk_group: int | None,
    global_num_experts: int,
    apply_router_weight_on_input: bool,
) -> torch.Tensor:
    """Apply the FlashInfer TRTLLM FP8 per-tensor MoE kernel.

    Validates that the layer's routing configuration matches the kernel's
    constraints, then dispatches to the registered
    ``torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe`` custom op.
    """
    from flashinfer.fused_moe import RoutingMethodType

    # Imported for its side effect of registering the custom op below.
    import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
    from vllm.model_executor.models.llama4 import Llama4MoE

    # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
    assert (
        hasattr(layer, "output1_scales_scalar")
        and hasattr(layer, "output1_scales_gate_scalar")
        and hasattr(layer, "output2_scales_scalar")
    )

    # Llama4 routing requires the Llama4 custom routing function and no
    # renormalization; every other routing method forbids a custom function.
    if layer.routing_method_type == RoutingMethodType.Llama4:
        assert (
            not layer.renormalize
            and layer.custom_routing_function == Llama4MoE.custom_routing_function
        ), (
            "FusedMoE flashinfer kernels with Llama4 routing method are only "
            "supported for Llama4"
        )
    else:
        assert layer.custom_routing_function is None, (
            "Custom routing function is only supported for Llama4"
        )
    activation_type = activation_to_flashinfer_int(layer.activation)

    return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
        routing_logits=router_logits,
        routing_bias=routing_bias,
        hidden_states=hidden_states,
        input_scale=layer.w13_input_scale,
        gemm1_weights=layer.w13_weight,
        gemm2_weights=layer.w2_weight,
        output1_scales_scalar=layer.output1_scales_scalar,
        output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
        output2_scales_scalar=layer.output2_scales_scalar,
        num_experts=global_num_experts,
        top_k=top_k,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        use_routing_scales_on_input=apply_router_weight_on_input,
        routing_method_type=layer.routing_method_type,
        activation_type=activation_type,
    )
|
||||
|
||||
|
||||
def make_fp8_moe_alpha_scales_for_fi(
    w13_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build per-gemm alpha scales (weight scale * input scale) for FlashInfer."""
    gemm1_alpha = w13_scale.mul(w13_input_scale).squeeze()
    gemm2_alpha = w2_scale.mul(w2_input_scale).squeeze()
    return gemm1_alpha, gemm2_alpha
|
||||
|
||||
|
||||
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
|
||||
backend_map = {
|
||||
"throughput": FlashinferMoeBackend.CUTLASS,
|
||||
@@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
min_alignment,
|
||||
)
|
||||
layer.intermediate_size_per_partition = new_intermediate
|
||||
layer.moe_config.intermediate_size_per_partition = new_intermediate
|
||||
|
||||
# FI kernels require W31 layout rather than W13.
|
||||
if layer.moe_config.is_act_and_mul:
|
||||
@@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi(
|
||||
w13_scale = swap_w13_to_w31(w13_scale)
|
||||
|
||||
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
|
||||
# and registration of alpha scales. Note that we do not register
|
||||
# as nn.Parameters since they are not needed for weight-reloading.
|
||||
# and registration of alpha scales.
|
||||
if is_trtllm and not block_quant:
|
||||
assert w13_input_scale is not None
|
||||
assert w2_input_scale is not None
|
||||
|
||||
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
|
||||
register_scales_for_trtllm_fp8_per_tensor_moe(
|
||||
layer,
|
||||
w13_scale=w13_scale,
|
||||
w13_input_scale=w13_input_scale,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
)
|
||||
|
||||
# Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel.
|
||||
# Some FP8 models have near-zero block scales (~1e-23) for dead/unused
|
||||
|
||||
@@ -53,7 +53,10 @@ logger = init_logger(__name__)
|
||||
def is_fp8(x: torch.dtype | torch.Tensor) -> bool:
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.dtype
|
||||
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
|
||||
try:
|
||||
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
# We need to pass in the is_hopper flag as argument because the function
|
||||
|
||||
373
vllm/model_executor/layers/quantization/utils/gguf_utils.py
Normal file
373
vllm/model_executor/layers/quantization/utils/gguf_utils.py
Normal file
@@ -0,0 +1,373 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from gguf.constants import GGMLQuantizationType
|
||||
|
||||
def get_awq_format(w, group_size=128, w_bit=4):
|
||||
org_w_shape = w.shape
|
||||
ori_w_dtype = torch.get_default_dtype()
|
||||
assert w_bit == 4
|
||||
assert w.shape[1] % group_size == 0
|
||||
|
||||
in_features = org_w_shape[1]
|
||||
w = w.reshape(-1, group_size)
|
||||
assert torch.isnan(w).sum() == 0
|
||||
|
||||
max_val = w.amax(dim=1, keepdim=True)
|
||||
min_val = w.amin(dim=1, keepdim=True)
|
||||
max_int = 2**w_bit - 1
|
||||
min_int = 0
|
||||
scales = (max_val - min_val).clamp(min=1e-5) / max_int
|
||||
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
|
||||
w = (
|
||||
torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
|
||||
) * scales
|
||||
zeros = zeros.view(org_w_shape[0], -1)
|
||||
scales = scales.view(org_w_shape[0], -1)
|
||||
w = w.reshape(org_w_shape)
|
||||
assert torch.isnan(scales).sum() == 0
|
||||
assert torch.isnan(w).sum() == 0
|
||||
|
||||
scales = scales.t().contiguous() # input // group, o
|
||||
zeros = zeros.t().contiguous() # input // group, o
|
||||
|
||||
# from auto awq
|
||||
scale_zeros = zeros * scales
|
||||
scales = scales.clone().to(ori_w_dtype)
|
||||
|
||||
pack_num = 32 // w_bit
|
||||
intweight = []
|
||||
for idx in range(in_features):
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(w[:, idx] + scale_zeros[idx // group_size])
|
||||
/ scales[idx // group_size]
|
||||
).to(torch.int)[:, None]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.to(dtype=torch.int32)
|
||||
|
||||
qweight = torch.zeros(
|
||||
(intweight.shape[0], intweight.shape[1] // 32 * w_bit),
|
||||
dtype=torch.int32,
|
||||
device=intweight.device,
|
||||
)
|
||||
|
||||
for col in range(intweight.shape[1] // pack_num):
|
||||
order_map = [0, 2, w_bit, 6, 1, 3, 5, 7]
|
||||
for i in range(pack_num):
|
||||
qweight_col = intweight[:, col * pack_num + order_map[i]]
|
||||
qweight[:, col] |= qweight_col << (i * w_bit)
|
||||
|
||||
zeros = zeros.to(dtype=torch.int32, device=qweight.device)
|
||||
|
||||
qzeros = torch.zeros(
|
||||
(zeros.shape[0], zeros.shape[1] // 32 * w_bit),
|
||||
dtype=torch.int32,
|
||||
device=zeros.device,
|
||||
)
|
||||
|
||||
for col in range(zeros.shape[1] // pack_num):
|
||||
order_map = [0, 2, w_bit, 6, 1, 3, 5, 7]
|
||||
for i in range(pack_num):
|
||||
qzero_col = zeros[:, col * pack_num + order_map[i]]
|
||||
qzeros[:, col] |= qzero_col << (i * w_bit)
|
||||
|
||||
return qweight, qzeros, scales
|
||||
|
||||
# Bytes per quantization block for each GGML tensor type.  Legacy formats
# (Q4_0/Q5_0/Q8_0) quantize 32 values per block; k-quants use 256-value
# super-blocks.  Each sum spells out the on-disk layout of one block.
GGML_BLOCK_SIZES = {
    "F32": 4,  # one float32 value
    "F16": 2,  # one float16 value
    "Q4_0": 2 + 16,  # f16 scale + 32 packed 4-bit values
    "Q5_0": 2 + 4 + 16,  # f16 scale + 32 high bits + 32 packed 4-bit values
    "Q8_0": 2 + 32,  # f16 scale + 32 int8 values
    "Q2_K": 256 // 16 + 256 // 4 + 2 + 2,  # sub-block scales + 2-bit values + d + dmin
    "Q3_K": 256 // 8 + 256 // 4 + 12 + 2,  # high-bit mask + 2-bit values + scales + d
    "Q4_K": 2 + 2 + 12 + 256 // 2,  # d + dmin + packed scales + 4-bit values
    "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,  # as Q4_K plus 1 high bit per value
    "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,  # low 4 bits + high 2 bits + int8 scales + d
    "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,  # presumably d + scale words + 4-bit values — TODO confirm
}
|
||||
|
||||
def dequantize_f32(data):
    """Reinterpret raw tensor bytes as a flat float32 array (zero-copy view)."""
    return np.frombuffer(data, dtype=np.float32)
|
||||
|
||||
def dequantize_f16(data):
    """Reinterpret raw tensor bytes as a flat float16 array (zero-copy view)."""
    return np.frombuffer(data, dtype=np.float16)
|
||||
|
||||
def dequantize_q4_0(data):
    """Dequantize GGML Q4_0 blocks: one f16 scale + 32 packed 4-bit values."""
    n_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"]

    # First f16 of each 18-byte block is the scale; the remaining 16 bytes
    # hold two 4-bit values each, stored with a +8 bias.
    d = np.frombuffer(data, dtype=np.float16).reshape(n_blocks, 1 + 8)[:, :1].astype(np.float32)
    nibbles = np.frombuffer(data, dtype=np.uint8).reshape(n_blocks, 2 + 16)[:, 2:]

    lo = (nibbles & 0xf).astype(np.int8) - 8
    hi = (nibbles >> 4).astype(np.int8) - 8
    return np.concatenate([d * lo, d * hi], axis=1)
|
||||
|
||||
def dequantize_q5_0(data):
    """Dequantize GGML Q5_0 blocks: f16 scale, 32 high bits, 32 packed 4-bit values.

    Each 5-bit value is split into a 4-bit low nibble (qs) and one high bit
    (qh), then shifted to the signed range [-16, 15].
    """
    num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"]

    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)
    qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]
    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]

    # One high bit per value, LSB-first within the 4 qh bytes.
    bits = np.unpackbits(qh, axis=-1, bitorder="little")

    # Reassemble 5-bit values: low nibble | high bit << 4, biased by -16.
    x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16
    x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16

    return np.concatenate([
        scales * x0,
        scales * x1,
    ], axis=1)
|
||||
|
||||
def dequantize_q8_0(data):
    """Dequantize GGML Q8_0 blocks: one f16 scale followed by 32 int8 values."""
    n_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]

    # Scale is the leading f16 of each 34-byte block; rest are int8 quants.
    d = np.frombuffer(data, dtype=np.float16).reshape(n_blocks, 1 + 16)[:, :1].astype(np.float32)
    quants = np.frombuffer(data, dtype=np.int8).reshape(n_blocks, 2 + 32)[:, 2:]
    return d * quants
|
||||
|
||||
def dequantize_q2_k(data):
    """Dequantize GGML Q2_K super-blocks (256 2-bit values each).

    Returns a float32 array of shape (num_blocks, 16, 16).
    """
    block_size = GGML_BLOCK_SIZES["Q2_K"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)

    # Trailing two f16s are the super-block min (dmin) and scale (d);
    # promoted to float32 because float16 math is slow on CPU.
    dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
    d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
    # 16 packed sub-block bytes: low nibble scales, high nibble mins.
    scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
    qs = data_u8[:, 16:80].reshape(num_blocks, 64)

    # Unpack the 16 sub-blocks of 2-bit values in their on-disk bit order.
    tmp = np.stack([
        qs[:, 00:16] >> 0,
        qs[:, 16:32] >> 0,
        qs[:, 00:16] >> 2,
        qs[:, 16:32] >> 2,
        qs[:, 00:16] >> 4,
        qs[:, 16:32] >> 4,
        qs[:, 00:16] >> 6,
        qs[:, 16:32] >> 6,
        qs[:, 32:48] >> 0,
        qs[:, 48:64] >> 0,
        qs[:, 32:48] >> 2,
        qs[:, 48:64] >> 2,
        qs[:, 32:48] >> 4,
        qs[:, 48:64] >> 4,
        qs[:, 32:48] >> 6,
        qs[:, 48:64] >> 6,
    ], axis=1)

    # value = d * sub_scale * q - dmin * sub_min
    return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
|
||||
|
||||
|
||||
def dequantize_q3_k(data):
    """Dequantize GGML Q3_K super-blocks (256 3-bit values each).

    Returns a float32 array of shape (num_blocks, 16, 16).
    """
    block_size = GGML_BLOCK_SIZES["Q3_K"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)

    # Trailing f16 is the super-block scale; float16 math is slow on CPU.
    d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
    # High-bit mask: one bit per value, LSB-first within each of 32 bytes.
    bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
    # Maps each mask bit to the offset to subtract: 4 when the stored high
    # bit is 0, 0 when it is 1.
    bits = 4 ^ (bits << 2)
    qs = data_u8[:, 32:32 + 64].astype(np.int16)
    # 12 bytes of packed 6-bit sub-block scales, split into three 4-byte rows.
    a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
    scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
    scales[:, 0] = (a & 15) | ((c & 3) << 4)
    scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
    scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
    scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
    scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)

    # value = d * (sub_scale - 32) * (2-bit low value - high-bit offset),
    # with the 16 sub-blocks laid out in on-disk bit order.
    return d * (scales - 32) * np.stack([
        (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
        (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
        (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
        (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
        (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
        (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
        (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
        (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
        (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
        (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
        (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
        (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
        (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
        (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
        (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
        (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
    ], axis=1)
|
||||
|
||||
def dequantize_q4_k(data, device=None):
    """Dequantize GGML Q4_K super-blocks (256 4-bit values each).

    Returns float32 of shape (num_blocks, 8, 32) as a numpy array, or as a
    torch tensor on ``device`` when one is given.
    """
    bytes_per_block = GGML_BLOCK_SIZES["Q4_K"]
    n_blocks = len(data) // bytes_per_block

    as_f16 = np.frombuffer(data, dtype=np.float16).reshape(n_blocks, bytes_per_block // 2)
    as_u8 = np.frombuffer(data, dtype=np.uint8).reshape(n_blocks, bytes_per_block)

    # Promote the two super-block f16 scales to float32 (f16 is slow on CPU).
    d = as_f16[:, 0].reshape(n_blocks, 1, 1).astype(np.float32)
    dmin = as_f16[:, 1].reshape(n_blocks, 1, 1).astype(np.float32)
    packed_scales = as_u8[:, 4:16].reshape(n_blocks, 12, 1)
    packed_quants = as_u8[:, 16:].reshape(n_blocks, 4, 32)

    # Each sub-block scale/offset is 6 bits: four come straight from the low
    # 6 bits, four are rebuilt from a 4-bit field plus two spilled high bits.
    sub_scales = d * np.concatenate(
        [
            packed_scales[:, 0:4] & 0b111111,
            (packed_scales[:, 8:] & 15) | ((packed_scales[:, 0:4] >> 6) << 4),
        ],
        axis=1,
    )
    sub_offsets = dmin * np.concatenate(
        [
            packed_scales[:, 4:8] & 0b111111,
            (packed_scales[:, 8:] >> 4) | ((packed_scales[:, 4:8] >> 6) << 4),
        ],
        axis=1,
    )

    # Interleave low and high nibbles into 8 sub-blocks of 32 values.
    quants = np.stack([packed_quants & 0xf, packed_quants >> 4], axis=2).reshape(n_blocks, 8, 32)

    out = sub_scales * quants - sub_offsets
    if device is None:
        return out
    return torch.from_numpy(out).to(device=device)
|
||||
|
||||
def dequantize_q5_k(data):
    """Dequantize GGML Q5_K super-blocks (256 5-bit values each).

    Returns a float32 array of shape (num_blocks, 256).
    """
    block_size = GGML_BLOCK_SIZES["Q5_K"]
    num_blocks = len(data) // block_size

    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)

    # Super-block scale (d) and min (dmin) are the first two f16s; promoted
    # to float32 because float16 math is slow on CPU.
    d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
    dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
    # 12 packed bytes of 6-bit sub-block scales/mins, 32 high-bit bytes (qh),
    # then 128 bytes of packed low nibbles (qs).
    scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
    qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1)
    qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32)

    # One high bit per value, LSB-first within each qh byte.
    bits = np.unpackbits(qh, axis=-1, bitorder="little")

    qs_hi_4 = qs >> 4
    qs_lo_4 = qs & 15

    # Sub-block scales/mins are 6 bits; the first 8 bytes hold the low 6 bits
    # and the last 4 bytes (split into nibbles) carry the rest.
    scales_lo_6 = scales[:, :8] & 63
    scales_hi_6 = scales[:, :8] >> 6
    scales_lo_4 = scales[:, 8:] & 15
    scales_hi_4 = scales[:, 8:] >> 4

    # Per-sub-block offsets (dmin * 6-bit min).
    m1 = dmin * scales_lo_6[:, 4]
    m2 = dmin * scales_lo_6[:, 5]
    m3 = dmin * scales_lo_6[:, 6]
    m4 = dmin * scales_lo_6[:, 7]
    m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
    m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
    m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
    m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))

    # Per-sub-block scales (d * 6-bit scale).
    d1 = d * scales_lo_6[:, 0]
    d2 = d * scales_lo_6[:, 1]
    d3 = d * scales_lo_6[:, 2]
    d4 = d * scales_lo_6[:, 3]
    d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
    d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
    d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
    d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))

    # value = d_i * (low nibble | high bit << 4) - m_i, per sub-block.
    return np.concatenate([
        d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
        d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
        d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
        d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
        d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
        d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
        d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
        d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
    ], axis=1)
|
||||
|
||||
def dequantize_q6_k(data, device=None):
    """Dequantize a buffer of GGML Q6_K super-blocks to float32.

    Each Q6_K super-block encodes 256 values as 6-bit quants: the low
    4 bits live in ``ql``, the high 2 bits in ``qh``. A block also carries
    16 signed int8 sub-block scales (``sc``) and one trailing fp16
    super-block scale.

    Args:
        data: raw bytes covering whole Q6_K blocks (length must be a
            multiple of the Q6_K block size).
        device: if given, the result is returned as a torch tensor on that
            device; otherwise a numpy float32 array of shape
            (num_blocks, 256) is returned.
    """
    block_size = GGML_BLOCK_SIZES["Q6_K"]
    num_blocks = len(data) // block_size

    # Three typed views over the same underlying buffer.
    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
    data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)

    # Super-block scale is the trailing fp16 of each block.
    scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)
    # TODO use uint8 and cast later?
    ql = data_u8[:, :128].astype(np.int16)    # low 4 bits of each quant
    qh = data_u8[:, 128:192].astype(np.int16)  # high 2 bits of each quant
    sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)  # 16 sub-block scales

    # Reassemble the 6-bit values, then recenter to [-32, 31].
    # NOTE: the parentheses are deliberate — combine the nibble with the
    # high bits FIRST, then subtract the bias. (The previous form relied on
    # `-` binding tighter than `|`, which only gave the right answer by a
    # two's-complement coincidence.) Subtraction requires a signed dtype,
    # hence the int16 casts above.
    q1 = ((ql[:, :32] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4)) - 32
    q2 = ((ql[:, 32:64] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4)) - 32
    q3 = ((ql[:, :32] >> 4) | (((qh[:, :32] >> 4) & 3) << 4)) - 32
    q4 = ((ql[:, 32:64] >> 4) | (((qh[:, :32] >> 6) & 3) << 4)) - 32
    q5 = ((ql[:, 64:96] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4)) - 32
    q6 = ((ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4)) - 32
    q7 = ((ql[:, 64:96] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4)) - 32
    q8 = ((ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4)) - 32

    # Dequantize: value = d * sc[i] * q, one sc entry per 16-value sub-block.
    weight = scales * np.concatenate([
        sc[:, 0] * q1[:, :16],
        sc[:, 1] * q1[:, 16:],
        sc[:, 2] * q2[:, :16],
        sc[:, 3] * q2[:, 16:],
        sc[:, 4] * q3[:, :16],
        sc[:, 5] * q3[:, 16:],
        sc[:, 6] * q4[:, :16],
        sc[:, 7] * q4[:, 16:],
        sc[:, 8] * q5[:, :16],
        sc[:, 9] * q5[:, 16:],
        sc[:, 10] * q6[:, :16],
        sc[:, 11] * q6[:, 16:],
        sc[:, 12] * q7[:, :16],
        sc[:, 13] * q7[:, 16:],
        sc[:, 14] * q8[:, :16],
        sc[:, 15] * q8[:, 16:],
    ], axis=1)

    if device is None:
        return weight
    return torch.from_numpy(weight).to(device=device)
|
||||
|
||||
# Number of weights per K-quant super-block.
QK_K = 256
# 4-bit nonlinear codebook for IQ4-family quantization: each packed 4-bit
# quant indexes this table (0..15) to recover its signed value.
kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
|
||||
|
||||
def dequantize_iq4_xs(data, device=None):
    """Dequantize a buffer of GGML IQ4_XS blocks to float32.

    Each block stores 256 values as 4-bit indices into the nonlinear
    codebook ``kvalues_iq4nl``, one fp16 super-block scale ``d``, and
    6-bit sub-block scales split across ``scales_l`` (low 4 bits) and
    ``scales_h`` (high 2 bits).

    Args:
        data: raw bytes covering whole IQ4_XS blocks.
        device: if given, the result is returned as a torch tensor on
            that device (for consistency with the other dequantizers);
            otherwise a flat float32 numpy array, as before.
    """
    block_size = GGML_BLOCK_SIZES["IQ4_XS"]
    num_blocks = len(data) // block_size

    # Block layout: fp16 d | u16 scales_h | 4 bytes scales_l | 128 bytes qs.
    d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1)
    scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1)
    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:]
    scales_l = data_u8[:, :4].reshape(num_blocks, 4)
    qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8)

    # Rebuild the 6-bit sub-block scales from their low/high bit planes.
    ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8)
    for ib in range(QK_K // 32):
        ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4)

    # Per-sub-block scale; the stored 6-bit scale is biased by 32.
    dl = (d * (ls - 32)).reshape(num_blocks, -1, 1)

    # Split each quant byte into its two 4-bit codebook indices.
    qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf
    qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4

    # Look up codebook values and apply the per-sub-block scales.
    y = np.zeros((num_blocks, QK_K), dtype=np.float32)
    for ib in range(QK_K // 32):
        y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]]
        y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]]

    y = y.flatten()
    if device is None:
        return y
    return torch.from_numpy(y).to(device=device)
|
||||
|
||||
# Dispatch table: GGMLQuantizationType id -> dequantization function.
# Each function takes the tensor's raw block bytes and returns a numpy
# float32 array (see the individual dequantize_* implementations).
GGML_DEQUANTIZE = {
    int(GGMLQuantizationType.F32): dequantize_f32,
    int(GGMLQuantizationType.F16): dequantize_f16,
    int(GGMLQuantizationType.Q4_0): dequantize_q4_0,
    int(GGMLQuantizationType.Q5_0): dequantize_q5_0,
    int(GGMLQuantizationType.Q8_0): dequantize_q8_0,
    int(GGMLQuantizationType.Q2_K): dequantize_q2_k,
    int(GGMLQuantizationType.Q3_K): dequantize_q3_k,
    int(GGMLQuantizationType.Q4_K): dequantize_q4_k,
    int(GGMLQuantizationType.Q5_K): dequantize_q5_k,
    int(GGMLQuantizationType.Q6_K): dequantize_q6_k,
    int(GGMLQuantizationType.IQ4_XS): dequantize_iq4_xs,
}
|
||||
|
||||
|
||||
def dequant_gguf(data, type, shape):
    """Dequantize raw GGUF tensor bytes and view them as *shape*.

    Looks up the dequantizer for the given GGML quantization *type*,
    applies it to *data*, and returns the result as a torch tensor
    reshaped to *shape*.
    """
    dequant_fn = GGML_DEQUANTIZE[type]
    raw_values = dequant_fn(data)
    return torch.from_numpy(raw_values).view(shape)
|
||||
@@ -255,18 +255,6 @@ def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tenso
|
||||
return w2_packed.size(1) * marlin_tile_size
|
||||
|
||||
|
||||
def marlin_make_workspace(
    output_size_per_partition: int, device: torch.device
) -> torch.Tensor:
    """Allocate the zero-initialized int32 scratch workspace for Marlin.

    The workspace length scales with the number of thread-N tiles in the
    output partition times the maximum kernel parallelism.
    """
    num_n_tiles = output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
    workspace_len = num_n_tiles * GPTQ_MARLIN_MAX_PARALLEL
    return torch.zeros(
        workspace_len, dtype=torch.int, device=device, requires_grad=False
    )
|
||||
|
||||
|
||||
def marlin_make_workspace_new(
|
||||
device: torch.device, max_blocks_per_sm: int = 1
|
||||
) -> torch.Tensor:
|
||||
@@ -297,12 +285,6 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
|
||||
)
|
||||
|
||||
|
||||
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    """Return an empty (0-element) int zero-point placeholder Parameter.

    Used when a layer has no zero points; wrapped as a non-trainable
    ``torch.nn.Parameter`` so it slots into module state like real ones.
    """
    empty_zp = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(empty_zp, requires_grad=False)
|
||||
|
||||
|
||||
def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Sort group indices ascending and return (sorted, permutation).

    The permutation is cast to int32, which is what the Marlin kernels
    consume.
    """
    perm = torch.argsort(g_idx)
    perm = perm.to(torch.int)
    sorted_g_idx = g_idx[perm]
    return sorted_g_idx, perm
|
||||
|
||||
@@ -175,7 +175,7 @@ try:
|
||||
op_func=_dequant_mxfp4,
|
||||
fake_impl=_dequant_mxfp4_fake,
|
||||
)
|
||||
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
|
||||
dequant_mxfp4 = None
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
@@ -185,6 +185,6 @@ try:
|
||||
op_func=_quant_dequant_mxfp4,
|
||||
fake_impl=_quant_dequant_mxfp4_fake,
|
||||
)
|
||||
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
|
||||
quant_dequant_mxfp4 = None
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
@@ -271,12 +271,12 @@ def scaled_quantize(
|
||||
If None, uses input dtype. Use torch.float32 for higher precision.
|
||||
"""
|
||||
group_shape = _normalize_quant_group_shape(x, group_shape)
|
||||
assert quant_dtype.is_floating_point, (
|
||||
"currently `scaled_quantize` only supports floating point dtypes "
|
||||
"but could be extended to support other dtypes"
|
||||
)
|
||||
# assert quant_dtype.is_floating_point, (
|
||||
# "currently `scaled_quantize` only supports floating point dtypes "
|
||||
# "but could be extended to support other dtypes"
|
||||
# )
|
||||
|
||||
finfo = torch.finfo(quant_dtype)
|
||||
finfo = torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype)
|
||||
|
||||
# Convert to compute dtype if specified
|
||||
x_compute = x if compute_dtype is None else x.to(compute_dtype)
|
||||
|
||||
Reference in New Issue
Block a user