Add float8 dynamic quant to torchao_utils (#1528)
This commit is contained in:
@@ -18,11 +18,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
|
|||||||
"""
|
"""
|
||||||
# Lazy import to suppress some warnings
|
# Lazy import to suppress some warnings
|
||||||
from torchao.quantization import (
|
from torchao.quantization import (
|
||||||
|
float8_dynamic_activation_float8_weight,
|
||||||
int4_weight_only,
|
int4_weight_only,
|
||||||
int8_dynamic_activation_int8_weight,
|
int8_dynamic_activation_int8_weight,
|
||||||
int8_weight_only,
|
int8_weight_only,
|
||||||
quantize_,
|
quantize_,
|
||||||
)
|
)
|
||||||
|
from torchao.quantization.observer import PerRow, PerTensor
|
||||||
|
|
||||||
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
|
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
|
||||||
dummy_linear.weight = param
|
dummy_linear.weight = param
|
||||||
@@ -45,6 +47,22 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
|
|||||||
# this requires newer hardware
|
# this requires newer hardware
|
||||||
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
||||||
quantize_(dummy_linear, float8_weight_only())
|
quantize_(dummy_linear, float8_weight_only())
|
||||||
|
elif "fp8dq" in torchao_config:
|
||||||
|
granularity = torchao_config.split("-")[-1]
|
||||||
|
GRANULARITY_MAP = {
|
||||||
|
"per_row": PerRow(),
|
||||||
|
"per_tensor": PerTensor(),
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
granularity in GRANULARITY_MAP
|
||||||
|
), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
|
||||||
|
quantize_(
|
||||||
|
dummy_linear,
|
||||||
|
float8_dynamic_activation_float8_weight(
|
||||||
|
granularity=GRANULARITY_MAP[granularity]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
return dummy_linear.weight
|
return dummy_linear.weight
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user