Add torchao quant (int4/int8/fp8) to llama models (#1341)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
Jerry Zhang
2024-09-09 05:32:41 -07:00
committed by GitHub
parent e4d68afcf0
commit a7c47e0f02
10 changed files with 151 additions and 12 deletions

View File

@@ -22,7 +22,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
"packaging", "pillow", "psutil", "pydantic", "python-multipart",
"torch", "uvicorn", "uvloop", "zmq",
"torch", "torchao", "uvicorn", "uvloop", "zmq",
"vllm==0.5.5", "outlines>=0.0.44"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]

View File

@@ -0,0 +1,36 @@
"""
Common utilities for torchao.
"""
import torch
from torchao.quantization import (
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
)
def torchao_quantize_param_data(param, torchao_config):
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
dummy_linear.weight = param
if "int8wo" in torchao_config:
quantize_(dummy_linear, int8_weight_only())
elif "int8dq" in torchao_config:
quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
elif "int4wo" in torchao_config:
group_size = int(torchao_config.split("-")[-1])
assert group_size in [
32,
64,
128,
256,
], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
quantize_(dummy_linear, int4_weight_only(group_size=group_size))
elif "fp8wo" in torchao_config:
from torchao.quantization import float8_weight_only
# this requires newer hardware
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
quantize_(dummy_linear, float8_weight_only())
return dummy_linear.weight

View File

@@ -97,6 +97,7 @@ class ModelRunner:
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"enable_mla": server_args.enable_mla,
"torchao_config": server_args.torchao_config,
}
)

View File

@@ -42,6 +42,8 @@ from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -299,6 +301,7 @@ class LlamaForCausalLM(nn.Module):
super().__init__()
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = LlamaModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@@ -361,6 +364,25 @@ class LlamaForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if self.torchao_config:
if name.endswith("proj.weight") and param.ndim == 2:
params_dict[name] = torchao_quantize_param_data(
param, self.torchao_config
)
if self.torchao_config:
# quantizing the loaded, stacked params, e.g. "...qkv_proj"
stacked_params = set(entry[0] for entry in stacked_params_mapping)
for param_suffix in stacked_params:
for name in params_dict:
if param_suffix in name:
param = params_dict[name]
params_dict[name] = torchao_quantize_param_data(
param, self.torchao_config
)
self.load_state_dict(params_dict, assign=True)
class Phi3ForCausalLM(LlamaForCausalLM):
pass

View File

@@ -95,6 +95,7 @@ class ServerArgs:
disable_custom_all_reduce: bool = False
enable_mixed_chunk: bool = False
enable_torch_compile: bool = False
torchao_config: str = ""
enable_p2p_check: bool = False
enable_mla: bool = False
triton_attention_reduce_in_fp32: bool = False
@@ -443,7 +444,13 @@ class ServerArgs:
parser.add_argument(
"--enable-torch-compile",
action="store_true",
help="Optimize the model with torch.compile, experimental feature.",
help="Optimize the model with torch.compile. Experimental feature.",
)
parser.add_argument(
"--torchao-config",
type=str,
default=ServerArgs.torchao_config,
help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
)
parser.add_argument(
"--enable-p2p-check",