Fix gemlite import (#2553)
This commit is contained in:
@@ -322,18 +322,6 @@ def throughput_test(
|
|||||||
)
|
)
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
|
||||||
try:
|
|
||||||
import os
|
|
||||||
import pwd
|
|
||||||
|
|
||||||
from gemlite.core import GemLiteLinearTriton
|
|
||||||
|
|
||||||
GemLiteLinearTriton.cache_config(
|
|
||||||
f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
logging.info("\nBenchmark...")
|
logging.info("\nBenchmark...")
|
||||||
result = throughput_test_once(
|
result = throughput_test_once(
|
||||||
backend_name=bench_args.backend,
|
backend_name=bench_args.backend,
|
||||||
|
|||||||
@@ -386,18 +386,6 @@ def latency_test(
|
|||||||
server_args.device,
|
server_args.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
import os
|
|
||||||
import pwd
|
|
||||||
|
|
||||||
from gemlite.core import GemLiteLinearTriton
|
|
||||||
|
|
||||||
GemLiteLinearTriton.cache_config(
|
|
||||||
f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
rank_print("Benchmark ...")
|
rank_print("Benchmark ...")
|
||||||
|
|
||||||
# Run the sweep
|
# Run the sweep
|
||||||
|
|||||||
@@ -2,8 +2,14 @@
|
|||||||
Common utilities for torchao.
|
Common utilities for torchao.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pwd
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def apply_torchao_config_to_model(
|
def apply_torchao_config_to_model(
|
||||||
model: torch.nn.Module, torchao_config: str, filter_fn=None
|
model: torch.nn.Module, torchao_config: str, filter_fn=None
|
||||||
@@ -50,27 +56,17 @@ def apply_torchao_config_to_model(
|
|||||||
elif "gemlite" in torchao_config:
|
elif "gemlite" in torchao_config:
|
||||||
# gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
|
# gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
|
||||||
# gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
|
# gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
|
||||||
import os
|
from gemlite.core import GemLiteLinearTriton
|
||||||
import pwd
|
from torchao.quantization import gemlite_uintx_weight_only
|
||||||
|
|
||||||
import gemlite
|
|
||||||
from gemlite.core import GemLiteLinearTriton, set_autotune
|
|
||||||
|
|
||||||
try:
|
|
||||||
from torchao.quantization import gemlite_uintx_weight_only
|
|
||||||
except:
|
|
||||||
print(
|
|
||||||
f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
|
|
||||||
_quant_args = torchao_config.split("-")
|
_quant_args = torchao_config.split("-")
|
||||||
bit_width = int(_quant_args[-2])
|
bit_width = int(_quant_args[-2])
|
||||||
group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
|
group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
packing_bitwidth = int(_quant_args[-3])
|
packing_bitwidth = int(_quant_args[-3])
|
||||||
except:
|
except (ValueError, IndexError):
|
||||||
# if only 2 inputs found, use default value
|
# if only 2 inputs found or conversion fails, use default value
|
||||||
packing_bitwidth = 32
|
packing_bitwidth = 32
|
||||||
|
|
||||||
quantize_(
|
quantize_(
|
||||||
|
|||||||
Reference in New Issue
Block a user