Fix the overlap for xgrammar (#2377)
This commit is contained in:
@@ -114,9 +114,6 @@ class Scheduler:
|
||||
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
||||
self.enable_metrics = server_args.enable_metrics
|
||||
|
||||
# Session info
|
||||
self.sessions = {}
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
|
||||
@@ -259,6 +256,10 @@ class Scheduler:
|
||||
self.num_generated_tokens = 0
|
||||
self.last_decode_stats_tic = time.time()
|
||||
self.stream_interval = server_args.stream_interval
|
||||
self.current_stream = torch.get_device_module(self.device).current_stream()
|
||||
|
||||
# Session info
|
||||
self.sessions = {}
|
||||
|
||||
# Init chunked prefill
|
||||
self.chunked_prefill_size = server_args.chunked_prefill_size
|
||||
@@ -356,6 +357,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
def watchdog_thread(self):
|
||||
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
|
||||
self.watchdog_last_forward_ct = 0
|
||||
self.watchdog_last_time = time.time()
|
||||
|
||||
@@ -433,61 +435,6 @@ class Scheduler:
|
||||
|
||||
self.last_batch = batch
|
||||
|
||||
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
    """Exchange token counts with the other ranks and align this rank's batch.

    Every rank in ``self.tp_cpu_group`` contributes the number of tokens it
    is about to run (0 when it has no batch). If this rank has no batch
    while the gathered maximum is positive, an idle batch from
    ``get_idle_batch()`` is substituted. Unless
    ``server_args.disable_cuda_graph`` is set, the ranks additionally run a
    MIN all-reduce over a flag that is 1 only when the local batch is in
    decode or idle mode.

    Args:
        local_batch: The batch scheduled on this rank, or ``None``.

    Returns:
        The (possibly substituted) batch with ``global_num_tokens`` set and,
        when CUDA graphs are not disabled, ``can_run_dp_cuda_graph`` set.
        ``None`` when this rank has no batch and no rank reported tokens.
    """
    # Token count contributed by this rank: batch size for decode,
    # extend_num_tokens otherwise, and 0 when there is nothing to run.
    if local_batch is None:
        token_count = 0
    else:
        token_count = (
            local_batch.batch_size()
            if local_batch.forward_mode.is_decode()
            else local_batch.extend_num_tokens
        )

    # Check if other DP workers have running batches
    gathered = torch.empty(self.tp_size, dtype=torch.int64)
    torch.distributed.all_gather_into_tensor(
        gathered,
        torch.tensor([token_count], dtype=torch.int64),
        group=self.tp_cpu_group,
    )

    if local_batch is None:
        if gathered.max().item() <= 0:
            # Nobody has work; there is no batch to annotate.
            return None
        local_batch = self.get_idle_batch()

    # Everything below dereferences local_batch, so it only runs once a
    # real or idle batch exists.
    local_batch.global_num_tokens = gathered.tolist()

    if self.server_args.disable_cuda_graph:
        return local_batch

    # Check forward mode for cuda graph
    mode = local_batch.forward_mode
    graph_ok = 1 if (mode.is_decode() or mode.is_idle()) else 0
    forward_mode_state = torch.tensor(graph_ok, dtype=torch.int32)
    torch.distributed.all_reduce(
        forward_mode_state,
        op=torch.distributed.ReduceOp.MIN,
        group=self.tp_cpu_group,
    )
    # MIN over all ranks stays 1 only if every rank contributed 1.
    local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
    return local_batch
|
||||
|
||||
def get_idle_batch(self):
    """Create a request-less ``ScheduleBatch`` and prepare it for idling.

    The batch is built via ``ScheduleBatch.init_new`` with an empty request
    list and this scheduler's pools, tree cache, model config and overlap
    flag; ``prepare_for_idle()`` is then called on it before returning.
    """
    no_reqs = []
    batch = ScheduleBatch.init_new(
        no_reqs,
        self.req_to_token_pool,
        self.token_to_kv_pool,
        self.tree_cache,
        self.model_config,
        self.enable_overlap,
    )
    batch.prepare_for_idle()
    return batch
|
||||
|
||||
def recv_requests(self):
|
||||
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
|
||||
recv_reqs = []
|
||||
@@ -993,7 +940,7 @@ class Scheduler:
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
elif batch.forward_mode.is_dummy_first():
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
torch.get_device_module(self.device).current_stream().synchronize()
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
@@ -1049,13 +996,14 @@ class Scheduler:
|
||||
|
||||
if req.grammar is not None:
|
||||
req.grammar.accept_token(next_token_id)
|
||||
req.grammar.finished = req.finished()
|
||||
else:
|
||||
# being chunked reqs' prefill is not finished
|
||||
req.is_being_chunked -= 1
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
torch.get_device_module(self.device).current_stream().synchronize()
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
else: # embedding or reward model
|
||||
@@ -1127,10 +1075,11 @@ class Scheduler:
|
||||
|
||||
if req.grammar is not None:
|
||||
req.grammar.accept_token(next_token_id)
|
||||
req.grammar.finished = req.finished()
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
torch.get_device_module(self.device).current_stream().synchronize()
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
self.stream_output(batch.reqs)
|
||||
@@ -1328,6 +1277,61 @@ class Scheduler:
|
||||
)
|
||||
)
|
||||
|
||||
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
    """Exchange token counts with the other ranks and align this rank's batch.

    Every rank in ``self.tp_cpu_group`` contributes the number of tokens it
    is about to run (0 when it has no batch). If this rank has no batch
    while the gathered maximum is positive, an idle batch from
    ``get_idle_batch()`` is substituted. Unless
    ``server_args.disable_cuda_graph`` is set, the ranks additionally run a
    MIN all-reduce over a flag that is 1 only when the local batch is in
    decode or idle mode.

    Args:
        local_batch: The batch scheduled on this rank, or ``None``.

    Returns:
        The (possibly substituted) batch with ``global_num_tokens`` set and,
        when CUDA graphs are not disabled, ``can_run_dp_cuda_graph`` set.
        ``None`` when this rank has no batch and no rank reported tokens.
    """
    # Token count contributed by this rank: batch size for decode,
    # extend_num_tokens otherwise, and 0 when there is nothing to run.
    if local_batch is None:
        token_count = 0
    else:
        token_count = (
            local_batch.batch_size()
            if local_batch.forward_mode.is_decode()
            else local_batch.extend_num_tokens
        )

    # Check if other DP workers have running batches
    gathered = torch.empty(self.tp_size, dtype=torch.int64)
    torch.distributed.all_gather_into_tensor(
        gathered,
        torch.tensor([token_count], dtype=torch.int64),
        group=self.tp_cpu_group,
    )

    if local_batch is None:
        if gathered.max().item() <= 0:
            # Nobody has work; there is no batch to annotate.
            return None
        local_batch = self.get_idle_batch()

    # Everything below dereferences local_batch, so it only runs once a
    # real or idle batch exists.
    local_batch.global_num_tokens = gathered.tolist()

    if self.server_args.disable_cuda_graph:
        return local_batch

    # Check forward mode for cuda graph
    mode = local_batch.forward_mode
    graph_ok = 1 if (mode.is_decode() or mode.is_idle()) else 0
    forward_mode_state = torch.tensor(graph_ok, dtype=torch.int32)
    torch.distributed.all_reduce(
        forward_mode_state,
        op=torch.distributed.ReduceOp.MIN,
        group=self.tp_cpu_group,
    )
    # MIN over all ranks stays 1 only if every rank contributed 1.
    local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
    return local_batch
|
||||
|
||||
def get_idle_batch(self):
    """Create a request-less ``ScheduleBatch`` and prepare it for idling.

    The batch is built via ``ScheduleBatch.init_new`` with an empty request
    list and this scheduler's pools, tree cache, model config and overlap
    flag; ``prepare_for_idle()`` is then called on it before returning.
    """
    no_reqs = []
    batch = ScheduleBatch.init_new(
        no_reqs,
        self.req_to_token_pool,
        self.token_to_kv_pool,
        self.tree_cache,
        self.model_config,
        self.enable_overlap,
    )
    batch.prepare_for_idle()
    return batch
|
||||
|
||||
def move_ready_grammar_requests(self):
|
||||
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
|
||||
num_ready_reqs = 0
|
||||
@@ -1469,10 +1473,6 @@ def run_scheduler_process(
|
||||
dp_rank: Optional[int],
|
||||
pipe_writer,
|
||||
):
|
||||
# set cpu affinity to this gpu process
|
||||
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
||||
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
|
||||
|
||||
# [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
|
||||
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
|
||||
dp_rank = int(os.environ["SGLANG_DP_RANK"])
|
||||
@@ -1482,6 +1482,10 @@ def run_scheduler_process(
|
||||
else:
|
||||
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
|
||||
|
||||
# set cpu affinity to this gpu process
|
||||
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
||||
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
|
||||
|
||||
suppress_other_loggers()
|
||||
parent_process = psutil.Process().parent()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user