diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py
index f001075bc..b5294a477 100644
--- a/python/sglang/srt/managers/router/infer_batch.py
+++ b/python/sglang/srt/managers/router/infer_batch.py
@@ -251,10 +251,14 @@ class Batch:
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]
 
-        # Handle logit bias
-        logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
+        # Handle logit bias but only allocate when needed
+        logit_bias = None
         for i in range(bs):
             if reqs[i].sampling_params.dtype == "int":
+                if logit_bias is None:
+                    logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
                 logit_bias[i] = int_token_logit_bias
 
         # Set fields
@@ -433,9 +437,12 @@ class Batch:
             "presence_penalties",
             "logit_bias",
         ]:
-            setattr(self, item, getattr(self, item)[new_indices])
+            self_val = getattr(self, item, None)
+            # logit_bias can be None
+            if self_val is not None:
+                setattr(self, item, self_val[new_indices])
 
-    def merge(self, other):
+    def merge(self, other: "Batch"):
         self.reqs.extend(other.reqs)
 
         self.req_pool_indices = torch.concat(
@@ -456,17 +463,37 @@ class Batch:
             "top_ks",
             "frequency_penalties",
             "presence_penalties",
-            "logit_bias",
         ]:
-            setattr(
-                self, item, torch.concat([getattr(self, item), getattr(other, item)])
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
             )
+            if self.logit_bias is None:
+                # self.reqs already includes other.reqs at this point
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs) - len(other.reqs), vocab_size),
+                    dtype=torch.float32,
+                    device="cuda",
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
     def sample(self, logits: torch.Tensor):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
-        logits.add_(self.logit_bias)
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
 
         has_regex = any(req.regex_fsm is not None for req in self.reqs)
         if has_regex:
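
A reviewer-side note on the pattern: the bias tensor stays `None` until some request in the batch actually sets an integer-token constraint, and every consumer (`filter_batch`, `merge`, `sample`) is taught to tolerate `None`. The merge case is the subtle one, since only one side may have a bias tensor. Below is a minimal standalone sketch of that logic; the helper name `merge_logit_bias` and the explicit pre-merge lengths `a_len`/`b_len` are illustrative, not part of the sglang codebase.

```python
import torch

def merge_logit_bias(a_bias, b_bias, a_len, b_len):
    """None-aware row-wise concat of two batches' logit-bias tensors.

    a_bias/b_bias are (a_len, vocab) / (b_len, vocab) float tensors or None.
    Returns None when neither batch uses a bias, so the common case
    allocates nothing; otherwise the missing side is zero-filled first.
    """
    if a_bias is None and b_bias is None:
        return None
    # At least one side is non-None, so vocab size and device can be read from it.
    template = a_bias if a_bias is not None else b_bias
    vocab_size, device = template.shape[1], template.device
    if a_bias is None:
        a_bias = torch.zeros((a_len, vocab_size), dtype=torch.float32, device=device)
    if b_bias is None:
        b_bias = torch.zeros((b_len, vocab_size), dtype=torch.float32, device=device)
    return torch.concat([a_bias, b_bias])

# Example: a 2-request batch without bias merged with a 3-request batch with one.
merged = merge_logit_bias(None, torch.zeros(3, 16), a_len=2, b_len=3)
assert merged.shape == (5, 16)
```

Passing the pre-merge lengths explicitly sidesteps the sizing pitfall in `Batch.merge`, where `self.reqs` has already been extended with `other.reqs` by the time the bias tensors are reconciled.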