From 27b557aea794d267e371d3bdaa4722a4db45a1e1 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 16 Sep 2024 18:16:27 -0700
Subject: [PATCH] Clean up model loader (#1440)

---
 python/sglang/srt/managers/tp_worker.py       |  4 +-
 .../sglang/srt/model_executor/model_runner.py | 41 +++++++-------
 python/sglang/srt/server.py                   |  2 +-
 python/sglang/srt/utils.py                    | 53 +------------------
 test/srt/test_update_weights.py               | 13 +++--
 5 files changed, 33 insertions(+), 80 deletions(-)

diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index fe9017f12..8053f27d0 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -415,7 +415,7 @@ class ModelTpServer:
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.warn(
+            logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
             )
@@ -936,6 +936,8 @@ class ModelTpServer:
 
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
         return success, message
 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index b754e41c7..f4cdb77ba 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -54,11 +54,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
-    monkey_patch_vllm_qvk_linear_loader,
 )
 
 logger = logging.getLogger(__name__)
@@ -166,10 +164,13 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
-        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+
+        # This can reduce thread conflicts and speed up weight loading.
+        torch.set_num_threads(1)
+
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
@@ -178,6 +179,7 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
+        # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -188,23 +190,16 @@ class ModelRunner:
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
             dtype=self.server_args.dtype,
-            seed=42,
+            seed=self.server_args.random_seed,
             skip_tokenizer_init=True,
         )
-
-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
-        # Drop this after Sept, 2024.
-        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            self.model_config.hf_config.num_key_value_heads = 8
-            self.vllm_model_config.hf_config.num_key_value_heads = 8
-            monkey_patch_vllm_qvk_linear_loader()
-
-        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
+        self.dtype = self.vllm_model_config.dtype
 
+        # Load the model
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
@@ -255,20 +250,20 @@ class ModelRunner:
                 tokenizer_mode=None,
                 trust_remote_code=self.server_args.trust_remote_code,
                 dtype=self.server_args.dtype,
-                seed=42,
+                seed=self.server_args.random_seed,
                 skip_tokenizer_init=True,
             )
         except Exception as e:
-            logger.error(f"Failed to load model config: {e}")
-            return False, "Failed to update model weights"
+            message = f"Failed to load model config: {e}."
+            return False, message
 
         load_config = LoadConfig(load_format=load_format)
 
         # Only support vllm DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
-            logger.error("Failed to get weights iterator: Unsupported loader")
-            return False, "Failed to update model weights"
+            message = f"Failed to get model loader: {loader}."
+            return False, message
 
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
@@ -293,14 +288,14 @@ class ModelRunner:
         try:
             iter = get_weight_iter(vllm_model_config)
         except Exception as e:
-            message = f"Failed to get weights iterator: {e}"
-            logger.error(message)
+            message = f"Failed to get weights iterator: {e}."
             return False, message
         try:
             model = model_load_weights(self.model, iter)
         except Exception as e:
-            message = f"Failed to update weights: {e}. \n Rolling back to original weights"
-            logger.error(message)
+            message = (
+                f"Failed to update weights: {e}.\nRolling back to original weights."
+            )
             del iter
             gc.collect()
             iter = get_weight_iter(self.vllm_model_config)
@@ -315,7 +310,7 @@ class ModelRunner:
         self.model_config.path = model_path
 
         logger.info("Update weights end.")
-        return True, "Succeeded to update model weights"
+        return True, "Succeeded to update model weights."
 
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index d2a248a92..9749075d0 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -152,7 +152,7 @@ async def flush_cache():
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
 
     success, message = await tokenizer_manager.update_weights(obj, request)
-    content = {"message": message, "success": str(success)}
+    content = {"success": success, "message": message}
     if success:
         return JSONResponse(
             content,
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 125bb556f..1f1a44870 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -187,7 +187,7 @@ def allocate_init_ports(
             cur_port += 1
 
     if port is not None and ret_ports[0] != port:
-        logger.warn(
+        logger.warning(
             f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
         )
 
@@ -623,56 +623,7 @@ def set_ulimit(target_soft_limit=65535):
         try:
             resource.setrlimit(resource_type, (target_soft_limit, current_hard))
         except ValueError as e:
-            logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
-
-
-def is_llama3_405b_fp8_head_16(model_config):
-    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
-    if (
-        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
-        and model_config.hf_config.hidden_size == 16384
-        and model_config.hf_config.intermediate_size == 53248
-        and model_config.hf_config.num_hidden_layers == 126
-        and model_config.hf_config.num_key_value_heads == 16
-        and hasattr(model_config.hf_config, "quantization_config")
-        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
-    ):
-        return True
-    return False
-
-
-def monkey_patch_vllm_qvk_linear_loader():
-    """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
-    from vllm.model_executor.layers.linear import QKVParallelLinear
-
-    origin_weight_loader = QKVParallelLinear.weight_loader
-
-    def get_original_weight(loaded_weight, head_dim):
-        n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
-        dim = loaded_weight.shape[1]
-        for i in range(n_kv_head):
-            loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
-                2 * i * head_dim : (2 * i + 1) * head_dim, :
-            ]
-        original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
-        assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
-        return original_kv_weight
-
-    def weight_loader_srt(
-        self,
-        param: Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        if (
-            loaded_shard_id in ["k", "v"]
-            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
-        ):
-            loaded_weight = get_original_weight(loaded_weight, self.head_size)
-
-        origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
-
-    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
+            logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")
 
 
 def add_api_key_middleware(app, api_key: str):
diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py
index 7b8404c73..3f0bbd7e2 100644
--- a/test/srt/test_update_weights.py
+++ b/test/srt/test_update_weights.py
@@ -44,7 +44,6 @@ class TestReplaceWeights(unittest.TestCase):
         )
         print(json.dumps(response.json()))
         print("=" * 100)
-        # return the "text" in response
         text = response.json()["text"]
         return text
 
@@ -61,7 +60,9 @@ class TestReplaceWeights(unittest.TestCase):
                 "model_path": model_path,
             },
         )
+        ret = response.json()
         print(json.dumps(response.json()))
+        return ret
 
     def test_replace_weights(self):
         origin_model_path = self.get_model_info()
@@ -70,7 +71,8 @@ class TestReplaceWeights(unittest.TestCase):
 
         # update weights
         new_model_path = "meta-llama/Meta-Llama-3.1-8B"
-        self.run_update_weights(new_model_path)
+        ret = self.run_update_weights(new_model_path)
+        assert ret["success"]
 
         updated_model_path = self.get_model_info()
         print(f"updated_model_path: {updated_model_path}")
@@ -81,7 +83,9 @@ class TestReplaceWeights(unittest.TestCase):
         assert origin_response[:32] != updated_response[:32]
 
         # update weights back
-        self.run_update_weights(origin_model_path)
+        ret = self.run_update_weights(origin_model_path)
+        assert ret["success"]
+
         updated_model_path = self.get_model_info()
         assert updated_model_path == origin_model_path
 
@@ -95,7 +99,8 @@ class TestReplaceWeights(unittest.TestCase):
 
         # update weights
         new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
-        self.run_update_weights(new_model_path)
+        ret = self.run_update_weights(new_model_path)
+        assert not ret["success"]
 
         updated_model_path = self.get_model_info()
         print(f"updated_model_path: {updated_model_path}")
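
The server.py change above makes "success" a JSON boolean instead of a string,
so clients can branch on the response directly. A minimal client-side sketch of
calling the endpoint and checking the new response shape; the base URL, the
default port 30000, and the "/update_weights" route name are assumptions based
on the handler and the test above, not guaranteed by this patch:

    import requests

    base_url = "http://localhost:30000"  # assumed server address

    # Ask the running server to swap weights in place, without a restart.
    response = requests.post(
        base_url + "/update_weights",
        json={"model_path": "meta-llama/Meta-Llama-3.1-8B"},
    )
    ret = response.json()

    # "success" is now a real boolean, so no comparison against the string
    # "True" is needed; "message" explains any failure, after which the
    # server has rolled back to the original weights.
    if ret["success"]:
        print("weights updated:", ret["message"])
    else:
        print("update failed:", ret["message"])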