Chunked prefill (#800)
@@ -38,24 +38,24 @@ class ScheduleHeuristic:
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache

-    def get_priority_queue(self, forward_queue):
+    def get_priority_queue(self, waiting_queue):
         if self.schedule_heuristic == "lpm":
             # longest prefix match
-            forward_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return forward_queue
+            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
+            return waiting_queue
         elif self.schedule_heuristic == "fcfs":
             # first come first serve
-            return forward_queue
+            return waiting_queue
         elif self.schedule_heuristic == "lof":
             # longest output first
-            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return forward_queue
+            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return waiting_queue
         elif self.schedule_heuristic == "random":
-            random.shuffle(forward_queue)
-            return forward_queue
+            random.shuffle(waiting_queue)
+            return waiting_queue
         elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
-            for req in forward_queue:
+            for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)

             node_to_weight = defaultdict(int)
@@ -67,7 +67,7 @@ class ScheduleHeuristic:
             self.get_dfs_priority(
                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(q) == len(forward_queue)
+            assert len(q) == len(waiting_queue)
             return q
         else:
             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
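
The heuristics above differ only in their sort key; this rename is purely `forward_queue` -> `waiting_queue`. For reference, a minimal sketch of the "lpm" and "lof" orderings, assuming only the two request fields the diff actually touches (the `SimpleReq` stand-in is hypothetical):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class SimpleReq:  # hypothetical stand-in for the scheduler's Req objects
        rid: str
        prefix_indices: List[int] = field(default_factory=list)
        max_new_tokens: int = 0

    reqs = [
        SimpleReq("a", prefix_indices=[1, 2], max_new_tokens=64),
        SimpleReq("b", prefix_indices=[1, 2, 3, 4], max_new_tokens=16),
        SimpleReq("c", max_new_tokens=256),
    ]

    # "lpm": longest shared prefix first, to maximize radix-cache reuse
    lpm = sorted(reqs, key=lambda r: -len(r.prefix_indices))
    assert [r.rid for r in lpm] == ["b", "a", "c"]

    # "lof": longest requested output first
    lof = sorted(reqs, key=lambda r: -r.max_new_tokens)
    assert [r.rid for r in lof] == ["c", "a", "b"]
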
@@ -77,6 +77,10 @@ class ModelTpServer:
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
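
These two fields carry the whole feature: `chunked_prefill_size` caps how many prompt tokens one prefill batch may process, and `current_inflight_req` remembers a request whose prompt is only partially prefilled. A rough sketch of the resulting chunk schedule, with illustrative numbers (not from the commit):

    # A 10,000-token prompt, no cache hit, 4096-token chunk budget.
    chunked_prefill_size = 4096
    prompt_len, prefix_len = 10_000, 0

    chunks = []
    while prefix_len < prompt_len:
        extend_len = min(prompt_len - prefix_len, chunked_prefill_size)
        chunks.append(extend_len)
        prefix_len += extend_len  # each finished chunk becomes cached prefix

    assert chunks == [4096, 4096, 1808]  # three prefill rounds, then decoding starts
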
@@ -157,7 +161,7 @@ class ModelTpServer:
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool

         # Init running status
-        self.forward_queue: List[Req] = []
+        self.waiting_queue: List[Req] = []
         self.running_batch: Batch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
@@ -220,6 +224,7 @@ class ModelTpServer:
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)
+            self.filter_out_inflight(new_batch)

             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -261,7 +266,7 @@ class ModelTpServer:
                     f"#token: {num_used}, "
                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                     f"gen throughput (token/s): {throughput:.2f}, "
-                    f"#queue-req: {len(self.forward_queue)}"
+                    f"#queue-req: {len(self.waiting_queue)}"
                 )

     def check_memory(self):
@@ -328,9 +333,10 @@ class ModelTpServer:
             ),
             self.max_req_input_len - 1 - len(req.origin_input_ids),
         )
-        self.forward_queue.append(req)
+        self.waiting_queue.append(req)

     def get_new_prefill_batch(self) -> Optional[Batch]:
+        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
@@ -338,7 +344,7 @@ class ModelTpServer:
             return

         # Compute matched prefix length
-        for req in self.forward_queue:
+        for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
@@ -348,7 +354,7 @@ class ModelTpServer:
             req.last_node = last_node

         # Get priority queue
-        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
+        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)

         # Add requests if there is available space
         can_run_list = []
@@ -367,7 +373,33 @@ class ModelTpServer:
                 ]
             )

-        for req in self.forward_queue:
+        # Handle the current inflight request
+        take_inflight = 0
+        if self.current_inflight_req:
+            take_inflight = 1
+            r = self.current_inflight_req
+            r.input_ids = r.origin_input_ids + r.output_ids
+            truncated = (
+                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+            )
+            r.extend_input_len = min(
+                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
+            )
+            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
+            can_run_list.append(r)
+
+            if not truncated:
+                # Finish inflight
+                self.current_inflight_req = None
+                new_batch_total_tokens += (
+                    r.extend_input_len + r.sampling_params.max_new_tokens
+                )
+                new_batch_input_tokens += r.extend_input_len
+            else:
+                new_batch_total_tokens += r.extend_input_len
+                new_batch_input_tokens += r.extend_input_len
+
+        for req in self.waiting_queue:
             if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
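
The inflight request always re-enters the next prefill batch first; whether it finishes prefilling this round depends on how many uncached prompt tokens remain. A condensed sketch of that arithmetic (the helper and its numbers are invented; the variable names follow the diff):

    def continue_inflight(input_len, prefix_len, chunk_size):
        # remaining = uncached prompt tokens left for this request
        remaining = input_len - prefix_len
        truncated = remaining > chunk_size      # more prompt left after this round?
        extend_input_len = min(remaining, chunk_size)
        return extend_input_len, truncated

    # 9000-token prompt with 8192 tokens already prefilled in earlier chunks:
    assert continue_inflight(9000, 8192, 4096) == (808, False)  # finishes this round
    # Fresh 9000-token prompt, nothing cached yet:
    assert continue_inflight(9000, 0, 4096) == (4096, True)     # stays inflight

Only a finished chunk reserves budget for its `max_new_tokens`; a still-truncated request accounts for its input tokens alone, since it cannot start decoding yet.
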
@@ -409,11 +441,36 @@ class ModelTpServer:
                     break
             else:
                 # Add this request to the running batch
-                can_run_list.append(req)
-                new_batch_total_tokens += (
-                    req.extend_input_len + req.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += req.extend_input_len
+                if (
+                    new_batch_input_tokens + req.extend_input_len
+                    <= self.chunked_prefill_size
+                    or (
+                        req.return_logprob and req.normalized_prompt_logprob is None
+                    )
+                ):
+                    can_run_list.append(req)
+                    new_batch_total_tokens += (
+                        req.extend_input_len + req.sampling_params.max_new_tokens
+                    )
+                    new_batch_input_tokens += req.extend_input_len
+                else:
+                    trunc_len = self.chunked_prefill_size - new_batch_input_tokens
+
+                    if trunc_len <= 0:
+                        # Undo locking
+                        delta = self.tree_cache.dec_lock_ref(req.last_node)
+                        available_size += delta
+                        break
+
+                    req.extend_input_len = trunc_len
+                    req.input_ids = req.input_ids[
+                        : len(req.prefix_indices) + req.extend_input_len
+                    ]
+                    can_run_list.append(req)
+                    self.current_inflight_req = req
+                    new_batch_input_tokens += req.extend_input_len
+                    new_batch_total_tokens += req.extend_input_len
+                    break
         else:
             break
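
A waiting request that would overflow the per-batch input budget is admitted with a truncated extend length and becomes the new `current_inflight_req`; the scheduler then stops admitting further requests. The budget math reduces to the following sketch (numbers are illustrative):

    chunk_size = 4096
    new_batch_input_tokens = 3000   # already admitted in this round
    req_extend_len = 2500           # uncached tokens the next request needs

    if new_batch_input_tokens + req_extend_len <= chunk_size:
        admitted = req_extend_len   # fits entirely
    else:
        # trunc_len in the diff: fill the rest of the budget, defer the remainder
        admitted = chunk_size - new_batch_input_tokens

    assert admitted == 1096         # the other 1404 tokens wait for the next batch

The logprob exception in the `if` lets a normalized-prompt-logprob request through whole, since computing that value requires the full prompt in one extend.
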
@@ -440,7 +497,7 @@ class ModelTpServer:
                 f"#cached-token: {hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
             )

         # Return the new batch
@@ -450,7 +507,7 @@ class ModelTpServer:
             self.token_to_kv_pool,
             self.tree_cache,
         )
-        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
+        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch

     def forward_prefill_batch(self, batch: Batch):
@@ -482,9 +539,10 @@ class ModelTpServer:
         # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
-            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_token_ids[i])
-            req.check_finished()
+            if req is not self.current_inflight_req:
+                req.completion_tokens_wo_jump_forward += 1
+                req.output_ids.append(next_token_ids[i])
+                req.check_finished()

             if req.return_logprob:
                 self.add_logprob_return_values(i, req, pt, next_token_ids, output)
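
Because the inflight request's prompt is still incomplete, the token sampled for it in this forward pass is meaningless, so the new guard skips recording it. A toy illustration of the effect (request names and token ids are made up):

    reqs = ["req-a", "req-b", "req-c"]
    next_token_ids = [11, 22, 33]
    current_inflight_req = reqs[1]  # req-b's prompt is only partially prefilled

    outputs = {}
    for req, tok in zip(reqs, next_token_ids):
        if req is not current_inflight_req:  # mirrors the new guard in the diff
            outputs.setdefault(req, []).append(tok)

    assert outputs == {"req-a": [11], "req-c": [33]}  # req-b's token 22 is discarded
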
@@ -545,7 +603,7 @@ class ModelTpServer:
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
+                token_ids=tuple(req.input_ids),
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,
@@ -553,6 +611,10 @@ class ModelTpServer:
             )
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

+            if req is self.current_inflight_req:
+                # inflight request would get a new req idx
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
+
     def forward_decode_batch(self, batch: Batch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
@@ -566,7 +628,7 @@ class ModelTpServer:
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
-            self.forward_queue.extend(retracted_reqs)
+            self.waiting_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -576,7 +638,7 @@ class ModelTpServer:
         if not self.disable_regex_jump_forward:
             # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-            self.forward_queue.extend(jump_forward_reqs)
+            self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return

@@ -711,8 +773,18 @@ class ModelTpServer:
         else:
             batch.reqs = []

+    def filter_out_inflight(self, batch: Batch):
+        # TODO(lsyin): reduce the overhead, make a special version for this
+        if self.current_inflight_req is None:
+            return
+
+        to_remove = batch.reqs.index(self.current_inflight_req)
+        unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
+
+        batch.filter_batch(unfinished_indices)
+
     def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
+        if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
             self.tree_cache.reset()
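
`filter_out_inflight` drops the partially prefilled request from the batch so only completed requests proceed to decode. Its index bookkeeping reduces to (toy data):

    batch_reqs = ["req-a", "req-b", "req-c"]
    current_inflight_req = batch_reqs[1]

    to_remove = batch_reqs.index(current_inflight_req)
    unfinished_indices = [i for i in range(len(batch_reqs)) if i != to_remove]
    kept = [batch_reqs[i] for i in unfinished_indices]  # what filter_batch keeps

    assert kept == ["req-a", "req-c"]
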
@@ -725,20 +797,20 @@ class ModelTpServer:
         else:
             warnings.warn(
                 f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
+                f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )

     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
         to_del = None
-        for i, req in enumerate(self.forward_queue):
+        for i, req in enumerate(self.waiting_queue):
             if req.rid == recv_req.rid:
                 to_del = i
                 break

         if to_del is not None:
-            del self.forward_queue[to_del]
+            del self.waiting_queue[to_del]

         # Delete requests in the running batch
         if self.running_batch:
@@ -45,7 +45,7 @@ class ReqToTokenPool:

         return select_index

-    def free(self, free_index: int):
+    def free(self, free_index):
         self.mem_state[free_index] = True
-        self.can_use_mem_size += 1
+        if isinstance(free_index, (int,)):
+            self.can_use_mem_size += 1
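
The relaxed signature and the `isinstance` branch suggest `free()` is now also called with a batch of indices rather than a single int; presumably the non-int path adjusts the counter by the number of freed slots. A sketch of that reading (the array branch is an assumption, not shown in the diff):

    import numpy as np

    class TinyPool:  # hypothetical miniature of ReqToTokenPool
        def __init__(self, size):
            self.mem_state = np.zeros(size, dtype=bool)
            self.can_use_mem_size = 0

        def free(self, free_index):
            self.mem_state[free_index] = True
            if isinstance(free_index, (int,)):
                self.can_use_mem_size += 1
            else:  # assumed: an index array frees several slots at once
                self.can_use_mem_size += len(free_index)

    pool = TinyPool(8)
    pool.free(3)                    # a single slot
    pool.free(np.array([0, 1, 2]))  # several slots in one call
    assert pool.can_use_mem_size == 4

Note the `int(req_pool_indices_cpu[i])` cast in `cache_filled_batch` above: a raw numpy integer would fail the `isinstance(free_index, (int,))` check.
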
@@ -175,6 +175,39 @@ def _set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 256


+def set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+    # Set ulimit
+    set_ulimit()
+
+    # Enable show time cost for debugging
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
+    # Disable disk cache
+    if server_args.disable_disk_cache:
+        disable_cache()
+
+    # Fix triton bugs
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
+    # Set torch compile config
+    if server_args.enable_torch_compile:
+        _set_torch_compile_config()
+
+    # Set global chat template
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
@@ -190,16 +223,6 @@ def launch_server(
         format="%(message)s",
     )

-    # Set global environments
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    set_ulimit()
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-    if server_args.disable_disk_cache:
-        disable_cache()
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
@@ -208,14 +231,8 @@ def launch_server(
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-    if server_args.enable_torch_compile:
-        _set_torch_compile_config()
+
+    set_envs_and_config(server_args)

     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
@@ -65,6 +65,9 @@ class ServerArgs:
     dp_size: int = 1
     load_balance_method: str = "round_robin"

+    # Chunked Prefill
+    chunked_prefill_size: Optional[int] = None
+
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -83,6 +86,8 @@ class ServerArgs:
     node_rank: Optional[int] = None

     def __post_init__(self):
+        if self.chunked_prefill_size is None:
+            self.chunked_prefill_size = 1 << 30
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
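
Resolving `None` to `1 << 30` (about a billion tokens) effectively disables chunking unless the user sets a real limit, since any realistic prompt then fits in a single chunk. For example (the helper is illustrative, not part of the commit):

    import math

    def num_prefill_chunks(prompt_len, chunked_prefill_size):
        # Illustrative helper, not part of the commit.
        return max(1, math.ceil(prompt_len / chunked_prefill_size))

    assert num_prefill_chunks(32_000, 1 << 30) == 1  # default: single-shot prefill
    assert num_prefill_chunks(32_000, 8192) == 4     # explicit limit: four chunks
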
@@ -223,7 +228,7 @@ class ServerArgs:
         parser.add_argument(
             "--max-num-reqs",
             type=int,
-            default=None,
+            default=ServerArgs.max_num_reqs,
             help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
         )
         parser.add_argument(
@@ -311,10 +316,18 @@ class ServerArgs:
         help="The nccl init address of multi-node server.",
     )
     parser.add_argument(
-        "--nnodes", type=int, default=1, help="The number of nodes."
+        "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
     )
     parser.add_argument("--node-rank", type=int, help="The node rank.")

+    # Chunked prefill
+    parser.add_argument(
+        "--chunked-prefill-size",
+        type=int,
+        default=ServerArgs.chunked_prefill_size,
+        help="The size of the chunked prefill.",
+    )
+
     # Optimization/debug options
     parser.add_argument(
         "--disable-flashinfer",
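
The new flag plugs into the standard argparse flow; a self-contained sketch of what it parses to (only the argument definition mirrors the diff):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--chunked-prefill-size",
        type=int,
        default=None,  # ServerArgs.__post_init__ resolves None to 1 << 30
    )

    assert parser.parse_args([]).chunked_prefill_size is None
    args = parser.parse_args(["--chunked-prefill-size", "4096"])
    assert args.chunked_prefill_size == 4096
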
@@ -393,6 +406,10 @@ class ServerArgs:
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"

+        assert not (
+            self.chunked_prefill_size < (1 << 30) and self.disable_radix_cache
+        ), "chunked prefill is not supported with radix cache disabled currently"
+

 @dataclasses.dataclass
 class PortArgs: