restruct compressed_tensors_w8a8_fp8 (#5475)
This commit is contained in:
@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import (
|
||||
Fp8LinearOp,
|
||||
apply_fp8_linear,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
|
||||
@@ -29,7 +29,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@@ -149,11 +148,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
return self.fp8_linear.apply(
|
||||
return apply_fp8_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
use_per_token_if_dynamic=True,
|
||||
compressed_tensor_quant=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user