From 38625e2139941fe8a02db81ebdd2babda359f05b Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 17 Nov 2024 15:48:12 -0800 Subject: [PATCH] Remove monkey_patch_vllm_dummy_weight_loader (#2064) --- python/sglang/srt/managers/scheduler.py | 4 +- .../srt/managers/tp_worker_overlap_thread.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 4 +- python/sglang/srt/utils.py | 51 ------------------- test/srt/test_bench_latency.py | 4 +- test/srt/test_bench_serving.py | 22 ++++---- 6 files changed, 17 insertions(+), 70 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2bdf4cda7..bb97efe2e 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -895,7 +895,7 @@ class Scheduler: logits_output, next_token_ids, bid = result if self.enable_overlap: - logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid) + logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) else: # Move next_token_ids and logprobs to cpu if batch.return_logprob: @@ -970,7 +970,7 @@ class Scheduler: self.num_generated_tokens += len(batch.reqs) if self.enable_overlap: - logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid) + logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) next_token_logprobs = logits_output.next_token_logprobs else: # Move next_token_ids and logprobs to cpu diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 8c924c442..3ae1e37b3 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -141,7 +141,7 @@ class TpModelWorkerClient: self.launch_event.set() self.output_queue.put((copy_event, logits_output, next_token_ids)) - def resulve_batch_result(self, bid: int): + def resolve_batch_result(self, bid: int): copy_event, logits_output, next_token_ids = self.output_queue.get() while not copy_event.query(): time.sleep(1e-5) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5cde1e942..55bf9afd8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -58,7 +58,6 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( enable_show_time_cost, get_available_gpu_memory, - monkey_patch_vllm_dummy_weight_loader, monkey_patch_vllm_p2p_access_check, ) @@ -242,7 +241,6 @@ class ModelRunner: raise RuntimeError("SGLang only supports sm75 and above.") # Prepare the vllm model config - monkey_patch_vllm_dummy_weight_loader() self.load_config = LoadConfig( load_format=self.server_args.load_format, download_dir=self.server_args.download_dir, @@ -261,7 +259,6 @@ class ModelRunner: self.vllm_model_config.hf_config.update( self.model_config.model_override_args ) - self.dtype = self.vllm_model_config.dtype # Load the model self.model = get_model( @@ -278,6 +275,7 @@ class ModelRunner: if hasattr(self.model, "get_attention_sliding_window_size") else None ) + self.dtype = self.vllm_model_config.dtype logger.info( f"Load weight end. " diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 32317ec2e..7e6174ad8 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -405,57 +405,6 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int): setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) -def monkey_patch_vllm_dummy_weight_loader(): - """ - Monkey patch the dummy weight loader in vllm to call process_weights_after_loading. - """ - - from vllm.model_executor.model_loader.loader import ( - CacheConfig, - DeviceConfig, - DummyModelLoader, - LoRAConfig, - ModelConfig, - ParallelConfig, - SchedulerConfig, - _initialize_model, - initialize_dummy_weights, - nn, - set_default_torch_dtype, - ) - - def load_model( - self, - *, - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - ) -> nn.Module: - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model( - model_config, - self.load_config, - lora_config, - cache_config, - ) - - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - quant_method.process_weights_after_loading(module) - - # NOTE(woosuk): For accurate performance evaluation, we assign - # random values to the weights. - initialize_dummy_weights(model) - return model.eval() - - setattr(DummyModelLoader, "load_model", load_model) - - vllm_all_gather_backup = None diff --git a/test/srt/test_bench_latency.py b/test/srt/test_bench_latency.py index fa6b8e2fa..e54f49088 100644 --- a/test/srt/test_bench_latency.py +++ b/test/srt/test_bench_latency.py @@ -13,7 +13,7 @@ class TestBenchLatency(unittest.TestCase): output_throughput = run_bench_latency(DEFAULT_MODEL_NAME_FOR_TEST, []) if is_in_ci(): - assert output_throughput > 130, f"{output_throughput=}" + self.assertGreater(output_throughput, 135) def test_moe_default(self): output_throughput = run_bench_latency( @@ -21,7 +21,7 @@ class TestBenchLatency(unittest.TestCase): ) if is_in_ci(): - assert output_throughput > 125, f"{output_throughput=}" + self.assertGreater(output_throughput, 125) if __name__ == "__main__": diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index 6955d4917..c3c6a7d13 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -20,7 +20,7 @@ class TestBenchServing(unittest.TestCase): ) if is_in_ci(): - assert res["output_throughput"] > 2830 + self.assertGreater(res["output_throughput"], 2850) def test_offline_throughput_non_stream_small_batch_size(self): res = run_bench_serving( @@ -35,7 +35,7 @@ class TestBenchServing(unittest.TestCase): ) if is_in_ci(): - assert res["output_throughput"] > 1000 + self.assertGreater(res["output_throughput"], 950) def test_offline_throughput_without_radix_cache(self): res = run_bench_serving( @@ -46,7 +46,7 @@ class TestBenchServing(unittest.TestCase): ) if is_in_ci(): - assert res["output_throughput"] > 2880 + self.assertGreater(res["output_throughput"], 2900) def test_offline_throughput_without_chunked_prefill(self): res = run_bench_serving( @@ -57,7 +57,7 @@ class TestBenchServing(unittest.TestCase): ) if is_in_ci(): - assert res["output_throughput"] > 2600 + self.assertGreater(res["output_throughput"], 2600) def test_offline_throughput_with_triton_attention_backend(self): res = run_bench_serving( @@ -73,7 +73,7 @@ class TestBenchServing(unittest.TestCase): ) if is_in_ci(): - assert res["output_throughput"] > 2930 + self.assertGreater(res["output_throughput"], 2950) def test_offline_throughput_default_fp8(self): res = run_bench_serving( @@ -84,7 +84,7 @@ class TestBenchServing(unittest.TestCase): ) if is_in_ci(): - assert res["output_throughput"] > 3100 + self.assertGreater(res["output_throughput"], 3200) def test_online_latency_default(self): res = run_bench_serving( @@ -95,9 +95,9 @@ class TestBenchServing(unittest.TestCase): ) if is_in_ci(): - assert res["median_e2e_latency_ms"] < 12000 - assert res["median_ttft_ms"] < 80 - assert res["median_itl_ms"] < 12 + self.assertLess(res["median_e2e_latency_ms"], 12000) + self.assertLess(res["median_ttft_ms"], 80) + self.assertLess(res["median_itl_ms"], 11) def test_moe_offline_throughput_default(self): res = run_bench_serving( @@ -108,7 +108,7 @@ class TestBenchServing(unittest.TestCase): ) if is_in_ci(): - assert res["output_throughput"] > 1850 + self.assertGreater(res["output_throughput"], 1900) def test_moe_offline_throughput_without_radix_cache(self): res = run_bench_serving( @@ -119,7 +119,7 @@ class TestBenchServing(unittest.TestCase): ) if is_in_ci(): - assert res["output_throughput"] > 1950 + self.assertGreater(res["output_throughput"], 1950) if __name__ == "__main__":