Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)
This commit is contained in:
@@ -225,14 +225,16 @@ def extend(reqs, model_runner):
|
||||
tree_cache=None,
|
||||
)
|
||||
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
||||
logits_output = model_runner.forward(batch)
|
||||
input_metadata = batch.get_input_metadata()
|
||||
logits_output = model_runner.forward(input_metadata)
|
||||
next_token_ids = model_runner.sample(logits_output, batch).tolist()
|
||||
return next_token_ids, logits_output.next_token_logits, batch
|
||||
|
||||
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.prepare_for_decode(input_token_ids)
|
||||
logits_output = model_runner.forward(batch)
|
||||
input_metadata = batch.get_input_metadata()
|
||||
logits_output = model_runner.forward(input_metadata)
|
||||
next_token_ids = model_runner.sample(logits_output, batch).tolist()
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
Reference in New Issue
Block a user