Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)

This commit is contained in:
Lianmin Zheng
2024-09-29 20:28:45 -07:00
committed by GitHub
parent 55b974f96f
commit 3f0fe08d37
12 changed files with 142 additions and 157 deletions

View File

@@ -575,8 +575,9 @@ class Scheduler:
 if self.is_generation:
 # Forward and sample the next tokens
 if batch.extend_num_tokens != 0:
+input_metadata = batch.get_input_metadata()
 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-batch
+input_metadata, batch
 )
 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
 next_token_ids
@@ -640,7 +641,8 @@ class Scheduler:
 )
 else:
 assert batch.extend_num_tokens != 0
-embeddings = self.tp_worker.forward_batch_embedding(batch)
+input_metadata = batch.get_input_metadata()
+embeddings = self.tp_worker.forward_batch_embedding(input_metadata)
 # Check finish conditions
 for i, req in enumerate(batch.reqs):
@@ -769,7 +771,10 @@ class Scheduler:
 batch.prepare_for_decode()
 # Forward and sample the next tokens
-logits_output, next_token_ids = self.tp_worker.forward_batch_generation(batch)
+input_metadata = batch.get_input_metadata()
+logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+input_metadata, batch
+)
 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
 next_token_ids
 )