Apply sgl w8a8 fp8 kernel (#3148)
@@ -7,7 +7,7 @@ import torch
 from torch.nn.parameter import Parameter
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear,
     convert_to_channelwise,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
@@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear

 # Initialize logger for the module
 logger = logging.getLogger(__name__)
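For context on what this import swap changes (not part of the diff itself): the consumer of apply_fp8_linear in this file is the linear method's apply path, which forwards the layer's fp8 weight and scales to the helper; the hunk above just reroutes that call from vllm's w8a8_utils to sglang's fp8_utils implementation backed by the sgl-kernel. A minimal sketch of such a call site, assuming sglang's helper keeps the vllm-style apply_fp8_linear(input, weight, weight_scale, input_scale, ...) keyword interface (the exact signature is an assumption, not shown in this diff):

# Illustrative sketch only: how ModelOptFp8LinearMethod.apply would typically
# invoke the swapped-in helper. Keyword names follow the vllm w8a8_utils
# signature and are assumed to carry over to sglang's fp8_utils version.
def apply(self, layer, x, bias=None):
    return apply_fp8_linear(
        x,                    # bf16/fp16 activations; quantized to fp8 inside the helper
        layer.weight,         # fp8 weight, transposed in process_weights_after_loading
        layer.weight_scale,   # per-channel scale on the cutlass sgl-kernel path
        input_scale=layer.input_scale,  # static per-tensor activation scale
        bias=bias,
        cutlass_fp8_supported=self.cutlass_fp8_supported,
    )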
@@ -161,6 +162,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
             layer.weight, layer.weight_scale, layer.logical_widths
         )
         layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
+        # cutlass sgl-kernel only supports per-channel scale
+        if self.cutlass_fp8_supported:
+            max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
         layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
         layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
 
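The three added lines are the substantive change in this hunk: requantize_with_max_scale collapses each logical shard of a merged weight (e.g. fused q/k/v or gate/up projections) to a single max scale, but the cutlass sgl-kernel, per the added comment, only accepts per-channel scales, so the per-shard scales are broadcast out to one scale per output channel before being stored. A self-contained sketch of that broadcast, assuming the usual [num_shards] -> [sum(logical_widths), 1] shape contract for convert_to_channelwise (an illustrative reimplementation, not the vllm source):

import torch

def convert_to_channelwise_sketch(
    max_w_scale: torch.Tensor,   # one scale per logical shard, shape [num_shards]
    logical_widths: list[int],   # output channels per shard, e.g. [4096, 4096, 1024]
) -> torch.Tensor:
    # Repeat each shard's per-tensor scale across that shard's output
    # channels, producing a per-channel scale of shape [sum(logical_widths), 1].
    channel_scale = torch.empty(
        (sum(logical_widths), 1),
        dtype=max_w_scale.dtype,
        device=max_w_scale.device,
    )
    start = 0
    for scale, width in zip(max_w_scale, logical_widths):
        channel_scale[start:start + width] = scale
        start += width
    return channel_scale

# e.g. shards of width [2, 3] with scales [0.5, 2.0]
# -> tensor([[0.5], [0.5], [2.0], [2.0], [2.0]])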