Enable overlap by default (#2067)

This commit is contained in:
Lianmin Zheng
2024-11-19 22:07:58 -08:00
committed by GitHub
parent 699384cb01
commit 7d671e4ad2
17 changed files with 92 additions and 75 deletions

View File

@@ -30,7 +30,7 @@ import torch
import zmq
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
@@ -102,7 +102,7 @@ class Scheduler:
self.disable_jump_forward = server_args.disable_jump_forward
self.lora_paths = server_args.lora_paths
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = server_args.enable_overlap_schedule
self.enable_overlap = not server_args.disable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics
@@ -159,6 +159,23 @@ class Scheduler:
trust_remote_code=server_args.trust_remote_code,
)
# Check whether overlap can be enabled
if not self.is_generation:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for embedding models.")
if (
server_args.attention_backend == "triton"
or server_args.enable_double_sparsity
or (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
)
):
self.enable_overlap = False
logger.info(
"Overlap scheduler is disabled if using triton attention backend."
)
# Launch a tensor parallel worker
if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
@@ -903,6 +920,7 @@ class Scheduler:
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -958,6 +976,7 @@ class Scheduler:
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
else: # embedding or reward model
@@ -1031,6 +1050,7 @@ class Scheduler:
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs)