Improve type annotation and styles (#2926)
This commit is contained in:
@@ -22,8 +22,9 @@ import time
|
||||
import warnings
|
||||
from collections import deque
|
||||
from concurrent import futures
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import psutil
|
||||
import setproctitle
|
||||
@@ -102,6 +103,19 @@ logger = logging.getLogger(__name__)
|
||||
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationBatchResult:
|
||||
logits_output: LogitsProcessorOutput
|
||||
next_token_ids: List[int]
|
||||
bid: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingBatchResult:
|
||||
embeddings: torch.Tensor
|
||||
bid: int
|
||||
|
||||
|
||||
class Scheduler:
|
||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||
|
||||
@@ -411,16 +425,16 @@ class Scheduler:
|
||||
self.watchdog_last_time = time.time()
|
||||
|
||||
while True:
|
||||
current = time.time()
|
||||
if self.cur_batch is not None:
|
||||
if self.watchdog_last_forward_ct == self.forward_ct:
|
||||
if time.time() > self.watchdog_last_time + self.watchdog_timeout:
|
||||
if current > self.watchdog_last_time + self.watchdog_timeout:
|
||||
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
|
||||
break
|
||||
else:
|
||||
self.watchdog_last_forward_ct = self.forward_ct
|
||||
self.watchdog_last_time = time.time()
|
||||
time.sleep(self.watchdog_timeout / 2)
|
||||
|
||||
self.watchdog_last_time = current
|
||||
time.sleep(self.watchdog_timeout // 2)
|
||||
# Wait sometimes so that the parent process can print the error.
|
||||
time.sleep(5)
|
||||
self.parent_process.send_signal(signal.SIGQUIT)
|
||||
@@ -1018,7 +1032,9 @@ class Scheduler:
|
||||
batch.prepare_for_decode()
|
||||
return batch
|
||||
|
||||
def run_batch(self, batch: ScheduleBatch):
|
||||
def run_batch(
|
||||
self, batch: ScheduleBatch
|
||||
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
|
||||
"""Run a batch."""
|
||||
self.forward_ct += 1
|
||||
|
||||
@@ -1040,15 +1056,26 @@ class Scheduler:
|
||||
else:
|
||||
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
|
||||
batch.output_ids = next_token_ids
|
||||
ret = logits_output, next_token_ids, model_worker_batch.bid
|
||||
|
||||
ret = GenerationBatchResult(
|
||||
logits_output=logits_output,
|
||||
next_token_ids=next_token_ids,
|
||||
bid=model_worker_batch.bid,
|
||||
)
|
||||
else: # embedding or reward model
|
||||
assert batch.extend_num_tokens != 0
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||
ret = embeddings, model_worker_batch.bid
|
||||
ret = EmbeddingBatchResult(
|
||||
embeddings=embeddings, bid=model_worker_batch.bid
|
||||
)
|
||||
return ret
|
||||
|
||||
def process_batch_result(self, batch: ScheduleBatch, result):
|
||||
def process_batch_result(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||
):
|
||||
if batch.forward_mode.is_decode():
|
||||
self.process_batch_result_decode(batch, result)
|
||||
if batch.is_empty():
|
||||
@@ -1057,17 +1084,29 @@ class Scheduler:
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
elif batch.forward_mode.is_idle():
|
||||
if self.enable_overlap:
|
||||
self.tp_worker.resolve_batch_result(result[-1])
|
||||
self.tp_worker.resolve_batch_result(result.bid)
|
||||
elif batch.forward_mode.is_dummy_first():
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
def process_batch_result_prefill(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||
):
|
||||
skip_stream_req = None
|
||||
|
||||
if self.is_generation:
|
||||
logits_output, next_token_ids, bid = result
|
||||
(
|
||||
logits_output,
|
||||
next_token_ids,
|
||||
bid,
|
||||
) = (
|
||||
result.logits_output,
|
||||
result.next_token_ids,
|
||||
result.bid,
|
||||
)
|
||||
|
||||
if self.enable_overlap:
|
||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||
@@ -1125,7 +1164,7 @@ class Scheduler:
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
else: # embedding or reward model
|
||||
embeddings, bid = result
|
||||
embeddings, bid = result.embeddings, result.bid
|
||||
embeddings = embeddings.tolist()
|
||||
|
||||
# Check finish conditions
|
||||
@@ -1149,8 +1188,16 @@ class Scheduler:
|
||||
|
||||
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
|
||||
|
||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||
logits_output, next_token_ids, bid = result
|
||||
def process_batch_result_decode(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
result: GenerationBatchResult,
|
||||
):
|
||||
logits_output, next_token_ids, bid = (
|
||||
result.logits_output,
|
||||
result.next_token_ids,
|
||||
result.bid,
|
||||
)
|
||||
self.num_generated_tokens += len(batch.reqs)
|
||||
|
||||
if self.enable_overlap:
|
||||
|
||||
Reference in New Issue
Block a user