Apply sgl w8a8 fp8 kernel (#3148)

This commit is contained in:
HandH1998
2025-03-09 16:03:32 +08:00
committed by GitHub
parent 9fb48f951f
commit 0dd6cda288
13 changed files with 523 additions and 37 deletions

View File

@@ -7,7 +7,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear,
convert_to_channelwise,
cutlass_fp8_supported,
requantize_with_max_scale,
)
@@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
# Initialize logger for the module
logger = logging.getLogger(__name__)
@@ -161,6 +162,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
layer.weight, layer.weight_scale, layer.logical_widths
)
layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
# cutlass sgl-kernel only supports per-channel scale
if self.cutlass_fp8_supported:
max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)