diff --git a/examples/usage/json_decode.py b/examples/usage/json_decode.py
index dc34d3527..ce8f5ba70 100644
--- a/examples/usage/json_decode.py
+++ b/examples/usage/json_decode.py
@@ -35,6 +35,9 @@ def character_gen(s, name):
         name
         + " is a character in Harry Potter. Please fill in the following information about this character.\n"
     )
+    s += "The constrained regex is:\n"
+    s += character_regex + "\n"
+    s += "The JSON output is:\n"
     s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index dd86747e3..d9131c87f 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -54,7 +54,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import suppress_other_loggers
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
new file mode 100644
index 000000000..3006e765c
--- /dev/null
+++ b/python/sglang/srt/layers/sampler.py
@@ -0,0 +1,101 @@
+import logging
+
+import torch
+from flashinfer.sampling import (
+    min_p_sampling_from_probs,
+    top_k_renorm_prob,
+    top_k_top_p_sampling_from_probs,
+    top_p_renorm_prob,
+)
+from vllm.model_executor.custom_op import CustomOp
+
+# TODO: move this dict to another place
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+logger = logging.getLogger(__name__)
+
+
+class Sampler(CustomOp):
+    def __init__(self):
+        super().__init__()
+
+    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # Post process logits
+        logits = logits.contiguous()
+        logits.div_(sampling_info.temperatures)
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
+
+        logits = sampling_info.penalizer_orchestrator.apply(logits)
+
+        probs = torch.softmax(logits, dim=-1)
+
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            if sampling_info.min_ps.any():
+                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                batch_next_token_ids, success = min_p_sampling_from_probs(
+                    probs, uniform_samples, sampling_info.min_ps
+                )
+            else:
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                    probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
+                )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+            )
+
+        if not torch.all(success):
+            logger.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                success, batch_next_token_ids, argmax_ids
+            )
+
+        return batch_next_token_ids
+
+    def forward_native(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        raise NotImplementedError("Native forward is not implemented yet.")
+
+
+def top_k_top_p_min_p_sampling_from_probs_torch(
+    probs: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+    min_ps: torch.Tensor,
+):
+    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    min_p_thresholds = probs_sort[:, 0] * min_ps
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    except RuntimeError as e:
+        logger.warning(f"Sampling error: {e}")
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index dc8224593..56e3d8f79 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -23,7 +23,7 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
 
 
 @dataclass
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 9abce6f9b..88a616832 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -20,22 +20,14 @@ from dataclasses import dataclass
 from typing import List, Optional, Union
 
 import torch
-import torch.distributed as dist
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
-from vllm.distributed import get_tensor_model_parallel_group
 
-import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -340,14 +332,6 @@ class ScheduleBatch:
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    min_ps: torch.Tensor = None
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    logit_bias: torch.Tensor = None
-
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
@@ -395,46 +379,6 @@ class ScheduleBatch:
 
         return out_cache_loc
 
-    def batch_sampling_params(self, vocab_size):
-        device = "cuda"
-        bs, reqs = self.batch_size(), self.reqs
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-        self.min_ps = torch.tensor(
-            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
-        )
-
-        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
-        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
-        # should not add hefty computation overhead other than simple checks.
-        #
-        # While we choose not to even create the class instances if they are not required, this
-        # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {merge()} cases as well.
-        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=self,
-            device=device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
-
-        # Handle logit bias but only allocate when needed
-        self.logit_bias = None
-
     def prepare_for_extend(self, vocab_size: int):
         bs = self.batch_size()
         reqs = self.reqs
@@ -475,7 +419,7 @@ class ScheduleBatch:
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
         self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
 
-        self.batch_sampling_params(vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
@@ -684,6 +628,8 @@ class ScheduleBatch:
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc
 
+        self.sampling_info.update_regex_vocab_mask(self)
+
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
@@ -704,24 +650,13 @@ class ScheduleBatch:
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-            "logit_bias",
-        ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None:  # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+        self.sampling_info.filter(unfinished_indices, new_indices)
 
     def merge(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        self.sampling_info.merge(other.sampling_info)
 
         self.reqs.extend(other.reqs)
 
@@ -736,125 +671,11 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
-
     def sample(self, logits: torch.Tensor):
-        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(self.temperatures)
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
+        from sglang.srt.layers.sampler import Sampler
 
-        has_regex = any(req.regex_fsm is not None for req in self.reqs)
-        if has_regex:
-            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    allowed_mask.zero_()
-                    allowed_mask[
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
-                    ] = 1
-                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
+        sampler = Sampler()
 
-        logits = self.penalizer_orchestrator.apply(logits)
-
-        probs = torch.softmax(logits, dim=-1)
-
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            if self.min_ps.any():
-                probs = top_k_renorm_prob(probs, self.top_ks)
-                probs = top_p_renorm_prob(probs, self.top_ps)
-                batch_next_token_ids, success = min_p_sampling_from_probs(
-                    probs, uniform_samples, self.min_ps
-                )
-            else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                    probs, uniform_samples, self.top_ks, self.top_ps
-                )
-        else:
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs, self.top_ks, self.top_ps, self.min_ps
-            )
-
-        if not torch.all(success):
-            logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
-            )
-
-        if has_regex:
-            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    req.regex_fsm_state = req.regex_fsm.get_next_state(
-                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                    )
-
-        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
+        batch_next_token_ids = sampler(logits, self.sampling_info)
 
         return batch_next_token_ids
-
-
-def top_k_top_p_min_p_sampling_from_probs_torch(
-    probs: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-    min_ps: torch.Tensor,
-):
-    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    min_p_thresholds = probs_sort[:, 0] * min_ps
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
-        >= top_ks.view(-1, 1)
-    ] = 0.0
-    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError as e:
-        logger.warning(f"Sampling error: {e}")
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index ab375a39a..32d1f43d3 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -50,7 +50,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightReqOutput,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
 from sglang.utils import get_exception_traceback
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 8772a4abb..41f908301 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -482,6 +482,9 @@ class ModelTpServer:
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
                 next_token_ids = batch.sample(output.next_token_logits)
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )
 
                 # Move logprobs to cpu
                 if output.next_token_logprobs is not None:
@@ -514,6 +517,11 @@ class ModelTpServer:
                 req.output_ids.append(next_token_ids[i])
                 req.check_finished()
 
+                if req.regex_fsm is not None:
+                    req.regex_fsm_state = req.regex_fsm.get_next_state(
+                        req.regex_fsm_state, next_token_ids[i]
+                    )
+
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
                 elif req not in decoding_reqs:
@@ -642,6 +650,9 @@ class ModelTpServer:
         # Forward and sample the next tokens
        output = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids = batch.sample(output.next_token_logits)
+        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+            next_token_ids
+        )
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
@@ -658,6 +669,11 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()
 
+            if req.regex_fsm is not None:
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
+                    req.regex_fsm_state, next_token_id
+                )
+
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index a00a73945..b91191c5d 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -120,9 +120,6 @@ class ModelRunner:
             self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()
-        self.is_multi_node_tp = not all(
-            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
-        )
 
         if self.tp_size > 1:
             total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
new file mode 100644
index 000000000..bc70a9018
--- /dev/null
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import dataclasses
+from typing import TYPE_CHECKING, List
+
+import torch
+
+import sglang.srt.sampling.penaltylib as penaltylib
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+
+
+@dataclasses.dataclass
+class SamplingBatchInfo:
+    # Basic Info
+    vocab_size: int
+
+    # Batched sampling params
+    temperatures: torch.Tensor = None
+    top_ps: torch.Tensor = None
+    top_ks: torch.Tensor = None
+    min_ps: torch.Tensor = None
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
+    logit_bias: torch.Tensor = None
+    vocab_mask: torch.Tensor = None
+
+    @classmethod
+    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        device = "cuda"
+        reqs = batch.reqs
+        ret = cls(vocab_size=vocab_size)
+
+        ret.temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
+        ret.top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
+        ret.top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+        )
+        ret.min_ps = torch.tensor(
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+        )
+
+        # Each penalizer will do nothing if it evaluates itself as not required by looking at
+        # the sampling_params of the requests (see {_is_required()} of each penalizer). So this
+        # should not add hefty computation overhead other than simple checks.
+        #
+        # While we choose not to even create the class instances if they are not required, this
+        # could add additional complexity to the {ScheduleBatch} class, especially since we need
+        # to handle {filter_batch()} and {merge()} cases as well.
+        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=batch,
+            device=device,
+            Penalizers={
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            },
+        )
+
+        # Handle logit bias but only allocate when needed
+        ret.logit_bias = None
+
+        ret.update_regex_vocab_mask(batch)
+
+        return ret
+
+    def update_regex_vocab_mask(self, batch: ScheduleBatch):
+        bs, reqs = batch.batch_size(), batch.reqs
+        device = "cuda"
+        has_regex = any(req.regex_fsm is not None for req in reqs)
+
+        # Reset the vocab mask
+        self.vocab_mask = None
+
+        if has_regex:
+            for i, req in enumerate(reqs):
+                if req.regex_fsm is not None:
+                    if self.vocab_mask is None:
+                        self.vocab_mask = torch.zeros(
+                            bs, self.vocab_size, dtype=torch.bool, device=device
+                        )
+                    self.vocab_mask[i][
+                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                    ] = 1
+
+    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+            "logit_bias",
+        ]:
+            self_val = getattr(self, item, None)
+            if self_val is not None:  # logit_bias can be None
+                setattr(self, item, self_val[new_indices])
+
+    def merge(self, other: "SamplingBatchInfo"):
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
+            )
+            # Note: the dataclass does not hold the request list, so the batch size
+            # for zero-padding is taken from the temperatures tensor (shape [bs, 1]).
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.temperatures), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.temperatures), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
similarity index 100%
rename from python/sglang/srt/sampling_params.py
rename to python/sglang/srt/sampling/sampling_params.py
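
The pure-PyTorch fallback added above in python/sglang/srt/layers/sampler.py is self-contained and can be exercised outside the server. A minimal sketch, assuming the module is importable as laid out in this diff (importing it also pulls in flashinfer and vllm, since those are module-level imports); the tensor values below are made up for illustration:

import torch

from sglang.srt.layers.sampler import top_k_top_p_min_p_sampling_from_probs_torch

# Two requests over a toy 8-token vocabulary; probabilities must already include
# temperature scaling, as Sampler.forward_cuda does before calling the fallback.
logits = torch.randn(2, 8)
probs = torch.softmax(logits / 0.7, dim=-1)

next_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
    probs,
    torch.tensor([40, 1]),      # top_ks: the second request is effectively greedy (top_k=1)
    torch.tensor([0.95, 1.0]),  # top_ps
    torch.tensor([0.0, 0.05]),  # min_ps
)
print(next_ids, success)  # success is all-False only if torch.multinomial raised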
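
Regex-constrained decoding now flows through SamplingBatchInfo.update_regex_vocab_mask, which builds a per-request boolean vocab_mask, and Sampler.forward_cuda, which applies it with masked_fill before the softmax. The masking step in isolation, on hand-written tensors rather than a real regex_fsm:

import torch

vocab_size = 6
logits = torch.tensor([[1.0, 2.0, 0.5, -1.0, 3.0, 0.0]])

# Pretend the FSM state only allows token ids 1 and 4 for this request.
vocab_mask = torch.zeros(1, vocab_size, dtype=torch.bool)
vocab_mask[0, [1, 4]] = True

# Same operation as in Sampler.forward_cuda: banned tokens are pushed to -inf.
masked = logits.masked_fill(~vocab_mask, float("-inf"))
probs = torch.softmax(masked, dim=-1)
print(probs)  # only indices 1 and 4 carry probability mass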
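
SamplingBatchInfo.merge keeps logit_bias lazily allocated: if only one side of the merge carries a bias, the other side is padded with zeros before concatenation (the version above sizes that padding from temperatures, since the dataclass does not hold the request list). The padding rule on bare tensors, for illustration:

import torch

vocab_size = 10
bias_a = torch.full((2, vocab_size), -1.0)  # batch of 2 with an explicit bias
bias_b = None                               # batch of 3 without one
batch_b_size = 3                            # in SamplingBatchInfo: other.temperatures.shape[0]

if bias_b is None:
    bias_b = torch.zeros((batch_b_size, vocab_size), dtype=torch.float32)

merged = torch.concat([bias_a, bias_b])
print(merged.shape)  # torch.Size([5, 10])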