Simplify batch result resolution (#1735)
This commit is contained in:
@@ -149,12 +149,8 @@ class Scheduler:
|
||||
# Launch a tensor parallel worker
|
||||
if self.enable_overlap:
|
||||
TpWorkerClass = TpModelWorkerClient
|
||||
self.resolve_next_token_ids = (
|
||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||
)
|
||||
else:
|
||||
TpWorkerClass = TpModelWorker
|
||||
self.resolve_next_token_ids = lambda bid, x: x.tolist()
|
||||
|
||||
self.tp_worker = TpWorkerClass(
|
||||
server_args=server_args,
|
||||
@@ -756,9 +752,12 @@ class Scheduler:
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
if self.is_generation:
|
||||
logits_output, next_token_ids, bid = result
|
||||
if batch.return_logprob:
|
||||
# Move logprobs to cpu
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
|
||||
if self.enable_overlap:
|
||||
logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
|
||||
else:
|
||||
# Move next_token_ids and logprobs to cpu
|
||||
if batch.return_logprob:
|
||||
logits_output.next_token_logprobs = (
|
||||
logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
@@ -771,8 +770,7 @@ class Scheduler:
|
||||
logits_output.normalized_prompt_logprobs = (
|
||||
logits_output.normalized_prompt_logprobs.tolist()
|
||||
)
|
||||
|
||||
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
|
||||
# Check finish conditions
|
||||
logprob_pt = 0
|
||||
@@ -825,14 +823,16 @@ class Scheduler:
|
||||
logits_output, next_token_ids, bid = result
|
||||
self.num_generated_tokens += len(batch.reqs)
|
||||
|
||||
# Move logprobs to cpu
|
||||
if batch.return_logprob:
|
||||
next_token_logprobs = logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
next_token_ids,
|
||||
].tolist()
|
||||
|
||||
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
|
||||
if self.enable_overlap:
|
||||
logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
|
||||
else:
|
||||
# Move next_token_ids and logprobs to cpu
|
||||
if batch.return_logprob:
|
||||
next_token_logprobs = logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=self.device),
|
||||
next_token_ids,
|
||||
].tolist()
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
|
||||
self.token_to_kv_pool.free_group_begin()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user