Added the ability to Modify the Context Length (#210)
@@ -57,7 +57,7 @@ class ModelRpcServer(rpyc.Service):
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path, server_args.trust_remote_code
+            server_args.model_path, server_args.trust_remote_code, context_length=server_args.context_length
         )
         self.model_runner = ModelRunner(
             self.model_config,
@@ -11,14 +11,19 @@ class ModelConfig:
         path: str,
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
+        context_length: Optional[int] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
 
+        if context_length is not None:
+            self.context_len = context_length
+        else:
+            self.context_len = get_context_length(self.hf_config)
+
         # Unify the config keys for hf_config
-        self.context_len = get_context_length(self.hf_config)
         self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
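The substantive change is in ModelConfig: an explicit context_length now wins, and only when it is None does the value fall back to get_context_length(self.hf_config). A minimal sketch of that resolution order, assuming the helper reads a field such as max_position_embeddings from the Hugging Face config (the fallback key names below are illustrative, not sglang's exact logic):

from typing import Optional

def resolve_context_len(hf_config, override: Optional[int] = None) -> int:
    """Return the explicit override if given, else derive it from the HF config."""
    if override is not None:
        return override
    # Illustrative fallback chain; model configs name this field differently.
    for key in ("max_position_embeddings", "n_positions", "seq_length"):
        value = getattr(hf_config, key, None)
        if value is not None:
            return value
    raise ValueError("Cannot infer the context length from the model config.")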
@@ -546,6 +546,7 @@ class Runtime:
         trust_remote_code: bool = True,
         mem_fraction_static: float = ServerArgs.mem_fraction_static,
         max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
+        context_length: int = ServerArgs.context_length,
         tp_size: int = 1,
         model_mode: List[str] = (),
         schedule_heuristic: str = "lpm",
@@ -567,6 +568,7 @@ class Runtime:
             trust_remote_code=trust_remote_code,
             mem_fraction_static=mem_fraction_static,
             max_prefill_num_token=max_prefill_num_token,
+            context_length=context_length,
             tp_size=tp_size,
             model_mode=model_mode,
             schedule_heuristic=schedule_heuristic,
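With the keyword threaded through Runtime into ServerArgs, the override can also be set from Python. A hypothetical call, assuming Runtime is importable from the package top level (the model path is a placeholder):

from sglang import Runtime  # import location assumed

# Cap the context at 2048 tokens instead of the value in config.json.
runtime = Runtime(
    model_path="meta-llama/Llama-2-7b-hf",  # placeholder model
    context_length=2048,
)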
@@ -16,6 +16,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     mem_fraction_static: Optional[float] = None
     max_prefill_num_token: Optional[int] = None
+    context_length: Optional[int] = None
     tp_size: int = 1
     model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
@@ -117,6 +118,12 @@ class ServerArgs:
             default=ServerArgs.max_prefill_num_token,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
+        )
         parser.add_argument(
             "--tp-size",
             type=int,
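Since argparse maps --context-length to the context_length attribute, the flag lines up with the ServerArgs field of the same name, so passing --context-length 2048 at server launch overrides the value from the model's config.json. A standalone round-trip check of that argparse behavior (not sglang's actual parser):

import argparse

parser = argparse.ArgumentParser()
# Mirrors the new option: when omitted, the value stays None and the
# length from the model's config.json wins.
parser.add_argument("--context-length", type=int, default=None)

args = parser.parse_args(["--context-length", "2048"])
assert args.context_length == 2048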