Improve type annotation and styles (#2926)
This commit is contained in:
@@ -103,6 +103,7 @@ def tree_search(s, question, num_branches):
|
|||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)
|
lines = read_jsonl(args.data_path)
|
||||||
|
lines = list(lines)
|
||||||
|
|
||||||
# Construct prompts
|
# Construct prompts
|
||||||
num_branches = 2
|
num_branches = 2
|
||||||
|
|||||||
@@ -226,8 +226,9 @@ class Req:
|
|||||||
else origin_input_ids # Before image padding
|
else origin_input_ids # Before image padding
|
||||||
)
|
)
|
||||||
self.origin_input_ids = origin_input_ids
|
self.origin_input_ids = origin_input_ids
|
||||||
self.output_ids = [] # Each decode stage's output ids
|
# Each decode stage's output ids
|
||||||
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
|
self.output_ids = []
|
||||||
|
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
self.input_embeds = input_embeds
|
self.input_embeds = input_embeds
|
||||||
|
|
||||||
@@ -265,6 +266,7 @@ class Req:
|
|||||||
# Prefix info
|
# Prefix info
|
||||||
self.prefix_indices = []
|
self.prefix_indices = []
|
||||||
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
|
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
|
||||||
|
# Updated if chunked.
|
||||||
self.extend_input_len = 0
|
self.extend_input_len = 0
|
||||||
self.last_node = None
|
self.last_node = None
|
||||||
|
|
||||||
@@ -280,10 +282,10 @@ class Req:
|
|||||||
self.top_logprobs_num = top_logprobs_num
|
self.top_logprobs_num = top_logprobs_num
|
||||||
|
|
||||||
# Logprobs (return value)
|
# Logprobs (return value)
|
||||||
self.input_token_logprobs_val = None
|
self.input_token_logprobs_val: Optional[List[float]] = None
|
||||||
self.input_token_logprobs_idx = None
|
self.input_token_logprobs_idx: Optional[List[int]] = None
|
||||||
self.input_top_logprobs_val = None
|
self.input_top_logprobs_val: Optional[List[float]] = None
|
||||||
self.input_top_logprobs_idx = None
|
self.input_top_logprobs_idx: Optional[List[int]] = None
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
self.output_token_logprobs_val = []
|
self.output_token_logprobs_val = []
|
||||||
|
|||||||
@@ -22,8 +22,9 @@ import time
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
|
from dataclasses import dataclass
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import setproctitle
|
import setproctitle
|
||||||
@@ -102,6 +103,19 @@ logger = logging.getLogger(__name__)
|
|||||||
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
|
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GenerationBatchResult:
|
||||||
|
logits_output: LogitsProcessorOutput
|
||||||
|
next_token_ids: List[int]
|
||||||
|
bid: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmbeddingBatchResult:
|
||||||
|
embeddings: torch.Tensor
|
||||||
|
bid: int
|
||||||
|
|
||||||
|
|
||||||
class Scheduler:
|
class Scheduler:
|
||||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||||
|
|
||||||
@@ -411,16 +425,16 @@ class Scheduler:
|
|||||||
self.watchdog_last_time = time.time()
|
self.watchdog_last_time = time.time()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
current = time.time()
|
||||||
if self.cur_batch is not None:
|
if self.cur_batch is not None:
|
||||||
if self.watchdog_last_forward_ct == self.forward_ct:
|
if self.watchdog_last_forward_ct == self.forward_ct:
|
||||||
if time.time() > self.watchdog_last_time + self.watchdog_timeout:
|
if current > self.watchdog_last_time + self.watchdog_timeout:
|
||||||
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
|
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self.watchdog_last_forward_ct = self.forward_ct
|
self.watchdog_last_forward_ct = self.forward_ct
|
||||||
self.watchdog_last_time = time.time()
|
self.watchdog_last_time = current
|
||||||
time.sleep(self.watchdog_timeout / 2)
|
time.sleep(self.watchdog_timeout // 2)
|
||||||
|
|
||||||
# Wait sometimes so that the parent process can print the error.
|
# Wait sometimes so that the parent process can print the error.
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
self.parent_process.send_signal(signal.SIGQUIT)
|
self.parent_process.send_signal(signal.SIGQUIT)
|
||||||
@@ -1018,7 +1032,9 @@ class Scheduler:
|
|||||||
batch.prepare_for_decode()
|
batch.prepare_for_decode()
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def run_batch(self, batch: ScheduleBatch):
|
def run_batch(
|
||||||
|
self, batch: ScheduleBatch
|
||||||
|
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
|
||||||
"""Run a batch."""
|
"""Run a batch."""
|
||||||
self.forward_ct += 1
|
self.forward_ct += 1
|
||||||
|
|
||||||
@@ -1040,15 +1056,26 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
|
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
|
||||||
batch.output_ids = next_token_ids
|
batch.output_ids = next_token_ids
|
||||||
ret = logits_output, next_token_ids, model_worker_batch.bid
|
|
||||||
|
ret = GenerationBatchResult(
|
||||||
|
logits_output=logits_output,
|
||||||
|
next_token_ids=next_token_ids,
|
||||||
|
bid=model_worker_batch.bid,
|
||||||
|
)
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
assert batch.extend_num_tokens != 0
|
assert batch.extend_num_tokens != 0
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
|
||||||
ret = embeddings, model_worker_batch.bid
|
ret = EmbeddingBatchResult(
|
||||||
|
embeddings=embeddings, bid=model_worker_batch.bid
|
||||||
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def process_batch_result(self, batch: ScheduleBatch, result):
|
def process_batch_result(
|
||||||
|
self,
|
||||||
|
batch: ScheduleBatch,
|
||||||
|
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||||
|
):
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
self.process_batch_result_decode(batch, result)
|
self.process_batch_result_decode(batch, result)
|
||||||
if batch.is_empty():
|
if batch.is_empty():
|
||||||
@@ -1057,17 +1084,29 @@ class Scheduler:
|
|||||||
self.process_batch_result_prefill(batch, result)
|
self.process_batch_result_prefill(batch, result)
|
||||||
elif batch.forward_mode.is_idle():
|
elif batch.forward_mode.is_idle():
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
self.tp_worker.resolve_batch_result(result[-1])
|
self.tp_worker.resolve_batch_result(result.bid)
|
||||||
elif batch.forward_mode.is_dummy_first():
|
elif batch.forward_mode.is_dummy_first():
|
||||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||||
self.current_stream.synchronize()
|
self.current_stream.synchronize()
|
||||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||||
|
|
||||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
def process_batch_result_prefill(
|
||||||
|
self,
|
||||||
|
batch: ScheduleBatch,
|
||||||
|
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||||
|
):
|
||||||
skip_stream_req = None
|
skip_stream_req = None
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
logits_output, next_token_ids, bid = result
|
(
|
||||||
|
logits_output,
|
||||||
|
next_token_ids,
|
||||||
|
bid,
|
||||||
|
) = (
|
||||||
|
result.logits_output,
|
||||||
|
result.next_token_ids,
|
||||||
|
result.bid,
|
||||||
|
)
|
||||||
|
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||||
@@ -1125,7 +1164,7 @@ class Scheduler:
|
|||||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||||
|
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
embeddings, bid = result
|
embeddings, bid = result.embeddings, result.bid
|
||||||
embeddings = embeddings.tolist()
|
embeddings = embeddings.tolist()
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
@@ -1149,8 +1188,16 @@ class Scheduler:
|
|||||||
|
|
||||||
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
|
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
|
||||||
|
|
||||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
def process_batch_result_decode(
|
||||||
logits_output, next_token_ids, bid = result
|
self,
|
||||||
|
batch: ScheduleBatch,
|
||||||
|
result: GenerationBatchResult,
|
||||||
|
):
|
||||||
|
logits_output, next_token_ids, bid = (
|
||||||
|
result.logits_output,
|
||||||
|
result.next_token_ids,
|
||||||
|
result.bid,
|
||||||
|
)
|
||||||
self.num_generated_tokens += len(batch.reqs)
|
self.num_generated_tokens += len(batch.reqs)
|
||||||
|
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
|
|||||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
get_attention_tp_group,
|
get_attention_tp_group,
|
||||||
|
get_attention_tp_size,
|
||||||
initialize_dp_attention,
|
initialize_dp_attention,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
@@ -532,7 +533,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cell_size = (
|
cell_size = (
|
||||||
self.model_config.get_num_kv_heads(self.tp_size)
|
self.model_config.get_num_kv_heads(get_attention_tp_size())
|
||||||
* self.model_config.head_dim
|
* self.model_config.head_dim
|
||||||
* self.model_config.num_hidden_layers
|
* self.model_config.num_hidden_layers
|
||||||
* 2
|
* 2
|
||||||
@@ -626,7 +627,7 @@ class ModelRunner:
|
|||||||
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
|
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
||||||
head_dim=self.model_config.head_dim,
|
head_dim=self.model_config.head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
@@ -637,7 +638,7 @@ class ModelRunner:
|
|||||||
self.token_to_kv_pool = MHATokenToKVPool(
|
self.token_to_kv_pool = MHATokenToKVPool(
|
||||||
self.max_total_num_tokens,
|
self.max_total_num_tokens,
|
||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
||||||
head_dim=self.model_config.head_dim,
|
head_dim=self.model_config.head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ class CompletionRequest(BaseModel):
|
|||||||
ignore_eos: bool = False
|
ignore_eos: bool = False
|
||||||
skip_special_tokens: bool = True
|
skip_special_tokens: bool = True
|
||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
|
session_params: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
class CompletionResponseChoice(BaseModel):
|
class CompletionResponseChoice(BaseModel):
|
||||||
@@ -322,6 +323,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
ignore_eos: bool = False
|
ignore_eos: bool = False
|
||||||
skip_special_tokens: bool = True
|
skip_special_tokens: bool = True
|
||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
|
session_params: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
class FunctionResponse(BaseModel):
|
class FunctionResponse(BaseModel):
|
||||||
|
|||||||
@@ -842,7 +842,6 @@ class Engine:
|
|||||||
generator = ret.body_iterator
|
generator = ret.body_iterator
|
||||||
|
|
||||||
async def generator_wrapper():
|
async def generator_wrapper():
|
||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@@ -239,8 +239,8 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Others
|
# Others
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
assert self.tp_size % self.dp_size == 0
|
|
||||||
self.dp_size = self.tp_size
|
self.dp_size = self.tp_size
|
||||||
|
assert self.tp_size % self.dp_size == 0
|
||||||
self.chunked_prefill_size = self.chunked_prefill_size // 2
|
self.chunked_prefill_size = self.chunked_prefill_size // 2
|
||||||
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
Reference in New Issue
Block a user