Clean up event loop (#1586)
This commit is contained in:
@@ -228,20 +228,14 @@ class Scheduler:
|
|||||||
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def event_loop(self):
|
def event_loop(self):
|
||||||
while True:
|
while True:
|
||||||
# Receive requests
|
recv_reqs = self.recv_requests()
|
||||||
if self.tp_rank == 0:
|
self.process_input_requests(recv_reqs)
|
||||||
recv_reqs = self.recv_requests_from_zmq()
|
|
||||||
else:
|
|
||||||
recv_reqs = None
|
|
||||||
|
|
||||||
# Process requests
|
# Run one step
|
||||||
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
|
self.run_step()
|
||||||
self.process_requests(recv_reqs)
|
|
||||||
|
|
||||||
# Forward
|
|
||||||
self.forward_step()
|
|
||||||
|
|
||||||
# Send results
|
# Send results
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
@@ -249,19 +243,23 @@ class Scheduler:
|
|||||||
self.send_to_detokenizer.send_pyobj(obj)
|
self.send_to_detokenizer.send_pyobj(obj)
|
||||||
self.out_pyobjs = []
|
self.out_pyobjs = []
|
||||||
|
|
||||||
def recv_requests_from_zmq(self):
|
def recv_requests(self):
|
||||||
recv_reqs = []
|
if self.tp_rank == 0:
|
||||||
|
recv_reqs = []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
||||||
except zmq.ZMQError:
|
except zmq.ZMQError:
|
||||||
break
|
break
|
||||||
recv_reqs.append(recv_req)
|
recv_reqs.append(recv_req)
|
||||||
|
else:
|
||||||
|
recv_reqs = None
|
||||||
|
|
||||||
|
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
|
||||||
return recv_reqs
|
return recv_reqs
|
||||||
|
|
||||||
def process_requests(self, recv_reqs: List):
|
def process_input_requests(self, recv_reqs: List):
|
||||||
for recv_req in recv_reqs:
|
for recv_req in recv_reqs:
|
||||||
if isinstance(recv_req, TokenizedGenerateReqInput):
|
if isinstance(recv_req, TokenizedGenerateReqInput):
|
||||||
self.handle_generate_request(recv_req)
|
self.handle_generate_request(recv_req)
|
||||||
@@ -279,83 +277,6 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid request: {recv_req}")
|
raise ValueError(f"Invalid request: {recv_req}")
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def forward_step(self):
|
|
||||||
if (
|
|
||||||
self.batch_is_full or len(self.waiting_queue) == 0
|
|
||||||
) and self.current_inflight_req is None:
|
|
||||||
new_batch = None
|
|
||||||
else:
|
|
||||||
new_batch = self.get_new_prefill_batch()
|
|
||||||
|
|
||||||
if new_batch is not None:
|
|
||||||
# Run a new prefill batch
|
|
||||||
self.forward_prefill_batch(new_batch)
|
|
||||||
|
|
||||||
if not new_batch.is_empty():
|
|
||||||
if self.running_batch is None:
|
|
||||||
self.running_batch = new_batch
|
|
||||||
else:
|
|
||||||
self.running_batch.merge_batch(new_batch)
|
|
||||||
else:
|
|
||||||
# Run a decode batch
|
|
||||||
if self.running_batch is not None:
|
|
||||||
# Run a few decode batches continuously for reducing overhead
|
|
||||||
for _ in range(global_config.num_continue_decode_steps):
|
|
||||||
self.num_generated_tokens += len(self.running_batch.reqs)
|
|
||||||
self.forward_decode_batch(self.running_batch)
|
|
||||||
|
|
||||||
# Print stats
|
|
||||||
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
|
||||||
self.print_decode_stats()
|
|
||||||
|
|
||||||
if self.running_batch.is_empty():
|
|
||||||
self.running_batch = None
|
|
||||||
break
|
|
||||||
|
|
||||||
if self.out_pyobjs and self.running_batch.has_stream:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
self.check_memory()
|
|
||||||
self.new_token_ratio = global_config.init_new_token_ratio
|
|
||||||
|
|
||||||
def print_decode_stats(self):
|
|
||||||
num_used = self.max_total_num_tokens - (
|
|
||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
|
||||||
)
|
|
||||||
throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
|
|
||||||
self.num_generated_tokens = 0
|
|
||||||
self.last_stats_tic = time.time()
|
|
||||||
logger.info(
|
|
||||||
f"Decode batch. "
|
|
||||||
f"#running-req: {len(self.running_batch.reqs)}, "
|
|
||||||
f"#token: {num_used}, "
|
|
||||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
|
||||||
f"gen throughput (token/s): {throughput:.2f}, "
|
|
||||||
f"#queue-req: {len(self.waiting_queue)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def check_memory(self):
|
|
||||||
available_size = (
|
|
||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
|
||||||
)
|
|
||||||
if available_size != self.max_total_num_tokens:
|
|
||||||
warnings.warn(
|
|
||||||
"Warning: "
|
|
||||||
f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
|
|
||||||
"KV cache pool leak detected!"
|
|
||||||
)
|
|
||||||
exit(1) if crash_on_warning else None
|
|
||||||
|
|
||||||
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
|
|
||||||
warnings.warn(
|
|
||||||
"Warning: "
|
|
||||||
f"available req slots={len(self.req_to_token_pool.free_slots)}, "
|
|
||||||
f"total slots={self.req_to_token_pool.size}\n"
|
|
||||||
"Memory pool leak detected!"
|
|
||||||
)
|
|
||||||
exit(1) if crash_on_warning else None
|
|
||||||
|
|
||||||
def handle_generate_request(
|
def handle_generate_request(
|
||||||
self,
|
self,
|
||||||
recv_req: TokenizedGenerateReqInput,
|
recv_req: TokenizedGenerateReqInput,
|
||||||
@@ -445,7 +366,88 @@ class Scheduler:
|
|||||||
|
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
|
def run_step(self):
|
||||||
|
new_batch = self.get_new_batch_prefill()
|
||||||
|
|
||||||
|
if new_batch is not None:
|
||||||
|
# Run a new prefill batch
|
||||||
|
result = self.run_batch(new_batch)
|
||||||
|
self.process_batch_result(new_batch, result)
|
||||||
|
|
||||||
|
if not new_batch.is_empty():
|
||||||
|
if self.running_batch is None:
|
||||||
|
self.running_batch = new_batch
|
||||||
|
else:
|
||||||
|
self.running_batch.merge_batch(new_batch)
|
||||||
|
else:
|
||||||
|
# Run a decode batch
|
||||||
|
if self.running_batch is not None:
|
||||||
|
# Run a few decode batches continuously for reducing overhead
|
||||||
|
for _ in range(global_config.num_continue_decode_steps):
|
||||||
|
batch = self.get_new_batch_decode()
|
||||||
|
|
||||||
|
if batch:
|
||||||
|
result = self.run_batch(batch)
|
||||||
|
self.process_batch_result(batch, result)
|
||||||
|
|
||||||
|
# Print stats
|
||||||
|
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
||||||
|
self.print_decode_stats()
|
||||||
|
|
||||||
|
if self.running_batch.is_empty():
|
||||||
|
self.running_batch = None
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.out_pyobjs and self.running_batch.has_stream:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.check_memory()
|
||||||
|
self.new_token_ratio = global_config.init_new_token_ratio
|
||||||
|
|
||||||
|
def print_decode_stats(self):
|
||||||
|
num_used = self.max_total_num_tokens - (
|
||||||
|
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||||
|
)
|
||||||
|
throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
|
||||||
|
self.num_generated_tokens = 0
|
||||||
|
self.last_stats_tic = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"Decode batch. "
|
||||||
|
f"#running-req: {len(self.running_batch.reqs)}, "
|
||||||
|
f"#token: {num_used}, "
|
||||||
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||||
|
f"gen throughput (token/s): {throughput:.2f}, "
|
||||||
|
f"#queue-req: {len(self.waiting_queue)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_memory(self):
|
||||||
|
available_size = (
|
||||||
|
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||||
|
)
|
||||||
|
if available_size != self.max_total_num_tokens:
|
||||||
|
warnings.warn(
|
||||||
|
"Warning: "
|
||||||
|
f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
|
||||||
|
"KV cache pool leak detected!"
|
||||||
|
)
|
||||||
|
exit(1) if crash_on_warning else None
|
||||||
|
|
||||||
|
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
|
||||||
|
warnings.warn(
|
||||||
|
"Warning: "
|
||||||
|
f"available req slots={len(self.req_to_token_pool.free_slots)}, "
|
||||||
|
f"total slots={self.req_to_token_pool.size}\n"
|
||||||
|
"Memory pool leak detected!"
|
||||||
|
)
|
||||||
|
exit(1) if crash_on_warning else None
|
||||||
|
|
||||||
|
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
||||||
|
# Handle the cases where prefill is not allowed
|
||||||
|
if (
|
||||||
|
self.batch_is_full or len(self.waiting_queue) == 0
|
||||||
|
) and self.current_inflight_req is None:
|
||||||
|
return None
|
||||||
|
|
||||||
running_bs = (
|
running_bs = (
|
||||||
len(self.running_batch.reqs) if self.running_batch is not None else 0
|
len(self.running_batch.reqs) if self.running_batch is not None else 0
|
||||||
)
|
)
|
||||||
@@ -456,8 +458,8 @@ class Scheduler:
|
|||||||
# Get priority queue
|
# Get priority queue
|
||||||
prefix_computed = self.policy.calc_priority(self.waiting_queue)
|
prefix_computed = self.policy.calc_priority(self.waiting_queue)
|
||||||
|
|
||||||
|
# Prefill policy
|
||||||
num_mixed_running = running_bs if self.is_mixed_chunk else 0
|
num_mixed_running = running_bs if self.is_mixed_chunk else 0
|
||||||
|
|
||||||
adder = PrefillAdder(
|
adder = PrefillAdder(
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
self.running_batch,
|
self.running_batch,
|
||||||
@@ -517,6 +519,8 @@ class Scheduler:
|
|||||||
if len(can_run_list) == 0:
|
if len(can_run_list) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
|
||||||
|
|
||||||
# Print stats
|
# Print stats
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
if isinstance(self.tree_cache, RadixCache):
|
if isinstance(self.tree_cache, RadixCache):
|
||||||
@@ -544,7 +548,7 @@ class Scheduler:
|
|||||||
f"#cached-token: {adder.log_hit_tokens}, "
|
f"#cached-token: {adder.log_hit_tokens}, "
|
||||||
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
||||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||||
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
|
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -555,41 +559,97 @@ class Scheduler:
|
|||||||
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
||||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||||
f"#running-req: {running_bs}, "
|
f"#running-req: {running_bs}, "
|
||||||
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
|
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return the new batch
|
# Create a new batch
|
||||||
new_batch = ScheduleBatch.init_new(
|
new_batch = ScheduleBatch.init_new(
|
||||||
can_run_list,
|
can_run_list,
|
||||||
self.req_to_token_pool,
|
self.req_to_token_pool,
|
||||||
self.token_to_kv_pool,
|
self.token_to_kv_pool,
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
)
|
)
|
||||||
self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
|
new_batch.prepare_for_extend(self.model_config.vocab_size)
|
||||||
return new_batch
|
|
||||||
|
|
||||||
def forward_prefill_batch(self, batch: ScheduleBatch):
|
|
||||||
# Build batch tensors
|
|
||||||
batch.prepare_for_extend(self.model_config.vocab_size)
|
|
||||||
|
|
||||||
|
# Mixed-style chunked prefill
|
||||||
decoding_reqs = []
|
decoding_reqs = []
|
||||||
if self.is_mixed_chunk and self.running_batch is not None:
|
if self.is_mixed_chunk and self.running_batch is not None:
|
||||||
self.running_batch.prepare_for_decode()
|
self.running_batch.prepare_for_decode()
|
||||||
batch.mix_with_running(self.running_batch)
|
new_batch.mix_with_running(self.running_batch)
|
||||||
decoding_reqs = self.running_batch.reqs
|
decoding_reqs = self.running_batch.reqs
|
||||||
self.running_batch = None
|
self.running_batch = None
|
||||||
|
new_batch.decoding_reqs = decoding_reqs
|
||||||
|
|
||||||
|
return new_batch
|
||||||
|
|
||||||
|
def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
|
||||||
|
batch = self.running_batch
|
||||||
|
|
||||||
|
# Check if decode out of memory
|
||||||
|
if not batch.check_decode_mem():
|
||||||
|
old_ratio = self.new_token_ratio
|
||||||
|
|
||||||
|
retracted_reqs, new_token_ratio = batch.retract_decode()
|
||||||
|
self.new_token_ratio = new_token_ratio
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Decode out of memory happened. "
|
||||||
|
f"#retracted_reqs: {len(retracted_reqs)}, "
|
||||||
|
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
|
||||||
|
)
|
||||||
|
self.waiting_queue.extend(retracted_reqs)
|
||||||
|
else:
|
||||||
|
self.new_token_ratio = max(
|
||||||
|
self.new_token_ratio - self.new_token_ratio_decay,
|
||||||
|
self.min_new_token_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for jump-forward
|
||||||
|
if not self.disable_regex_jump_forward:
|
||||||
|
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
||||||
|
self.waiting_queue.extend(jump_forward_reqs)
|
||||||
|
if batch.is_empty():
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Update batch tensors
|
||||||
|
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
||||||
|
batch.prepare_for_decode()
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def run_batch(self, batch: ScheduleBatch):
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
# Forward and sample the next tokens
|
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
|
||||||
if batch.extend_num_tokens != 0:
|
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||||
model_worker_batch
|
model_worker_batch
|
||||||
)
|
)
|
||||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
else:
|
||||||
next_token_ids
|
logits_output = None
|
||||||
)
|
if self.tokenizer is not None:
|
||||||
|
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||||
|
else:
|
||||||
|
next_token_ids = [0] * len(batch.reqs)
|
||||||
|
return logits_output, next_token_ids
|
||||||
|
else: # embedding or reward model
|
||||||
|
assert batch.extend_num_tokens != 0
|
||||||
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def process_batch_result(self, batch: ScheduleBatch, result):
|
||||||
|
if batch.forward_mode.is_decode():
|
||||||
|
self.process_batch_result_decode(batch, result)
|
||||||
|
else:
|
||||||
|
self.process_batch_result_prefill(batch, result)
|
||||||
|
|
||||||
|
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||||
|
if self.is_generation:
|
||||||
|
logits_output, next_token_ids = result
|
||||||
|
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
|
next_token_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
if logits_output:
|
||||||
# Move logprobs to cpu
|
# Move logprobs to cpu
|
||||||
if logits_output.next_token_logprobs is not None:
|
if logits_output.next_token_logprobs is not None:
|
||||||
logits_output.next_token_logprobs = (
|
logits_output.next_token_logprobs = (
|
||||||
@@ -607,16 +667,7 @@ class Scheduler:
|
|||||||
logits_output.normalized_prompt_logprobs.tolist()
|
logits_output.normalized_prompt_logprobs.tolist()
|
||||||
)
|
)
|
||||||
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
else:
|
|
||||||
if self.tokenizer is None:
|
|
||||||
next_token_ids = []
|
|
||||||
for req in batch.reqs:
|
|
||||||
next_token_ids.append(
|
|
||||||
next(iter(req.sampling_params.stop_token_ids))
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
logprob_pt = 0
|
logprob_pt = 0
|
||||||
@@ -634,7 +685,7 @@ class Scheduler:
|
|||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
self.tree_cache.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
elif req not in decoding_reqs:
|
elif req not in batch.decoding_reqs:
|
||||||
# To reduce overhead, only cache prefill reqs
|
# To reduce overhead, only cache prefill reqs
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
@@ -646,10 +697,9 @@ class Scheduler:
|
|||||||
logprob_pt += self.add_logprob_return_values(
|
logprob_pt += self.add_logprob_return_values(
|
||||||
i, req, logprob_pt, next_token_ids, logits_output
|
i, req, logprob_pt, next_token_ids, logits_output
|
||||||
)
|
)
|
||||||
else:
|
else: # embedding or reward model
|
||||||
assert batch.extend_num_tokens != 0
|
assert batch.extend_num_tokens != 0
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
embeddings = result
|
||||||
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
@@ -671,6 +721,45 @@ class Scheduler:
|
|||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
|
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||||
|
logits_output, next_token_ids = result
|
||||||
|
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
|
next_token_ids
|
||||||
|
)
|
||||||
|
self.num_generated_tokens += len(batch.reqs)
|
||||||
|
|
||||||
|
# Move logprobs to cpu
|
||||||
|
if logits_output.next_token_logprobs is not None:
|
||||||
|
next_token_logprobs = logits_output.next_token_logprobs[
|
||||||
|
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
||||||
|
next_token_ids,
|
||||||
|
].tolist()
|
||||||
|
|
||||||
|
next_token_ids = next_token_ids.tolist()
|
||||||
|
|
||||||
|
# Check finish condition
|
||||||
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||||
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
|
req.output_ids.append(next_token_id)
|
||||||
|
req.check_finished()
|
||||||
|
|
||||||
|
if req.regex_fsm is not None:
|
||||||
|
req.regex_fsm_state = req.regex_fsm.get_next_state(
|
||||||
|
req.regex_fsm_state, next_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if req.finished():
|
||||||
|
self.tree_cache.cache_finished_req(req)
|
||||||
|
|
||||||
|
if req.return_logprob:
|
||||||
|
req.output_token_logprobs.append(
|
||||||
|
(next_token_logprobs[i], next_token_id)
|
||||||
|
)
|
||||||
|
if req.top_logprobs_num > 0:
|
||||||
|
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
||||||
|
|
||||||
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
def add_logprob_return_values(
|
def add_logprob_return_values(
|
||||||
self,
|
self,
|
||||||
i: int,
|
i: int,
|
||||||
@@ -744,80 +833,6 @@ class Scheduler:
|
|||||||
|
|
||||||
return num_input_logprobs
|
return num_input_logprobs
|
||||||
|
|
||||||
def forward_decode_batch(self, batch: ScheduleBatch):
|
|
||||||
# Check if decode out of memory
|
|
||||||
if not batch.check_decode_mem():
|
|
||||||
old_ratio = self.new_token_ratio
|
|
||||||
|
|
||||||
retracted_reqs, new_token_ratio = batch.retract_decode()
|
|
||||||
self.new_token_ratio = new_token_ratio
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Decode out of memory happened. "
|
|
||||||
f"#retracted_reqs: {len(retracted_reqs)}, "
|
|
||||||
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
|
|
||||||
)
|
|
||||||
self.waiting_queue.extend(retracted_reqs)
|
|
||||||
else:
|
|
||||||
self.new_token_ratio = max(
|
|
||||||
self.new_token_ratio - self.new_token_ratio_decay,
|
|
||||||
self.min_new_token_ratio,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check for jump-forward
|
|
||||||
if not self.disable_regex_jump_forward:
|
|
||||||
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
|
||||||
self.waiting_queue.extend(jump_forward_reqs)
|
|
||||||
if batch.is_empty():
|
|
||||||
return
|
|
||||||
|
|
||||||
# Update batch tensors
|
|
||||||
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
|
||||||
batch.prepare_for_decode()
|
|
||||||
|
|
||||||
# Forward and sample the next tokens
|
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
|
||||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
|
||||||
model_worker_batch
|
|
||||||
)
|
|
||||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
|
||||||
next_token_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
# Move logprobs to cpu
|
|
||||||
if logits_output.next_token_logprobs is not None:
|
|
||||||
next_token_logprobs = logits_output.next_token_logprobs[
|
|
||||||
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
|
||||||
next_token_ids,
|
|
||||||
].tolist()
|
|
||||||
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
|
||||||
|
|
||||||
# Check finish condition
|
|
||||||
has_finished = False
|
|
||||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
|
||||||
req.output_ids.append(next_token_id)
|
|
||||||
req.check_finished()
|
|
||||||
|
|
||||||
if req.regex_fsm is not None:
|
|
||||||
req.regex_fsm_state = req.regex_fsm.get_next_state(
|
|
||||||
req.regex_fsm_state, next_token_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if req.finished():
|
|
||||||
self.tree_cache.cache_finished_req(req)
|
|
||||||
has_finished = True
|
|
||||||
|
|
||||||
if req.return_logprob:
|
|
||||||
req.output_token_logprobs.append(
|
|
||||||
(next_token_logprobs[i], next_token_id)
|
|
||||||
)
|
|
||||||
if req.top_logprobs_num > 0:
|
|
||||||
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
|
||||||
|
|
||||||
def handle_finished_requests(self, batch: ScheduleBatch):
|
def handle_finished_requests(self, batch: ScheduleBatch):
|
||||||
output_rids = []
|
output_rids = []
|
||||||
output_meta_info = []
|
output_meta_info = []
|
||||||
@@ -829,7 +844,7 @@ class Scheduler:
|
|||||||
output_read_offsets = []
|
output_read_offsets = []
|
||||||
output_skip_special_tokens = []
|
output_skip_special_tokens = []
|
||||||
output_spaces_between_special_tokens = []
|
output_spaces_between_special_tokens = []
|
||||||
else: # for embedding model
|
else: # embedding or reward model
|
||||||
output_embeddings = []
|
output_embeddings = []
|
||||||
unfinished_indices = []
|
unfinished_indices = []
|
||||||
|
|
||||||
@@ -886,7 +901,7 @@ class Scheduler:
|
|||||||
req.normalized_prompt_logprob,
|
req.normalized_prompt_logprob,
|
||||||
)
|
)
|
||||||
output_meta_info.append(meta_info)
|
output_meta_info.append(meta_info)
|
||||||
else: # for embedding model
|
else: # embedding or reward model
|
||||||
output_embeddings.append(req.embedding)
|
output_embeddings.append(req.embedding)
|
||||||
meta_info = {
|
meta_info = {
|
||||||
"prompt_tokens": len(req.origin_input_ids),
|
"prompt_tokens": len(req.origin_input_ids),
|
||||||
@@ -909,7 +924,7 @@ class Scheduler:
|
|||||||
output_finished_reason,
|
output_finished_reason,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else: # for embedding model
|
else: # embedding or reward model
|
||||||
self.out_pyobjs.append(
|
self.out_pyobjs.append(
|
||||||
BatchEmbeddingOut(
|
BatchEmbeddingOut(
|
||||||
output_rids,
|
output_rids,
|
||||||
|
|||||||
Reference in New Issue
Block a user