Move sgl.Runtime under sglang/lang (#2990)
@@ -34,6 +34,7 @@ import zmq
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -149,9 +150,7 @@ class Scheduler:
             else 1
         )
 
-        # Init inter-process communication
-        context = zmq.Context(2)
-
+        # Distributed rank info
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
             compute_dp_attention_world_info(
@@ -162,6 +161,8 @@ class Scheduler:
             )
         )
 
+        # Init inter-process communication
+        context = zmq.Context(2)
         if self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
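
For readers unfamiliar with the IPC wiring being moved in this hunk: the scheduler receives tokenized requests over a ZeroMQ PULL socket bound to an IPC endpoint, while the tokenizer side pushes to it. A minimal, self-contained pyzmq sketch of that PULL/PUSH pattern follows; the endpoint path and payload are made up for illustration, and sglang's `get_zmq_socket` helper is not reproduced here.

```python
import zmq

ctx = zmq.Context(2)

# Scheduler side: bind a PULL socket and wait for tokenized requests.
pull = ctx.socket(zmq.PULL)
pull.bind("ipc:///tmp/example_scheduler_input")  # hypothetical endpoint

# Tokenizer side: connect a PUSH socket and send Python objects.
push = ctx.socket(zmq.PUSH)
push.connect("ipc:///tmp/example_scheduler_input")

push.send_pyobj({"rid": "req-0", "input_ids": [1, 2, 3]})  # toy payload
print(pull.recv_pyobj())  # -> {'rid': 'req-0', 'input_ids': [1, 2, 3]}
```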
@@ -243,7 +244,7 @@ class Scheduler:
             nccl_port=port_args.nccl_port,
         )
 
-        # Launch worker for speculative decoding if need
+        # Launch a worker for speculative decoding if needed
        if self.spec_algorithm.is_eagle():
             from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
@@ -316,6 +317,8 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()
@@ -337,28 +340,9 @@ class Scheduler:
         # Init the grammar backend for constrained generation
         self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-            if server_args.grammar_backend == "outlines":
-                from sglang.srt.constrained.outlines_backend import (
-                    OutlinesGrammarBackend,
-                )
-
-                self.grammar_backend = OutlinesGrammarBackend(
-                    self.tokenizer,
-                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
-                    allow_jump_forward=not server_args.disable_jump_forward,
-                )
-            elif server_args.grammar_backend == "xgrammar":
-                from sglang.srt.constrained.xgrammar_backend import (
-                    XGrammarGrammarBackend,
-                )
-
-                self.grammar_backend = XGrammarGrammarBackend(
-                    self.tokenizer, vocab_size=self.model_config.vocab_size
-                )
-            else:
-                raise ValueError(
-                    f"Invalid grammar backend: {server_args.grammar_backend}"
-                )
+            self.grammar_backend = create_grammar_backend(
+                server_args, self.tokenizer, self.model_config.vocab_size
+            )
         else:
             self.grammar_backend = None
 
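The inlined backend selection removed above is replaced by a single call to `create_grammar_backend`, imported at the top of the file. The factory's body is not part of this diff; a plausible sketch, reusing only the branching logic visible in the removed lines, would be:

```python
# Hypothetical sketch of the create_grammar_backend factory; the real
# implementation lives in sglang.srt.constrained.base_grammar_backend.
def create_grammar_backend(server_args, tokenizer, vocab_size):
    if server_args.grammar_backend == "outlines":
        from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

        return OutlinesGrammarBackend(
            tokenizer,
            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
            allow_jump_forward=not server_args.disable_jump_forward,
        )
    elif server_args.grammar_backend == "xgrammar":
        from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

        return XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
    else:
        raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
```

Centralizing the choice in one factory keeps the scheduler free of per-backend imports and lets other entry points reuse the same selection logic.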
@@ -424,7 +408,8 @@ class Scheduler:
             },
         )
 
-        self._dispatcher = TypeBasedDispatcher(
+        # Init request dispatcher
+        self._request_dispatcher = TypeBasedDispatcher(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
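`TypeBasedDispatcher` routes each incoming request to a handler based on the request's type; the rename to `_request_dispatcher` just makes explicit what is being dispatched. A minimal sketch of the pattern (the class name comes from the diff, but this body is illustrative rather than the sglang implementation):

```python
class TypeBasedDispatcher:
    def __init__(self, mapping):
        # mapping: list of (request type, handler callable) pairs
        self._mapping = mapping

    def __call__(self, obj):
        # Dispatch to the first handler whose type matches the object.
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"No handler registered for {type(obj)}")
```

In `process_input_requests` below, each received request goes through `self._request_dispatcher(recv_req)` and any non-None handler result is sent back to the tokenizer process.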
@@ -480,10 +465,6 @@ class Scheduler:
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
-
-            if self.server_args.enable_dp_attention:  # TODO: simplify this
-                batch = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch
 
             if batch:
@@ -506,10 +487,6 @@ class Scheduler:
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
-
-            if self.server_args.enable_dp_attention:  # TODO: simplify this
-                batch = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch
 
             if batch:
@@ -517,7 +494,7 @@ class Scheduler:
                 result_queue.append((batch.copy(), result))
 
                 if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap scheduler.
+                    # Create a dummy first batch to start the pipeline for overlap schedule.
                     # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,
@@ -593,7 +570,7 @@ class Scheduler:
 
     def process_input_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
-            output = self._dispatcher(recv_req)
+            output = self._request_dispatcher(recv_req)
             if output is not None:
                 self.send_to_tokenizer.send_pyobj(output)
 
@@ -798,15 +775,32 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
-        logger.info(
-            f"Decode batch. "
-            f"#running-req: {num_running_reqs}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
-        )
+
+        if self.spec_algorithm.is_none():
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+        else:
+            accept_length = (
+                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+            )
+            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"accept len: {accept_length:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+
+        logger.info(msg)
         if self.enable_metrics:
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
@@ -855,16 +849,23 @@ class Scheduler:
         else:
             self.running_batch.merge_batch(self.last_batch)
 
-        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-            return new_batch
+            # Run prefill first if possible
+            ret = new_batch
+        else:
+            # Run decode
+            if self.running_batch is None:
+                ret = None
+            else:
+                self.running_batch = self.update_running_batch(self.running_batch)
+                ret = self.running_batch
 
-        # Run decode
-        if self.running_batch is None:
-            return None
-        self.running_batch = self.update_running_batch(self.running_batch)
-        return self.running_batch
+        # Handle DP attention
+        if self.server_args.enable_dp_attention:
+            ret = self.prepare_dp_attn_batch(ret)
+
+        return ret
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
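The refactor above funnels both the prefill and decode paths into a single `ret` value so that DP-attention preparation, which the earlier hunks removed from the two event loops, runs once at a single exit point. A condensed, illustrative restatement of that control flow (not the verbatim sglang method):

```python
def get_next_batch_to_run_sketch(self):
    # Prefer a fresh prefill batch when one can be formed.
    new_batch = self.get_new_batch_prefill()
    if new_batch is not None:
        ret = new_batch
    elif self.running_batch is not None:
        # Otherwise continue decoding the running batch.
        self.running_batch = self.update_running_batch(self.running_batch)
        ret = self.running_batch
    else:
        ret = None

    # Single exit point: DP-attention preparation covers both paths.
    if self.server_args.enable_dp_attention:
        ret = self.prepare_dp_attn_batch(ret)
    return ret
```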
@@ -1053,6 +1054,10 @@ class Scheduler:
                     model_worker_batch,
                     num_accepted_tokens,
                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
             else:
                 assert False, "batch.extend_num_tokens == 0, this is unexpected!"
 
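Together with the logging change in hunk @@ -798,15 +775,32 @@, these counters produce the reported `accept len`: per speculative forward pass, each running request contributes one token of its own plus however many draft tokens were accepted. A small worked example under that reading (the numbers are invented):

```python
# One speculative forward pass over a batch of 4 requests in which the
# target model accepted 6 draft tokens in total (made-up figures).
batch_size = 4
num_accepted_tokens = 6

spec_num_total_accepted_tokens = num_accepted_tokens + batch_size  # 10
spec_num_total_forward_ct = batch_size                             # 4

accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct
print(f"accept len: {accept_length:.2f}")  # accept len: 2.50
# i.e. each request advanced by 2.5 tokens per forward pass on average.
```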