Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)
This commit is contained in:
@@ -575,8 +575,9 @@ class Scheduler:
|
||||
if self.is_generation:
|
||||
# Forward and sample the next tokens
|
||||
if batch.extend_num_tokens != 0:
|
||||
input_metadata = batch.get_input_metadata()
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
batch
|
||||
input_metadata, batch
|
||||
)
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
@@ -640,7 +641,8 @@ class Scheduler:
|
||||
)
|
||||
else:
|
||||
assert batch.extend_num_tokens != 0
|
||||
embeddings = self.tp_worker.forward_batch_embedding(batch)
|
||||
input_metadata = batch.get_input_metadata()
|
||||
embeddings = self.tp_worker.forward_batch_embedding(input_metadata)
|
||||
|
||||
# Check finish conditions
|
||||
for i, req in enumerate(batch.reqs):
|
||||
@@ -769,7 +771,10 @@ class Scheduler:
|
||||
batch.prepare_for_decode()
|
||||
|
||||
# Forward and sample the next tokens
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(batch)
|
||||
input_metadata = batch.get_input_metadata()
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
input_metadata, batch
|
||||
)
|
||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
next_token_ids
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user