Code structure refactor (#807)

This commit is contained in:
Liangsheng Yin
2024-07-29 23:04:48 -07:00
committed by GitHub
parent 21e22b9e96
commit cdcbde5fc3
41 changed files with 106 additions and 105 deletions

View File

@@ -29,7 +29,7 @@ If OOM happens during prefill, try to decrease `--max-prefill-tokens`.
If OOM happens during decoding, try to decrease `--max-running-requests`. If OOM happens during decoding, try to decrease `--max-running-requests`.
You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding. You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.
### (Minor) Tune `--schedule-heuristic` ### (Minor) Tune `--schedule-policy`
If you have many shared prefixes, use the default `--schedule-heuristic lpm`. `lpm` stands for longest prefix match. If you have many shared prefixes, use the default `--schedule-policy lpm`. `lpm` stands for longest prefix match.
When you have no shared prefixes at all or you always send the requests with the shared prefixes together, When you have no shared prefixes at all or you always send the requests with the shared prefixes together,
you can try `--schedule-heuristic fcfs`. `fcfs` stands for first come, first served. you can try `--schedule-policy fcfs`. `fcfs` stands for first come, first served.

View File

@@ -1,4 +1,5 @@
# SGL API Components # SGL API Components
from sglang.api import ( from sglang.api import (
Runtime, Runtime,
assistant, assistant,
@@ -22,46 +23,46 @@ from sglang.api import (
video, video,
) )
# SGLang DSL APIs
__all__ = [
"Runtime",
"assistant",
"assistant_begin",
"assistant_end",
"flush_cache",
"function",
"gen",
"gen_int",
"gen_string",
"get_server_args",
"image",
"select",
"set_default_backend",
"system",
"system_begin",
"system_end",
"user",
"user_begin",
"user_end",
"video",
]
# Global Configurations # Global Configurations
from sglang.global_config import global_config from sglang.global_config import global_config
__all__ += ["global_config"]
from sglang.version import __version__
__all__ += ["__version__"]
# SGL Backends # SGL Backends
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import LazyImport from sglang.utils import LazyImport
from sglang.version import __version__
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"]
# public APIs management
__all__ = [
"global_config",
"Anthropic",
"LiteLLM",
"OpenAI",
"RuntimeEndpoint",
"VertexAI",
"function",
"Runtime",
"set_default_backend",
"flush_cache",
"get_server_args",
"gen",
"gen_int",
"gen_string",
"image",
"video",
"select",
"system",
"user",
"assistant",
"user_begin",
"user_end",
"assistant_begin",
"assistant_end",
"system_begin",
"system_end",
]

View File

@@ -37,9 +37,9 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling_params import SamplingParams from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import suppress_other_loggers from sglang.srt.utils import suppress_other_loggers

View File

@@ -25,7 +25,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
@dataclasses.dataclass @dataclasses.dataclass

View File

@@ -22,7 +22,7 @@ from torch import nn
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.layers.extend_attention import extend_attention_fwd from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.controller.model_runner import ( from sglang.srt.model_executor.model_runner import (
ForwardMode, ForwardMode,
InputMetadata, InputMetadata,
global_server_args_dict, global_server_args_dict,

View File

@@ -20,7 +20,7 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from sglang.srt.managers.controller.infer_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
if global_server_args_dict.get("attention_reduce_in_fp32", False): if global_server_args_dict.get("attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32 REDUCE_TRITON_TYPE = tl.float32

View File

@@ -27,7 +27,7 @@ from enum import Enum, auto
import numpy as np import numpy as np
import zmq import zmq
from sglang.srt.managers.controller.manager_single import ( from sglang.srt.managers.controller_single import (
start_controller_process as start_controller_process_single, start_controller_process as start_controller_process_single,
) )
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (

View File

@@ -22,7 +22,7 @@ from typing import List
import zmq import zmq
from sglang.srt.managers.controller.tp_worker import ( from sglang.srt.managers.tp_worker import (
ModelTpServer, ModelTpServer,
broadcast_recv_input, broadcast_recv_input,
launch_tp_servers, launch_tp_servers,

View File

@@ -25,8 +25,8 @@ import zmq
import zmq.asyncio import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry

View File

@@ -22,7 +22,7 @@ import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from sglang.srt.managers.controller.infer_batch import BaseFinishReason from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams from sglang.srt.sampling_params import SamplingParams

View File

@@ -13,47 +13,47 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
"""Request scheduler heuristic.""" """Request policy scheduler"""
import random import random
from collections import defaultdict from collections import defaultdict
class ScheduleHeuristic: class PolicyScheduler:
def __init__( def __init__(
self, self,
schedule_heuristic, policy,
max_running_seqs, max_running_seqs,
max_prefill_num_tokens, max_prefill_num_tokens,
max_total_num_tokens, max_total_num_tokens,
tree_cache, tree_cache,
): ):
if tree_cache.disable and schedule_heuristic == "lpm": if tree_cache.disable and policy == "lpm":
# LPM is meaningless when the tree cache is disabled. # LPM is meaningless when the tree cache is disabled.
schedule_heuristic = "fcfs" policy = "fcfs"
self.schedule_heuristic = schedule_heuristic self.policy = policy
self.max_running_seqs = max_running_seqs self.max_running_seqs = max_running_seqs
self.max_prefill_num_tokens = max_prefill_num_tokens self.max_prefill_num_tokens = max_prefill_num_tokens
self.max_total_num_tokens = max_total_num_tokens self.max_total_num_tokens = max_total_num_tokens
self.tree_cache = tree_cache self.tree_cache = tree_cache
def get_priority_queue(self, waiting_queue): def get_priority_queue(self, waiting_queue):
if self.schedule_heuristic == "lpm": if self.policy == "lpm":
# longest prefix match # longest prefix match
waiting_queue.sort(key=lambda x: -len(x.prefix_indices)) waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
return waiting_queue return waiting_queue
elif self.schedule_heuristic == "fcfs": elif self.policy == "fcfs":
# first come first serve # first come first serve
return waiting_queue return waiting_queue
elif self.schedule_heuristic == "lof": elif self.policy == "lof":
# longest output first # longest output first
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
return waiting_queue return waiting_queue
elif self.schedule_heuristic == "random": elif self.policy == "random":
random.shuffle(waiting_queue) random.shuffle(waiting_queue)
return waiting_queue return waiting_queue
elif self.schedule_heuristic == "dfs-weight": elif self.policy == "dfs-weight":
last_node_to_reqs = defaultdict(list) last_node_to_reqs = defaultdict(list)
for req in waiting_queue: for req in waiting_queue:
last_node_to_reqs[req.last_node].append(req) last_node_to_reqs[req.last_node].append(req)
@@ -70,7 +70,7 @@ class ScheduleHeuristic:
assert len(q) == len(waiting_queue) assert len(q) == len(waiting_queue)
return q return q
else: else:
raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}") raise ValueError(f"Unknown schedule_policy: {self.policy}")
def calc_weight(self, cur_node, node_to_weight): def calc_weight(self, cur_node, node_to_weight):
for child in cur_node.children.values(): for child in cur_node.children.values():

View File

@@ -28,8 +28,8 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.managers.controller.radix_cache import RadixCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.mem_cache.radix_cache import RadixCache
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

View File

@@ -29,23 +29,23 @@ from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.controller.infer_batch import (
FINISH_ABORT,
BaseFinishReason,
Batch,
ForwardMode,
Req,
)
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BatchTokenIDOut, BatchTokenIDOut,
FlushCacheReq, FlushCacheReq,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
) )
from sglang.srt.managers.policy_scheduler import PolicyScheduler
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
BaseFinishReason,
Batch,
ForwardMode,
Req,
)
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_config import ModelConfig from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
get_int_token_logit_bias, get_int_token_logit_bias,
@@ -74,7 +74,7 @@ class ModelTpServer:
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.tp_size = server_args.tp_size self.tp_size = server_args.tp_size
self.dp_size = server_args.dp_size self.dp_size = server_args.dp_size
self.schedule_heuristic = server_args.schedule_heuristic self.schedule_policy = server_args.schedule_policy
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
# Chunked prefill # Chunked prefill
@@ -150,8 +150,8 @@ class ModelTpServer:
disable=server_args.disable_radix_cache, disable=server_args.disable_radix_cache,
) )
self.tree_cache_metrics = {"total": 0, "hit": 0} self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = ScheduleHeuristic( self.scheduler = PolicyScheduler(
self.schedule_heuristic, self.schedule_policy,
self.max_running_requests, self.max_running_requests,
self.max_prefill_tokens, self.max_prefill_tokens,
self.max_total_num_tokens, self.max_total_num_tokens,

View File

@@ -17,7 +17,7 @@ limitations under the License.
Flush the KV cache. Flush the KV cache.
Usage: Usage:
python3 -m sglang.srt.flush_cache --url http://localhost:30000 python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
""" """
import argparse import argparse

View File

@@ -29,7 +29,7 @@ from sglang.srt.layers.logits_processor import (
LogitsMetadata, LogitsMetadata,
LogitsProcessor, LogitsProcessor,
) )
from sglang.srt.managers.controller.infer_batch import ( from sglang.srt.managers.schedule_batch import (
Batch, Batch,
ForwardMode, ForwardMode,
InputMetadata, InputMetadata,

View File

@@ -40,13 +40,13 @@ from vllm.distributed import (
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.managers.controller.infer_batch import ( from sglang.srt.managers.schedule_batch import (
Batch, Batch,
ForwardMode, ForwardMode,
InputMetadata, InputMetadata,
global_server_args_dict, global_server_args_dict,
) )
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
get_available_gpu_memory, get_available_gpu_memory,
@@ -273,7 +273,7 @@ class ModelRunner:
) )
def init_cuda_graphs(self): def init_cuda_graphs(self):
from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer: if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
self.cuda_graph_runner = None self.cuda_graph_runner = None

View File

@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
LoraConfig = None LoraConfig = None

View File

@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
@torch.compile @torch.compile

View File

@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
class DbrxRouter(nn.Module): class DbrxRouter(nn.Module):

View File

@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.infer_batch import InputMetadata from sglang.srt.managers.schedule_batch import InputMetadata
class DeepseekMLP(nn.Module): class DeepseekMLP(nn.Module):

View File

@@ -45,7 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):

View File

@@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
class GemmaMLP(nn.Module): class GemmaMLP(nn.Module):

View File

@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
class GemmaRMSNorm(CustomOp): class GemmaRMSNorm(CustomOp):

View File

@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.infer_batch import InputMetadata from sglang.srt.managers.schedule_batch import InputMetadata
class GPTBigCodeAttention(nn.Module): class GPTBigCodeAttention(nn.Module):

View File

@@ -52,7 +52,7 @@ from vllm.utils import print_warning_once
from sglang.srt.layers.fused_moe import fused_moe from sglang.srt.layers.fused_moe import fused_moe
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
use_fused = True use_fused = True

View File

@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
class InternLM2MLP(nn.Module): class InternLM2MLP(nn.Module):

View File

@@ -36,7 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
MergedColumnParallelLinear = None MergedColumnParallelLinear = None
QKVParallelLinear = None QKVParallelLinear = None

View File

@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitProcessorOutput from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaModel from sglang.srt.models.llama2 import LlamaModel

View File

@@ -32,13 +32,13 @@ from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.controller.infer_batch import ForwardMode from sglang.srt.managers.schedule_batch import ForwardMode
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.mm_utils import ( from sglang.srt.mm_utils import (
get_anyres_image_grid_shape, get_anyres_image_grid_shape,
unpad_image, unpad_image,
unpad_image_shape, unpad_image_shape,
) )
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM

View File

@@ -26,13 +26,13 @@ from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.controller.infer_batch import ForwardMode from sglang.srt.managers.schedule_batch import ForwardMode
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.mm_utils import ( from sglang.srt.mm_utils import (
get_anyres_image_grid_shape, get_anyres_image_grid_shape,
unpad_image, unpad_image,
unpad_image_shape, unpad_image_shape,
) )
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.llama2 import LlamaForCausalLM

View File

@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
class MiniCPMMLP(nn.Module): class MiniCPMMLP(nn.Module):

View File

@@ -50,7 +50,7 @@ from vllm.utils import print_warning_once
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
class MixtralMoE(nn.Module): class MixtralMoE(nn.Module):

View File

@@ -45,7 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
class MixtralMLP(nn.Module): class MixtralMLP(nn.Module):

View File

@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
class QWenMLP(nn.Module): class QWenMLP(nn.Module):

View File

@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
Qwen2Config = None Qwen2Config = None

View File

@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
class Qwen2MoeMLP(nn.Module): class Qwen2MoeMLP(nn.Module):

View File

@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import InputMetadata
class StablelmMLP(nn.Module): class StablelmMLP(nn.Module):

View File

@@ -44,11 +44,11 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.manager_multi import ( from sglang.srt.managers.controller_multi import (
start_controller_process as start_controller_process_multi, start_controller_process as start_controller_process_multi,
) )
from sglang.srt.managers.controller.manager_single import launch_tp_servers from sglang.srt.managers.controller_single import launch_tp_servers
from sglang.srt.managers.controller.manager_single import ( from sglang.srt.managers.controller_single import (
start_controller_process as start_controller_process_single, start_controller_process as start_controller_process_single,
) )
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process from sglang.srt.managers.detokenizer_manager import start_detokenizer_process

View File

@@ -44,7 +44,7 @@ class ServerArgs:
max_prefill_tokens: Optional[int] = None max_prefill_tokens: Optional[int] = None
max_running_requests: Optional[int] = None max_running_requests: Optional[int] = None
max_num_reqs: Optional[int] = None max_num_reqs: Optional[int] = None
schedule_heuristic: str = "lpm" schedule_policy: str = "lpm"
schedule_conservativeness: float = 1.0 schedule_conservativeness: float = 1.0
# Other runtime options # Other runtime options
@@ -231,11 +231,11 @@ class ServerArgs:
help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.", help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
) )
parser.add_argument( parser.add_argument(
"--schedule-heuristic", "--schedule-policy",
type=str, type=str,
default=ServerArgs.schedule_heuristic, default=ServerArgs.schedule_policy,
choices=["lpm", "random", "fcfs", "dfs-weight"], choices=["lpm", "random", "fcfs", "dfs-weight"],
help="The scheduling heuristic.", help="The scheduling policy of the requests.",
) )
parser.add_argument( parser.add_argument(
"--schedule-conservativeness", "--schedule-conservativeness",