Add max_prefill_num_token into server arguments (#133)

This commit is contained in:
Ying Sheng
2024-02-03 02:35:54 -08:00
committed by GitHub
parent 67be11c790
commit e095b16236
3 changed files with 12 additions and 2 deletions

View File

@@ -82,7 +82,8 @@ class ModelRpcServer(rpyc.Service):
self.max_total_num_token = self.model_runner.max_total_num_token
self.max_num_running_seq = self.max_total_num_token // 2
self.max_prefill_num_token = max(
self.model_config.context_len, self.max_total_num_token // 6
self.model_config.context_len,
self.max_total_num_token // 6 if server_args.max_prefill_num_token is None else server_args.max_prefill_num_token,
)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)

View File

@@ -430,7 +430,8 @@ class Runtime:
load_format: str = "auto",
tokenizer_mode: str = "auto",
trust_remote_code: bool = True,
mem_fraction_static: float = 0.9,
mem_fraction_static: float = ServerArgs.mem_fraction_static,
max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
tp_size: int = 1,
model_mode: List[str] = (),
schedule_heuristic: str = "lpm",
@@ -451,6 +452,7 @@ class Runtime:
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
mem_fraction_static=mem_fraction_static,
max_prefill_num_token=max_prefill_num_token,
tp_size=tp_size,
model_mode=model_mode,
schedule_heuristic=schedule_heuristic,

View File

@@ -15,6 +15,7 @@ class ServerArgs:
chat_template: Optional[str] = None
trust_remote_code: bool = True
mem_fraction_static: Optional[float] = None
max_prefill_num_token: Optional[int] = None
tp_size: int = 1
model_mode: List[str] = ()
schedule_heuristic: str = "lpm"
@@ -109,6 +110,12 @@ class ServerArgs:
default=ServerArgs.mem_fraction_static,
help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
)
parser.add_argument(
"--max-prefill-num-token",
type=int,
default=ServerArgs.max_prefill_num_token,
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length."
)
parser.add_argument(
"--tp-size",
type=int,