Enable overlap by default (#2067)
@@ -30,7 +30,7 @@ import torch
 import zmq
 
 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -102,7 +102,7 @@ class Scheduler:
         self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
-        self.enable_overlap = server_args.enable_overlap_schedule
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
 
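Note: this hunk inverts the flag, so the overlap scheduler is now opt-out rather than opt-in. A minimal sketch of the new flag semantics, assuming a hypothetical argparse wiring rather than the actual ServerArgs definition:

import argparse

parser = argparse.ArgumentParser()
# Old behavior: overlap stayed off unless explicitly requested.
parser.add_argument("--enable-overlap-schedule", action="store_true")
# New behavior: overlap stays on unless explicitly disabled.
parser.add_argument("--disable-overlap-schedule", action="store_true")

args = parser.parse_args([])
enable_overlap = not args.disable_overlap_schedule
assert enable_overlap  # with no flags given, overlap is the default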
@@ -159,6 +159,23 @@ class Scheduler:
                 trust_remote_code=server_args.trust_remote_code,
             )
 
+        # Check whether overlap can be enabled
+        if not self.is_generation:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for embedding models.")
+        if (
+            server_args.attention_backend == "triton"
+            or server_args.enable_double_sparsity
+            or (
+                self.model_config.attention_arch == AttentionArch.MLA
+                and not self.server_args.disable_mla
+            )
+        ):
+            self.enable_overlap = False
+            logger.info(
+                "Overlap scheduler is disabled if using triton attention backend."
+            )
+
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
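Note: since overlap is now the default, this guard turns it back off for the cases it does not yet support: non-generation (embedding/reward) models, the triton attention backend, double sparsity, and MLA-architecture models running with MLA enabled. The same decision restated as a standalone predicate (this helper and its parameters are illustrative, not part of the repo):

def can_enable_overlap(
    is_generation: bool,
    attention_backend: str,
    enable_double_sparsity: bool,
    is_mla_arch: bool,
    disable_mla: bool,
) -> bool:
    # Embedding and reward models never take the overlap path.
    if not is_generation:
        return False
    # Triton backend and double sparsity are unsupported with overlap.
    if attention_backend == "triton" or enable_double_sparsity:
        return False
    # MLA models are excluded unless MLA itself is disabled.
    if is_mla_arch and not disable_mla:
        return False
    return True

assert can_enable_overlap(True, "flashinfer", False, False, False)
assert not can_enable_overlap(True, "triton", False, False, False)

Observe that the log message names only the triton backend, even though the condition also covers double sparsity and MLA.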
@@ -903,6 +920,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
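Note: this hunk and the two below add the same line, a torch.cuda.current_stream().synchronize() between enqueuing the vocab-mask update and setting sampling_info_done. Presumably update_regex_vocab_mask enqueues asynchronous work on the current CUDA stream, while sampling_info_done is an event consumed by the overlap worker on another thread; synchronizing first guarantees the mask is fully written before that thread is woken. A minimal sketch of the handshake under those assumptions (hypothetical names, not the actual TpModelWorkerClient code):

import threading

import torch

mask_ready = threading.Event()  # stands in for sampling_info_done
vocab_mask = torch.zeros(32_000, dtype=torch.bool, device="cuda")

def scheduler_thread():
    vocab_mask.fill_(True)                     # async kernel on the current stream
    torch.cuda.current_stream().synchronize()  # wait until the write has landed
    mask_ready.set()                           # only now wake the sampler

def sampler_thread():
    mask_ready.wait()  # safe to read: the mask contents are final
    assert bool(vocab_mask.all())

t = threading.Thread(target=sampler_thread)
t.start()
scheduler_thread()
t.join()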
@@ -958,6 +976,7 @@ class Scheduler:
 
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
+                torch.cuda.current_stream().synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
 
         else:  # embedding or reward model
@@ -1031,6 +1050,7 @@ class Scheduler:
 
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
+            torch.cuda.current_stream().synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
         self.stream_output(batch.reqs)