Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)

This commit is contained in:
Lianmin Zheng
2024-09-29 20:28:45 -07:00
committed by GitHub
parent 55b974f96f
commit 3f0fe08d37
12 changed files with 142 additions and 157 deletions

View File

@@ -225,14 +225,16 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
logits_output = model_runner.forward(batch)
input_metadata = batch.get_input_metadata()
logits_output = model_runner.forward(input_metadata)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
logits_output = model_runner.forward(batch)
input_metadata = batch.get_input_metadata()
logits_output = model_runner.forward(input_metadata)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits