Clean up event loop (#1586)
@@ -228,20 +228,14 @@ class Scheduler:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.batch_is_full = False
 
+    @torch.inference_mode()
     def event_loop(self):
         while True:
-            # Receive requests
-            if self.tp_rank == 0:
-                recv_reqs = self.recv_requests_from_zmq()
-            else:
-                recv_reqs = None
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
 
-            # Process requests
-            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
-            self.process_requests(recv_reqs)
-
-            # Forward
-            self.forward_step()
+            # Run one step
+            self.run_step()
 
             # Send results
             if self.tp_rank == 0:
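Note: with this change the loop body reduces to receive -> process inputs -> run one step -> send results. A minimal, runnable sketch of that shape (the stub methods below are illustrative stand-ins, not this commit's code):

    class LoopSketch:
        def recv_requests(self):               # stand-in: rank 0 polls ZMQ, result is broadcast
            return ["req"]

        def process_input_requests(self, reqs):
            self.pending = list(reqs)

        def run_step(self):                    # stand-in: one prefill or decode step
            self.pending.clear()

        def send_results(self):                # stand-in: rank 0 sends to the detokenizer
            pass

        def event_loop(self, iterations=3):    # bounded here; `while True` in the real loop
            for _ in range(iterations):
                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                self.run_step()
                self.send_results()

    LoopSketch().event_loop()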
@@ -249,19 +243,23 @@ class Scheduler:
                     self.send_to_detokenizer.send_pyobj(obj)
                 self.out_pyobjs = []
 
-    def recv_requests_from_zmq(self):
-        recv_reqs = []
+    def recv_requests(self):
+        if self.tp_rank == 0:
+            recv_reqs = []
 
-        while True:
-            try:
-                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-            except zmq.ZMQError:
-                break
-            recv_reqs.append(recv_req)
+            while True:
+                try:
+                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_req)
+        else:
+            recv_reqs = None
 
+        recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs
 
-    def process_requests(self, recv_reqs: List):
+    def process_input_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
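Note: recv_requests now hides both the non-blocking ZMQ drain on rank 0 and the broadcast_pyobj call to the other TP ranks behind a single method. A hedged sketch of the drain pattern, using a standard-library queue as a stand-in for the ZMQ socket:

    from queue import Empty, Queue

    def drain_nonblocking(q: Queue) -> list:
        # Mirrors the `recv_pyobj(zmq.NOBLOCK)` loop: take everything that is
        # already queued and stop at the first call that would block.
        items = []
        while True:
            try:
                items.append(q.get_nowait())
            except Empty:
                break
        return items

    q = Queue()
    for i in range(3):
        q.put(i)
    print(drain_nonblocking(q))  # [0, 1, 2]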
@@ -279,83 +277,6 @@ class Scheduler:
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
 
-    @torch.inference_mode()
-    def forward_step(self):
-        if (
-            self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.current_inflight_req is None:
-            new_batch = None
-        else:
-            new_batch = self.get_new_prefill_batch()
-
-        if new_batch is not None:
-            # Run a new prefill batch
-            self.forward_prefill_batch(new_batch)
-
-            if not new_batch.is_empty():
-                if self.running_batch is None:
-                    self.running_batch = new_batch
-                else:
-                    self.running_batch.merge_batch(new_batch)
-        else:
-            # Run a decode batch
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    self.num_generated_tokens += len(self.running_batch.reqs)
-                    self.forward_decode_batch(self.running_batch)
-
-                    # Print stats
-                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        self.print_decode_stats()
-
-                    if self.running_batch.is_empty():
-                        self.running_batch = None
-                        break
-
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
-
-    def print_decode_stats(self):
-        num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
-        )
-        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
-        self.num_generated_tokens = 0
-        self.last_stats_tic = time.time()
-        logger.info(
-            f"Decode batch. "
-            f"#running-req: {len(self.running_batch.reqs)}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
-        )
-
-    def check_memory(self):
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
-        )
-        if available_size != self.max_total_num_tokens:
-            warnings.warn(
-                "Warning: "
-                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
-                "KV cache pool leak detected!"
-            )
-            exit(1) if crash_on_warning else None
-
-        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
-            warnings.warn(
-                "Warning: "
-                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
-                f"total slots={self.req_to_token_pool.size}\n"
-                "Memory pool leak detected!"
-            )
-            exit(1) if crash_on_warning else None
-
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
@@ -445,7 +366,88 @@ class Scheduler:
 
         self.waiting_queue.append(req)
 
-    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
+    def run_step(self):
+        new_batch = self.get_new_batch_prefill()
+
+        if new_batch is not None:
+            # Run a new prefill batch
+            result = self.run_batch(new_batch)
+            self.process_batch_result(new_batch, result)
+
+            if not new_batch.is_empty():
+                if self.running_batch is None:
+                    self.running_batch = new_batch
+                else:
+                    self.running_batch.merge_batch(new_batch)
+        else:
+            # Run a decode batch
+            if self.running_batch is not None:
+                # Run a few decode batches continuously for reducing overhead
+                for _ in range(global_config.num_continue_decode_steps):
+                    batch = self.get_new_batch_decode()
+
+                    if batch:
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+
+                    # Print stats
+                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                        self.print_decode_stats()
+
+                    if self.running_batch.is_empty():
+                        self.running_batch = None
+                        break
+
+                    if self.out_pyobjs and self.running_batch.has_stream:
+                        break
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+    def print_decode_stats(self):
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
+        logger.info(
+            f"Decode batch. "
+            f"#running-req: {len(self.running_batch.reqs)}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"gen throughput (token/s): {throughput:.2f}, "
+            f"#queue-req: {len(self.waiting_queue)}"
+        )
+
+    def check_memory(self):
+        available_size = (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        if available_size != self.max_total_num_tokens:
+            warnings.warn(
+                "Warning: "
+                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                "KV cache pool leak detected!"
+            )
+            exit(1) if crash_on_warning else None
+
+        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
+            warnings.warn(
+                "Warning: "
+                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
+                f"total slots={self.req_to_token_pool.size}\n"
+                "Memory pool leak detected!"
+            )
+            exit(1) if crash_on_warning else None
+
+    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
+        # Handle the cases where prefill is not allowed
+        if (
+            self.batch_is_full or len(self.waiting_queue) == 0
+        ) and self.current_inflight_req is None:
+            return None
+
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
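Note: run_step prefers a fresh prefill batch and otherwise runs a short burst of decode steps, with batch construction split out into get_new_batch_prefill / get_new_batch_decode. A simplified, runnable sketch of that control flow (the parameters are hypothetical stand-ins for the scheduler's methods):

    NUM_CONTINUE_DECODE_STEPS = 10  # assumed stand-in for global_config.num_continue_decode_steps

    def run_step(get_prefill, get_decode, run_batch, process_result):
        new_batch = get_prefill()
        if new_batch is not None:
            process_result(new_batch, run_batch(new_batch))  # one prefill step
            return
        for _ in range(NUM_CONTINUE_DECODE_STEPS):           # bounded decode burst
            batch = get_decode()
            if batch is None:
                break
            process_result(batch, run_batch(batch))

    # Example: one prefill step is available, so no decode steps run.
    run_step(
        lambda: "prefill-batch",
        lambda: None,
        lambda b: f"ran {b}",
        lambda b, r: print(r),
    )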
@@ -456,8 +458,8 @@ class Scheduler:
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)
 
-        # Prefill policy
+        num_mixed_running = running_bs if self.is_mixed_chunk else 0
 
         adder = PrefillAdder(
             self.tree_cache,
             self.running_batch,
@@ -517,6 +519,8 @@ class Scheduler:
         if len(can_run_list) == 0:
             return None
 
+        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+
         # Print stats
         if self.tp_rank == 0:
             if isinstance(self.tree_cache, RadixCache):
@@ -544,7 +548,7 @@ class Scheduler:
                     f"#cached-token: {adder.log_hit_tokens}, "
                     f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                 )
             else:
                 logger.info(
@@ -555,41 +559,97 @@ class Scheduler:
                     f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                     f"#running-req: {running_bs}, "
-                    f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                 )
 
-        # Return the new batch
+        # Create a new batch
         new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
             self.token_to_kv_pool,
             self.tree_cache,
         )
-        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
-        return new_batch
-
-    def forward_prefill_batch(self, batch: ScheduleBatch):
-        # Build batch tensors
-        batch.prepare_for_extend(self.model_config.vocab_size)
+        new_batch.prepare_for_extend(self.model_config.vocab_size)
 
         # Mixed-style chunked prefill
         decoding_reqs = []
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.prepare_for_decode()
-            batch.mix_with_running(self.running_batch)
+            new_batch.mix_with_running(self.running_batch)
             decoding_reqs = self.running_batch.reqs
             self.running_batch = None
+        new_batch.decoding_reqs = decoding_reqs
+
+        return new_batch
+
+    def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
+        batch = self.running_batch
+
+        # Check if decode out of memory
+        if not batch.check_decode_mem():
+            old_ratio = self.new_token_ratio
+
+            retracted_reqs, new_token_ratio = batch.retract_decode()
+            self.new_token_ratio = new_token_ratio
+
+            logger.info(
+                "Decode out of memory happened. "
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+            )
+            self.waiting_queue.extend(retracted_reqs)
+        else:
+            self.new_token_ratio = max(
+                self.new_token_ratio - self.new_token_ratio_decay,
+                self.min_new_token_ratio,
+            )
+
+        # Check for jump-forward
+        if not self.disable_regex_jump_forward:
+            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
+            self.waiting_queue.extend(jump_forward_reqs)
+            if batch.is_empty():
+                return None
+
+        # Update batch tensors
+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
+        batch.prepare_for_decode()
+        return batch
+
+    def run_batch(self, batch: ScheduleBatch):
         if self.is_generation:
-            # Forward and sample the next tokens
-            if batch.extend_num_tokens != 0:
+            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
-                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                    next_token_ids
-                )
             else:
                 logits_output = None
                 if self.tokenizer is not None:
                     next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
                 else:
                     next_token_ids = [0] * len(batch.reqs)
+            return logits_output, next_token_ids
+        else: # embedding or reward model
+            assert batch.extend_num_tokens != 0
+            model_worker_batch = batch.get_model_worker_batch()
+            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
+            return embeddings
+
+    def process_batch_result(self, batch: ScheduleBatch, result):
+        if batch.forward_mode.is_decode():
+            self.process_batch_result_decode(batch, result)
+        else:
+            self.process_batch_result_prefill(batch, result)
+
+    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+        if self.is_generation:
+            logits_output, next_token_ids = result
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                next_token_ids
+            )
 
             if logits_output:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
                     logits_output.next_token_logprobs = (
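Note: run_batch now only performs the forward pass and sampling and returns the result; all bookkeeping moves into process_batch_result, which dispatches on the batch's forward mode. A minimal dispatch sketch (the types are illustrative, not the repo's classes):

    from dataclasses import dataclass

    @dataclass
    class BatchSketch:
        is_decode: bool  # stand-in for batch.forward_mode.is_decode()

    def process_batch_result(batch: BatchSketch, result) -> str:
        # Decode results and prefill/extend results take separate paths.
        if batch.is_decode:
            return "process_batch_result_decode"
        return "process_batch_result_prefill"

    print(process_batch_result(BatchSketch(is_decode=True), result=None))
    print(process_batch_result(BatchSketch(is_decode=False), result=None))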
@@ -607,16 +667,7 @@ class Scheduler:
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
 
-                next_token_ids = next_token_ids.tolist()
-            else:
-                if self.tokenizer is None:
-                    next_token_ids = []
-                    for req in batch.reqs:
-                        next_token_ids.append(
-                            next(iter(req.sampling_params.stop_token_ids))
-                        )
-                else:
-                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+            next_token_ids = next_token_ids.tolist()
 
             # Check finish conditions
             logprob_pt = 0
@@ -634,7 +685,7 @@ class Scheduler:
 
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                elif req not in decoding_reqs:
+                elif req not in batch.decoding_reqs:
                     # To reduce overhead, only cache prefill reqs
                     self.tree_cache.cache_unfinished_req(req)
 
@@ -646,10 +697,9 @@ class Scheduler:
                     logprob_pt += self.add_logprob_return_values(
                         i, req, logprob_pt, next_token_ids, logits_output
                     )
-        else:
+        else: # embedding or reward model
             assert batch.extend_num_tokens != 0
-            model_worker_batch = batch.get_model_worker_batch()
-            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
+            embeddings = result
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -671,6 +721,45 @@ class Scheduler:
 
         self.handle_finished_requests(batch)
 
+    def process_batch_result_decode(self, batch: ScheduleBatch, result):
+        logits_output, next_token_ids = result
+        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+            next_token_ids
+        )
+        self.num_generated_tokens += len(batch.reqs)
+
+        # Move logprobs to cpu
+        if logits_output.next_token_logprobs is not None:
+            next_token_logprobs = logits_output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
+            ].tolist()
+
+        next_token_ids = next_token_ids.tolist()
+
+        # Check finish condition
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
+            req.output_ids.append(next_token_id)
+            req.check_finished()
+
+            if req.regex_fsm is not None:
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
+                    req.regex_fsm_state, next_token_id
+                )
+
+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
+            if req.return_logprob:
+                req.output_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
+                if req.top_logprobs_num > 0:
+                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
+
+        self.handle_finished_requests(batch)
+
     def add_logprob_return_values(
         self,
         i: int,
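Note: the decode path gathers one logprob per request by indexing the logprob matrix with torch.arange and the sampled token ids, the same pattern used in the method above. A small self-contained example with made-up values:

    import torch

    logprobs = torch.log_softmax(torch.randn(4, 8), dim=-1)  # [batch, vocab]
    next_token_ids = torch.tensor([1, 5, 2, 7])
    # Row i, column next_token_ids[i]: the logprob of the token each request sampled.
    chosen = logprobs[torch.arange(len(next_token_ids)), next_token_ids].tolist()
    print(chosen)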
@@ -744,80 +833,6 @@ class Scheduler:
 
         return num_input_logprobs
 
-    def forward_decode_batch(self, batch: ScheduleBatch):
-        # Check if decode out of memory
-        if not batch.check_decode_mem():
-            old_ratio = self.new_token_ratio
-
-            retracted_reqs, new_token_ratio = batch.retract_decode()
-            self.new_token_ratio = new_token_ratio
-
-            logger.info(
-                "Decode out of memory happened. "
-                f"#retracted_reqs: {len(retracted_reqs)}, "
-                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
-            )
-            self.waiting_queue.extend(retracted_reqs)
-        else:
-            self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_decay,
-                self.min_new_token_ratio,
-            )
-
-        # Check for jump-forward
-        if not self.disable_regex_jump_forward:
-            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
-            self.waiting_queue.extend(jump_forward_reqs)
-            if batch.is_empty():
-                return
-
-        # Update batch tensors
-        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
-        batch.prepare_for_decode()
-
-        # Forward and sample the next tokens
-        model_worker_batch = batch.get_model_worker_batch()
-        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-            model_worker_batch
-        )
-        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-            next_token_ids
-        )
-
-        # Move logprobs to cpu
-        if logits_output.next_token_logprobs is not None:
-            next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=next_token_ids.device),
-                next_token_ids,
-            ].tolist()
-
-        next_token_ids = next_token_ids.tolist()
-
-        # Check finish condition
-        has_finished = False
-        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_token_id)
-            req.check_finished()
-
-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, next_token_id
-                )
-
-            if req.finished():
-                self.tree_cache.cache_finished_req(req)
-                has_finished = True
-
-            if req.return_logprob:
-                req.output_token_logprobs.append(
-                    (next_token_logprobs[i], next_token_id)
-                )
-                if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
-
-        self.handle_finished_requests(batch)
-
     def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
         output_meta_info = []
@@ -829,7 +844,7 @@ class Scheduler:
             output_read_offsets = []
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
-        else: # for embedding model
+        else: # embedding or reward model
             output_embeddings = []
         unfinished_indices = []
 
@@ -886,7 +901,7 @@ class Scheduler:
                         req.normalized_prompt_logprob,
                     )
                 output_meta_info.append(meta_info)
-            else: # for embedding model
+            else: # embedding or reward model
                 output_embeddings.append(req.embedding)
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
@@ -909,7 +924,7 @@ class Scheduler:
                     output_finished_reason,
                 )
             )
-        else: # for embedding model
+        else: # embedding or reward model
             self.out_pyobjs.append(
                 BatchEmbeddingOut(
                     output_rids,