From cdcbde5fc3155edaa6b98a13ab8764101e657b23 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 29 Jul 2024 23:04:48 -0700 Subject: [PATCH] Code structure refactor (#807) --- docs/en/hyperparameter_tuning.md | 6 +- python/sglang/__init__.py | 63 ++++++++++--------- python/sglang/bench_latency.py | 4 +- python/sglang/srt/layers/logits_processor.py | 2 +- python/sglang/srt/layers/radix_attention.py | 2 +- python/sglang/srt/layers/token_attention.py | 2 +- .../manager_multi.py => controller_multi.py} | 2 +- ...manager_single.py => controller_single.py} | 2 +- .../srt/managers/detokenizer_manager.py | 2 +- python/sglang/srt/managers/io_struct.py | 2 +- ...edule_heuristic.py => policy_scheduler.py} | 24 +++---- .../infer_batch.py => schedule_batch.py} | 4 +- .../managers/{controller => }/tp_worker.py | 26 ++++---- .../sglang/srt/{ => mem_cache}/flush_cache.py | 2 +- .../sglang/srt/{ => mem_cache}/memory_pool.py | 0 .../controller => mem_cache}/radix_cache.py | 0 .../cuda_graph_runner.py | 2 +- .../model_runner.py | 6 +- python/sglang/srt/models/chatglm.py | 2 +- python/sglang/srt/models/commandr.py | 2 +- python/sglang/srt/models/dbrx.py | 2 +- python/sglang/srt/models/deepseek.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 2 +- python/sglang/srt/models/gemma.py | 2 +- python/sglang/srt/models/gemma2.py | 2 +- python/sglang/srt/models/gpt_bigcode.py | 2 +- python/sglang/srt/models/grok.py | 2 +- python/sglang/srt/models/internlm2.py | 2 +- python/sglang/srt/models/llama2.py | 2 +- .../sglang/srt/models/llama_classification.py | 2 +- python/sglang/srt/models/llava.py | 4 +- python/sglang/srt/models/llavavid.py | 4 +- python/sglang/srt/models/minicpm.py | 2 +- python/sglang/srt/models/mixtral.py | 2 +- python/sglang/srt/models/mixtral_quant.py | 2 +- python/sglang/srt/models/qwen.py | 2 +- python/sglang/srt/models/qwen2.py | 2 +- python/sglang/srt/models/qwen2_moe.py | 2 +- python/sglang/srt/models/stablelm.py | 2 +- python/sglang/srt/server.py | 6 +- python/sglang/srt/server_args.py | 8 +-- 41 files changed, 106 insertions(+), 105 deletions(-) rename python/sglang/srt/managers/{controller/manager_multi.py => controller_multi.py} (99%) rename python/sglang/srt/managers/{controller/manager_single.py => controller_single.py} (98%) rename python/sglang/srt/managers/{controller/schedule_heuristic.py => policy_scheduler.py} (82%) rename python/sglang/srt/managers/{controller/infer_batch.py => schedule_batch.py} (99%) rename python/sglang/srt/managers/{controller => }/tp_worker.py (98%) rename python/sglang/srt/{ => mem_cache}/flush_cache.py (92%) rename python/sglang/srt/{ => mem_cache}/memory_pool.py (100%) rename python/sglang/srt/{managers/controller => mem_cache}/radix_cache.py (100%) rename python/sglang/srt/{managers/controller => model_executor}/cuda_graph_runner.py (99%) rename python/sglang/srt/{managers/controller => model_executor}/model_runner.py (98%) diff --git a/docs/en/hyperparameter_tuning.md b/docs/en/hyperparameter_tuning.md index 85315e745..2ea43e26a 100644 --- a/docs/en/hyperparameter_tuning.md +++ b/docs/en/hyperparameter_tuning.md @@ -29,7 +29,7 @@ If OOM happens during prefill, try to decrease `--max-prefill-tokens`. If OOM happens during decoding, try to decrease `--max-running-requests`. You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding. -### (Minor) Tune `--schedule-heuristic` -If you have many shared prefixes, use the default `--schedule-heuristic lpm`. 
`lpm` stands for longest prefix match. +### (Minor) Tune `--schedule-policy` +If you have many shared prefixes, use the default `--schedule-policy lpm`. `lpm` stands for longest prefix match. When you have no shared prefixes at all or you always send the requests with the shared prefixes together, -you can try `--schedule-heuristic fcfs`. `fcfs` stands for first come first serve. +you can try `--schedule-policy fcfs`. `fcfs` stands for first come first serve. diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index 413ab9e7c..f4eec131e 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -1,4 +1,5 @@ # SGL API Components + from sglang.api import ( Runtime, assistant, @@ -22,46 +23,46 @@ from sglang.api import ( video, ) +# SGLang DSL APIs +__all__ = [ + "Runtime", + "assistant", + "assistant_begin", + "assistant_end", + "flush_cache", + "function", + "gen", + "gen_int", + "gen_string", + "get_server_args", + "image", + "select", + "set_default_backend", + "system", + "system_begin", + "system_end", + "user", + "user_begin", + "user_end", + "video", +] + # Global Configurations from sglang.global_config import global_config +__all__ += ["global_config"] + +from sglang.version import __version__ + +__all__ += ["__version__"] + # SGL Backends from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.utils import LazyImport -from sglang.version import __version__ Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") - -# public APIs management -__all__ = [ - "global_config", - "Anthropic", - "LiteLLM", - "OpenAI", - "RuntimeEndpoint", - "VertexAI", - "function", - "Runtime", - "set_default_backend", - "flush_cache", - "get_server_args", - "gen", - "gen_int", - "gen_string", - "image", - "video", - "select", - "system", - "user", - "assistant", - "user_begin", - "user_end", - "assistant_begin", - "assistant_end", - "system_begin", - "system_end", -] +__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"] diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index c2eb93a24..c4ffce634 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -37,9 +37,9 @@ import torch import torch.distributed as dist from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req -from sglang.srt.managers.controller.model_runner import ModelRunner +from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req from sglang.srt.model_config import ModelConfig +from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs from sglang.srt.utils import suppress_other_loggers diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index d3aa2469a..200071c60 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -25,7 +25,7 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, ) -from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata +from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata @dataclasses.dataclass diff --git 
a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index fb95106be..ab3a65029 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -22,7 +22,7 @@ from torch import nn from sglang.global_config import global_config from sglang.srt.layers.extend_attention import extend_attention_fwd from sglang.srt.layers.token_attention import token_attention_fwd -from sglang.srt.managers.controller.model_runner import ( +from sglang.srt.model_executor.model_runner import ( ForwardMode, InputMetadata, global_server_args_dict, diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py index 565e1359f..a792b7f3a 100644 --- a/python/sglang/srt/layers/token_attention.py +++ b/python/sglang/srt/layers/token_attention.py @@ -20,7 +20,7 @@ import torch import triton import triton.language as tl -from sglang.srt.managers.controller.infer_batch import global_server_args_dict +from sglang.srt.managers.schedule_batch import global_server_args_dict if global_server_args_dict.get("attention_reduce_in_fp32", False): REDUCE_TRITON_TYPE = tl.float32 diff --git a/python/sglang/srt/managers/controller/manager_multi.py b/python/sglang/srt/managers/controller_multi.py similarity index 99% rename from python/sglang/srt/managers/controller/manager_multi.py rename to python/sglang/srt/managers/controller_multi.py index 08c9db82b..dcd984e0f 100644 --- a/python/sglang/srt/managers/controller/manager_multi.py +++ b/python/sglang/srt/managers/controller_multi.py @@ -27,7 +27,7 @@ from enum import Enum, auto import numpy as np import zmq -from sglang.srt.managers.controller.manager_single import ( +from sglang.srt.managers.controller_single import ( start_controller_process as start_controller_process_single, ) from sglang.srt.managers.io_struct import ( diff --git a/python/sglang/srt/managers/controller/manager_single.py b/python/sglang/srt/managers/controller_single.py similarity index 98% rename from python/sglang/srt/managers/controller/manager_single.py rename to python/sglang/srt/managers/controller_single.py index 012d6c1d6..415325b13 100644 --- a/python/sglang/srt/managers/controller/manager_single.py +++ b/python/sglang/srt/managers/controller_single.py @@ -22,7 +22,7 @@ from typing import List import zmq -from sglang.srt.managers.controller.tp_worker import ( +from sglang.srt.managers.tp_worker import ( ModelTpServer, broadcast_recv_input, launch_tp_servers, diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index b8607482e..0bd03d314 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -25,8 +25,8 @@ import zmq import zmq.asyncio from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut +from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR from sglang.srt.server_args import PortArgs, ServerArgs from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index f0b927a69..036837a37 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -22,7 +22,7 @@ import uuid from dataclasses import dataclass from typing import Dict, List, 
Optional, Union

-from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams

diff --git a/python/sglang/srt/managers/controller/schedule_heuristic.py b/python/sglang/srt/managers/policy_scheduler.py
similarity index 82%
rename from python/sglang/srt/managers/controller/schedule_heuristic.py
rename to python/sglang/srt/managers/policy_scheduler.py
index d1f45836b..0eecc41d8 100644
--- a/python/sglang/srt/managers/controller/schedule_heuristic.py
+++ b/python/sglang/srt/managers/policy_scheduler.py
@@ -13,47 +13,47 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Request scheduler heuristic."""
+"""Request policy scheduler."""

 import random
 from collections import defaultdict


-class ScheduleHeuristic:
+class PolicyScheduler:
     def __init__(
         self,
-        schedule_heuristic,
+        policy,
         max_running_seqs,
         max_prefill_num_tokens,
         max_total_num_tokens,
         tree_cache,
     ):
-        if tree_cache.disable and schedule_heuristic == "lpm":
+        if tree_cache.disable and policy == "lpm":
             # LPM is meaningless when the tree cache is disabled.
-            schedule_heuristic = "fcfs"
+            policy = "fcfs"

-        self.schedule_heuristic = schedule_heuristic
+        self.policy = policy
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache

     def get_priority_queue(self, waiting_queue):
-        if self.schedule_heuristic == "lpm":
+        if self.policy == "lpm":
             # longest prefix match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
             return waiting_queue
-        elif self.schedule_heuristic == "fcfs":
+        elif self.policy == "fcfs":
             # first come first serve
             return waiting_queue
-        elif self.schedule_heuristic == "lof":
+        elif self.policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
             return waiting_queue
-        elif self.schedule_heuristic == "random":
+        elif self.policy == "random":
             random.shuffle(waiting_queue)
             return waiting_queue
-        elif self.schedule_heuristic == "dfs-weight":
+        elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
@@ -70,7 +70,7 @@ class ScheduleHeuristic:
             assert len(q) == len(waiting_queue)
             return q
         else:
-            raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
+            raise ValueError(f"Unknown schedule_policy: {self.policy}")

     def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/schedule_batch.py
similarity index 99%
rename from python/sglang/srt/managers/controller/infer_batch.py
rename to python/sglang/srt/managers/schedule_batch.py
index a80a9d657..6cfd2f650 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -28,8 +28,8 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.radix_cache import
RadixCache INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/tp_worker.py similarity index 98% rename from python/sglang/srt/managers/controller/tp_worker.py rename to python/sglang/srt/managers/tp_worker.py index a688c53e3..d21a0c694 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -29,23 +29,23 @@ from sglang.global_config import global_config from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.managers.controller.infer_batch import ( - FINISH_ABORT, - BaseFinishReason, - Batch, - ForwardMode, - Req, -) -from sglang.srt.managers.controller.model_runner import ModelRunner -from sglang.srt.managers.controller.radix_cache import RadixCache -from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic from sglang.srt.managers.io_struct import ( AbortReq, BatchTokenIDOut, FlushCacheReq, TokenizedGenerateReqInput, ) +from sglang.srt.managers.policy_scheduler import PolicyScheduler +from sglang.srt.managers.schedule_batch import ( + FINISH_ABORT, + BaseFinishReason, + Batch, + ForwardMode, + Req, +) +from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.model_config import ModelConfig +from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_int_token_logit_bias, @@ -74,7 +74,7 @@ class ModelTpServer: self.tp_rank = tp_rank self.tp_size = server_args.tp_size self.dp_size = server_args.dp_size - self.schedule_heuristic = server_args.schedule_heuristic + self.schedule_policy = server_args.schedule_policy self.disable_regex_jump_forward = server_args.disable_regex_jump_forward # Chunked prefill @@ -150,8 +150,8 @@ class ModelTpServer: disable=server_args.disable_radix_cache, ) self.tree_cache_metrics = {"total": 0, "hit": 0} - self.scheduler = ScheduleHeuristic( - self.schedule_heuristic, + self.scheduler = PolicyScheduler( + self.schedule_policy, self.max_running_requests, self.max_prefill_tokens, self.max_total_num_tokens, diff --git a/python/sglang/srt/flush_cache.py b/python/sglang/srt/mem_cache/flush_cache.py similarity index 92% rename from python/sglang/srt/flush_cache.py rename to python/sglang/srt/mem_cache/flush_cache.py index 4ef3ab1d3..3ac425ac8 100644 --- a/python/sglang/srt/flush_cache.py +++ b/python/sglang/srt/mem_cache/flush_cache.py @@ -17,7 +17,7 @@ limitations under the License. Flush the KV cache. 
Usage: -python3 -m sglang.srt.flush_cache --url http://localhost:30000 +python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000 """ import argparse diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py similarity index 100% rename from python/sglang/srt/memory_pool.py rename to python/sglang/srt/mem_cache/memory_pool.py diff --git a/python/sglang/srt/managers/controller/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py similarity index 100% rename from python/sglang/srt/managers/controller/radix_cache.py rename to python/sglang/srt/mem_cache/radix_cache.py diff --git a/python/sglang/srt/managers/controller/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py similarity index 99% rename from python/sglang/srt/managers/controller/cuda_graph_runner.py rename to python/sglang/srt/model_executor/cuda_graph_runner.py index 7d59eeef5..458395e73 100644 --- a/python/sglang/srt/managers/controller/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -29,7 +29,7 @@ from sglang.srt.layers.logits_processor import ( LogitsMetadata, LogitsProcessor, ) -from sglang.srt.managers.controller.infer_batch import ( +from sglang.srt.managers.schedule_batch import ( Batch, ForwardMode, InputMetadata, diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/model_executor/model_runner.py similarity index 98% rename from python/sglang/srt/managers/controller/model_runner.py rename to python/sglang/srt/model_executor/model_runner.py index 24c59b6ff..10b1b40de 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -40,13 +40,13 @@ from vllm.distributed import ( from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config -from sglang.srt.managers.controller.infer_batch import ( +from sglang.srt.managers.schedule_batch import ( Batch, ForwardMode, InputMetadata, global_server_args_dict, ) -from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_available_gpu_memory, @@ -273,7 +273,7 @@ class ModelRunner: ) def init_cuda_graphs(self): - from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner + from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer: self.cuda_graph_runner = None diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 9df6e4fd3..4589a14ac 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata LoraConfig = None diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index cc4ce9d4a..671746bf7 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs from sglang.srt.layers.logits_processor import LogitsProcessor from 
sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata @torch.compile diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 3104ca7c3..1d0f40bd3 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata class DbrxRouter(nn.Module): diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index c12b8a09c..09481e71b 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.infer_batch import InputMetadata +from sglang.srt.managers.schedule_batch import InputMetadata class DeepseekMLP(nn.Module): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index ae3d06ed0..4cc37c388 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -45,7 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata class DeepseekV2MLP(nn.Module): diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index aa42ad508..843bc5d28 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata class GemmaMLP(nn.Module): diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 02f20e705..4c77e0c69 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata class GemmaRMSNorm(CustomOp): diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 0ac89f648..eee7f6483 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention 
import RadixAttention -from sglang.srt.managers.controller.infer_batch import InputMetadata +from sglang.srt.managers.schedule_batch import InputMetadata class GPTBigCodeAttention(nn.Module): diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 9c4251b09..b989c4e79 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -52,7 +52,7 @@ from vllm.utils import print_warning_once from sglang.srt.layers.fused_moe import fused_moe from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata use_fused = True diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index bf6d99e3c..35f81f8a9 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata class InternLM2MLP(nn.Module): diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index 2287cf0a1..3e24e7b9c 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -36,7 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata MergedColumnParallelLinear = None QKVParallelLinear = None diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index f96eae093..3ffb256dd 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitProcessorOutput -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata from sglang.srt.models.llama2 import LlamaModel diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 2fcc4e998..f89a9b618 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -32,13 +32,13 @@ from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.managers.controller.infer_batch import ForwardMode -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.managers.schedule_batch import ForwardMode from sglang.srt.mm_utils import ( get_anyres_image_grid_shape, unpad_image, unpad_image_shape, ) +from sglang.srt.model_executor.model_runner import InputMetadata from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.mistral import 
MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 1f08137c6..3f88d41a1 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -26,13 +26,13 @@ from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.managers.controller.infer_batch import ForwardMode -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.managers.schedule_batch import ForwardMode from sglang.srt.mm_utils import ( get_anyres_image_grid_shape, unpad_image, unpad_image_shape, ) +from sglang.srt.model_executor.model_runner import InputMetadata from sglang.srt.models.llama2 import LlamaForCausalLM diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 7a07335d1..ab2a08325 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata class MiniCPMMLP(nn.Module): diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 0cfbad719..a7d45d455 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -50,7 +50,7 @@ from vllm.utils import print_warning_once from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata class MixtralMoE(nn.Module): diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index fce04cc89..d643db33f 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -45,7 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata class MixtralMLP(nn.Module): diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index cf6b264f3..52edd28bc 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.controller.model_runner import InputMetadata +from sglang.srt.model_executor.model_runner import InputMetadata class QWenMLP(nn.Module): diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 80ab61b64..2df91814e 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader 
from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata

 Qwen2Config = None
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 213ba6d3c..7475d8f62 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata


 class Qwen2MoeMLP(nn.Module):
diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py
index 4589c997c..76f40437a 100644
--- a/python/sglang/srt/models/stablelm.py
+++ b/python/sglang/srt/models/stablelm.py
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata


 class StablelmMLP(nn.Module):
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index f1b5dae9c..4c8ace962 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -44,11 +44,11 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.manager_multi import (
+from sglang.srt.managers.controller_multi import (
     start_controller_process as start_controller_process_multi,
 )
-from sglang.srt.managers.controller.manager_single import launch_tp_servers
-from sglang.srt.managers.controller.manager_single import (
+from sglang.srt.managers.controller_single import launch_tp_servers
+from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 8b3de98e2..e62987dd9 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -44,7 +44,7 @@ class ServerArgs:
     max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     max_num_reqs: Optional[int] = None
-    schedule_heuristic: str = "lpm"
+    schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0

     # Other runtime options
@@ -231,11 +231,11 @@
         help="The maximum number of requests to serve in the memory pool. If the model has a large context length, you may need to decrease this value to avoid out-of-memory errors.",
     )
     parser.add_argument(
-        "--schedule-heuristic",
+        "--schedule-policy",
         type=str,
-        default=ServerArgs.schedule_heuristic,
+        default=ServerArgs.schedule_policy,
         choices=["lpm", "random", "fcfs", "dfs-weight"],
-        help="The scheduling heuristic.",
+        help="The scheduling policy for requests.",
     )
     parser.add_argument(
         "--schedule-conservativeness",
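
Note for downstream users: every change in this patch is a mechanical rename, so external code that imports the old paths can be migrated by plain substitution. Below is a minimal, hypothetical migration sketch; it is not part of the PR, and the mapping is read directly off the rename list above. It assumes simple textual replacement is safe, which holds here because no old dotted path appears inside any other entry.

import sys

# Old module path -> new module path, exactly as renamed in this patch.
RENAMES = {
    "sglang.srt.managers.controller.manager_multi": "sglang.srt.managers.controller_multi",
    "sglang.srt.managers.controller.manager_single": "sglang.srt.managers.controller_single",
    "sglang.srt.managers.controller.schedule_heuristic": "sglang.srt.managers.policy_scheduler",
    "sglang.srt.managers.controller.infer_batch": "sglang.srt.managers.schedule_batch",
    "sglang.srt.managers.controller.tp_worker": "sglang.srt.managers.tp_worker",
    "sglang.srt.managers.controller.radix_cache": "sglang.srt.mem_cache.radix_cache",
    "sglang.srt.managers.controller.cuda_graph_runner": "sglang.srt.model_executor.cuda_graph_runner",
    "sglang.srt.managers.controller.model_runner": "sglang.srt.model_executor.model_runner",
    "sglang.srt.flush_cache": "sglang.srt.mem_cache.flush_cache",
    "sglang.srt.memory_pool": "sglang.srt.mem_cache.memory_pool",
    # The class was renamed as well.
    "ScheduleHeuristic": "PolicyScheduler",
}

def rewrite(source: str) -> str:
    """Apply every rename to one file's source text."""
    for old, new in RENAMES.items():
        source = source.replace(old, new)
    return source

if __name__ == "__main__":
    # Usage: python migrate_imports.py file1.py file2.py ...
    for path in sys.argv[1:]:
        with open(path) as f:
            text = f.read()
        new_text = rewrite(text)
        if new_text != text:
            with open(path, "w") as f:
                f.write(new_text)
            print(f"rewrote imports in {path}")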
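
The renamed `--schedule-policy` flag keeps the old behavior: under "lpm", the waiting queue is sorted so that requests with the longest matched radix-cache prefix run first, maximizing KV-cache reuse, while "fcfs" preserves arrival order. A toy sketch of that ordering rule follows; FakeReq is a hypothetical stand-in for the real Req in sglang.srt.managers.schedule_batch and models only the one field the "lpm" policy inspects.

from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeReq:
    """Hypothetical stand-in for schedule_batch.Req; only prefix_indices is modeled."""
    rid: str
    prefix_indices: List[int] = field(default_factory=list)

waiting_queue = [
    FakeReq("a", prefix_indices=[0, 1]),
    FakeReq("b", prefix_indices=[0, 1, 2, 3]),
    FakeReq("c"),  # no cached prefix
]

# The same sort key PolicyScheduler.get_priority_queue applies for "lpm":
# longer shared prefixes sort first.
waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
print([r.rid for r in waiting_queue])  # ['b', 'a', 'c']

At launch time the flag is passed the same way as before, only under its new name (assuming the usual server entry point, which this diff does not touch):

python3 -m sglang.launch_server --model-path <model> --schedule-policy fcfs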