Chunked prefill support (#797)
@@ -38,24 +38,24 @@ class ScheduleHeuristic:
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self, forward_queue):
+    def get_priority_queue(self, waiting_queue):
         if self.schedule_heuristic == "lpm":
             # longest prefix match
-            forward_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return forward_queue
+            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            return waiting_queue
         elif self.schedule_heuristic == "fcfs":
             # first come first serve
-            return forward_queue
+            return waiting_queue
         elif self.schedule_heuristic == "lof":
             # longest output first
-            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return forward_queue
+            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return waiting_queue
         elif self.schedule_heuristic == "random":
-            random.shuffle(forward_queue)
-            return forward_queue
+            random.shuffle(waiting_queue)
+            return waiting_queue
         elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
-            for req in forward_queue:
+            for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
 
             node_to_weight = defaultdict(int)
@@ -67,7 +67,7 @@ class ScheduleHeuristic:
             self.get_dfs_priority(
                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(q) == len(forward_queue)
+            assert len(q) == len(waiting_queue)
             return q
         else:
             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
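The heuristics above only reorder the waiting queue. As a minimal, self-contained sketch of what "lpm" and "lof" do, using simplified stand-ins for the repo's Req and SamplingParams classes (not the real definitions):

    from dataclasses import dataclass, field

    @dataclass
    class SamplingParams:
        max_new_tokens: int = 16

    @dataclass
    class Req:
        rid: str
        prefix_indices: list = field(default_factory=list)  # tokens already cached
        sampling_params: SamplingParams = field(default_factory=SamplingParams)

    reqs = [
        Req("a", prefix_indices=[0] * 2, sampling_params=SamplingParams(128)),
        Req("b", prefix_indices=[0] * 9, sampling_params=SamplingParams(8)),
    ]

    # "lpm" (longest prefix match): most cached tokens first, maximizing radix-cache reuse
    reqs.sort(key=lambda x: -len(x.prefix_indices))
    assert [r.rid for r in reqs] == ["b", "a"]

    # "lof" (longest output first): largest decode budget first
    reqs.sort(key=lambda x: -x.sampling_params.max_new_tokens)
    assert [r.rid for r in reqs] == ["a", "b"]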
@@ -77,6 +77,10 @@ class ModelTpServer:
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -157,7 +161,7 @@ class ModelTpServer:
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
         # Init running status
-        self.forward_queue: List[Req] = []
+        self.waiting_queue: List[Req] = []
         self.running_batch: Batch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
@@ -220,6 +224,7 @@ class ModelTpServer:
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)
+            self.filter_out_inflight(new_batch)
 
             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -261,7 +266,7 @@
                     f"#token: {num_used}, "
                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                     f"gen throughput (token/s): {throughput:.2f}, "
-                    f"#queue-req: {len(self.forward_queue)}"
+                    f"#queue-req: {len(self.waiting_queue)}"
                 )
 
     def check_memory(self):
@@ -328,9 +333,10 @@
             ),
             self.max_req_input_len - 1 - len(req.origin_input_ids),
         )
-        self.forward_queue.append(req)
+        self.waiting_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
+        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
@@ -338,7 +344,7 @@
             return
 
         # Compute matched prefix length
-        for req in self.forward_queue:
+        for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
@@ -348,7 +354,7 @@
             req.last_node = last_node
 
         # Get priority queue
-        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
+        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
 
         # Add requests if there is available space
         can_run_list = []
@@ -367,7 +373,33 @@
             ]
         )
 
-        for req in self.forward_queue:
+        # Handle the current inflight request
+        take_inflight = 0
+        if self.current_inflight_req:
+            take_inflight = 1
+            r = self.current_inflight_req
+            r.input_ids = r.origin_input_ids + r.output_ids
+            truncated = (
+                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+            )
+            r.extend_input_len = min(
+                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
+            )
+            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
+            can_run_list.append(r)
+
+            if not truncated:
+                # Finish inflight
+                self.current_inflight_req = None
+                new_batch_total_tokens += (
+                    r.extend_input_len + r.sampling_params.max_new_tokens
+                )
+                new_batch_input_tokens += r.extend_input_len
+            else:
+                new_batch_total_tokens += r.extend_input_len
+                new_batch_input_tokens += r.extend_input_len
+
+        for req in self.waiting_queue:
             if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
@@ -409,11 +441,36 @@
                     break
             else:
                 # Add this request to the running batch
-                can_run_list.append(req)
-                new_batch_total_tokens += (
-                    req.extend_input_len + req.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += req.extend_input_len
+                if (
+                    new_batch_input_tokens + req.extend_input_len
+                    <= self.chunked_prefill_size
+                    or (
+                        req.return_logprob and req.normalized_prompt_logprob is None
+                    )
+                ):
+                    can_run_list.append(req)
+                    new_batch_total_tokens += (
+                        req.extend_input_len + req.sampling_params.max_new_tokens
+                    )
+                    new_batch_input_tokens += req.extend_input_len
+                else:
+                    trunc_len = self.chunked_prefill_size - new_batch_input_tokens
+
+                    if trunc_len <= 0:
+                        # Undo locking
+                        delta = self.tree_cache.dec_lock_ref(req.last_node)
+                        available_size += delta
+                        break
+
+                    req.extend_input_len = trunc_len
+                    req.input_ids = req.input_ids[
+                        : len(req.prefix_indices) + req.extend_input_len
+                    ]
+                    can_run_list.append(req)
+                    self.current_inflight_req = req
+                    new_batch_input_tokens += req.extend_input_len
+                    new_batch_total_tokens += req.extend_input_len
+                    break
         else:
             break
 
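These two hunks are the heart of chunked prefill: when the uncached part of a prompt exceeds chunked_prefill_size, the request is truncated, recorded as current_inflight_req, and resumed on later scheduling rounds until the whole prompt has been prefilled. A toy sketch of the resulting chunking arithmetic, with an illustrative helper that is not part of the repo:

    def split_prefill(num_input_tokens: int, num_cached: int, chunk_size: int):
        """Yield (extend_len, still_inflight) for each scheduling round."""
        remaining = num_input_tokens - num_cached
        while remaining > 0:
            extend_len = min(remaining, chunk_size)
            remaining -= extend_len
            yield extend_len, remaining > 0  # still inflight if tokens remain

    # A 10000-token prompt, nothing cached, chunk size 4096:
    assert list(split_prefill(10000, 0, 4096)) == [
        (4096, True), (4096, True), (1808, False)
    ]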
@@ -440,7 +497,7 @@
                 f"#cached-token: {hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
             )
 
         # Return the new batch
@@ -450,7 +507,7 @@
             self.token_to_kv_pool,
             self.tree_cache,
         )
-        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
+        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch
 
     def forward_prefill_batch(self, batch: Batch):
@@ -482,9 +539,10 @@
         # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
-            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_token_ids[i])
-            req.check_finished()
+            if req is not self.current_inflight_req:
+                req.completion_tokens_wo_jump_forward += 1
+                req.output_ids.append(next_token_ids[i])
+                req.check_finished()
 
             if req.return_logprob:
                 self.add_logprob_return_values(i, req, pt, next_token_ids, output)
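The new guard exists because an inflight request's forward pass covered only part of its prompt, so the token sampled from its logits is meaningless and must not be committed. A toy illustration with minimal stand-ins (not the repo's classes):

    class Req:
        def __init__(self):
            self.output_ids = []

    reqs = [Req(), Req()]
    current_inflight_req = reqs[1]  # prefill not finished yet
    next_token_ids = [101, 102]

    for i, req in enumerate(reqs):
        if req is not current_inflight_req:
            req.output_ids.append(next_token_ids[i])

    assert reqs[0].output_ids == [101] and reqs[1].output_ids == []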
@@ -545,7 +603,7 @@
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
+                token_ids=tuple(req.input_ids),
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,
@@ -553,6 +611,10 @@
             )
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
 
+            if req is self.current_inflight_req:
+                # inflight request would get a new req idx
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
+
     def forward_decode_batch(self, batch: Batch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
@@ -566,7 +628,7 @@
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
-            self.forward_queue.extend(retracted_reqs)
+            self.waiting_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -576,7 +638,7 @@
         if not self.disable_regex_jump_forward:
             # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-            self.forward_queue.extend(jump_forward_reqs)
+            self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
 
@@ -711,8 +773,18 @@
         else:
             batch.reqs = []
 
+    def filter_out_inflight(self, batch: Batch):
+        # TODO(lsyin): reduce the overhead, make a special version for this
+        if self.current_inflight_req is None:
+            return
+
+        unfinished_indices = list(range(len(batch.reqs)))
+        unfinished_indices.remove(batch.reqs.index(self.current_inflight_req))
+
+        batch.filter_batch(unfinished_indices)
+
     def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
+        if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
             self.tree_cache.reset()
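filter_out_inflight drops the partially prefilled request from the batch before decoding starts; it re-enters through get_new_prefill_batch on the next round. The index bookkeeping reduces to the following sketch over plain lists (Batch.filter_batch is the real method; everything else here is illustrative):

    reqs = ["req0", "req1_inflight", "req2"]
    current_inflight_req = "req1_inflight"

    unfinished_indices = list(range(len(reqs)))
    unfinished_indices.remove(reqs.index(current_inflight_req))

    kept = [reqs[i] for i in unfinished_indices]
    assert kept == ["req0", "req2"]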
@@ -725,20 +797,20 @@
         else:
             warnings.warn(
                 f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
+                f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
 
     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
         to_del = None
-        for i, req in enumerate(self.forward_queue):
+        for i, req in enumerate(self.waiting_queue):
             if req.rid == recv_req.rid:
                 to_del = i
                 break
 
         if to_del is not None:
-            del self.forward_queue[to_del]
+            del self.waiting_queue[to_del]
 
         # Delete requests in the running batch
         if self.running_batch:
@@ -45,7 +45,7 @@ class ReqToTokenPool:
 
         return select_index
 
-    def free(self, free_index: int):
+    def free(self, free_index):
         self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
             self.can_use_mem_size += 1
@@ -175,6 +175,39 @@ def _set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 256
 
 
+def set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+    # Set ulimit
+    set_ulimit()
+
+    # Enable show time cost for debugging
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
+    # Disable disk cache
+    if server_args.disable_disk_cache:
+        disable_cache()
+
+    # Fix triton bugs
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
+    # Set torch compile config
+    if server_args.enable_torch_compile:
+        _set_torch_compile_config()
+
+    # Set global chat template
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
@@ -190,16 +223,6 @@ def launch_server(
         format="%(message)s",
     )
 
-    # Set global environments
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    set_ulimit()
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-    if server_args.disable_disk_cache:
-        disable_cache()
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
@@ -208,14 +231,8 @@
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-    if server_args.enable_torch_compile:
-        _set_torch_compile_config()
+    set_envs_and_config(server_args)
 
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
@@ -65,6 +65,9 @@ class ServerArgs:
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
+    # Chunked Prefill
+    chunked_prefill_size: Optional[int] = None
+
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -83,6 +86,8 @@ class ServerArgs:
     node_rank: Optional[int] = None
 
     def __post_init__(self):
+        if self.chunked_prefill_size is None:
+            self.chunked_prefill_size = int(10**9)
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
@@ -223,7 +228,7 @@ class ServerArgs:
         parser.add_argument(
             "--max-num-reqs",
             type=int,
-            default=None,
+            default=ServerArgs.max_num_reqs,
             help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
         )
         parser.add_argument(
@@ -311,10 +316,18 @@ class ServerArgs:
             help="The nccl init address of multi-node server.",
         )
         parser.add_argument(
-            "--nnodes", type=int, default=1, help="The number of nodes."
+            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")
 
+        # Chunked prefill
+        parser.add_argument(
+            "--chunked-prefill-size",
+            type=int,
+            default=ServerArgs.chunked_prefill_size,
+            help="The size of the chunked prefill.",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
@@ -393,6 +406,10 @@ class ServerArgs:
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
 
+        assert not (
+            self.chunked_prefill_size is not None and self.disable_radix_cache
+        ), "chunked prefill is not supported with radix cache disabled currently"
+
 
 @dataclasses.dataclass
 class PortArgs:
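With these defaults, leaving --chunked-prefill-size unset makes __post_init__ substitute int(10**9), so the truncation branch in the scheduler is effectively never taken; passing a smaller value caps the prefill tokens scheduled per round. A sketch of the flag's behavior using a bare argparse parser for illustration (not the repo's full CLI):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--chunked-prefill-size", type=int, default=None)

    args = parser.parse_args(["--chunked-prefill-size", "4096"])
    chunked_prefill_size = args.chunked_prefill_size
    if chunked_prefill_size is None:  # mirrors __post_init__ above
        chunked_prefill_size = int(10**9)
    assert chunked_prefill_size == 4096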