Add float8 dynamic quant to torchao_utils (#1528)
This commit is contained in:
@@ -18,11 +18,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
|
||||
"""
|
||||
# Lazy import to suppress some warnings
|
||||
from torchao.quantization import (
|
||||
float8_dynamic_activation_float8_weight,
|
||||
int4_weight_only,
|
||||
int8_dynamic_activation_int8_weight,
|
||||
int8_weight_only,
|
||||
quantize_,
|
||||
)
|
||||
from torchao.quantization.observer import PerRow, PerTensor
|
||||
|
||||
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
|
||||
dummy_linear.weight = param
|
||||
@@ -45,6 +47,22 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
|
||||
# float8 weight-only quantization requires newer hardware (CUDA arch >= 89); on older GPUs it fails with:
|
||||
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
||||
quantize_(dummy_linear, float8_weight_only())
|
||||
elif "fp8dq" in torchao_config:
|
||||
granularity = torchao_config.split("-")[-1]
|
||||
GRANULARITY_MAP = {
|
||||
"per_row": PerRow(),
|
||||
"per_tensor": PerTensor(),
|
||||
}
|
||||
assert (
|
||||
granularity in GRANULARITY_MAP
|
||||
), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
|
||||
quantize_(
|
||||
dummy_linear,
|
||||
float8_dynamic_activation_float8_weight(
|
||||
granularity=GRANULARITY_MAP[granularity]
|
||||
),
|
||||
)
|
||||
|
||||
return dummy_linear.weight
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user