From fbb4754cb8c6585763ab631231508e84e6c287e2 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Tue, 10 Sep 2024 13:10:36 -0700 Subject: [PATCH] Fix vocab mask update bug (#1376) --- python/sglang/srt/managers/schedule_batch.py | 2 - .../srt/model_executor/forward_batch_info.py | 3 +- .../srt/sampling/sampling_batch_info.py | 46 +++++++++++-------- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6c6b7f842..2e2489cd2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -652,8 +652,6 @@ class ScheduleBatch: self.req_pool_indices, self.seq_lens - 1 ] = self.out_cache_loc - self.sampling_info.update_regex_vocab_mask(self) - def filter_batch(self, unfinished_indices: List[int]): if unfinished_indices is None or len(unfinished_indices) == 0: # Filter out all requests diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a6ad63ce1..c1fb23357 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -195,7 +195,8 @@ class InputMetadata: top_logprobs_nums=batch.top_logprobs_nums, ) - ret.sampling_info.prepare_penalties() + ret.sampling_info.update_penalties() + ret.sampling_info.update_regex_vocab_mask(batch) ret.compute_positions(batch) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 20b1968d2..622f27df1 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -34,6 +34,9 @@ class SamplingBatchInfo: linear_penalties: torch.Tensor = None scaling_penalties: torch.Tensor = None + def __len__(self): + return len(self.temperatures) + def can_run_in_cuda_graph(self): # Vocab bias and min_ps are not supported in CUDA graph 
return ( @@ -118,11 +121,9 @@ class SamplingBatchInfo: # Handle logit bias but only allocate when needed ret.logit_bias = None - ret.update_regex_vocab_mask(batch) - return ret - def prepare_penalties(self): + def update_penalties(self): self.scaling_penalties = None self.linear_penalties = None @@ -174,6 +175,26 @@ class SamplingBatchInfo: if self_val is not None: # logit_bias can be None setattr(self, item, self_val[new_indices]) + @staticmethod + def merge_bias_tensor( + lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0 + ): + # bias tensor can be None + if lhs is not None or rhs is not None: + shape, dtype = None, None + if lhs is not None: + shape, dtype = lhs.shape[1:], lhs.dtype + else: + shape, dtype = rhs.shape[1:], rhs.dtype + with torch.device("cuda"): + if lhs is None: + lhs = torch.empty((bs1, *shape), dtype=dtype).fill_(default) + if rhs is None: + rhs = torch.empty((bs2, *shape), dtype=dtype).fill_(default) + return torch.cat([lhs, rhs]) + + return None + def merge(self, other: "SamplingBatchInfo"): self.penalizer_orchestrator.merge(other.penalizer_orchestrator) @@ -187,19 +208,6 @@ class SamplingBatchInfo: other_val = getattr(other, item, None) setattr(self, item, torch.concat([self_val, other_val])) - # logit_bias can be None - if self.logit_bias is not None or other.logit_bias is not None: - vocab_size = ( - self.logit_bias.shape[1] - if self.logit_bias is not None - else other.logit_bias.shape[1] - ) - if self.logit_bias is None: - self.logit_bias = torch.zeros( - (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda" - ) - if other.logit_bias is None: - other.logit_bias = torch.zeros( - (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda" - ) - self.logit_bias = torch.concat([self.logit_bias, other.logit_bias]) + self.logit_bias = SamplingBatchInfo.merge_bias_tensor( + self.logit_bias, other.logit_bias, len(self), len(other) + )