Flashinfer sample kernel (#617)
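Replaces the PyTorch top-p/top-k sampling path in `Batch.sample` with flashinfer's fused `top_k_top_p_sampling_from_probs` kernel (falling back to greedy `argmax` if the kernel raises), drops the now-unused `_top_p_top_k` helper and the returned probability tensor, and bumps the required flashinfer version from 0.0.8 to 0.1.0.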
@@ -156,14 +156,14 @@ def extend(reqs, model_runner):
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits, batch
 
 
 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids.cpu().numpy())
     output = model_runner.forward(batch, ForwardMode.DECODE)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits
 
 
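`Batch.sample` now returns only the token ids, so every caller stops unpacking a `(ids, probs)` tuple. A minimal sketch of the new contract (the helper name and shapes here are illustrative, not sglang code):

```python
import torch

# Hypothetical stand-in for the new Batch.sample contract: take the
# last-position logits (batch_size, vocab_size) and return a 1-D int64
# tensor of token ids (batch_size,) -- no probability tensor anymore.
def sample_like(logits: torch.Tensor) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(-1)

logits = torch.randn(4, 32000)        # batch_size=4, vocab_size=32000
next_token_ids = sample_like(logits)  # shape: (4,)
```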
@@ -7,6 +7,7 @@ from typing import List, Union
 
 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
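The new import is flashinfer's fused sampling kernel. Based on the call site added later in this diff, it takes the full probability tensor, a matrix of pre-drawn uniform samples (one row per sampling round), and per-request `top_k`/`top_p` tensors, and returns the sampled ids plus a second value that this diff discards. A hedged usage sketch (shapes inferred from that call site; requires a GPU build of flashinfer):

```python
import torch
from flashinfer.sampling import top_k_top_p_sampling_from_probs

# Toy inputs mirroring the call added in this diff. Shapes are inferred
# from that call site: probs is (batch, vocab); uniform_samples holds one
# pre-drawn uniform number per request per sampling round.
batch_size, vocab_size, max_rounds = 2, 128, 32
device = "cuda"  # flashinfer kernels run on GPU
probs = torch.softmax(torch.randn(batch_size, vocab_size, device=device), dim=-1)
uniform_samples = torch.rand((max_rounds, batch_size), device=device)
top_ks = torch.full((batch_size,), 40, dtype=torch.int, device=device)
top_ps = torch.full((batch_size,), 0.9, dtype=torch.float, device=device)

samples, _ = top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
print(samples.shape)  # torch.Size([2]) -- one token id per request
```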
@@ -398,10 +399,10 @@ class Batch:
         ).view(-1, 1)
         self.top_ps = torch.tensor(
             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        ).view(-1, 1)
+        )
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        ).view(-1, 1)
+        )
         self.frequency_penalties = torch.tensor(
             [r.sampling_params.frequency_penalty for r in reqs],
             dtype=torch.float,
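The `.view(-1, 1)` reshapes are dropped because the two layouts serve different consumers: the old `_top_p_top_k` helper broadcast a `(batch, 1)` column against the `(batch, vocab)` sorted-probability matrix, while the flashinfer kernel takes flat `(batch,)` vectors, one threshold per request. For example:

```python
import torch

top_p_values = [0.9, 0.5]

# Old layout: a (batch, 1) column, so comparisons against the
# (batch, vocab) sorted-probability matrix broadcast row-wise.
top_ps_col = torch.tensor(top_p_values, dtype=torch.float).view(-1, 1)

# New layout: a flat (batch,) vector, matching what the flashinfer
# kernel expects (one threshold per request).
top_ps_flat = torch.tensor(top_p_values, dtype=torch.float)

print(top_ps_col.shape, top_ps_flat.shape)  # torch.Size([2, 1]) torch.Size([2])
```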
@@ -659,20 +660,17 @@ class Batch:
 
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
         try:
-            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, self.top_ks, self.top_ps
+            )
         except RuntimeError as e:
             warnings.warn(f"Ignore errors in sampling: {e}")
-            sampled_index = torch.ones(
-                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
-            )
-        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-            -1
-        )
-        batch_next_token_probs = torch.gather(
-            probs_sort, dim=1, index=sampled_index
-        ).view(-1)
+            batch_next_token_ids = torch.argmax(probs, dim=-1)
 
         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
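`uniform_samples` is shaped `(max_top_k_round, batch_size)` because the fused kernel is a rejection sampler: each round consumes one uniform draw per request, and 32 rounds is the cap. The `except` branch keeps the old defensive behavior, but now falls back to greedy `argmax` instead of a dummy index. A rough pure-PyTorch illustration of the rejection idea (not flashinfer's actual algorithm, just the shape of it, shown for top-k only):

```python
import torch

def rejection_top_k(probs: torch.Tensor, top_k: int, max_rounds: int = 32) -> torch.Tensor:
    # Accept a draw from the full distribution only if it falls inside
    # the top-k set; retry up to max_rounds times, else fall back to
    # argmax (mirroring the except branch above).
    topk_ids = probs.topk(top_k, dim=-1).indices   # (batch, top_k)
    out = probs.argmax(dim=-1)                     # greedy fallback
    done = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
    for _ in range(max_rounds):
        draw = torch.multinomial(probs, num_samples=1).view(-1)
        accept = (topk_ids == draw.unsqueeze(1)).any(dim=1) & ~done
        out = torch.where(accept, draw, out)
        done |= accept
        if done.all():
            break
    return out

probs = torch.softmax(torch.randn(4, 128), dim=-1)
ids = rejection_top_k(probs, top_k=10)  # shape: (4,)
```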
@@ -682,18 +680,7 @@ class Batch:
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )
 
-        return batch_next_token_ids, batch_next_token_probs
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
+        return batch_next_token_ids
 
 
 @dataclass
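The deleted `_top_p_top_k` helper doubles as a CPU reference for the semantics the kernel replaces: sort, zero everything past the top-p cumulative mass or the top-k rank, then sample (dividing by the row max rather than the row sum was fine because `torch.multinomial` accepts unnormalized weights). To sanity-check the kernel against a reference, the removed code can be reassembled into a standalone sampler, as in this sketch (note it expects the old `(batch, 1)` column layout for `top_ps`/`top_ks`):

```python
import torch

def ref_sample(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor) -> torch.Tensor:
    # Reassembled from the deleted _top_p_top_k plus the old gather step.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0  # top-p mask
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
    ] = 0.0                                              # top-k mask
    sampled_index = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)

probs = torch.softmax(torch.randn(2, 16), dim=-1)
ids = ref_sample(probs, torch.tensor([[0.9], [0.5]]), torch.tensor([[8], [4]]))
```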
@@ -451,7 +451,7 @@ class ModelTpServer:
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids, _ = batch.sample(output.next_token_logits)
+            next_token_ids = batch.sample(output.next_token_logits)
 
             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -574,7 +574,7 @@ class ModelTpServer:
 
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(output.next_token_logits)
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
@@ -154,7 +154,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.0.8",
+            "0.1.0",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
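The version floor moves from 0.0.8 to 0.1.0, presumably because `top_k_top_p_sampling_from_probs` ships with the newer flashinfer release. For reference, a generic version gate like `assert_pkg_version` boils down to something like this sketch (an illustration, not sglang's implementation):

```python
from importlib.metadata import version

from packaging.version import parse

def check_min_version(pkg: str, min_version: str, message: str) -> None:
    # Compare the installed distribution version against the floor.
    installed = version(pkg)
    if parse(installed) < parse(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is older than required {min_version}. {message}"
        )

check_min_version("flashinfer", "0.1.0", "See https://docs.flashinfer.ai/installation.html.")
```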