From ed27a6b99258c905502bdc7f37300ea060d9b9b1 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Wed, 3 Apr 2024 12:45:01 +0800 Subject: [PATCH] Revert "Eliminate 2 gpu ops during sampling when logit_bias is zero" (#345) --- .../sglang/srt/managers/router/infer_batch.py | 40 ++++--------------- 1 file changed, 8 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index b5294a477..f001075bc 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -251,14 +251,10 @@ class Batch: ] = out_cache_loc[pt : pt + extend_lens[i]] pt += extend_lens[i] - # Handle logit bias but only allocate when needed - logit_bias = None + # Handle logit bias + logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device) for i in range(bs): if reqs[i].sampling_params.dtype == "int": - if logit_bias is None: - logit_bias = torch.zeros( - (bs, vocab_size), dtype=torch.float32, device=device - ) logit_bias[i] = int_token_logit_bias # Set fields @@ -437,12 +433,9 @@ class Batch: "presence_penalties", "logit_bias", ]: - self_val = getattr(self, item, None) - # logit_bias can be None - if self_val is not None: - setattr(self, item, self_val[new_indices]) + setattr(self, item, getattr(self, item)[new_indices]) - def merge(self, other: "Batch"): + def merge(self, other): self.reqs.extend(other.reqs) self.req_pool_indices = torch.concat( @@ -463,34 +456,17 @@ class Batch: "top_ks", "frequency_penalties", "presence_penalties", + "logit_bias", ]: - self_val = getattr(self, item, None) - other_val = getattr(other, item, None) - setattr(self, item, torch.concat([self_val, other_val])) - - # logit_bias can be None - if self.logit_bias is not None or other.logit_bias is not None: - vocab_size = ( - self.logit_bias.shape[1] - if self.logit_bias is not None - else other.logit_bias.shape[1] + setattr( + self, item, torch.concat([getattr(self, item), getattr(other, item)]) ) - if self.logit_bias is None: - self.logit_bias = torch.zeros( - (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda" - ) - if other.logit_bias is None: - other.logit_bias = torch.zeros( - (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda" - ) - self.logit_bias = torch.concat([self.logit_bias, other.logit_bias]) def sample(self, logits: torch.Tensor): # Post process logits logits = logits.contiguous() logits.div_(self.temperatures) - if self.logit_bias is not None: - logits.add_(self.logit_bias) + logits.add_(self.logit_bias) has_regex = any(req.regex_fsm is not None for req in self.reqs) if has_regex: