Simplify batch result resolution (#1735)

This commit is contained in:
Lianmin Zheng
2024-10-20 19:47:14 -07:00
committed by GitHub
parent e12358dc91
commit b121bc03a3
5 changed files with 64 additions and 90 deletions

View File

@@ -149,12 +149,8 @@ class Scheduler:
# Launch a tensor parallel worker
if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
else:
TpWorkerClass = TpModelWorker
self.resolve_next_token_ids = lambda bid, x: x.tolist()
self.tp_worker = TpWorkerClass(
server_args=server_args,
@@ -756,9 +752,12 @@ class Scheduler:
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if self.is_generation:
logits_output, next_token_ids, bid = result
if batch.return_logprob:
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
else:
# Move next_token_ids and logprobs to cpu
if batch.return_logprob:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device),
@@ -771,8 +770,7 @@ class Scheduler:
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
next_token_ids = next_token_ids.tolist()
# Check finish conditions
logprob_pt = 0
@@ -825,14 +823,16 @@ class Scheduler:
logits_output, next_token_ids, bid = result
self.num_generated_tokens += len(batch.reqs)
# Move logprobs to cpu
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device),
next_token_ids,
].tolist()
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
else:
# Move next_token_ids and logprobs to cpu
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device),
next_token_ids,
].tolist()
next_token_ids = next_token_ids.tolist()
self.token_to_kv_pool.free_group_begin()