Flashinfer sample kernel (#617)
This commit is contained in:
@@ -451,7 +451,7 @@ class ModelTpServer:
|
||||
# Forward and sample the next tokens
|
||||
if batch.extend_num_tokens != 0:
|
||||
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||
next_token_ids, _ = batch.sample(output.next_token_logits)
|
||||
next_token_ids = batch.sample(output.next_token_logits)
|
||||
|
||||
# Move logprobs to cpu
|
||||
if output.next_token_logprobs is not None:
|
||||
@@ -574,7 +574,7 @@ class ModelTpServer:
|
||||
|
||||
# Forward and sample the next tokens
|
||||
output = self.model_runner.forward(batch, ForwardMode.DECODE)
|
||||
next_token_ids, _ = batch.sample(output.next_token_logits)
|
||||
next_token_ids = batch.sample(output.next_token_logits)
|
||||
|
||||
# Move logprobs to cpu
|
||||
if output.next_token_logprobs is not None:
|
||||
|
||||
Reference in New Issue
Block a user