restructure compressed_tensors_w8a8_fp8 (#5475)

This commit is contained in:
Xiaoyu Zhang
2025-04-19 19:52:15 +08:00
committed by GitHub
parent dca90f1db8
commit bf86c5e990
4 changed files with 222 additions and 243 deletions

View File

@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.fp8_utils import (
Fp8LinearOp,
apply_fp8_linear,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -29,7 +29,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
@classmethod
def get_min_capability(cls) -> int:
@@ -149,11 +148,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.fp8_linear.apply(
return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
use_per_token_if_dynamic=True,
compressed_tensor_quant=True,
)