Flashinfer sample kernel (#617)
This commit is contained in:
@@ -156,14 +156,14 @@ def extend(reqs, model_runner):
|
|||||||
)
|
)
|
||||||
batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
|
batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
|
||||||
output = model_runner.forward(batch, ForwardMode.EXTEND)
|
output = model_runner.forward(batch, ForwardMode.EXTEND)
|
||||||
next_token_ids, _ = batch.sample(output.next_token_logits)
|
next_token_ids = batch.sample(output.next_token_logits)
|
||||||
return next_token_ids, output.next_token_logits, batch
|
return next_token_ids, output.next_token_logits, batch
|
||||||
|
|
||||||
|
|
||||||
def decode(input_token_ids, batch, model_runner):
|
def decode(input_token_ids, batch, model_runner):
|
||||||
batch.prepare_for_decode(input_token_ids.cpu().numpy())
|
batch.prepare_for_decode(input_token_ids.cpu().numpy())
|
||||||
output = model_runner.forward(batch, ForwardMode.DECODE)
|
output = model_runner.forward(batch, ForwardMode.DECODE)
|
||||||
next_token_ids, _ = batch.sample(output.next_token_logits)
|
next_token_ids = batch.sample(output.next_token_logits)
|
||||||
return next_token_ids, output.next_token_logits
|
return next_token_ids, output.next_token_logits
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import List, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||||
|
|
||||||
from sglang.srt.constrained import RegexGuide
|
from sglang.srt.constrained import RegexGuide
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||||
@@ -398,10 +399,10 @@ class Batch:
|
|||||||
).view(-1, 1)
|
).view(-1, 1)
|
||||||
self.top_ps = torch.tensor(
|
self.top_ps = torch.tensor(
|
||||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
|
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
|
||||||
).view(-1, 1)
|
)
|
||||||
self.top_ks = torch.tensor(
|
self.top_ks = torch.tensor(
|
||||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
|
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
|
||||||
).view(-1, 1)
|
)
|
||||||
self.frequency_penalties = torch.tensor(
|
self.frequency_penalties = torch.tensor(
|
||||||
[r.sampling_params.frequency_penalty for r in reqs],
|
[r.sampling_params.frequency_penalty for r in reqs],
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
@@ -659,20 +660,17 @@ class Batch:
|
|||||||
|
|
||||||
# TODO(lmzheng): apply penalty
|
# TODO(lmzheng): apply penalty
|
||||||
probs = torch.softmax(logits, dim=-1)
|
probs = torch.softmax(logits, dim=-1)
|
||||||
probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
|
|
||||||
try:
|
try:
|
||||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||||
|
uniform_samples = torch.rand(
|
||||||
|
(max_top_k_round, batch_size), device=probs.device
|
||||||
|
)
|
||||||
|
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
|
||||||
|
probs, uniform_samples, self.top_ks, self.top_ps
|
||||||
|
)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
warnings.warn(f"Ignore errors in sampling: {e}")
|
warnings.warn(f"Ignore errors in sampling: {e}")
|
||||||
sampled_index = torch.ones(
|
batch_next_token_ids = torch.argmax(probs, dim=-1)
|
||||||
probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
|
|
||||||
)
|
|
||||||
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
|
|
||||||
-1
|
|
||||||
)
|
|
||||||
batch_next_token_probs = torch.gather(
|
|
||||||
probs_sort, dim=1, index=sampled_index
|
|
||||||
).view(-1)
|
|
||||||
|
|
||||||
if has_regex:
|
if has_regex:
|
||||||
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
|
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
|
||||||
@@ -682,18 +680,7 @@ class Batch:
|
|||||||
req.regex_fsm_state, batch_next_token_ids_cpu[i]
|
req.regex_fsm_state, batch_next_token_ids_cpu[i]
|
||||||
)
|
)
|
||||||
|
|
||||||
return batch_next_token_ids, batch_next_token_probs
|
return batch_next_token_ids
|
||||||
|
|
||||||
|
|
||||||
def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
|
|
||||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
|
||||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
|
||||||
probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
|
|
||||||
probs_sort[
|
|
||||||
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
|
|
||||||
] = 0.0
|
|
||||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
|
||||||
return probs_sort, probs_idx
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -451,7 +451,7 @@ class ModelTpServer:
|
|||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||||
next_token_ids, _ = batch.sample(output.next_token_logits)
|
next_token_ids = batch.sample(output.next_token_logits)
|
||||||
|
|
||||||
# Move logprobs to cpu
|
# Move logprobs to cpu
|
||||||
if output.next_token_logprobs is not None:
|
if output.next_token_logprobs is not None:
|
||||||
@@ -574,7 +574,7 @@ class ModelTpServer:
|
|||||||
|
|
||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
output = self.model_runner.forward(batch, ForwardMode.DECODE)
|
output = self.model_runner.forward(batch, ForwardMode.DECODE)
|
||||||
next_token_ids, _ = batch.sample(output.next_token_logits)
|
next_token_ids = batch.sample(output.next_token_logits)
|
||||||
|
|
||||||
# Move logprobs to cpu
|
# Move logprobs to cpu
|
||||||
if output.next_token_logprobs is not None:
|
if output.next_token_logprobs is not None:
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
if not server_args.disable_flashinfer:
|
if not server_args.disable_flashinfer:
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"flashinfer",
|
"flashinfer",
|
||||||
"0.0.8",
|
"0.1.0",
|
||||||
"Please uninstall the old version and "
|
"Please uninstall the old version and "
|
||||||
"reinstall the latest version by following the instructions "
|
"reinstall the latest version by following the instructions "
|
||||||
"at https://docs.flashinfer.ai/installation.html.",
|
"at https://docs.flashinfer.ai/installation.html.",
|
||||||
|
|||||||
Reference in New Issue
Block a user