Move sgl.Runtime under sglang/lang (#2990)

Lianmin Zheng
2025-01-19 17:10:29 -08:00
committed by GitHub
parent e403d23757
commit 61f42b5732
17 changed files with 267 additions and 329 deletions

python/sglang/srt/managers/scheduler.py

@@ -34,6 +34,7 @@ import zmq
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -149,9 +150,7 @@ class Scheduler:
             else 1
         )

-        # Init inter-process communication
-        context = zmq.Context(2)
         # Distributed rank info
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
             compute_dp_attention_world_info(
@@ -162,6 +161,8 @@ class Scheduler:
             )
         )

+        # Init inter-process communication
+        context = zmq.Context(2)
         if self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
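
The move groups the ZMQ setup after the distributed rank info it now sits next to: only the scheduler with attn_tp_rank == 0 binds the inter-process sockets. For readers unfamiliar with the pattern, a minimal pyzmq sketch of a PULL socket bound to an IPC endpoint (illustrative only; sglang's get_zmq_socket wraps this with extra socket tuning):

    import zmq

    def make_pull_socket(ctx: zmq.Context, ipc_name: str) -> zmq.Socket:
        # Bind a PULL socket to a local IPC endpoint so another process
        # (here, the tokenizer) can push serialized requests to this one.
        sock = ctx.socket(zmq.PULL)
        sock.bind(f"ipc://{ipc_name}")
        return sock

    context = zmq.Context(2)  # 2 I/O threads, as in the diff above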
@@ -243,7 +244,7 @@ class Scheduler:
             nccl_port=port_args.nccl_port,
         )

-        # Launch worker for speculative decoding if need
+        # Launch a worker for speculative decoding if needed
         if self.spec_algorithm.is_eagle():
             from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -316,6 +317,8 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()
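
The two new spec_* counters follow the same interval-accumulator pattern already used by num_generated_tokens and last_decode_stats_tic: accumulate between log flushes, then reset. A self-contained sketch of that pattern (names illustrative, not sglang's):

    import time

    class IntervalStats:
        def __init__(self) -> None:
            self.num_generated_tokens = 0
            self.last_tic = time.time()

        def flush(self) -> float:
            # Report tokens/s since the last flush, then reset the window.
            elapsed = time.time() - self.last_tic
            throughput = self.num_generated_tokens / elapsed if elapsed > 0 else 0.0
            self.num_generated_tokens = 0
            self.last_tic = time.time()
            return throughput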
@@ -337,28 +340,9 @@ class Scheduler:
         # Init the grammar backend for constrained generation
         self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-            if server_args.grammar_backend == "outlines":
-                from sglang.srt.constrained.outlines_backend import (
-                    OutlinesGrammarBackend,
-                )
-
-                self.grammar_backend = OutlinesGrammarBackend(
-                    self.tokenizer,
-                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
-                    allow_jump_forward=not server_args.disable_jump_forward,
-                )
-            elif server_args.grammar_backend == "xgrammar":
-                from sglang.srt.constrained.xgrammar_backend import (
-                    XGrammarGrammarBackend,
-                )
-
-                self.grammar_backend = XGrammarGrammarBackend(
-                    self.tokenizer, vocab_size=self.model_config.vocab_size
-                )
-            else:
-                raise ValueError(
-                    f"Invalid grammar backend: {server_args.grammar_backend}"
-                )
+            self.grammar_backend = create_grammar_backend(
+                server_args, self.tokenizer, self.model_config.vocab_size
+            )
         else:
             self.grammar_backend = None
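
The deleted branching is folded into the create_grammar_backend factory imported at the top of this diff. Reassembled from the removed lines alone, the factory plausibly reads as follows (a sketch, not necessarily the verbatim sglang implementation):

    def create_grammar_backend(server_args, tokenizer, vocab_size):
        if server_args.grammar_backend == "outlines":
            from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

            return OutlinesGrammarBackend(
                tokenizer,
                whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                allow_jump_forward=not server_args.disable_jump_forward,
            )
        elif server_args.grammar_backend == "xgrammar":
            from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

            return XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
        else:
            raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

Keeping backend selection in one place lets new grammar backends be added without touching the Scheduler constructor.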
@@ -424,7 +408,8 @@ class Scheduler:
             },
         )

-        self._dispatcher = TypeBasedDispatcher(
+        # Init request dispatcher
+        self._request_dispatcher = TypeBasedDispatcher(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
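
The rename from _dispatcher to _request_dispatcher makes the attribute's role explicit. A minimal sketch of what a type-based dispatcher does (illustrative; sglang ships its own TypeBasedDispatcher):

    class TypeBasedDispatcher:
        def __init__(self, mapping):
            # mapping: ordered (request type, handler) pairs
            self._mapping = mapping

        def __call__(self, obj):
            # Route each request to the handler registered for its type.
            for ty, fn in self._mapping:
                if isinstance(obj, ty):
                    return fn(obj)
            raise ValueError(f"Invalid request: {obj}")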
@@ -480,10 +465,6 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
-
-            if self.server_args.enable_dp_attention:  # TODO: simplify this
-                batch = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch

             if batch:
@@ -506,10 +487,6 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
-
-            if self.server_args.enable_dp_attention:  # TODO: simplify this
-                batch = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch

             if batch:
@@ -517,7 +494,7 @@ class Scheduler:
                 result_queue.append((batch.copy(), result))

                 if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap scheduler.
+                    # Create a dummy first batch to start the pipeline for overlap schedule.
                     # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,
@@ -593,7 +570,7 @@ class Scheduler:
     def process_input_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
-            output = self._dispatcher(recv_req)
+            output = self._request_dispatcher(recv_req)
             if output is not None:
                 self.send_to_tokenizer.send_pyobj(output)
@@ -798,15 +775,32 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
-        logger.info(
-            f"Decode batch. "
-            f"#running-req: {num_running_reqs}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
-        )
+
+        if self.spec_algorithm.is_none():
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+        else:
+            accept_length = (
+                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+            )
+            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"accept len: {accept_length:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+
+        logger.info(msg)
+
         if self.enable_metrics:
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
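
"accept len" is the average number of tokens emitted per request per forward pass under speculative decoding. It is at least 1.0, because every verified pass yields one token per request even when all draft tokens are rejected, which is why the counter update later in this diff adds batch.batch_size() on top of num_accepted_tokens. A worked example with assumed numbers:

    # Assume one logging interval covered 10 forward passes on a batch of
    # 2 requests, with 30 draft tokens accepted in total on top of the
    # guaranteed one token per request per pass.
    spec_num_total_forward_ct = 10 * 2             # one count per request per pass
    spec_num_total_accepted_tokens = 30 + 10 * 2   # accepted drafts + guaranteed tokens
    accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct
    print(f"accept len: {accept_length:.2f}")      # accept len: 2.50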
@@ -855,16 +849,23 @@ class Scheduler:
             else:
                 self.running_batch.merge_batch(self.last_batch)

-        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-            return new_batch
+            # Run prefill first if possible
+            ret = new_batch
+        else:
+            # Run decode
+            if self.running_batch is None:
+                ret = None
+            else:
+                self.running_batch = self.update_running_batch(self.running_batch)
+                ret = self.running_batch

-        # Run decode
-        if self.running_batch is None:
-            return None
-        self.running_batch = self.update_running_batch(self.running_batch)
-        return self.running_batch
+        # Handle DP attention
+        if self.server_args.enable_dp_attention:
+            ret = self.prepare_dp_attn_batch(ret)
+
+        return ret

     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
@@ -1053,6 +1054,10 @@ class Scheduler:
                     model_worker_batch,
                     num_accepted_tokens,
                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
             else:
                 assert False, "batch.extend_num_tokens == 0, this is unexpected!"