Clean up wrapper in flashinfer backend (#2638)
This commit is contained in:
@@ -90,7 +90,7 @@ from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Test retract decode
|
||||
# Test retract decode for debugging purposes
|
||||
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
|
||||
|
||||
|
||||
@@ -129,12 +129,12 @@ class Scheduler:
|
||||
)
|
||||
|
||||
if server_args.skip_tokenizer_init:
|
||||
# Directly send to the tokenizer/api
|
||||
# Directly send to the TokenizerManager
|
||||
self.send_to_detokenizer = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.tokenizer_ipc_name
|
||||
)
|
||||
else:
|
||||
# Send to the detokenizer
|
||||
# Send to the DetokenizerManager
|
||||
self.send_to_detokenizer = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.detokenizer_ipc_name
|
||||
)
|
||||
@@ -385,7 +385,8 @@ class Scheduler:
|
||||
self.process_input_requests(recv_reqs)
|
||||
|
||||
batch = self.get_next_batch_to_run()
|
||||
if self.server_args.enable_dp_attention:
|
||||
|
||||
if self.server_args.enable_dp_attention: # TODO: simplify this
|
||||
batch = self.prepare_dp_attn_batch(batch)
|
||||
|
||||
self.cur_batch = batch
|
||||
@@ -394,7 +395,7 @@ class Scheduler:
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
else:
|
||||
# Self-check and re-init some states when the server is idle
|
||||
# When the server is idle, do a self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
@@ -411,12 +412,13 @@ class Scheduler:
|
||||
|
||||
batch = self.get_next_batch_to_run()
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
result_queue.append((batch.copy(), result))
|
||||
|
||||
if self.last_batch is None:
|
||||
# A dummy first batch to start the pipeline for overlap scheduler.
|
||||
# Create a dummy first batch to start the pipeline for overlap scheduler.
|
||||
# It is now used for triggering the sampling_info_done event.
|
||||
tmp_batch = ScheduleBatch(
|
||||
reqs=None,
|
||||
@@ -426,19 +428,21 @@ class Scheduler:
|
||||
self.process_batch_result(tmp_batch, None)
|
||||
|
||||
if self.last_batch:
|
||||
# Process the results of the last batch
|
||||
tmp_batch, tmp_result = result_queue.popleft()
|
||||
tmp_batch.next_batch_sampling_info = (
|
||||
self.tp_worker.cur_sampling_info if batch else None
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
elif batch is None:
|
||||
# Self-check and re-init some states when the server is idle
|
||||
# When the server is idle, do a self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
self.last_batch = batch
|
||||
|
||||
def recv_requests(self):
|
||||
def recv_requests(self) -> List[Req]:
|
||||
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
||||
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
|
||||
recv_reqs = []
|
||||
|
||||
@@ -812,6 +816,8 @@ class Scheduler:
|
||||
if res == AddReqResult.NO_TOKEN:
|
||||
self.batch_is_full = True
|
||||
break
|
||||
if self.server_args.prefill_only_one_req:
|
||||
break
|
||||
|
||||
# Update waiting queue
|
||||
can_run_list = adder.can_run_list
|
||||
@@ -1528,18 +1534,20 @@ def run_scheduler_process(
|
||||
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
|
||||
dp_rank = int(os.environ["SGLANG_DP_RANK"])
|
||||
|
||||
# Configure the logger
|
||||
if dp_rank is None:
|
||||
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
||||
else:
|
||||
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
|
||||
suppress_other_loggers()
|
||||
|
||||
# set cpu affinity to this gpu process
|
||||
# Set cpu affinity to this gpu process
|
||||
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
||||
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
|
||||
|
||||
suppress_other_loggers()
|
||||
parent_process = psutil.Process().parent()
|
||||
|
||||
# Create a scheduler and run the event loop
|
||||
try:
|
||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
||||
pipe_writer.send(
|
||||
|
||||
Reference in New Issue
Block a user