From a5a134f39f9b032496fa895050e56485d8fe9957 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Mon, 2 Sep 2024 16:18:48 -0700
Subject: [PATCH] Fix bugs in sampler with CUDA graph / torch.compile (#1306)

---
 python/sglang/srt/layers/sampler.py           | 44 ++++++++++++++-----
 .../srt/model_executor/cuda_graph_runner.py   |  2 +
 .../sglang/srt/model_executor/model_runner.py |  2 +-
 .../srt/sampling/sampling_batch_info.py       | 26 +++++------
 4 files changed, 48 insertions(+), 26 deletions(-)

diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index f56fee828..6cb7d5b55 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -1,6 +1,6 @@
 import dataclasses
 import logging
-from typing import Union
+from typing import Tuple, Union
 
 import torch
 from flashinfer.sampling import (
@@ -9,6 +9,7 @@ from flashinfer.sampling import (
     top_k_top_p_sampling_from_probs,
     top_p_renorm_prob,
 )
+from torch.library import custom_op as torch_custom_op
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -30,6 +31,9 @@ class SampleOutput:
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()
+        # FIXME: torch.multinomial has too many bugs
+        self.forward_native = self.forward_cuda
+        self.is_torch_compile = False
 
     def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # min-token, presence, frequency
@@ -46,16 +50,11 @@ class Sampler(CustomOp):
 
         return logits
 
-    def _get_probs(
-        self,
-        logits: torch.Tensor,
-        sampling_info: SamplingBatchInfo,
-        is_torch_compile: bool = False,
-    ):
+    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
-        if is_torch_compile:
+        if self.is_torch_compile:
             # FIXME: Temporary workaround for unknown bugs in torch.compile
             logits.add_(0)
 
@@ -91,7 +90,7 @@
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                batch_next_token_ids, success = flashinfer_top_k_top_p(
                     probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                 )
         else:
@@ -110,7 +109,7 @@
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits
 
-        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+        probs = self._get_probs(logits, sampling_info)
 
         batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
             probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
@@ -119,6 +118,31 @@ class Sampler(CustomOp):
         return SampleOutput(success, probs, batch_next_token_ids)
 
 
+@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
+def flashinfer_top_k_top_p(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # NOTE: we do not use min_p in either CUDA graph or torch.compile mode
+    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
+
+
+@flashinfer_top_k_top_p.register_fake
+def _(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    bs = probs.shape[0]
+    return (
+        torch.ones(bs, dtype=torch.bool, device=probs.device),
+        torch.zeros(bs, dtype=torch.int32, device=probs.device),
+    )
+
+
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
     top_ks: torch.Tensor,
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 40c87af88..4459213b0 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -46,8 +46,10 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
         if isinstance(sub, CustomOp):
             if reverse:
                 sub._forward_method = sub.forward_cuda
+                setattr(sub, "is_torch_compile", False)
             else:
                 sub._forward_method = sub.forward_native
+                setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 05a751365..26afe6600 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -523,7 +523,7 @@ class ModelRunner:
         if (
             self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and not batch.sampling_info.has_bias()
+            and batch.sampling_info.can_run_in_cuda_graph()
         ):
             return self.cuda_graph_runner.replay(batch)
 
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 38b6701c7..20b1968d2 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -34,12 +34,14 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None
 
-    def has_bias(self):
+    def can_run_in_cuda_graph(self):
+        # Vocab bias and min_ps are not supported in CUDA graph
         return (
-            self.logit_bias is not None
-            or self.vocab_mask is not None
-            or self.linear_penalties is not None
-            or self.scaling_penalties is not None
+            self.logit_bias is None
+            and self.vocab_mask is None
+            and self.linear_penalties is None
+            and self.scaling_penalties is None
+            and not self.need_min_p_sampling
         )
 
     @classmethod
@@ -48,35 +50,29 @@
         ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
         ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
         ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
-        ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
         return ret
 
     def __getitem__(self, key):
         if isinstance(key, slice):
-            # NOTE: We do not use cuda graph when there is bias tensors
-            assert not self.has_bias()
+            # NOTE: This method is only used in CUDA graph
+            assert self.can_run_in_cuda_graph()
             return SamplingBatchInfo(
                 vocab_size=self.vocab_size,
                 temperatures=self.temperatures[key],
                 top_ps=self.top_ps[key],
                 top_ks=self.top_ks[key],
-                min_ps=self.min_ps[key],
-                need_min_p_sampling=self.need_min_p_sampling,
             )
         else:
             raise NotImplementedError
 
     def inplace_assign(self, bs: int, other: SamplingBatchInfo):
-        # NOTE: We do not use cuda graph when there is bias tensors
-        assert not self.has_bias()
+        # NOTE: This method is only used in CUDA graph
+        assert self.can_run_in_cuda_graph()
 
         self.vocab_size = other.vocab_size
-        self.need_min_p_sampling = other.need_min_p_sampling
-
         self.temperatures[:bs] = other.temperatures
         self.top_ps[:bs] = other.top_ps
         self.top_ks[:bs] = other.top_ks
-        self.min_ps[:bs] = other.min_ps
 
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
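
Note on the torch.compile fix: top_k_top_p_sampling_from_probs is an opaque
flashinfer kernel that torch.compile cannot trace, so the patch wraps it in a
torch.library custom op and registers a fake (meta) kernel that only reports
output shapes, dtypes, and device. The sketch below is a minimal,
self-contained illustration of that same pattern, not sglang code; it assumes
PyTorch >= 2.4, and the op name "demo_lib::scaled_argmax" and the stand-in
kernel body are hypothetical.

    import torch
    from torch.library import custom_op


    # Real kernel: registered as a custom op so torch.compile keeps it as a
    # single opaque graph node instead of tracing into its internals.
    @custom_op("demo_lib::scaled_argmax", mutates_args=())
    def scaled_argmax(probs: torch.Tensor, temperature: float) -> torch.Tensor:
        # Stand-in for an external sampling kernel (e.g. from flashinfer).
        return torch.argmax(probs / temperature, dim=-1).to(torch.int32)


    # Fake (meta) kernel: tells the compiler only the output shape, dtype, and
    # device, mirroring the register_fake stub added for flashinfer_top_k_top_p.
    @scaled_argmax.register_fake
    def _(probs: torch.Tensor, temperature: float) -> torch.Tensor:
        return torch.empty(probs.shape[0], dtype=torch.int32, device=probs.device)


    @torch.compile
    def sample(probs: torch.Tensor) -> torch.Tensor:
        return scaled_argmax(probs, 0.7)


    print(sample(torch.rand(4, 128)))

As in the patch, the fake kernel returns correctly shaped dummy tensors: it is
consulted only while tracing and never at real execution time, which is why
returning ones/zeros (or uninitialized tensors) from it is safe.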