Flashinfer sample kernel (#617)

This commit is contained in:
Liangsheng Yin
2024-07-17 13:24:43 -07:00
committed by GitHub
parent 4efcc59d4f
commit 3de2f30a27
4 changed files with 17 additions and 30 deletions

View File

@@ -451,7 +451,7 @@ class ModelTpServer:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids, _ = batch.sample(output.next_token_logits)
next_token_ids = batch.sample(output.next_token_logits)
# Move logprobs to cpu
if output.next_token_logprobs is not None:
@@ -574,7 +574,7 @@ class ModelTpServer:
# Forward and sample the next tokens
output = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids, _ = batch.sample(output.next_token_logits)
next_token_ids = batch.sample(output.next_token_logits)
# Move logprobs to cpu
if output.next_token_logprobs is not None: