Add integration with gemlite weight only quant (#2528)
@@ -21,7 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
     "orjson", "outlines>=0.0.44,<0.1.0",
     "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart",
-    "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
+    "pyzmq>=25.1.2", "torchao>=0.7.0", "gemlite", "uvicorn", "uvloop",
     "xgrammar>=0.1.6"]
 srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6"]
@@ -322,6 +322,18 @@ def throughput_test(
     )
     time.sleep(0.5)
 
+    try:
+        import os
+        import pwd
+
+        from gemlite.core import GemLiteLinearTriton
+
+        GemLiteLinearTriton.cache_config(
+            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+        )
+    except ImportError:
+        pass
+
     logging.info("\nBenchmark...")
     result = throughput_test_once(
         backend_name=bench_args.backend,
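One detail worth flagging in the cache path above: it is keyed on the GECOS field of the current user (pw_gecos), not the login name (pw_name). On many systems the GECOS field is empty or holds a full name with spaces and commas, so the resulting file can be oddly named or shared across users. A quick probe of what the path expands to (values in the comments are illustrative):

# Probe what the gemlite cache path from the hunk above expands to.
import os
import pwd

entry = pwd.getpwuid(os.getuid())
print(entry.pw_name)   # login name, e.g. "alice"
print(entry.pw_gecos)  # GECOS field, e.g. "Alice Example,,," or often ""
print(f"/tmp/{entry.pw_gecos}_gemlite.json")  # the exact path used above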
@@ -385,6 +385,19 @@ def latency_test(
         8,  # shorter decoding to speed up the warmup
         server_args.device,
     )
+
+    try:
+        import os
+        import pwd
+
+        from gemlite.core import GemLiteLinearTriton
+
+        GemLiteLinearTriton.cache_config(
+            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+        )
+    except ImportError:
+        pass
+
     rank_print("Benchmark ...")
 
     # Run the sweep
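Both benchmark entry points persist gemlite's autotuned Triton kernel configurations once a run completes, and apply_torchao_config_to_model (next hunk) reloads the same file so later runs skip re-autotuning for shapes that were already tuned. A minimal sketch of that round trip, assuming gemlite is installed; cache_config and load_config are the only calls taken from this diff, the rest is illustrative:

# Minimal sketch of the autotune-cache round trip (illustrative; only
# cache_config/load_config are taken from this diff).
import os
import pwd

from gemlite.core import GemLiteLinearTriton

# Per-user cache path, keyed on the GECOS field, exactly as above.
cache_path = f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"

# Writer side: after a benchmark has exercised the gemlite kernels,
# persist the tuned Triton launch configs.
GemLiteLinearTriton.cache_config(cache_path)

# Reader side: a later process (e.g. model loading) restores them and
# skips re-autotuning for already-seen shapes.
GemLiteLinearTriton.load_config(cache_path)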
@@ -47,6 +47,41 @@ def apply_torchao_config_to_model(
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
         quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
+    elif "gemlite" in torchao_config:
+        # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
+        # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
+        import os
+        import pwd
+
+        import gemlite
+        from gemlite.core import GemLiteLinearTriton, set_autotune
+
+        try:
+            from torchao.quantization import gemlite_uintx_weight_only
+        except:
+            print(
+                f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
+            )
+            return model
+
+        _quant_args = torchao_config.split("-")
+        bit_width = int(_quant_args[-2])
+        group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+        try:
+            packing_bitwidth = int(_quant_args[-3])
+        except:
+            # if only 2 inputs found, use default value
+            packing_bitwidth = 32
+
+        quantize_(
+            model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)
+        )
+
+        # try to load gemlite kernel config
+        GemLiteLinearTriton.load_config(
+            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+        )
+
     elif "fp8wo" in torchao_config:
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
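The two comment lines document the accepted config strings; since the parser reads fields from the right, packing_bitwidth is optional, and with only two numeric fields int(_quant_args[-3]) lands on the "gemlite" prefix and raises, triggering the default of 32. A standalone sketch of the same split logic (parse_gemlite_config is a hypothetical name, and the diff's bare except is narrowed here for clarity):

# Standalone sketch of the gemlite config-string parsing above.
def parse_gemlite_config(torchao_config: str):
    _quant_args = torchao_config.split("-")
    bit_width = int(_quant_args[-2])
    group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
    try:
        packing_bitwidth = int(_quant_args[-3])
    except (IndexError, ValueError):
        # only <bit_width>-<group_size> given; fall back to the default
        packing_bitwidth = 32
    return packing_bitwidth, bit_width, group_size

assert parse_gemlite_config("gemlite-8-4-64") == (8, 4, 64)     # explicit packing bitwidth
assert parse_gemlite_config("gemlite-4-64") == (32, 4, 64)      # packing defaults to 32
assert parse_gemlite_config("gemlite-4-None") == (32, 4, None)  # no grouping

These strings reach this function through sglang's existing --torchao-config server argument, which already routes the int4wo-<group_size> and fp8wo cases here, so e.g. --torchao-config gemlite-4-64 should select 4-bit weights with group size 64 and the default 32-bit packing.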