Mixed style of chunked prefill (#1013)

This commit is contained in:
Liangsheng Yin
2024-08-16 02:13:00 -07:00
committed by GitHub
parent 5a261bd055
commit 3694f8f996
14 changed files with 195 additions and 59 deletions

View File

@@ -111,11 +111,14 @@ class PrefillAdder:
rem_total_tokens: int,
rem_input_tokens: int,
rem_chunk_tokens: Optional[int],
mixed_with_decode_tokens: int = 0,
):
self.tree_cache = tree_cache
self.rem_total_tokens = rem_total_tokens
self.rem_input_tokens = rem_input_tokens
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
self.rem_chunk_tokens = rem_chunk_tokens
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= mixed_with_decode_tokens
self.can_run_list = []
self.new_inflight_req = None

View File

@@ -329,6 +329,9 @@ class ScheduleBatch:
out_cache_loc: torch.Tensor = None
extend_num_tokens: int = None
# For mixed chunked prefill
prefix_lens_cpu: List[int] = None
# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: List[int] = None
@@ -462,9 +465,33 @@ class ScheduleBatch:
self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
self.batch_sampling_params(vocab_size)
    def mix_with_running(self, running_batch: "ScheduleBatch"):
        """Fold an in-flight decode batch into this prefill batch (mixed chunked prefill).

        After this call, `self` contains both the new prefill requests and the
        running decode requests, each decode request contributing exactly one
        new token (extend_input_len == 1) to the combined extend pass.

        Args:
            running_batch: the currently running decode batch to absorb.
        """
        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        # Prefill requests: prefix length == number of cached (prefix) tokens.
        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
        # Decode requests: everything except the single token being decoded now
        # counts as prefix (full prompt + all previously generated tokens - 1).
        prefix_lens_cpu.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) - 1
                for r in running_batch.reqs
            ]
        )

        # Each decode request extends by exactly one token in the mixed batch.
        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        # Capture the concatenated tensors BEFORE merge(): merge() (defined
        # elsewhere in this class) presumably combines per-request metadata and
        # may overwrite these fields — TODO(review) confirm merge semantics.
        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
        # One extra extend token per decode request.
        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
        self.merge(running_batch)
        # Restore the combined tensors/counters after the merge.
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens_cpu = prefix_lens_cpu
def check_decode_mem(self):
bs = self.batch_size()
if self.token_to_kv_pool.available_size() >= bs:

View File

@@ -174,6 +174,9 @@ class ModelTpServer:
# Chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None
self.is_mixed_chunk = (
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
)
# Init the FSM cache for constrained generation
if not server_args.skip_tokenizer_init:
@@ -366,11 +369,14 @@ class ModelTpServer:
# Get priority queue
prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
num_mixed_running = running_bs if self.is_mixed_chunk else 0
adder = PrefillAdder(
self.tree_cache,
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
self.max_prefill_tokens,
self.chunked_prefill_size,
num_mixed_running,
)
if self.running_batch is not None:
@@ -416,15 +422,27 @@ class ModelTpServer:
)
else:
tree_cache_hit_rate = 0.0
logger.info(
f"[gpu={self.gpu_id}] Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
)
if num_mixed_running > 0:
logger.info(
f"[gpu={self.gpu_id}] Prefill batch"
f"(mixed #running-req: {num_mixed_running}). "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
)
else:
logger.info(
f"[gpu={self.gpu_id}] Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
)
# Return the new batch
new_batch = ScheduleBatch.init_new(
@@ -440,6 +458,13 @@ class ModelTpServer:
# Build batch tensors
batch.prepare_for_extend(self.model_config.vocab_size)
decoding_reqs = []
if self.is_mixed_chunk and self.running_batch is not None:
self.running_batch.prepare_for_decode()
batch.mix_with_running(self.running_batch)
decoding_reqs = self.running_batch.reqs
self.running_batch = None
if self.model_runner.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
@@ -481,7 +506,8 @@ class ModelTpServer:
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
elif req not in decoding_reqs:
# To reduce overhead, only cache prefill reqs
self.tree_cache.cache_unfinished_req(req)
if req is self.current_inflight_req:

View File

@@ -88,11 +88,11 @@ class InputMetadata:
self.image_sizes = [r.image_size for r in reqs]
self.image_offsets = [
(
(r.image_offset - len(r.prefix_indices))
(r.image_offset - batch.prefix_lens_cpu[i])
if r.image_offset is not None
else 0
)
for r in reqs
for i, r in enumerate(reqs)
]
def compute_positions(self, batch: ScheduleBatch):
@@ -109,8 +109,8 @@ class InputMetadata:
self.positions = torch.tensor(
np.concatenate(
[
np.arange(len(req.prefix_indices), len(req.fill_ids))
for req in batch.reqs
np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
for i, req in enumerate(batch.reqs)
],
axis=0,
),
@@ -123,7 +123,7 @@ class InputMetadata:
np.concatenate(
[
np.arange(
len(req.prefix_indices) + position_ids_offsets_cpu[i],
batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
len(req.fill_ids) + position_ids_offsets_cpu[i],
)
for i, req in enumerate(batch.reqs)
@@ -141,12 +141,13 @@ class InputMetadata:
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
else:
extend_lens_cpu = [
len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
len(r.fill_ids) - batch.prefix_lens_cpu[i]
for i, r in enumerate(batch.reqs)
]
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@classmethod
def from_schedule_batch(
@@ -180,14 +181,8 @@ class InputMetadata:
if forward_mode != ForwardMode.DECODE:
ret.init_multimuldal_info(batch)
prefix_lens = None
if forward_mode != ForwardMode.DECODE:
prefix_lens = torch.tensor(
[len(r.prefix_indices) for r in batch.reqs], device="cuda"
)
if model_runner.server_args.disable_flashinfer:
ret.init_triton_args(batch, prefix_lens)
ret.init_triton_args(batch)
flashinfer_use_ragged = False
if not model_runner.server_args.disable_flashinfer:
@@ -198,30 +193,35 @@ class InputMetadata:
):
flashinfer_use_ragged = True
ret.init_flashinfer_handlers(
model_runner, prefix_lens, flashinfer_use_ragged
model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
)
return ret
def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
def init_triton_args(self, batch: ScheduleBatch):
"""Init auxiliary variables for triton attention backend."""
self.triton_max_seq_len = int(torch.max(self.seq_lens))
self.triton_prefix_lens = prefix_lens
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
if self.forward_mode == ForwardMode.DECODE:
self.triton_max_extend_len = None
else:
extend_seq_lens = self.seq_lens - prefix_lens
self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
extend_seq_lens = self.seq_lens - self.triton_prefix_lens
self.triton_max_extend_len = int(torch.max(extend_seq_lens))
def init_flashinfer_handlers(
self,
model_runner,
prefix_lens,
prefix_lens_cpu,
flashinfer_use_ragged,
):
if self.forward_mode != ForwardMode.DECODE:
prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
else:
prefix_lens = None
update_flashinfer_indices(
self.forward_mode,
model_runner,

View File

@@ -445,15 +445,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
sys.exit(1)
# Print warnings here
if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
logger.warning(
"You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
"This combination is an experimental feature and we noticed it can lead to "
"wrong generation results. If you want to use chunked prefill, it is recommended "
"not using `--disable-radix-cache`."
)
logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
pipe_finish_writer.send("init ok")

View File

@@ -80,6 +80,7 @@ class ServerArgs:
disable_regex_jump_forward: bool = False
disable_cuda_graph: bool = False
disable_disk_cache: bool = False
enable_mixed_chunk: bool = False
enable_torch_compile: bool = False
enable_p2p_check: bool = False
enable_mla: bool = False
@@ -396,6 +397,11 @@ class ServerArgs:
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
help="Enabling mixing prefill and decode in a chunked batch.",
)
parser.add_argument(
"--enable-torch-compile",
action="store_true",