[Minor] improve kill scripts and torchao import (#1375)

This commit is contained in:
Lianmin Zheng
2024-09-10 11:27:03 -07:00
committed by GitHub
parent dff2860a69
commit 6c7cb90365
2 changed files with 9 additions and 7 deletions

View File

@@ -3,15 +3,17 @@ Common utilities for torchao.
"""
import torch
from torchao.quantization import (
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
)
def torchao_quantize_param_data(param, torchao_config):
# Lazy import to suppress some warnings
from torchao.quantization import (
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
)
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
dummy_linear.weight = param
if "int8wo" in torchao_config: