Add output_ids into ScheduleBatch (#1659)
This commit is contained in:
@@ -232,17 +232,18 @@ def extend(reqs, model_runner):
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output = model_runner.forward(forward_batch)
|
||||
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
|
||||
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
||||
return next_token_ids, logits_output.next_token_logits, batch
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.prepare_for_decode(input_token_ids)
|
||||
batch.output_ids = input_token_ids
|
||||
batch.prepare_for_decode()
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output = model_runner.forward(forward_batch)
|
||||
next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
|
||||
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
@@ -252,6 +253,7 @@ def correctness_test(
|
||||
bench_args,
|
||||
tp_rank,
|
||||
):
|
||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||
|
||||
# Load the model
|
||||
@@ -279,8 +281,9 @@ def correctness_test(
|
||||
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
|
||||
for _ in range(bench_args.output_len[0] - 1):
|
||||
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
||||
next_token_ids_list = next_token_ids.tolist()
|
||||
for i in range(len(reqs)):
|
||||
output_ids[i].append(next_token_ids[i])
|
||||
output_ids[i].append(next_token_ids_list[i])
|
||||
|
||||
# Print
|
||||
for i in range(len(reqs)):
|
||||
|
||||
@@ -410,6 +410,8 @@ class ScheduleBatch:
|
||||
seq_lens: torch.Tensor = None
|
||||
out_cache_loc: torch.Tensor = None
|
||||
|
||||
output_ids: torch.Tensor = None
|
||||
|
||||
# For processing logprobs
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
@@ -720,19 +722,12 @@ class ScheduleBatch:
|
||||
|
||||
return jump_forward_reqs
|
||||
|
||||
def prepare_for_decode(self, input_ids=None):
|
||||
def prepare_for_decode(self):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
|
||||
if input_ids is None:
|
||||
input_ids = [
|
||||
r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
|
||||
for r in self.reqs
|
||||
]
|
||||
|
||||
self.input_ids = torch.tensor(
|
||||
input_ids, dtype=torch.int32, device=self.seq_lens.device
|
||||
)
|
||||
self.input_ids = self.output_ids
|
||||
self.seq_lens.add_(1)
|
||||
self.output_ids = None
|
||||
|
||||
# Alloc mem
|
||||
bs = len(self.reqs)
|
||||
@@ -759,6 +754,7 @@ class ScheduleBatch:
|
||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||
self.seq_lens = self.seq_lens[new_indices]
|
||||
self.out_cache_loc = None
|
||||
self.output_ids = self.output_ids[new_indices]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
if self.return_logprob:
|
||||
self.top_logprobs_nums = [
|
||||
@@ -783,6 +779,8 @@ class ScheduleBatch:
|
||||
)
|
||||
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
|
||||
self.out_cache_loc = None
|
||||
if self.output_ids is not None:
|
||||
self.output_ids = torch.concat([self.output_ids, other.output_ids])
|
||||
if self.return_logprob and other.return_logprob:
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
elif self.return_logprob:
|
||||
@@ -838,7 +836,9 @@ class ScheduleBatch:
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
tree_cache=self.tree_cache,
|
||||
forward_mode=self.forward_mode,
|
||||
output_token_ids=self.output_token_ids,
|
||||
output_ids=self.output_ids,
|
||||
sampling_info=self.sampling_info,
|
||||
decoding_reqs=self.decoding_reqs,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -247,7 +247,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop(self):
|
||||
def event_loop_normal(self):
|
||||
self.last_batch = None
|
||||
|
||||
while True:
|
||||
@@ -411,9 +411,10 @@ class Scheduler:
|
||||
throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
|
||||
self.num_generated_tokens = 0
|
||||
self.last_stats_tic = time.time()
|
||||
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||
logger.info(
|
||||
f"Decode batch. "
|
||||
f"#running-req: {len(self.running_batch.reqs)}, "
|
||||
f"#running-req: {num_running_reqs}, "
|
||||
f"#token: {num_used}, "
|
||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||
f"gen throughput (token/s): {throughput:.2f}, "
|
||||
@@ -659,6 +660,7 @@ class Scheduler:
|
||||
)
|
||||
else:
|
||||
next_token_ids = torch.full((batch.batch_size(),), 0)
|
||||
batch.output_ids = next_token_ids
|
||||
ret = logits_output, next_token_ids
|
||||
else: # embedding or reward model
|
||||
assert batch.extend_num_tokens != 0
|
||||
@@ -753,7 +755,7 @@ class Scheduler:
|
||||
# Inflight request would get a new req idx
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
self.handle_finished_requests(batch)
|
||||
self.stream_output(batch)
|
||||
|
||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||
logits_output, next_token_ids = result
|
||||
@@ -793,7 +795,7 @@ class Scheduler:
|
||||
if req.top_logprobs_num > 0:
|
||||
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
||||
|
||||
self.handle_finished_requests(batch)
|
||||
self.stream_output(batch)
|
||||
|
||||
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
||||
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
||||
@@ -872,7 +874,7 @@ class Scheduler:
|
||||
|
||||
return num_input_logprobs
|
||||
|
||||
def handle_finished_requests(self, batch: ScheduleBatch):
|
||||
def stream_output(self, batch: ScheduleBatch):
|
||||
output_rids = []
|
||||
output_meta_info = []
|
||||
output_finished_reason: List[BaseFinishReason] = []
|
||||
@@ -949,6 +951,9 @@ class Scheduler:
|
||||
}
|
||||
output_meta_info.append(meta_info)
|
||||
|
||||
# Remove finished reqs: update batch tensors
|
||||
batch.filter_batch(unfinished_indices)
|
||||
|
||||
# Send to detokenizer
|
||||
if output_rids:
|
||||
if self.is_generation:
|
||||
@@ -976,9 +981,6 @@ class Scheduler:
|
||||
)
|
||||
)
|
||||
|
||||
# Remove finished reqs: update batch tensors
|
||||
batch.filter_batch(unfinished_indices)
|
||||
|
||||
def flush_cache(self):
|
||||
if len(self.waiting_queue) == 0 and (
|
||||
self.running_batch is None or len(self.running_batch.reqs) == 0
|
||||
@@ -1060,7 +1062,7 @@ def run_scheduler_process(
|
||||
try:
|
||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
|
||||
pipe_writer.send("ready")
|
||||
scheduler.event_loop()
|
||||
scheduler.event_loop_normal()
|
||||
except Exception:
|
||||
msg = get_exception_traceback()
|
||||
logger.error(msg)
|
||||
|
||||
Reference in New Issue
Block a user