Improve type annotation and styles (#2926)
This commit is contained in:
@@ -226,8 +226,9 @@ class Req:
|
||||
else origin_input_ids # Before image padding
|
||||
)
|
||||
self.origin_input_ids = origin_input_ids
|
||||
self.output_ids = [] # Each decode stage's output ids
|
||||
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
|
||||
# Each decode stage's output ids
|
||||
self.output_ids = []
|
||||
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
|
||||
self.session_id = session_id
|
||||
self.input_embeds = input_embeds
|
||||
|
||||
@@ -265,6 +266,7 @@ class Req:
|
||||
# Prefix info
|
||||
self.prefix_indices = []
|
||||
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
|
||||
# Updated if chunked.
|
||||
self.extend_input_len = 0
|
||||
self.last_node = None
|
||||
|
||||
@@ -280,10 +282,10 @@ class Req:
|
||||
self.top_logprobs_num = top_logprobs_num
|
||||
|
||||
# Logprobs (return value)
|
||||
self.input_token_logprobs_val = None
|
||||
self.input_token_logprobs_idx = None
|
||||
self.input_top_logprobs_val = None
|
||||
self.input_top_logprobs_idx = None
|
||||
self.input_token_logprobs_val: Optional[List[float]] = None
|
||||
self.input_token_logprobs_idx: Optional[List[int]] = None
|
||||
self.input_top_logprobs_val: Optional[List[float]] = None
|
||||
self.input_top_logprobs_idx: Optional[List[int]] = None
|
||||
|
||||
if return_logprob:
|
||||
self.output_token_logprobs_val = []
|
||||
|
||||
@@ -22,8 +22,9 @@ import time
|
||||
import warnings
|
||||
from collections import deque
|
||||
from concurrent import futures
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import psutil
|
||||
import setproctitle
|
||||
@@ -102,6 +103,19 @@ logger = logging.getLogger(__name__)
|
||||
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationBatchResult:
|
||||
logits_output: LogitsProcessorOutput
|
||||
next_token_ids: List[int]
|
||||
bid: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingBatchResult:
|
||||
embeddings: torch.Tensor
|
||||
bid: int
|
||||
|
||||
|
||||
class Scheduler:
|
||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||
|
||||
@@ -411,16 +425,16 @@ class Scheduler:
|
||||
self.watchdog_last_time = time.time()
|
||||
|
||||
while True:
|
||||
current = time.time()
|
||||
if self.cur_batch is not None:
|
||||
if self.watchdog_last_forward_ct == self.forward_ct:
|
||||
if time.time() > self.watchdog_last_time + self.watchdog_timeout:
|
||||
if current > self.watchdog_last_time + self.watchdog_timeout:
|
||||
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
|
||||
break
|
||||
else:
|
||||
self.watchdog_last_forward_ct = self.forward_ct
|
||||
self.watchdog_last_time = time.time()
|
||||
time.sleep(self.watchdog_timeout / 2)
|
||||
|
||||
self.watchdog_last_time = current
|
||||
time.sleep(self.watchdog_timeout // 2)
|
||||
# Wait sometimes so that the parent process can print the error.
|
||||
time.sleep(5)
|
||||
self.parent_process.send_signal(signal.SIGQUIT)
|
||||
@@ -1018,7 +1032,9 @@ class Scheduler:
|
||||
batch.prepare_for_decode()
|
||||
return batch
|
||||
|
||||
def run_batch(self, batch: ScheduleBatch):
|
||||
def run_batch(
|
||||
self, batch: ScheduleBatch
|
||||
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
|
||||
"""Run a batch."""
|
||||
self.forward_ct += 1
|
||||
|
||||
@@ -1040,15 +1056,26 @@ class Scheduler:
|
||||
else:
|
||||
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
|
||||
batch.output_ids = next_token_ids
|
||||
ret = logits_output, next_token_ids, model_worker_batch.bid
|
||||
|
||||
ret = GenerationBatchResult(
|
||||
logits_output=logits_output,
|
||||
next_token_ids=next_token_ids,
|
||||
bid=model_worker_batch.bid,
|
||||
)
|
||||
else: # embedding or reward model
|
||||
assert batch.extend_num_tokens != 0
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||
ret = embeddings, model_worker_batch.bid
|
||||
ret = EmbeddingBatchResult(
|
||||
embeddings=embeddings, bid=model_worker_batch.bid
|
||||
)
|
||||
return ret
|
||||
|
||||
def process_batch_result(self, batch: ScheduleBatch, result):
|
||||
def process_batch_result(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||
):
|
||||
if batch.forward_mode.is_decode():
|
||||
self.process_batch_result_decode(batch, result)
|
||||
if batch.is_empty():
|
||||
@@ -1057,17 +1084,29 @@ class Scheduler:
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
elif batch.forward_mode.is_idle():
|
||||
if self.enable_overlap:
|
||||
self.tp_worker.resolve_batch_result(result[-1])
|
||||
self.tp_worker.resolve_batch_result(result.bid)
|
||||
elif batch.forward_mode.is_dummy_first():
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
def process_batch_result_prefill(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||
):
|
||||
skip_stream_req = None
|
||||
|
||||
if self.is_generation:
|
||||
logits_output, next_token_ids, bid = result
|
||||
(
|
||||
logits_output,
|
||||
next_token_ids,
|
||||
bid,
|
||||
) = (
|
||||
result.logits_output,
|
||||
result.next_token_ids,
|
||||
result.bid,
|
||||
)
|
||||
|
||||
if self.enable_overlap:
|
||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||
@@ -1125,7 +1164,7 @@ class Scheduler:
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
else: # embedding or reward model
|
||||
embeddings, bid = result
|
||||
embeddings, bid = result.embeddings, result.bid
|
||||
embeddings = embeddings.tolist()
|
||||
|
||||
# Check finish conditions
|
||||
@@ -1149,8 +1188,16 @@ class Scheduler:
|
||||
|
||||
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
|
||||
|
||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||
logits_output, next_token_ids, bid = result
|
||||
def process_batch_result_decode(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
result: GenerationBatchResult,
|
||||
):
|
||||
logits_output, next_token_ids, bid = (
|
||||
result.logits_output,
|
||||
result.next_token_ids,
|
||||
result.bid,
|
||||
)
|
||||
self.num_generated_tokens += len(batch.reqs)
|
||||
|
||||
if self.enable_overlap:
|
||||
|
||||
@@ -37,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
|
||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_group,
|
||||
get_attention_tp_size,
|
||||
initialize_dp_attention,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
@@ -532,7 +533,7 @@ class ModelRunner:
|
||||
)
|
||||
else:
|
||||
cell_size = (
|
||||
self.model_config.get_num_kv_heads(self.tp_size)
|
||||
self.model_config.get_num_kv_heads(get_attention_tp_size())
|
||||
* self.model_config.head_dim
|
||||
* self.model_config.num_hidden_layers
|
||||
* 2
|
||||
@@ -626,7 +627,7 @@ class ModelRunner:
|
||||
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=self.device,
|
||||
@@ -637,7 +638,7 @@ class ModelRunner:
|
||||
self.token_to_kv_pool = MHATokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=self.device,
|
||||
|
||||
@@ -180,6 +180,7 @@ class CompletionRequest(BaseModel):
|
||||
ignore_eos: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
session_params: Optional[Dict] = None
|
||||
|
||||
|
||||
class CompletionResponseChoice(BaseModel):
|
||||
@@ -322,6 +323,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
ignore_eos: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
session_params: Optional[Dict] = None
|
||||
|
||||
|
||||
class FunctionResponse(BaseModel):
|
||||
|
||||
@@ -842,7 +842,6 @@ class Engine:
|
||||
generator = ret.body_iterator
|
||||
|
||||
async def generator_wrapper():
|
||||
|
||||
offset = 0
|
||||
|
||||
while True:
|
||||
|
||||
@@ -239,8 +239,8 @@ class ServerArgs:
|
||||
|
||||
# Others
|
||||
if self.enable_dp_attention:
|
||||
assert self.tp_size % self.dp_size == 0
|
||||
self.dp_size = self.tp_size
|
||||
assert self.tp_size % self.dp_size == 0
|
||||
self.chunked_prefill_size = self.chunked_prefill_size // 2
|
||||
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
||||
logger.warning(
|
||||
|
||||
Reference in New Issue
Block a user