Improve type annotations and style (#2926)

This commit is contained in:
Lianmin Zheng
2025-01-16 12:51:11 -08:00
committed by GitHub
parent a883f0790d
commit bc6915e3b9
7 changed files with 78 additions and 26 deletions

View File

@@ -22,8 +22,9 @@ import time
import warnings
from collections import deque
from concurrent import futures
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import psutil
import setproctitle
@@ -102,6 +103,19 @@ logger = logging.getLogger(__name__)
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
@dataclass
class GenerationBatchResult:
    """Result of `Scheduler.run_batch` for a generation batch.

    Produced in the generation branch of `run_batch` and unpacked by
    `process_batch_result_prefill` / `process_batch_result_decode`.
    """

    # Logits/logprob output of the model forward pass.
    # NOTE(review): exact fields of LogitsProcessorOutput are defined
    # elsewhere in the project — not visible here.
    logits_output: LogitsProcessorOutput
    # Sampled next token id for each request in the batch.
    next_token_ids: List[int]
    # Batch id taken from model_worker_batch.bid; when overlap scheduling
    # is enabled it is passed to tp_worker.resolve_batch_result() to fetch
    # the real (deferred) outputs.
    bid: int
@dataclass
class EmbeddingBatchResult:
    """Result of `Scheduler.run_batch` for an embedding/reward-model batch.

    Produced in the embedding branch of `run_batch` and unpacked by
    `process_batch_result_prefill`.
    """

    # Embedding vectors returned by tp_worker.forward_batch_embedding;
    # converted with .tolist() before streaming to requests.
    embeddings: torch.Tensor
    # Batch id taken from model_worker_batch.bid, parallel to
    # GenerationBatchResult.bid.
    bid: int
class Scheduler:
"""A scheduler that manages a tensor parallel GPU worker."""
@@ -411,16 +425,16 @@ class Scheduler:
self.watchdog_last_time = time.time()
while True:
current = time.time()
if self.cur_batch is not None:
if self.watchdog_last_forward_ct == self.forward_ct:
if time.time() > self.watchdog_last_time + self.watchdog_timeout:
if current > self.watchdog_last_time + self.watchdog_timeout:
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
break
else:
self.watchdog_last_forward_ct = self.forward_ct
self.watchdog_last_time = time.time()
time.sleep(self.watchdog_timeout / 2)
self.watchdog_last_time = current
time.sleep(self.watchdog_timeout // 2)
# Wait sometimes so that the parent process can print the error.
time.sleep(5)
self.parent_process.send_signal(signal.SIGQUIT)
@@ -1018,7 +1032,9 @@ class Scheduler:
batch.prepare_for_decode()
return batch
def run_batch(self, batch: ScheduleBatch):
def run_batch(
self, batch: ScheduleBatch
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
"""Run a batch."""
self.forward_ct += 1
@@ -1040,15 +1056,26 @@ class Scheduler:
else:
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
batch.output_ids = next_token_ids
ret = logits_output, next_token_ids, model_worker_batch.bid
ret = GenerationBatchResult(
logits_output=logits_output,
next_token_ids=next_token_ids,
bid=model_worker_batch.bid,
)
else: # embedding or reward model
assert batch.extend_num_tokens != 0
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
ret = embeddings, model_worker_batch.bid
ret = EmbeddingBatchResult(
embeddings=embeddings, bid=model_worker_batch.bid
)
return ret
def process_batch_result(self, batch: ScheduleBatch, result):
def process_batch_result(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
if batch.is_empty():
@@ -1057,17 +1084,29 @@ class Scheduler:
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_batch_result(result[-1])
self.tp_worker.resolve_batch_result(result.bid)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
def process_batch_result_prefill(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
):
skip_stream_req = None
if self.is_generation:
logits_output, next_token_ids, bid = result
(
logits_output,
next_token_ids,
bid,
) = (
result.logits_output,
result.next_token_ids,
result.bid,
)
if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
@@ -1125,7 +1164,7 @@ class Scheduler:
batch.next_batch_sampling_info.sampling_info_done.set()
else: # embedding or reward model
embeddings, bid = result
embeddings, bid = result.embeddings, result.bid
embeddings = embeddings.tolist()
# Check finish conditions
@@ -1149,8 +1188,16 @@ class Scheduler:
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids, bid = result
def process_batch_result_decode(
self,
batch: ScheduleBatch,
result: GenerationBatchResult,
):
logits_output, next_token_ids, bid = (
result.logits_output,
result.next_token_ids,
result.bid,
)
self.num_generated_tokens += len(batch.reqs)
if self.enable_overlap: