diff --git a/python/pyproject.toml b/python/pyproject.toml
index daf09ea25..1389822a3 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
        "packaging", "pillow", "psutil", "pydantic", "python-multipart",
-       "torch", "uvicorn", "uvloop", "zmq",
+       "torch", "torchao", "uvicorn", "uvloop", "zmq",
        "vllm==0.5.5", "outlines>=0.0.44"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
new file mode 100644
index 000000000..16eb1f2c5
--- /dev/null
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -0,0 +1,36 @@
+"""
+Common utilities for torchao.
+"""
+
+import torch
+from torchao.quantization import (
+    int4_weight_only,
+    int8_dynamic_activation_int8_weight,
+    int8_weight_only,
+    quantize_,
+)
+
+
+def torchao_quantize_param_data(param, torchao_config):
+    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    dummy_linear.weight = param
+    if "int8wo" in torchao_config:
+        quantize_(dummy_linear, int8_weight_only())
+    elif "int8dq" in torchao_config:
+        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+    elif "int4wo" in torchao_config:
+        group_size = int(torchao_config.split("-")[-1])
+        assert group_size in [
+            32,
+            64,
+            128,
+            256,
+        ], f"int4wo group size needs to be one of [32, 64, 128, 256] but got {group_size}"
+        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+    elif "fp8wo" in torchao_config:
+        from torchao.quantization import float8_weight_only
+
+        # float8 weight-only requires newer hardware (CUDA arch >= 8.9); on older GPUs:
+        # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
+        quantize_(dummy_linear, float8_weight_only())
+    return dummy_linear.weight
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 9c82b2a81..78f99dcd6 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -97,6 +97,7 @@ class ModelRunner:
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
+                "torchao_config": server_args.torchao_config,
             }
         )
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 926d87db8..ac53712fc 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -42,6 +42,8 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -299,6 +301,7 @@ class LlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -361,6 +364,25 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+                if self.torchao_config:
+                    if name.endswith("proj.weight") and param.ndim == 2:
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+
+        if self.torchao_config:
+            # quantizing the loaded, stacked params, e.g. "...qkv_proj"
+            stacked_params = set(entry[0] for entry in stacked_params_mapping)
+            for param_suffix in stacked_params:
+                for name in params_dict:
+                    if param_suffix in name:
+                        param = params_dict[name]
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+
+        self.load_state_dict(params_dict, assign=True)
+

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 14dd63b5a..3dfb1dc41 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -95,6 +95,7 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
+    torchao_config: str = ""
     enable_p2p_check: bool = False
     enable_mla: bool = False
     triton_attention_reduce_in_fp32: bool = False
@@ -443,7 +444,13 @@ class ServerArgs:
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
-            help="Optimize the model with torch.compile, experimental feature.",
+            help="Optimize the model with torch.compile. Experimental feature.",
         )
+        parser.add_argument(
+            "--torchao-config",
+            type=str,
+            default=ServerArgs.torchao_config,
+            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
+        )
         parser.add_argument(
             "--enable-p2p-check",
diff --git a/test/srt/test_eval_accuracy_mini.py b/test/srt/test_eval_accuracy_mini.py
index 25aa0ca11..6ddd97d94 100644
--- a/test/srt/test_eval_accuracy_mini.py
+++ b/test/srt/test_eval_accuracy_mini.py
@@ -29,12 +29,12 @@ class TestEvalAccuracyMini(unittest.TestCase):
             base_url=self.base_url,
             model=self.model,
             eval_name="mmlu",
-            num_examples=32,
+            num_examples=64,
             num_threads=32,
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.6
+        assert metrics["score"] >= 0.65


 if __name__ == "__main__":
diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py
index b15308dce..b6027b61c 100644
--- a/test/srt/test_moe_eval_accuracy_large.py
+++ b/test/srt/test_moe_eval_accuracy_large.py
@@ -42,7 +42,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.62, f"{metrics}"
+        assert metrics["score"] >= 0.625, f"{metrics}"

     def test_human_eval(self):
         args = SimpleNamespace(
@@ -54,7 +54,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.42, f"{metrics}"
+        assert metrics["score"] >= 0.425, f"{metrics}"

     def test_mgsm_en(self):
         args = SimpleNamespace(
@@ -66,7 +66,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.62, f"{metrics}"
+        assert metrics["score"] >= 0.625, f"{metrics}"


 if __name__ == "__main__":
diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py
index e8cafa15d..40f47d6b6 100644
--- a/test/srt/test_torch_compile.py
+++ b/test/srt/test_torch_compile.py
@@ -22,7 +22,7 @@ class TestTorchCompile(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-torch-compile", "--disable-radix-cache"],
+            other_args=["--enable-torch-compile"],
         )

     @classmethod
@@ -34,12 +34,12 @@ class TestTorchCompile(unittest.TestCase):
             base_url=self.base_url,
             model=self.model,
             eval_name="mmlu",
-            num_examples=32,
+            num_examples=64,
             num_threads=32,
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.6
+        assert metrics["score"] >= 0.65

     def run_decode(self, max_new_tokens):
         response = requests.post(
diff --git a/test/srt/test_torchao.py b/test/srt/test_torchao.py
new file mode 100644
index 000000000..d2084e7d5
--- /dev/null
+++ b/test/srt/test_torchao.py
@@ -0,0 +1,73 @@
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestTorchAO(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--torchao-config", "int4wo-128"],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.65
+
+    def run_decode(self, max_new_tokens):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": max_new_tokens,
+                },
+                "ignore_eos": True,
+            },
+        )
+        return response.json()
+
+    def test_throughput(self):
+        import time
+
+        max_tokens = 256
+
+        tic = time.time()
+        res = self.run_decode(max_tokens)
+        tok = time.time()
+        print(res["text"])
+        throughput = max_tokens / (tok - tic)
+        print(f"Throughput: {throughput} tokens/s")
+        assert throughput >= 210
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_triton_attn_backend.py b/test/srt/test_triton_attn_backend.py
index a94ca9212..b3f65ac13 100644
--- a/test/srt/test_triton_attn_backend.py
+++ b/test/srt/test_triton_attn_backend.py
@@ -32,12 +32,12 @@ class TestTritonAttnBackend(unittest.TestCase):
             base_url=self.base_url,
             model=self.model,
             eval_name="mmlu",
-            num_examples=32,
+            num_examples=64,
             num_threads=32,
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.6
+        assert metrics["score"] >= 0.65


 if __name__ == "__main__":
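
Note: the core trick in torchao_quantize_param_data above is that torchao's quantize_ API operates on modules, not bare tensors, so the helper wraps a weight in a throwaway nn.Linear, quantizes that, and reads the transformed weight back out. Below is a minimal standalone sketch of the same pattern, not part of the diff; it reuses the quantize_/int8_weight_only imports from torchao_utils.py, and the tensor shape is illustrative only.

    import torch
    from torchao.quantization import int8_weight_only, quantize_

    # Stand-in for a model weight; the shape is arbitrary for this sketch.
    param = torch.nn.Parameter(torch.randn(256, 512, dtype=torch.bfloat16))

    # quantize_ works on modules, so wrap the weight in a dummy Linear
    # (in_features=param.shape[1], out_features=param.shape[0], no bias).
    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
    dummy_linear.weight = param

    quantize_(dummy_linear, int8_weight_only())

    # The weight is now a torchao tensor subclass. llama.py assigns such
    # tensors back into the model via load_state_dict(params_dict, assign=True).
    print(type(dummy_linear.weight))

End to end, the feature is exercised by launching the server with --torchao-config (e.g. int4wo-128, where the group size must be one of 32/64/128/256 per the assert in torchao_utils.py), as test/srt/test_torchao.py does.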