diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e25af5583..2c78e70bf 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1002,7 +1002,7 @@ class Scheduler: if req.is_retracted: continue - if self.server_args.enable_overlap_schedule and (req.finished()): + if self.enable_overlap and req.finished(): self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) continue @@ -1319,7 +1319,7 @@ def run_scheduler_process( try: scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) pipe_writer.send("ready") - if server_args.enable_overlap_schedule: + if scheduler.enable_overlap: scheduler.event_loop_overlap() else: scheduler.event_loop_normal() diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 412680f98..f5e43f348 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -26,7 +26,6 @@ import torch from sglang.srt.managers.io_struct import UpdateWeightReqInput from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import ServerArgs logger = logging.getLogger(__name__) @@ -176,16 +175,8 @@ class TpModelWorkerClient: ) % self.future_token_ids_limit return None, future_next_token_ids - def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): - forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - logits_output = self.model_runner.forward(forward_batch) - embeddings = logits_output.embeddings - return embeddings - def update_weights(self, recv_req: UpdateWeightReqInput): - success, message = self.model_runner.update_weights( - recv_req.model_path, recv_req.load_format - ) + success, message = self.worker.update_weights(recv_req) return success, message def __delete__(self): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8096fec5a..ea09b3c26 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -276,10 +276,6 @@ class ModelRunner: else None ) self.dtype = self.vllm_model_config.dtype - if self.sliding_window_size: - assert ( - self.server_args.attention_backend == "flashinfer" - ), "Only flashinfer supports window attention." logger.info( f"Load weight end. " diff --git a/test/srt/test_triton_attention_backend.py b/test/srt/test_triton_attention_backend.py index 2687bb435..5a8187de5 100644 --- a/test/srt/test_triton_attention_backend.py +++ b/test/srt/test_triton_attention_backend.py @@ -1,4 +1,8 @@ -import subprocess +""" +Usage: +python3 -m unittest test_triton_attention_backend.TestTritonAttnBackend.test_mmlu +""" + import unittest from types import SimpleNamespace