compressed_tensors: port w8a16 fp8 from vllm (#4852)
This commit is contained in:
@@ -33,6 +33,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
|
|||||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
CompressedTensorsW8A8Fp8,
|
CompressedTensorsW8A8Fp8,
|
||||||
|
CompressedTensorsW8A16Fp8,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
||||||
find_matched_target,
|
find_matched_target,
|
||||||
|
|||||||
@@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||||
|
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CompressedTensorsScheme",
|
"CompressedTensorsScheme",
|
||||||
"CompressedTensorsW8A8Fp8",
|
"CompressedTensorsW8A8Fp8",
|
||||||
|
"CompressedTensorsW8A16Fp8",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,153 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from compressed_tensors.quantization import QuantizationStrategy
|
||||||
|
|
||||||
|
from sglang.srt.layers.parameter import (
|
||||||
|
ChannelQuantScaleParameter,
|
||||||
|
ModelWeightParameter,
|
||||||
|
PerTensorScaleParameter,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||||
|
CompressedTensorsScheme,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.utils import convert_to_channelwise
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
|
apply_fp8_marlin_linear,
|
||||||
|
prepare_fp8_layer_for_marlin,
|
||||||
|
)
|
||||||
|
|
||||||
|
MARLIN_FP8_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
MARLIN_FP8_AVAILABLE = False
|
||||||
|
|
||||||
|
def apply_fp8_marlin_linear(*args, **kwargs):
|
||||||
|
raise ImportError("vllm is not installed")
|
||||||
|
|
||||||
|
def prepare_fp8_layer_for_marlin(*args, **kwargs):
|
||||||
|
raise ImportError("vllm is not installed")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["CompressedTensorsW8A16Fp8"]
|
||||||
|
|
||||||
|
SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
|
||||||
|
|
||||||
|
|
||||||
|
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||||
|
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||||
|
self.strategy = strategy
|
||||||
|
self.is_static_input_scheme = is_static_input_scheme
|
||||||
|
|
||||||
|
if not MARLIN_FP8_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
# ampere and up
|
||||||
|
return 80
|
||||||
|
|
||||||
|
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
|
||||||
|
# So if we have a fused module (QKV, MLP) with per tensor scales,
|
||||||
|
# we expand each scale to its shard's channels.
|
||||||
|
def process_weights_after_loading(self, layer) -> None:
|
||||||
|
if self.strategy == QuantizationStrategy.TENSOR:
|
||||||
|
ws_channelwise = convert_to_channelwise(
|
||||||
|
layer.weight_scale, layer.logical_widths
|
||||||
|
)
|
||||||
|
layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False)
|
||||||
|
else:
|
||||||
|
# required by torch.compile to be torch.nn.Parameter
|
||||||
|
layer.weight_scale = torch.nn.Parameter(
|
||||||
|
layer.weight_scale.data, requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Weights must be transposed for marlin
|
||||||
|
layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)
|
||||||
|
|
||||||
|
if self.is_static_input_scheme:
|
||||||
|
# required by torch.compile to be torch.nn.Parameter
|
||||||
|
layer.input_scale = torch.nn.Parameter(
|
||||||
|
layer.input_scale.data, requires_grad=False
|
||||||
|
)
|
||||||
|
prepare_fp8_layer_for_marlin(layer, strategy="channel")
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
input_size: int,
|
||||||
|
output_partition_sizes: List[int],
|
||||||
|
input_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
weight_loader: Callable,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
|
layer.logical_widths = output_partition_sizes
|
||||||
|
layer.input_size_per_partition = input_size_per_partition
|
||||||
|
layer.output_size_per_partition = output_size_per_partition
|
||||||
|
layer.orig_dtype = params_dtype
|
||||||
|
|
||||||
|
# WEIGHT
|
||||||
|
weight = ModelWeightParameter(
|
||||||
|
data=torch.empty(
|
||||||
|
output_size_per_partition,
|
||||||
|
input_size_per_partition,
|
||||||
|
dtype=torch.float8_e4m3fn,
|
||||||
|
),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=0,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("weight", weight)
|
||||||
|
|
||||||
|
# WEIGHT SCALE
|
||||||
|
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||||
|
weight_scale = ChannelQuantScaleParameter(
|
||||||
|
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||||
|
output_dim=0,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
elif self.strategy == QuantizationStrategy.TENSOR:
|
||||||
|
weight_scale = PerTensorScaleParameter(
|
||||||
|
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported weight strategy={self.strategy}, "
|
||||||
|
f"supported strategies are {SUPPORTED_STRATEGIES}"
|
||||||
|
)
|
||||||
|
|
||||||
|
weight_scale[:] = torch.finfo(torch.float32).min
|
||||||
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
|
|
||||||
|
# INPUT SCALE (to deal with converted checkpoints)
|
||||||
|
if self.is_static_input_scheme:
|
||||||
|
input_scale = PerTensorScaleParameter(
|
||||||
|
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("input_scale", input_scale)
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return apply_fp8_marlin_linear(
|
||||||
|
input=x,
|
||||||
|
weight=layer.weight,
|
||||||
|
weight_scale=layer.weight_scale,
|
||||||
|
workspace=layer.workspace,
|
||||||
|
size_n=layer.output_size_per_partition,
|
||||||
|
size_k=layer.input_size_per_partition,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user