From 9de9a46815bded248b01daba75936b642c2a7c06 Mon Sep 17 00:00:00 2001
From: psych0v0yager <105936906+psych0v0yager@users.noreply.github.com>
Date: Tue, 20 Feb 2024 18:22:56 -0600
Subject: [PATCH] Added the ability to Modify the Context Length (#210)

---
 python/sglang/srt/managers/router/model_rpc.py | 2 +-
 python/sglang/srt/model_config.py              | 7 ++++++-
 python/sglang/srt/server.py                    | 2 ++
 python/sglang/srt/server_args.py               | 7 +++++++
 4 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 8f1f1e58a..f59c3f0a1 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -57,7 +57,7 @@ class ModelRpcServer(rpyc.Service):
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path, server_args.trust_remote_code
+            server_args.model_path, server_args.trust_remote_code, context_length=server_args.context_length
         )
         self.model_runner = ModelRunner(
             self.model_config,
diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py
index 5f8aa50ce..504f499dc 100644
--- a/python/sglang/srt/model_config.py
+++ b/python/sglang/srt/model_config.py
@@ -11,14 +11,19 @@ class ModelConfig:
         path: str,
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
+        context_length: Optional[int] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
+
+        if context_length is not None:
+            self.context_len = context_length
+        else:
+            self.context_len = get_context_length(self.hf_config)
 
         # Unify the config keys for hf_config
-        self.context_len = get_context_length(self.hf_config)
         self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 55b3ff046..9dc6e9fd4 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -546,6 +546,7 @@ class Runtime:
         trust_remote_code: bool = True,
         mem_fraction_static: float = ServerArgs.mem_fraction_static,
         max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
+        context_length: int = ServerArgs.context_length,
         tp_size: int = 1,
         model_mode: List[str] = (),
         schedule_heuristic: str = "lpm",
@@ -567,6 +568,7 @@
             trust_remote_code=trust_remote_code,
             mem_fraction_static=mem_fraction_static,
             max_prefill_num_token=max_prefill_num_token,
+            context_length=context_length,
             tp_size=tp_size,
             model_mode=model_mode,
             schedule_heuristic=schedule_heuristic,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index d6f5704d3..73583d1fa 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -16,6 +16,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     mem_fraction_static: Optional[float] = None
     max_prefill_num_token: Optional[int] = None
+    context_length: Optional[int] = None
     tp_size: int = 1
     model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
@@ -117,6 +118,12 @@
             default=ServerArgs.max_prefill_num_token,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
+        )
         parser.add_argument(
             "--tp-size",
             type=int,
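
Usage note (not part of the patch): a minimal sketch of how the new option could be exercised once this change is applied. The `model_path` keyword on `Runtime`, the model name, and the `sglang.launch_server` entry point are assumptions for illustration and may differ in your checkout.

    # Illustrative only -- assumes Runtime is importable from sglang.srt.server
    # and accepts a model_path keyword alongside the options shown in the diff above.
    from sglang.srt.server import Runtime

    # Cap the context window at 4096 tokens instead of using the value read from
    # the model's config.json; a smaller context shrinks the KV cache and saves memory.
    runtime = Runtime(
        model_path="meta-llama/Llama-2-7b-chat-hf",  # hypothetical model choice
        context_length=4096,
    )

    # Assumed command-line equivalent using the new --context-length flag:
    #   python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --context-length 4096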