compressed_tensors: port w8a16 fp8 from vllm (#4852)
This commit is contained in:
@@ -33,6 +33,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
|
||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A16Fp8,
|
||||
)
|
||||
from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target,
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
|
||||
__all__ = [
|
||||
"CompressedTensorsScheme",
|
||||
"CompressedTensorsW8A8Fp8",
|
||||
"CompressedTensorsW8A16Fp8",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from sglang.srt.layers.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from sglang.srt.layers.quantization.utils import convert_to_channelwise
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
)
|
||||
|
||||
MARLIN_FP8_AVAILABLE = True
|
||||
except ImportError:
|
||||
MARLIN_FP8_AVAILABLE = False
|
||||
|
||||
def apply_fp8_marlin_linear(*args, **kwargs):
|
||||
raise ImportError("vllm is not installed")
|
||||
|
||||
def prepare_fp8_layer_for_marlin(*args, **kwargs):
|
||||
raise ImportError("vllm is not installed")
|
||||
|
||||
|
||||
__all__ = ["CompressedTensorsW8A16Fp8"]
|
||||
|
||||
SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
|
||||
|
||||
|
||||
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
|
||||
if not MARLIN_FP8_AVAILABLE:
|
||||
raise ImportError(
|
||||
"vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere and up
|
||||
return 80
|
||||
|
||||
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
|
||||
# So if we have a fused module (QKV, MLP) with per tensor scales,
|
||||
# we expand each scale to its shard's channels.
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
ws_channelwise = convert_to_channelwise(
|
||||
layer.weight_scale, layer.logical_widths
|
||||
)
|
||||
layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False)
|
||||
else:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
# Weights must be transposed for marlin
|
||||
layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)
|
||||
|
||||
if self.is_static_input_scheme:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.input_scale = torch.nn.Parameter(
|
||||
layer.input_scale.data, requires_grad=False
|
||||
)
|
||||
prepare_fp8_layer_for_marlin(layer, strategy="channel")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
elif self.strategy == QuantizationStrategy.TENSOR:
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported weight strategy={self.strategy}, "
|
||||
f"supported strategies are {SUPPORTED_STRATEGIES}"
|
||||
)
|
||||
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE (to deal with converted checkpoints)
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
Reference in New Issue
Block a user