Sync from v0.13
This commit is contained in:
124
vllm/model_executor/layers/quantization/utils/petit_utils.py
Normal file
124
vllm/model_executor/layers/quantization/utils/petit_utils.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
# TYPE_CHECKING is used for static type analysis to prevent circular imports.
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
|
||||
# 1. Create a global variable as a placeholder for the module
# Lazily-populated cache for the optional `petit_kernel` extension module;
# stays None until _import_petit_kernel() performs the real import.
_petit_kernel: Optional["ModuleType"] = None

# Error text raised when the optional dependency is missing.
_PETIT_INSTALL_MSG = (
    "Petit is not installed. Please install it with `pip install petit-kernel`."
)
|
||||
|
||||
|
||||
def _import_petit_kernel() -> "ModuleType":
    """Lazily import and cache the ``petit_kernel`` extension module.

    The first call performs the actual import and stores the module in the
    module-level ``_petit_kernel`` cache; every subsequent call returns the
    cached module directly.

    Raises:
        ImportError: if ``petit_kernel`` is not installed (with an
            installation hint, and without chaining the original error).
    """
    global _petit_kernel
    if _petit_kernel is None:
        try:
            import petit_kernel as _pk
        except ImportError:
            # 'from None' keeps the traceback clean of the original
            # ImportError.
            raise ImportError(_PETIT_INSTALL_MSG) from None
        _petit_kernel = _pk
    return _petit_kernel
|
||||
|
||||
|
||||
# The _require_petit function can now be a simple alias for consistency.
# Callers that only need to assert the dependency is installed may call
# _require_petit() and ignore the returned module.
_require_petit = _import_petit_kernel
|
||||
|
||||
|
||||
def _check_petit_nvfp4_supported(
|
||||
quant_method: str, group_size: int | None
|
||||
) -> tuple[bool, str | None]:
|
||||
if quant_method != "NVFP4":
|
||||
return (
|
||||
False,
|
||||
(
|
||||
"Petit currently only supports: NVFP4 quantizations in sglang. "
|
||||
"Please check the `hf_quant_config.json` file for your model's "
|
||||
"quant configuration."
|
||||
),
|
||||
)
|
||||
if group_size is not None and group_size != 16:
|
||||
return (
|
||||
False,
|
||||
"Petit currently only supports: group_size=16 quantizations.",
|
||||
)
|
||||
return (True, None)
|
||||
|
||||
|
||||
def verify_petit_nvfp4_supported(quant_method: str, group_size: int | None) -> None:
    """Raise ``ValueError`` if Petit cannot handle this quant configuration.

    Thin raising wrapper around ``_check_petit_nvfp4_supported``.
    """
    is_supported, reason = _check_petit_nvfp4_supported(quant_method, group_size)
    if is_supported:
        return
    assert reason is not None
    raise ValueError(reason)
|
||||
|
||||
|
||||
def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
    """Repack a layer's NVFP4 weights and scales into Petit's format in place.

    Replaces ``layer.weight`` and ``layer.weight_scale`` with non-trainable
    Parameters holding the Petit-repacked tensors.
    """
    # 2. Call _import_petit_kernel() to trigger (or get) the import.
    kernel = _import_petit_kernel()

    # Repack weights to petit format.
    size_n = layer.output_size_per_partition
    size_k = layer.input_size_per_partition
    packed = layer.weight.view(torch.int32).contiguous()

    # 3. Call functions through the imported module variable.
    repacked = kernel.repack_nvfp4(packed, size_n=size_n, size_k=size_k)
    layer.weight = torch.nn.Parameter(repacked, requires_grad=False)

    # Permute scales to match the repacked weight layout.
    permuted_scales = kernel.process_nvfp4_scales(
        scales=layer.weight_scale, size_k=size_k, size_n=size_n
    )
    layer.weight_scale = torch.nn.Parameter(permuted_scales, requires_grad=False)
|
||||
|
||||
|
||||
def apply_petit_nvfp4_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_scale_2: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run an NVFP4 GEMM through the Petit kernel and return the result.

    The input is flattened to 2-D for the kernel call, then the output is
    reshaped back so its leading dimensions match ``input`` with the last
    dimension replaced by ``size_n``. ``bias`` (if given) is added in place.
    """
    # Trigger (or get) the import here as well.
    kernel = _import_petit_kernel()

    flat_x = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (size_n,)

    # TODO: Use auto-tuning to find the performant solution_id
    # Call the function via the module variable.
    result = kernel.mul_nvfp4_a16(
        a=flat_x,
        b=weight,
        s=weight_scale,
        global_scale=weight_scale_2,
        size_m=flat_x.size(0),
        size_n=size_n,
        size_k=size_k,
        solution_id=-1,
    )
    if bias is not None:
        # In-place add of the bias.
        result.add_(bias)

    return result.reshape(result_shape)
|
||||
Reference in New Issue
Block a user