Clean up wrapper in flashinfer backend (#2638)

This commit is contained in:
Lianmin Zheng
2024-12-29 00:45:57 -08:00
committed by GitHub
parent fd34f2da35
commit 3815b23ccb
12 changed files with 197 additions and 94 deletions

View File

@@ -90,7 +90,7 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
# Test retract decode
# Test retract decode for debugging purposes
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
@@ -129,12 +129,12 @@ class Scheduler:
)
if server_args.skip_tokenizer_init:
# Directly send to the tokenizer/api
# Directly send to the TokenizerManager
self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.tokenizer_ipc_name
)
else:
# Send to the detokenizer
# Send to the DetokenizerManager
self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.detokenizer_ipc_name
)
@@ -385,7 +385,8 @@ class Scheduler:
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
if self.server_args.enable_dp_attention:
if self.server_args.enable_dp_attention: # TODO: simplify this
batch = self.prepare_dp_attn_batch(batch)
self.cur_batch = batch
@@ -394,7 +395,7 @@ class Scheduler:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
else:
# Self-check and re-init some states when the server is idle
# When the server is idle, self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
@@ -411,12 +412,13 @@ class Scheduler:
batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))
if self.last_batch is None:
# A dummy first batch to start the pipeline for overlap scheduler.
# Create a dummy first batch to start the pipeline for overlap scheduler.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
@@ -426,19 +428,21 @@ class Scheduler:
self.process_batch_result(tmp_batch, None)
if self.last_batch:
# Process the results of the last batch
tmp_batch, tmp_result = result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
# Self-check and re-init some states when the server is idle
# When the server is idle, self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
self.last_batch = batch
def recv_requests(self):
def recv_requests(self) -> List[Req]:
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
recv_reqs = []
@@ -812,6 +816,8 @@ class Scheduler:
if res == AddReqResult.NO_TOKEN:
self.batch_is_full = True
break
if self.server_args.prefill_only_one_req:
break
# Update waiting queue
can_run_list = adder.can_run_list
@@ -1528,18 +1534,20 @@ def run_scheduler_process(
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
dp_rank = int(os.environ["SGLANG_DP_RANK"])
# Configure the logger
if dp_rank is None:
configure_logger(server_args, prefix=f" TP{tp_rank}")
else:
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
suppress_other_loggers()
# set cpu affinity to this gpu process
# Set cpu affinity to this gpu process
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
suppress_other_loggers()
parent_process = psutil.Process().parent()
# Create a scheduler and run the event loop
try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
pipe_writer.send(