[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions

View File

@@ -0,0 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .quark_scheme import QuarkScheme
from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
from .quark_w8a8_fp8 import QuarkW8A8Fp8
from .quark_w8a8_int8 import QuarkW8A8Int8
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"]

View File

@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Optional
import torch
__all__ = ["QuarkScheme"]
class QuarkScheme(ABC):
"""
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by Quark.
"""
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
raise NotImplementedError
@abstractmethod
def create_weights(self, *args, **kwargs):
"""
Weight creation for the particular scheme. Inputs to this function
"""
raise NotImplementedError
@abstractmethod
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
"""
raise NotImplementedError
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError

View File

@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
import vllm.envs as envs
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.platforms import current_platform
__all__ = ["QuarkW4A4MXFP4"]
class QuarkW4A4MXFP4(QuarkScheme):
def __init__(self, weight_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any]):
self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
self.emulate = not current_platform.supports_mx()
@classmethod
def get_min_capability(cls) -> int:
return 70
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
if self.emulate:
try:
from quark.torch.export.nn.modules import realquantizer
from quark.torch.quantization.config.config import (
QuantizationSpec)
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use AMD Quark "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
weight_quant_spec = QuantizationSpec.from_dict(
self.weight_quant_spec)
weight_quantizer = realquantizer.get_real_quantizer(
qspec=weight_quant_spec,
quantizer=None,
real_quantized=True,
reorder=False,
float_dtype=self.out_dtype,
scale_shape=layer.weight_scale.shape,
zero_point_shape=None,
)
weight_quantizer.scale.data = layer.weight_scale.data
if not envs.VLLM_QUARK_EMU_MEM_OPT:
layer.weight = torch.nn.Parameter(
weight_quantizer(layer.weight.data).to(self.out_dtype),
requires_grad=False,
)
else:
self.weight_quantizer = weight_quantizer
layer.weight_scale = None
# This call is necessary to release the scales memory.
torch.cuda.empty_cache()
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // 2,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=2,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.emulate:
if envs.VLLM_QUARK_EMU_MEM_OPT:
dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
else:
dq_w = layer.weight
qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
return F.linear(qdq_x, dq_w, bias)
else:
raise NotImplementedError()

View File

@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
__all__ = ["QuarkW8A8Fp8"]
class QuarkW8A8Fp8(QuarkScheme):
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
self.qscheme = qscheme
self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
self.out_dtype = torch.get_default_dtype()
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
return 89
def process_weights_after_loading(self, layer) -> None:
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
if self.qscheme == "per_tensor":
if current_platform.is_rocm():
input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
else:
max_w_scale = layer.weight_scale
weight = layer.weight
max_w_scale, weight = requantize_with_max_scale(
weight=weight,
weight_scale=max_w_scale,
logical_widths=layer.logical_widths,
)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose.
elif self.qscheme == "per_channel":
weight = layer.weight
if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None)
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
else:
weight_scale = layer.weight_scale.data
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else:
raise ValueError(f"Unknown quantization scheme {self.qscheme}")
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
layer.input_scale = None
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
if self.qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.qscheme == "per_tensor"
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
# min requirement for fp8 kernels
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)

View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
logger = init_logger(__name__)
class QuarkW8A8Int8(QuarkScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
input_symmetric: Optional[bool]):
self.qscheme = qscheme
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
@classmethod
def get_min_capability(cls) -> int:
# turing and up
return 75
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
is_channelwise=(self.qscheme == "per_channel"),
is_static_input_scheme=(self.is_static_input_scheme is True),
input_symmetric=(self.input_symmetric is True))
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
if self.qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
ChannelQuantZPParameter = ChannelQuantScaleParameter
weight_zero_point = ChannelQuantZPParameter(
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.int8),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.qscheme == "per_tensor"
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
PerTensorZPParameter = PerTensorScaleParameter
weight_zero_point = PerTensorZPParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_zero_point", weight_zero_point)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
input_zero_point = BasevLLMParameter(data=torch.empty(
1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
w_s_param_name="weight_scale",
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj")
# Checkpoints are serialized in quark format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.register_parameter("weight_zero_point", None)
delattr(layer, 'weight_zero_point')
if self.input_symmetric:
layer.register_parameter("input_zero_point", None)
delattr(layer, 'input_zero_point')
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)