Fix tp worker only checking req[0] for stream (#546)

This commit is contained in:
Qubitium-modelcloud
2024-06-15 13:56:10 +08:00
committed by GitHub
parent 40e53d65cb
commit bbec01c9aa
2 changed files with 7 additions and 3 deletions

View File

@@ -303,6 +303,10 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0

+    # whether batch has at least 1 streaming request
+    def has_stream(self) -> bool:
+        return any(r.stream for r in self.reqs)
+
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
         bs = len(self.reqs)

View File

@@ -5,7 +5,7 @@ import logging
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import List, Optional

 import rpyc
 import torch
@@ -253,7 +253,7 @@ class ModelTpServer:
                 self.running_batch = None
                 break

-            if self.out_pyobjs and self.running_batch.reqs[0].stream:
+            if self.out_pyobjs and self.running_batch.has_stream():
                 break
         else:
             # Check the available size
@@ -314,7 +314,7 @@ class ModelTpServer:
             )
             self.forward_queue.append(req)

-    def get_new_fill_batch(self):
+    def get_new_fill_batch(self) -> Optional[Batch]:
         if (
             self.running_batch is not None
             and len(self.running_batch.reqs) > self.max_running_requests