Use dtype to control generate (#1082)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Liangsheng Yin
2024-08-14 08:58:07 -07:00
committed by GitHub
parent 67c0d832a6
commit a34dd86a7d
12 changed files with 110 additions and 88 deletions

View File

@@ -195,7 +195,7 @@ def extend(reqs, model_runner):
token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
output = model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(output.next_token_logits)
return next_token_ids, output.next_token_logits, batch