Enable overlap scheduler by default for the triton attention backend (#2105)
@@ -53,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
             start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
 
-            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
                 (self.num_head, total_num_tokens),
                 dtype=self.reduce_dtype,
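The only functional change in this hunk is replacing torch.sum(...).item() with the precomputed forward_batch.seq_lens_sum. The hunk itself does not state the motivation, but .item() on a CUDA tensor copies the value to the host and blocks the CPU until the producing kernels finish, which would stall the CPU thread that the overlap scheduler relies on to run ahead of the GPU. A minimal sketch of the difference (hypothetical, not part of the commit):

    import torch

    # .item() on a device tensor forces a device-to-host copy and a synchronization.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    seq_lens = torch.tensor([5, 7, 3], device=device)
    total_blocking = torch.sum(seq_lens).item()  # synchronizes CPU and GPU

    # If the scheduler already tracks the running sum on the host (as seq_lens_sum
    # appears to be), the total is plain Python arithmetic and nothing synchronizes.
    seq_lens_host = [5, 7, 3]
    total_nonblocking = sum(seq_lens_host)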
@@ -170,18 +170,9 @@ class Scheduler:
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
-        if (
-            server_args.attention_backend == "triton"
-            or server_args.enable_double_sparsity
-            or (
-                self.model_config.attention_arch == AttentionArch.MLA
-                and not self.server_args.disable_mla
-            )
-        ):
-            self.enable_overlap = False
-            logger.info(
-                "Overlap scheduler is disabled if using triton attention backend."
-            )
+
+        if self.enable_overlap:
+            self.disable_jump_forward = True
 
         # Launch a tensor parallel worker
         if self.enable_overlap:
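The scheduler no longer force-disables the overlap scheduler for the triton backend, double sparsity, or MLA models; only embedding models keep it off. In exchange, jump-forward decoding is switched off whenever overlap is active (the same rule that disappears from ServerArgs further below). A condensed sketch of the resulting gating, under the assumption that these are the only relevant inputs here:

    # Hypothetical simplification; the real Scheduler reads these flags from
    # ServerArgs and the model config.
    def resolve_flags(is_generation: bool, disable_overlap_schedule: bool,
                      disable_jump_forward: bool) -> tuple[bool, bool]:
        enable_overlap = not disable_overlap_schedule
        if not is_generation:
            enable_overlap = False          # embedding models never overlap
        if enable_overlap:
            disable_jump_forward = True     # overlap trades away jump-forward decoding
        return enable_overlap, disable_jump_forward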
@@ -94,10 +94,21 @@ class TpModelWorkerClient:
 
     @torch.no_grad()
     def forward_thread_func_(self):
+        batch_pt = 0
+        batch_lists = [None] * 2
+
         while True:
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
             if not model_worker_batch:
                 break
+
+            # Keep a reference of model_worker_batch by storing it into a list.
+            # Otherwise, the tensor members of model_worker_batch will be released
+            # by pytorch and cause CUDA illegal memory access errors.
+            batch_lists[batch_pt % 2] = model_worker_batch
+            batch_pt += 1
+
+            # Create event
             self.launch_done = threading.Event()
             copy_done = torch.cuda.Event()
 
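The new comment spells out the bug being avoided: the worker thread launches GPU work asynchronously, and if the Python reference to model_worker_batch is dropped as soon as the loop moves on, PyTorch's caching allocator can hand that memory to another tensor while the GPU is still reading it, which surfaces as CUDA illegal memory access errors. A stripped-down sketch of the same keep-the-last-two-batches-alive pattern (hypothetical names, not the real worker):

    class BatchKeeper:
        """Hold references to the last `depth` batches so their tensors stay
        allocated while previously launched GPU work may still touch them."""

        def __init__(self, depth: int = 2):
            self._slots = [None] * depth
            self._pt = 0

        def keep(self, batch) -> None:
            # Overwriting a slot only releases a batch that is `depth` iterations
            # old, which the GPU is assumed to have finished with by then.
            self._slots[self._pt % len(self._slots)] = batch
            self._pt += 1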
@@ -170,7 +170,6 @@ class CudaGraphRunner:
         self.encoder_lens = None
 
         if self.enable_dp_attention:
-            self.global_num_tokens = [0] * self.tp_size
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.tp_size,
@@ -264,10 +263,10 @@ class CudaGraphRunner:
         mrope_positions = self.mrope_positions[:, :bs]
 
         if self.enable_dp_attention:
-            self.global_num_tokens[:] = [bs] * self.tp_size
+            global_num_tokens = [bs] * self.tp_size
             gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
         else:
-            self.global_num_tokens = None
+            global_num_tokens = None
             gathered_buffer = None
 
         # Attention backend
@@ -296,7 +295,7 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
             mrope_positions=mrope_positions,
-            global_num_tokens=self.global_num_tokens,
+            global_num_tokens=global_num_tokens,
             gathered_buffer=gathered_buffer,
         )
         logits_output = forward(input_ids, forward_batch.positions, forward_batch)
@@ -348,8 +347,6 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
-            self.global_num_tokens[:] = [bs] * self.tp_size
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
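Across the CudaGraphRunner hunks, the persistent self.global_num_tokens list, previously re-filled before every capture and every replay, becomes a local variable used only while a graph is being captured, and the replay-time update is dropped. This matches how CUDA graphs behave: Python code, and therefore anything read from a Python list, runs only during capture, so a replay-time mutation of that list has nothing to influence; only device buffers that the captured kernels read can be refreshed before a replay. A self-contained toy demonstration of that capture-time baking (hypothetical, not from the repository):

    import torch

    def demo():
        # CUDA graphs need a CUDA device; skip quietly otherwise.
        if not torch.cuda.is_available():
            return

        static_inp = torch.zeros(8, device="cuda")
        host_side = [1]                          # Python list, read only while capturing

        # Warm up on a side stream, as the PyTorch CUDA graph docs recommend.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            _ = static_inp * len(host_side)
        torch.cuda.current_stream().wait_stream(s)

        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            out = static_inp * len(host_side)    # len(...) == 1 is frozen into the graph

        host_side.append(2)                      # no effect on any later replay
        static_inp.fill_(3.0)                    # does affect replay: same device buffer
        g.replay()
        assert torch.allclose(out, torch.full((8,), 3.0, device="cuda"))

    demo()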
@@ -174,17 +174,17 @@ class ServerArgs:
             self.cuda_graph_max_bs = 4
             logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
 
+        # Choose kernel backends
         if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"
-
-        # Default kernel backends
         if self.attention_backend is None:
             self.attention_backend = "flashinfer"
 
         if self.sampling_backend is None:
             self.sampling_backend = "flashinfer"
 
+        # Others
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             self.chunked_prefill_size = self.chunked_prefill_size // 2
@@ -205,9 +205,6 @@ class ServerArgs:
             )
             self.disable_overlap_schedule = True
 
-        if not self.disable_overlap_schedule:
-            self.disable_jump_forward = True
-
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
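This removes the jump-forward rule from argument post-processing; as the Scheduler hunk above shows, the same self.disable_jump_forward = True decision now appears to be made inside the Scheduler, after enable_overlap has already been adjusted for embedding models.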
@@ -2,3 +2,4 @@
 
 kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
 kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
+kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}')