diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 0cab5455d..653225d68 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -303,6 +303,10 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0
 
+    # whether batch has at least 1 streaming request
+    def has_stream(self) -> bool:
+        return any(r.stream for r in self.reqs)
+
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
         bs = len(self.reqs)
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index 2f3e86593..d8fee6537 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -5,7 +5,7 @@ import logging
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import List, Optional
 
 import rpyc
 import torch
@@ -253,7 +253,7 @@
                     self.running_batch = None
                     break
 
-                if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                if self.out_pyobjs and self.running_batch.has_stream():
                     break
             else:
                 # Check the available size
@@ -314,7 +314,7 @@
             )
             self.forward_queue.append(req)
 
-    def get_new_fill_batch(self):
+    def get_new_fill_batch(self) -> Optional[Batch]:
         if (
             self.running_batch is not None
             and len(self.running_batch.reqs) > self.max_running_requests