Flashinfer sample kernel (#617)
This commit is contained in:
@@ -156,14 +156,14 @@ def extend(reqs, model_runner):
|
||||
)
|
||||
batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
|
||||
output = model_runner.forward(batch, ForwardMode.EXTEND)
|
||||
next_token_ids, _ = batch.sample(output.next_token_logits)
|
||||
next_token_ids = batch.sample(output.next_token_logits)
|
||||
return next_token_ids, output.next_token_logits, batch
|
||||
|
||||
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.prepare_for_decode(input_token_ids.cpu().numpy())
|
||||
output = model_runner.forward(batch, ForwardMode.DECODE)
|
||||
next_token_ids, _ = batch.sample(output.next_token_logits)
|
||||
next_token_ids = batch.sample(output.next_token_logits)
|
||||
return next_token_ids, output.next_token_logits
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user