Improve code style of sampler (#1168)
@@ -54,7 +54,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import suppress_other_loggers

python/sglang/srt/layers/sampler.py (new file, 101 lines)
@@ -0,0 +1,101 @@
import logging

import torch
from flashinfer.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
)
from vllm.model_executor.custom_op import CustomOp

# TODO: move this dict to another place
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)


class Sampler(CustomOp):
    def __init__(self):
        super().__init__()

    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        # Post process logits
        logits = logits.contiguous()
        logits.div_(sampling_info.temperatures)
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        if sampling_info.vocab_mask is not None:
            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))

        logits = sampling_info.penalizer_orchestrator.apply(logits)

        probs = torch.softmax(logits, dim=-1)

        if not global_server_args_dict["disable_flashinfer_sampling"]:
            max_top_k_round, batch_size = 32, probs.shape[0]
            uniform_samples = torch.rand(
                (max_top_k_round, batch_size), device=probs.device
            )
            if sampling_info.min_ps.any():
                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                batch_next_token_ids, success = min_p_sampling_from_probs(
                    probs, uniform_samples, sampling_info.min_ps
                )
            else:
                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                    probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                )
        else:
            # Here we provide a slower fallback implementation.
            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
            )

        if not torch.all(success):
            logging.warning("Sampling failed, fallback to top_k=1 strategy")
            probs = probs.masked_fill(torch.isnan(probs), 0.0)
            argmax_ids = torch.argmax(probs, dim=-1)
            batch_next_token_ids = torch.where(
                success, batch_next_token_ids, argmax_ids
            )

        return batch_next_token_ids

    def forward_native(self):
        raise NotImplementedError("Native forward is not implemented yet.")

def top_k_top_p_min_p_sampling_from_probs_torch(
    probs: torch.Tensor,
    top_ks: torch.Tensor,
    top_ps: torch.Tensor,
    min_ps: torch.Tensor,
):
    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    min_p_thresholds = probs_sort[:, 0] * min_ps
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
        >= top_ks.view(-1, 1)
    ] = 0.0
    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        logger.warning(f"Sampling error: {e}")
        batch_next_token_ids = torch.zeros(
            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
        )
        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
        return batch_next_token_ids, success

    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
    return batch_next_token_ids, success

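For reference, here is a minimal self-contained sketch of the same top-k/top-p/min-p masking on a toy batch; the numbers and thresholds below are illustrative and not taken from the commit.

import torch

# Toy batch of one request with an already-normalized distribution over 4 tokens.
probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
top_ks = torch.tensor([3])      # keep at most the 3 highest-probability tokens
top_ps = torch.tensor([0.80])   # keep the smallest prefix whose mass reaches 0.8
min_ps = torch.tensor([0.20])   # drop tokens below 0.2 * max probability

probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
min_p_thresholds = probs_sort[:, 0:1] * min_ps.view(-1, 1)

# top-p: zero tokens whose preceding cumulative mass already exceeds top_p
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
# top-k: zero everything past the first top_k sorted positions
ranks = torch.arange(probs.shape[-1]).view(1, -1)
probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
# min-p: zero tokens below min_p * (largest probability)
probs_sort[probs_sort < min_p_thresholds] = 0.0

sampled = torch.multinomial(probs_sort / probs_sort.sum(-1, keepdim=True), num_samples=1)
next_token = torch.gather(probs_idx, dim=1, index=sampled)
print(next_token)  # token 0, 1, or 2: the ids that survive all three filters here
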
@@ -23,7 +23,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.sampling.sampling_params import SamplingParams


@dataclass

@@ -20,22 +20,14 @@ from dataclasses import dataclass
from typing import List, Optional, Union

import torch
import torch.distributed as dist
from flashinfer.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
)
from vllm.distributed import get_tensor_model_parallel_group

import sglang.srt.sampling.penaltylib as penaltylib
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -340,14 +332,6 @@ class ScheduleBatch:
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # Batched sampling params
    temperatures: torch.Tensor = None
    top_ps: torch.Tensor = None
    top_ks: torch.Tensor = None
    min_ps: torch.Tensor = None
    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
    logit_bias: torch.Tensor = None

    @classmethod
    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
        return_logprob = any(req.return_logprob for req in reqs)

@@ -395,46 +379,6 @@ class ScheduleBatch:

        return out_cache_loc

    def batch_sampling_params(self, vocab_size):
        device = "cuda"
        bs, reqs = self.batch_size(), self.reqs
        self.temperatures = torch.tensor(
            [r.sampling_params.temperature for r in reqs],
            dtype=torch.float,
            device=device,
        ).view(-1, 1)
        self.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
        )
        self.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
        )
        self.min_ps = torch.tensor(
            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
        )

        # Each penalizer will do nothing if it evaluates itself as not required by looking at
        # the sampling_params of the requests (see {_is_required()} of each penalizer), so this
        # should not add hefty computation overhead other than simple checks.
        #
        # While we could choose not to even create the class instances when they are not
        # required, that would add additional complexity to the {ScheduleBatch} class,
        # especially since we would need to handle the {filter_batch()} and {merge()} cases
        # as well.
        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
            vocab_size=vocab_size,
            batch=self,
            device=device,
            Penalizers={
                penaltylib.BatchedFrequencyPenalizer,
                penaltylib.BatchedMinNewTokensPenalizer,
                penaltylib.BatchedPresencePenalizer,
                penaltylib.BatchedRepetitionPenalizer,
            },
        )

        # Handle logit bias but only allocate when needed
        self.logit_bias = None

    def prepare_for_extend(self, vocab_size: int):
        bs = self.batch_size()
        reqs = self.reqs

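The removed method above, and its replacement SamplingBatchInfo.from_schedule_batch, both vectorize per-request sampling parameters into batched tensors. A minimal sketch of that layout, using a hypothetical ToyParams stand-in instead of the real SamplingParams class and made-up values:

import torch
from dataclasses import dataclass

@dataclass
class ToyParams:  # hypothetical stand-in for SamplingParams; only the batched fields
    temperature: float
    top_p: float
    top_k: int
    min_p: float

reqs = [ToyParams(0.7, 0.9, 40, 0.0), ToyParams(1.0, 1.0, 50, 0.05)]

temperatures = torch.tensor([r.temperature for r in reqs]).view(-1, 1)  # (bs, 1) broadcasts over vocab
top_ps = torch.tensor([r.top_p for r in reqs])                          # (bs,)
top_ks = torch.tensor([r.top_k for r in reqs], dtype=torch.int)         # (bs,)
min_ps = torch.tensor([r.min_p for r in reqs])                          # (bs,)
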
@@ -475,7 +419,7 @@ class ScheduleBatch:
        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]

        self.batch_sampling_params(vocab_size)
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step

@@ -684,6 +628,8 @@ class ScheduleBatch:
            self.req_pool_indices, self.seq_lens - 1
        ] = self.out_cache_loc

        self.sampling_info.update_regex_vocab_mask(self)

    def filter_batch(self, unfinished_indices: List[int]):
        if unfinished_indices is None or len(unfinished_indices) == 0:
            # Filter out all requests

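update_regex_vocab_mask builds a per-request boolean mask over the vocabulary, which the new Sampler later applies with masked_fill. A small self-contained illustration of that masking step; the batch size, vocabulary size, and allowed token ids are made up, and the unconstrained row is explicitly set to all-True for the toy:

import torch

bs, vocab_size = 2, 6
logits = torch.randn(bs, vocab_size)

vocab_mask = torch.zeros(bs, vocab_size, dtype=torch.bool)
vocab_mask[0, [1, 3]] = True  # request 0: a regex constraint allows only tokens 1 and 3
vocab_mask[1, :] = True       # request 1: unconstrained, so every token stays allowed

masked = logits.masked_fill(~vocab_mask, float("-inf"))
probs = torch.softmax(masked, dim=-1)  # disallowed tokens now have probability 0
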
@@ -704,24 +650,13 @@ class ScheduleBatch:
        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)

        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
            "logit_bias",
        ]:
            self_val = getattr(self, item, None)
            if self_val is not None:  # logit_bias can be None
                setattr(self, item, self_val[new_indices])
        self.sampling_info.filter(unfinished_indices, new_indices)

    def merge(self, other: "ScheduleBatch"):
        # The penalizer orchestrator must be merged before Batch.reqs is merged, because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizer,
        # so it needs to be called with the pre-merged Batch.reqs.
        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
        self.sampling_info.merge(other.sampling_info)

        self.reqs.extend(other.reqs)

@@ -736,125 +671,11 @@ class ScheduleBatch:
        self.top_logprobs_nums.extend(other.top_logprobs_nums)
        self.return_logprob = any(req.return_logprob for req in self.reqs)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.concat([self_val, other_val]))

        # logit_bias can be None
        if self.logit_bias is not None or other.logit_bias is not None:
            vocab_size = (
                self.logit_bias.shape[1]
                if self.logit_bias is not None
                else other.logit_bias.shape[1]
            )
            if self.logit_bias is None:
                self.logit_bias = torch.zeros(
                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            if other.logit_bias is None:
                other.logit_bias = torch.zeros(
                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

    def sample(self, logits: torch.Tensor):
        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
        # Post process logits
        logits = logits.contiguous()
        logits.div_(self.temperatures)
        if self.logit_bias is not None:
            logits.add_(self.logit_bias)
        from sglang.srt.layers.sampler import Sampler

        has_regex = any(req.regex_fsm is not None for req in self.reqs)
        if has_regex:
            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    allowed_mask.zero_()
                    allowed_mask[
                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                    ] = 1
                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
        sampler = Sampler()

        logits = self.penalizer_orchestrator.apply(logits)

        probs = torch.softmax(logits, dim=-1)

        if not global_server_args_dict["disable_flashinfer_sampling"]:
            max_top_k_round, batch_size = 32, probs.shape[0]
            uniform_samples = torch.rand(
                (max_top_k_round, batch_size), device=probs.device
            )
            if self.min_ps.any():
                probs = top_k_renorm_prob(probs, self.top_ks)
                probs = top_p_renorm_prob(probs, self.top_ps)
                batch_next_token_ids, success = min_p_sampling_from_probs(
                    probs, uniform_samples, self.min_ps
                )
            else:
                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                    probs, uniform_samples, self.top_ks, self.top_ps
                )
        else:
            # Here we provide a slower fallback implementation.
            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
                probs, self.top_ks, self.top_ps, self.min_ps
            )

        if not torch.all(success):
            logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
            probs = probs.masked_fill(torch.isnan(probs), 0.0)
            argmax_ids = torch.argmax(probs, dim=-1)
            batch_next_token_ids = torch.where(
                success, batch_next_token_ids, argmax_ids
            )

        if has_regex:
            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.get_next_state(
                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
                    )

        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
        batch_next_token_ids = sampler(logits, self.sampling_info)

        return batch_next_token_ids


def top_k_top_p_min_p_sampling_from_probs_torch(
    probs: torch.Tensor,
    top_ks: torch.Tensor,
    top_ps: torch.Tensor,
    min_ps: torch.Tensor,
):
    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    min_p_thresholds = probs_sort[:, 0] * min_ps
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
        >= top_ks.view(-1, 1)
    ] = 0.0
    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        logger.warning(f"Sampling error: {e}")
        batch_next_token_ids = torch.zeros(
            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
        )
        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
        return batch_next_token_ids, success

    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
    return batch_next_token_ids, success

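Both the removed ScheduleBatch.sample and the new Sampler layer share the same failure fallback: if a row's sampling fails (for example because the distribution contains NaNs), that row's token is replaced with the argmax of the NaN-scrubbed probabilities. A tiny illustration of the pattern with invented inputs:

import torch

probs = torch.tensor([[0.7, 0.2, 0.1],
                      [float("nan"), float("nan"), float("nan")]])  # second row is corrupted
sampled = torch.tensor([1, 2])               # token ids the sampler produced
success = torch.tensor([True, False])        # per-row success flags

if not torch.all(success):
    safe = probs.masked_fill(torch.isnan(probs), 0.0)
    argmax_ids = torch.argmax(safe, dim=-1)
    sampled = torch.where(success, sampled, argmax_ids)
# sampled is now [1, 0]: row 0 keeps its sample, row 1 falls back to argmax
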
@@ -50,7 +50,7 @@ from sglang.srt.managers.io_struct import (
    UpdateWeightReqOutput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
from sglang.utils import get_exception_traceback

@@ -482,6 +482,9 @@ class ModelTpServer:
        if batch.extend_num_tokens != 0:
            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
            next_token_ids = batch.sample(output.next_token_logits)
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                next_token_ids
            )

            # Move logprobs to cpu
            if output.next_token_logprobs is not None:

@@ -514,6 +517,11 @@ class ModelTpServer:
                req.output_ids.append(next_token_ids[i])
                req.check_finished()

                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.get_next_state(
                        req.regex_fsm_state, next_token_ids[i]
                    )

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                elif req not in decoding_reqs:

@@ -642,6 +650,9 @@ class ModelTpServer:
        # Forward and sample the next tokens
        output = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids = batch.sample(output.next_token_logits)
        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
            next_token_ids
        )

        # Move logprobs to cpu
        if output.next_token_logprobs is not None:

@@ -658,6 +669,11 @@ class ModelTpServer:
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.regex_fsm is not None:
                req.regex_fsm_state = req.regex_fsm.get_next_state(
                    req.regex_fsm_state, next_token_id
                )

            if req.finished():
                self.tree_cache.cache_finished_req(req)

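The hunks above move two pieces of per-token bookkeeping out of ScheduleBatch.sample and into the server loop: feeding the sampled ids back to the penalizer orchestrator, and stepping each request's regex FSM. A toy sketch of the FSM-stepping part, using a dict-based stand-in rather than the real RegexGuide API:

# Hypothetical stand-in: real requests carry a compiled regex FSM, not a dict.
class ToyFSM:
    def __init__(self, table):
        self.table = table  # (state, token_id) -> next state

    def get_next_state(self, state, token_id):
        return self.table.get((state, token_id), state)

fsm = ToyFSM({(0, 5): 1, (1, 7): 2})
regex_fsm_state = 0
for token_id in [5, 7, 9]:          # ids sampled on successive decode steps
    regex_fsm_state = fsm.get_next_state(regex_fsm_state, token_id)
# regex_fsm_state == 2: advanced by tokens 5 and 7, unchanged by 9
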
@@ -120,9 +120,6 @@ class ModelRunner:
            self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
        self.is_multi_node_tp = not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        )

        if self.tp_size > 1:
            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)

python/sglang/srt/sampling/sampling_batch_info.py (new file, 136 lines)
@@ -0,0 +1,136 @@
from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, List

import torch

import sglang.srt.sampling.penaltylib as penaltylib

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import ScheduleBatch


@dataclasses.dataclass
class SamplingBatchInfo:
    # Basic Info
    vocab_size: int

    # Batched sampling params
    temperatures: torch.Tensor = None
    top_ps: torch.Tensor = None
    top_ks: torch.Tensor = None
    min_ps: torch.Tensor = None
    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
    logit_bias: torch.Tensor = None
    vocab_mask: torch.Tensor = None

    @classmethod
    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
        device = "cuda"
        reqs = batch.reqs
        ret = cls(vocab_size=vocab_size)

        ret.temperatures = torch.tensor(
            [r.sampling_params.temperature for r in reqs],
            dtype=torch.float,
            device=device,
        ).view(-1, 1)
        ret.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
        )
        ret.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
        )
        ret.min_ps = torch.tensor(
            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
        )

        # Each penalizer will do nothing if it evaluates itself as not required by looking at
        # the sampling_params of the requests (see {_is_required()} of each penalizer), so this
        # should not add hefty computation overhead other than simple checks.
        #
        # While we could choose not to even create the class instances when they are not
        # required, that would add additional complexity to the {ScheduleBatch} class,
        # especially since we would need to handle the {filter_batch()} and {merge()} cases
        # as well.
        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
            vocab_size=vocab_size,
            batch=batch,
            device=device,
            Penalizers={
                penaltylib.BatchedFrequencyPenalizer,
                penaltylib.BatchedMinNewTokensPenalizer,
                penaltylib.BatchedPresencePenalizer,
                penaltylib.BatchedRepetitionPenalizer,
            },
        )

        # Handle logit bias but only allocate when needed
        ret.logit_bias = None

        ret.update_regex_vocab_mask(batch)

        return ret

    def update_regex_vocab_mask(self, batch: ScheduleBatch):
        bs, reqs = batch.batch_size(), batch.reqs
        device = "cuda"
        has_regex = any(req.regex_fsm is not None for req in reqs)

        # Reset the vocab mask
        self.vocab_mask = None

        if has_regex:
            for i, req in enumerate(reqs):
                if req.regex_fsm is not None:
                    if self.vocab_mask is None:
                        self.vocab_mask = torch.zeros(
                            bs, self.vocab_size, dtype=torch.bool, device=device
                        )
                    self.vocab_mask[i][
                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                    ] = 1

    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
            "logit_bias",
        ]:
            self_val = getattr(self, item, None)
            if self_val is not None:  # logit_bias can be None
                setattr(self, item, self_val[new_indices])

    def merge(self, other: "SamplingBatchInfo"):
        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.concat([self_val, other_val]))

        # logit_bias can be None
        if self.logit_bias is not None or other.logit_bias is not None:
            vocab_size = (
                self.logit_bias.shape[1]
                if self.logit_bias is not None
                else other.logit_bias.shape[1]
            )
            if self.logit_bias is None:
                self.logit_bias = torch.zeros(
                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            if other.logit_bias is None:
                other.logit_bias = torch.zeros(
                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

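filter() and merge() above apply the same two tensor idioms to every batched field: row selection by index when requests finish, and concatenation when two batches are combined. A minimal standalone sketch of both steps with invented tensor values:

import torch

temperatures = torch.tensor([[0.7], [1.0], [0.3]])   # (bs, 1)
top_ps = torch.tensor([0.9, 1.0, 0.8])                # (bs,)

# filter: keep only the rows of the unfinished requests
new_indices = torch.tensor([0, 2])
temperatures, top_ps = temperatures[new_indices], top_ps[new_indices]

# merge: append another batch's rows
other_temperatures = torch.tensor([[1.2]])
other_top_ps = torch.tensor([0.95])
temperatures = torch.concat([temperatures, other_temperatures])   # now (3, 1)
top_ps = torch.concat([top_ps, other_top_ps])                      # now (3,)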