From 7d1485d3765eed0ed2f55c60210dc47c6573478a Mon Sep 17 00:00:00 2001
From: Chayenne
Date: Fri, 29 Nov 2024 23:36:38 -0800
Subject: [PATCH] Add get weights by parameter name for llama (#2266)

---
 .github/workflows/pr-test.yml                 |   2 +
 3rdparty/amd/profiling/PROFILING.md           |   2 +-
 python/sglang/srt/managers/io_struct.py       |  11 ++
 python/sglang/srt/managers/scheduler.py       |   9 ++
 .../sglang/srt/managers/tokenizer_manager.py  |  30 ++++
 python/sglang/srt/managers/tp_worker.py       |  11 +-
 .../srt/managers/tp_worker_overlap_thread.py  |   8 +-
 .../sglang/srt/model_executor/model_runner.py |  22 ++-
 python/sglang/srt/models/llama.py             |  94 +++++++++++--
 python/sglang/srt/server.py                   |  36 +++++
 test/srt/run_suite.py                         |   1 +
 test/srt/test_get_parameter_by_name.py        | 128 ++++++++++++++++++
 12 files changed, 337 insertions(+), 17 deletions(-)
 create mode 100644 test/srt/test_get_parameter_by_name.py

diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 0d7889a5f..6a24a5747 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -128,6 +128,7 @@ jobs:
           python3 test_mla_fp8.py
           python3 test_dp_attention.py
 
+
   performance-test-1-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner
@@ -242,6 +243,7 @@ jobs:
         cd test/srt
         python3 test_eval_accuracy_large.py
 
+
   accuracy-test-2-gpu:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 2-gpu-runner
diff --git a/3rdparty/amd/profiling/PROFILING.md b/3rdparty/amd/profiling/PROFILING.md
index 90ad8665e..79bc75b50 100644
--- a/3rdparty/amd/profiling/PROFILING.md
+++ b/3rdparty/amd/profiling/PROFILING.md
@@ -421,5 +421,5 @@ index 62d1ff9..6ecd78c 100644
 3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container.
 4. Similar to step 6 in RPD profiling section, but remove the last 2 lines in client.sh, which converted rpd file into csv and json files. Run modified client.sh for PyTorch profiling.
-=======
+-------
 
 - [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5fb8c6e0e..058e930ed 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -365,6 +365,17 @@ class UpdateWeightFromDiskReqOutput:
     message: str
 
 
+@dataclass
+class GetWeightsByNameReqInput:
+    name: str
+    truncate_size: int = 100
+
+
+@dataclass
+class GetWeightsByNameReqOutput:
+    parameter: list
+
+
 @dataclass
 class AbortReq:
     # The request id
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 04c18e2e0..95ac0bd0c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -511,6 +513,9 @@ class Scheduler:
             self.send_to_tokenizer.send_pyobj(
                 UpdateWeightFromDiskReqOutput(success, message)
             )
+        elif isinstance(recv_req, GetWeightsByNameReqInput):
+            parameter = self.get_weights_by_name(recv_req)
+            self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
         elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()
@@ -1373,6 +1378,10 @@ class Scheduler:
             logger.error(message)
             return success, message
 
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.tp_worker.get_weights_by_name(recv_req)
+        return parameter
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 9c1c591dd..630c5ec42 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -454,6 +456,23 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."
 
+    async def get_weights_by_name(
+        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        self.send_to_scheduler.send_pyobj(obj)
+        self.get_weights_by_name_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.get_weights_by_name_result
+            return result.parameter
+        else:
+            self.get_weights_by_name_tmp = []
+            result = await self.get_weights_by_name_result
+            all_parameters = [r.parameter for r in result]
+            return all_parameters
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -527,6 +546,7 @@ class TokenizerManager:
                 BatchEmbeddingOut,
                 BatchTokenIDOut,
                 UpdateWeightFromDiskReqOutput,
+                GetWeightsByNameReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()
 
             if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
@@ -538,6 +558,16 @@ class TokenizerManager:
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
                 continue
+            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
+                if self.server_args.dp_size == 1:
+                    self.get_weights_by_name_result.set_result(recv_obj)
+                else:
+                    self.get_weights_by_name_tmp.append(recv_obj)
+                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
+                        self.get_weights_by_name_result.set_result(
+                            self.get_weights_by_name_tmp
+                        )
+                continue
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
                     recv_obj.session_id
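A note on the fan-in logic above: with `dp_size == 1` the tokenizer manager's future resolves to a single `GetWeightsByNameReqOutput`, while with `dp_size > 1` it resolves to one output per data-parallel replica. A minimal sketch of the two return shapes, using only the dataclass added by this patch (values are illustrative):

```python
from sglang.srt.managers.io_struct import GetWeightsByNameReqOutput

# dp_size == 1: callers of TokenizerManager.get_weights_by_name receive
# the single output's .parameter field directly.
single = GetWeightsByNameReqOutput(parameter=[0.1, 0.2])
assert single.parameter == [0.1, 0.2]

# dp_size > 1: one output arrives per replica, and callers receive a list
# of .parameter lists, one entry per data-parallel rank.
replicas = [GetWeightsByNameReqOutput(parameter=[0.1, 0.2]) for _ in range(2)]
assert [r.parameter for r in replicas] == [[0.1, 0.2], [0.1, 0.2]]
```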
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index bdbf58ba7..d79498c77 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -19,7 +19,10 @@ from typing import Optional
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    UpdateWeightFromDiskReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
@@ -160,3 +163,9 @@ class TpModelWorker:
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.model_runner.get_weights_by_name(
+            recv_req.name, recv_req.truncate_size
+        )
+        return parameter
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 786656271..1b0be30df 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -23,7 +23,10 @@ from typing import Optional
 import psutil
 import torch
 
-from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    UpdateWeightFromDiskReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
@@ -208,6 +211,9 @@ class TpModelWorkerClient:
         success, message = self.worker.update_weights_from_disk(recv_req)
         return success, message
 
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        return self.worker.get_weights_by_name(recv_req)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index da311c7ec..0542b7b0b 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -20,13 +20,10 @@ import inspect
 import json
 import logging
 import pkgutil
-import time
 from functools import lru_cache
-from tokenize import tabsize
-from typing import Any, Optional, Type, Union
+from typing import Optional, Type
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
@@ -403,6 +400,23 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."
 
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100
+    ) -> Optional[list]:
+        """Get the weights of a parameter by its name, similar to `get_parameter` in Hugging Face.
+
+        Only used for unit tests; the performance is not optimized.
+        For efficient bulk access to weights, use torch.save and torch.load instead.
+        """
+        # TODO: (chenyang) Add support for Qwen models.
+        try:
+            return self.model.get_weights_by_name(
+                name, truncate_size, tp_size=self.tp_size
+            )
+        except Exception as e:
+            logger.error(f"Error when getting parameter {name}: {e}")
+            return None
+
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
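The docstring above steers performance-sensitive callers toward `torch.save`/`torch.load`. A hedged sketch of that offline alternative (the model checkpoint and the temp path are hypothetical; the loading pattern mirrors the unit test below):

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint: snapshot the full state dict once, then index
# parameters by name locally instead of querying the running server.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
torch.save(model.state_dict(), "/tmp/llama_weights.pt")
state = torch.load("/tmp/llama_weights.pt", map_location="cpu")
q_proj = state["model.layers.0.self_attn.q_proj.weight"]  # full, unsharded tensor
```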
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 7e9fd0f72..5f472ef3b 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
+import logging
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -45,6 +46,8 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import make_layers
 
+logger = logging.getLogger(__name__)
+
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -305,6 +308,14 @@ class LlamaForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
 
     @torch.no_grad()
     def forward(
@@ -349,15 +360,8 @@ class LlamaForCausalLM(nn.Module):
         return params_mapping.get(name, name)
 
     def get_module_name_from_weight_name(self, name):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id, num_shard)
-            ("qkv_proj", "q_proj", "q", 3),
-            ("qkv_proj", "k_proj", "k", 3),
-            ("qkv_proj", "v_proj", "v", 3),
-            ("gate_up_proj", "gate_proj", 0, 2),
-            ("gate_up_proj", "up_proj", 1, 2),
-        ]
-        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+        for param_name, weight_name, shard_id in self.stacked_params_mapping:
             if weight_name in name:
+                num_shard = 3 if shard_id in ("q", "k", "v") else 2
                 return (
                     name.replace(weight_name, param_name)[: -len(".weight")],
@@ -370,6 +373,7 @@ class LlamaForCausalLM(nn.Module):
         return len(params_dict)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        embed_tokens_weight = None
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -378,6 +382,7 @@ class LlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
+
         params_dict = dict(self.named_parameters())
 
         load_tie_word_embeddings = (
@@ -425,10 +430,79 @@ class LlamaForCausalLM(nn.Module):
             # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
             param = self.lm_head.weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, embed_tokens_weight)
+            if embed_tokens_weight is not None:
+                weight_loader(param, embed_tokens_weight)
 
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
 
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100, tp_size: int = 1
+    ) -> Optional[list]:
+        """Get the weights of a parameter by its name, similar to `get_parameter` in Hugging Face.
+
+        Only used for unit tests; the performance is not optimized.
+        For efficient bulk access to weights, use torch.save and torch.load instead.
+        """
+        try:
+            mapped_name = name
+            mapped_shard_id = None
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name in name:
+                    mapped_name = name.replace(weight_name, param_name)
+                    mapped_shard_id = shard_id
+                    break
+            params_dict = dict(self.named_parameters())
+            if mapped_name in params_dict:
+                param = params_dict[mapped_name]
+                if mapped_shard_id is not None:
+                    if mapped_shard_id in ["q", "k", "v"]:
+                        num_heads = self.config.num_attention_heads // tp_size
+                        num_kv_heads = self.config.num_key_value_heads // tp_size
+                        head_dim = (
+                            self.config.hidden_size // self.config.num_attention_heads
+                        )
+                        if mapped_shard_id == "q":
+                            offset = 0
+                            size = num_heads * head_dim
+                        elif mapped_shard_id == "k":
+                            offset = num_heads * head_dim
+                            size = num_kv_heads * head_dim
+                        elif mapped_shard_id == "v":
+                            offset = (num_heads + num_kv_heads) * head_dim
+                            size = num_kv_heads * head_dim
+                        weight = param.data.narrow(0, offset, size)
+                    elif mapped_shard_id in [0, 1]:
+                        intermediate_size = self.config.intermediate_size
+                        hidden_size = self.config.hidden_size
+                        slice_size = intermediate_size // tp_size
+                        if mapped_shard_id == 0:  # gate_proj
+                            offset = 0
+                            size = slice_size
+                        elif mapped_shard_id == 1:  # up_proj
+                            offset = slice_size
+                            size = slice_size
+
+                        weight = param.data.narrow(0, offset, size)
+                    else:
+                        weight = param.data
+                else:
+                    weight = param.data
+                if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                    gathered_weights = [
+                        torch.zeros_like(weight) for _ in range(tp_size)
+                    ]
+                    torch.distributed.all_gather(gathered_weights, weight)
+                    weight = torch.cat(gathered_weights, dim=1)
+                return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+            else:
+                return None
+
+        except Exception as e:
+            logger.error(
+                f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
+            )
+            return None
+
 
 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
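To make the `narrow` offsets above concrete, here is a worked example of the q/k/v slicing under assumed config values (32 query heads, 8 KV heads, hidden size 4096, `tp_size=2`); the patch itself works for any Llama config:

```python
# Per-rank qkv_proj layout along dim 0: [q | k | v].
num_attention_heads, num_key_value_heads, hidden_size, tp_size = 32, 8, 4096, 2
head_dim = hidden_size // num_attention_heads   # 128
num_heads = num_attention_heads // tp_size      # 16 query heads per rank
num_kv_heads = num_key_value_heads // tp_size   # 4 KV heads per rank

q = (0, num_heads * head_dim)                                         # offset 0, size 2048
k = (num_heads * head_dim, num_kv_heads * head_dim)                   # offset 2048, size 512
v = ((num_heads + num_kv_heads) * head_dim, num_kv_heads * head_dim)  # offset 2560, size 512
print(q, k, v)  # (0, 2048) (2048, 512) (2560, 512)
```

Row-parallel weights (`o_proj`, `down_proj`) take the other branch: each rank holds a column slice, so the method all-gathers across ranks and concatenates along dim 1 instead of narrowing.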
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 7eec7cd1f..71755654c 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -52,6 +52,7 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
     OpenSessionReqInput,
     UpdateWeightFromDiskReqInput,
 )
@@ -210,6 +211,24 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
         )
 
 
+@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
+async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
+    """Get a model parameter by name."""
+    try:
+        ret = await tokenizer_manager.get_weights_by_name(obj, request)
+        if ret is None:
+            return ORJSONResponse(
+                {"error": {"message": "Get parameter by name failed"}},
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
+        else:
+            return ORJSONResponse(ret, status_code=200)
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -269,6 +288,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         )
 
 
+@time_func_latency
+async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
+    """Handle a get-parameter-by-name request."""
+    try:
+        ret = await tokenizer_manager.get_weights_by_name(obj, request)
+        return ret
+    except ValueError as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def encode_request(obj: EmbeddingReqInput, request: Request):
@@ -938,3 +969,8 @@ class Engine:
 
     async def get_server_info(self):
         return await _get_server_info()
+
+    def get_weights_by_name(self, name, truncate_size=100):
+        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(get_weights_by_name_request(obj, None))
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index c04a1671e..d441cf9b2 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -38,6 +38,7 @@ suites = {
         "test_update_weights.py",
         "test_vision_openai_server.py",
         "test_session_control.py",
+        "test_get_parameter_by_name.py",
     ],
     "sampling/penaltylib": glob.glob(
         "sampling/penaltylib/**/test_*.py", recursive=True
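The new `/get_weights_by_name` route can be exercised the same way the test below does; a minimal sketch against a locally running server (host and port are hypothetical):

```python
import requests

# GET with a JSON body mirrors the usage in test_get_parameter_by_name.py.
resp = requests.get(
    "http://127.0.0.1:30000/get_weights_by_name",
    json={"name": "model.norm.weight", "truncate_size": 5},
)
print(resp.json())  # first 5 values; a list of lists when dp_size > 1
```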
diff --git a/test/srt/test_get_parameter_by_name.py b/test/srt/test_get_parameter_by_name.py
new file mode 100644
index 000000000..73b0a3f74
--- /dev/null
+++ b/test/srt/test_get_parameter_by_name.py
@@ -0,0 +1,128 @@
+import gc
+import unittest
+
+import numpy as np
+import requests
+import torch
+from transformers import AutoModelForCausalLM
+
+import sglang as sgl
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+from sglang.utils import terminate_process
+
+
+class TestGetWeightsByName(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.hf_model = AutoModelForCausalLM.from_pretrained(
+            cls.model, torch_dtype="bfloat16"
+        ).to("cuda:0")
+
+    @classmethod
+    def tearDownClass(cls):
+        del cls.hf_model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def init_backend(self, backend, dp, tp):
+        self.engine = None
+        self.process = None
+        self.backend = backend
+        self.dp = dp
+        self.tp = tp
+        if backend == "Engine":
+            self.engine = sgl.Engine(
+                model_path=self.model,
+                random_seed=42,
+                tp_size=self.tp,
+                dp_size=self.dp,
+                mem_fraction_static=0.85,
+            )
+        else:
+            self.process = popen_launch_server(
+                self.model,
+                self.base_url,
+                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                other_args=(
+                    "--tp-size",
+                    str(tp),
+                    "--dp-size",
+                    str(dp),
+                ),
+            )
+
+    def close_engine_and_server(self):
+        if self.engine:
+            self.engine.shutdown()
+        if self.process:
+            terminate_process(self.process)
+
+    def assert_weights_all_close(self, param_name, truncate_size):
+        print(
+            f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
+        )
+        param = self.hf_model.get_parameter(param_name)[:truncate_size]
+        param_np = param.cpu().detach().float().numpy()
+
+        if self.backend == "Engine":
+            engine_ret = self.engine.get_weights_by_name(param_name, truncate_size)
+            engine_ret = self._process_return(engine_ret)
+            np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)
+
+        if self.backend == "Runtime":
+            runtime_ret = requests.get(
+                f"{self.base_url}/get_weights_by_name",
+                json={"name": param_name, "truncate_size": truncate_size},
+            ).json()
+            runtime_ret = self._process_return(runtime_ret)
+            np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)
+
+    @staticmethod
+    def _process_return(ret):
+        if isinstance(ret, list) and len(ret) == 2:
+            print("running assert_allclose on data-parallel replicas")
+            np.testing.assert_allclose(ret[0], ret[1])
+            return np.array(ret[0])
+        return np.array(ret)
+
+    def test_get_weights_by_name(self):
+        test_suites = [("Engine", 1, 1), ("Runtime", 1, 1)]
+
+        if torch.cuda.device_count() >= 2:
+            test_suites.append(("Engine", 1, 2))
+            test_suites.append(("Runtime", 2, 1))
+
+        if torch.cuda.device_count() >= 4:
+            test_suites.extend([("Engine", 2, 2), ("Runtime", 2, 2)])
+
+        parameters = [
+            "model.embed_tokens.weight",
+            "model.layers.0.input_layernorm.weight",
+            "model.layers.1.self_attn.q_proj.weight",
+            "model.layers.2.self_attn.k_proj.weight",
+            "model.layers.3.self_attn.v_proj.weight",
+            "model.layers.4.self_attn.o_proj.weight",
+            "model.layers.5.mlp.gate_proj.weight",
+            "model.layers.6.mlp.up_proj.weight",
+            "model.layers.7.mlp.down_proj.weight",
+            "model.layers.8.post_attention_layernorm.weight",
+            "model.norm.weight",
+            "lm_head.weight",
+        ]
+
+        for test_suite in test_suites:
+            self.init_backend(*test_suite)
+            for param_name in parameters:
+                self.assert_weights_all_close(param_name, 100)
+            self.close_engine_and_server()
+
+
+if __name__ == "__main__":
+    unittest.main()
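For the offline engine, the same lookup goes through `Engine.get_weights_by_name`, as the test above exercises; a hedged sketch (the model path is a placeholder, and only the Llama family implements the hook in this patch):

```python
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.2-1B", random_seed=42)
values = engine.get_weights_by_name("model.norm.weight", truncate_size=5)
print(values)  # first 5 entries of the norm weight as plain Python floats
engine.shutdown()
```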