Enable overlap scheduler by default for the triton attention backend (#2105)
This commit is contained in:
@@ -53,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
|
||||
start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
|
||||
|
||||
total_num_tokens = torch.sum(forward_batch.seq_lens).item()
|
||||
total_num_tokens = forward_batch.seq_lens_sum
|
||||
attn_logits = torch.empty(
|
||||
(self.num_head, total_num_tokens),
|
||||
dtype=self.reduce_dtype,
|
||||
|
||||
@@ -170,18 +170,9 @@ class Scheduler:
|
||||
if not self.is_generation:
|
||||
self.enable_overlap = False
|
||||
logger.info("Overlap scheduler is disabled for embedding models.")
|
||||
if (
|
||||
server_args.attention_backend == "triton"
|
||||
or server_args.enable_double_sparsity
|
||||
or (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and not self.server_args.disable_mla
|
||||
)
|
||||
):
|
||||
self.enable_overlap = False
|
||||
logger.info(
|
||||
"Overlap scheduler is disabled if using triton attention backend."
|
||||
)
|
||||
|
||||
if self.enable_overlap:
|
||||
self.disable_jump_forward = True
|
||||
|
||||
# Launch a tensor parallel worker
|
||||
if self.enable_overlap:
|
||||
|
||||
@@ -94,10 +94,21 @@ class TpModelWorkerClient:
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_thread_func_(self):
|
||||
batch_pt = 0
|
||||
batch_lists = [None] * 2
|
||||
|
||||
while True:
|
||||
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
||||
if not model_worker_batch:
|
||||
break
|
||||
|
||||
# Keep a reference of model_worker_batch by storing it into a list.
|
||||
# Otherwise, the tensor members of model_worker_batch will be released
|
||||
# by pytorch and cause CUDA illegal memory access errors.
|
||||
batch_lists[batch_pt % 2] = model_worker_batch
|
||||
batch_pt += 1
|
||||
|
||||
# Create event
|
||||
self.launch_done = threading.Event()
|
||||
copy_done = torch.cuda.Event()
|
||||
|
||||
|
||||
@@ -170,7 +170,6 @@ class CudaGraphRunner:
|
||||
self.encoder_lens = None
|
||||
|
||||
if self.enable_dp_attention:
|
||||
self.global_num_tokens = [0] * self.tp_size
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_bs * self.tp_size,
|
||||
@@ -264,10 +263,10 @@ class CudaGraphRunner:
|
||||
mrope_positions = self.mrope_positions[:, :bs]
|
||||
|
||||
if self.enable_dp_attention:
|
||||
self.global_num_tokens[:] = [bs] * self.tp_size
|
||||
global_num_tokens = [bs] * self.tp_size
|
||||
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
|
||||
else:
|
||||
self.global_num_tokens = None
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
|
||||
# Attention backend
|
||||
@@ -296,7 +295,7 @@ class CudaGraphRunner:
|
||||
top_logprobs_nums=[0] * bs,
|
||||
positions=clamp_position(seq_lens),
|
||||
mrope_positions=mrope_positions,
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
global_num_tokens=global_num_tokens,
|
||||
gathered_buffer=gathered_buffer,
|
||||
)
|
||||
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
||||
@@ -348,8 +347,6 @@ class CudaGraphRunner:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
if forward_batch.mrope_positions is not None:
|
||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||
if self.enable_dp_attention:
|
||||
self.global_num_tokens[:] = [bs] * self.tp_size
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
|
||||
@@ -174,17 +174,17 @@ class ServerArgs:
|
||||
self.cuda_graph_max_bs = 4
|
||||
logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
|
||||
|
||||
# Choose kernel backends
|
||||
if not is_flashinfer_available():
|
||||
self.attention_backend = "triton"
|
||||
self.sampling_backend = "pytorch"
|
||||
|
||||
# Default kernel backends
|
||||
if self.attention_backend is None:
|
||||
self.attention_backend = "flashinfer"
|
||||
|
||||
if self.sampling_backend is None:
|
||||
self.sampling_backend = "flashinfer"
|
||||
|
||||
# Others
|
||||
if self.enable_dp_attention:
|
||||
self.dp_size = self.tp_size
|
||||
self.chunked_prefill_size = self.chunked_prefill_size // 2
|
||||
@@ -205,9 +205,6 @@ class ServerArgs:
|
||||
)
|
||||
self.disable_overlap_schedule = True
|
||||
|
||||
if not self.disable_overlap_schedule:
|
||||
self.disable_jump_forward = True
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
# Model and port args
|
||||
|
||||
@@ -2,3 +2,4 @@
|
||||
|
||||
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
|
||||
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
|
||||
kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}')
|
||||
|
||||
Reference in New Issue
Block a user