Separate two entry points: Engine and HTTP server (#2996)
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
@@ -91,7 +91,7 @@ Here is how you can do it:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from sglang.srt.models.registry import ModelRegistry
|
from sglang.srt.models.registry import ModelRegistry
|
||||||
from sglang.srt.server import launch_server
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
|
|
||||||
# for a single model, you can add it to the registry
|
# for a single model, you can add it to the registry
|
||||||
ModelRegistry.models[model_name] = model_class
|
ModelRegistry.models[model_name] = model_class
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ def Runtime(*args, **kwargs):
|
|||||||
|
|
||||||
def Engine(*args, **kwargs):
|
def Engine(*args, **kwargs):
|
||||||
# Avoid importing unnecessary dependency
|
# Avoid importing unnecessary dependency
|
||||||
from sglang.srt.server import Engine
|
from sglang.srt.entrypoints.engine import Engine
|
||||||
|
|
||||||
return Engine(*args, **kwargs)
|
return Engine(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from sglang.bench_serving import (
|
|||||||
set_ulimit,
|
set_ulimit,
|
||||||
)
|
)
|
||||||
from sglang.lang.backend.runtime_endpoint import Runtime
|
from sglang.lang.backend.runtime_endpoint import Runtime
|
||||||
from sglang.srt.server import Engine
|
from sglang.srt.entrypoints.engine import Engine
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -57,12 +57,12 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.entrypoints.engine import _set_envs_and_config
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server import _set_envs_and_config
|
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
|
from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from typing import Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from sglang.srt.server import launch_server
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
|
||||||
|
|||||||
@@ -351,7 +351,7 @@ class Runtime:
|
|||||||
"""See the arguments in server_args.py::ServerArgs"""
|
"""See the arguments in server_args.py::ServerArgs"""
|
||||||
# We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
|
# We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
|
||||||
# client code without installing SRT server and its dependency if they want.
|
# client code without installing SRT server and its dependency if they want.
|
||||||
from sglang.srt.server import launch_server
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import is_port_available
|
from sglang.srt.utils import is_port_available
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from sglang.srt.server import launch_server
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
from sglang.srt.server_args import prepare_server_args
|
from sglang.srt.server_args import prepare_server_args
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
|
||||||
|
|||||||
449
python/sglang/srt/entrypoints/engine.py
Normal file
449
python/sglang/srt/entrypoints/engine.py
Normal file
@@ -0,0 +1,449 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
The entry point of inference server. (SRT = SGLang Runtime)
|
||||||
|
|
||||||
|
This file implements python APIs for the inference engine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import atexit
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import threading
|
||||||
|
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
# Fix a bug of Python threading
|
||||||
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import uvloop
|
||||||
|
|
||||||
|
from sglang.srt.managers.data_parallel_controller import (
|
||||||
|
run_data_parallel_controller_process,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
||||||
|
from sglang.srt.managers.io_struct import (
|
||||||
|
EmbeddingReqInput,
|
||||||
|
GenerateReqInput,
|
||||||
|
GetWeightsByNameReqInput,
|
||||||
|
InitWeightsUpdateGroupReqInput,
|
||||||
|
ReleaseMemoryOccupationReqInput,
|
||||||
|
ResumeMemoryOccupationReqInput,
|
||||||
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
UpdateWeightsFromTensorReqInput,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
|
||||||
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
|
from sglang.srt.utils import (
|
||||||
|
MultiprocessingSerializer,
|
||||||
|
assert_pkg_version,
|
||||||
|
configure_logger,
|
||||||
|
kill_process_tree,
|
||||||
|
maybe_set_triton_cache_manager,
|
||||||
|
prepare_model_and_tokenizer,
|
||||||
|
set_prometheus_multiproc_dir,
|
||||||
|
set_ulimit,
|
||||||
|
)
|
||||||
|
from sglang.version import __version__
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
|
||||||
|
class Engine:
|
||||||
|
"""
|
||||||
|
The entry point to the inference engine.
|
||||||
|
|
||||||
|
- The engine consists of three components:
|
||||||
|
1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
|
||||||
|
2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
|
||||||
|
3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
|
||||||
|
2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
"""
|
||||||
|
The arguments of this function is the same as `sglang/srt/server_args.py::ServerArgs`.
|
||||||
|
Please refer to `ServerArgs` for the documentation.
|
||||||
|
"""
|
||||||
|
if "server_args" in kwargs:
|
||||||
|
# Directly load server_args
|
||||||
|
server_args = kwargs["server_args"]
|
||||||
|
else:
|
||||||
|
# Construct server_args from kwargs
|
||||||
|
if "log_level" not in kwargs:
|
||||||
|
# Do not print logs by default
|
||||||
|
kwargs["log_level"] = "error"
|
||||||
|
server_args = ServerArgs(**kwargs)
|
||||||
|
|
||||||
|
# Shutdown the subprocesses automatically when the program exists
|
||||||
|
atexit.register(self.shutdown)
|
||||||
|
|
||||||
|
# Launch subprocesses
|
||||||
|
tokenizer_manager, scheduler_info = _launch_subprocesses(
|
||||||
|
server_args=server_args
|
||||||
|
)
|
||||||
|
self.tokenizer_manager = tokenizer_manager
|
||||||
|
self.scheduler_info = scheduler_info
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
|
prompt: Optional[Union[List[str], str]] = None,
|
||||||
|
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||||
|
# The token ids for text; one can either specify text or input_ids.
|
||||||
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||||
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
|
lora_path: Optional[List[Optional[str]]] = None,
|
||||||
|
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Union[Dict, Iterator[Dict]]:
|
||||||
|
"""
|
||||||
|
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
||||||
|
Please refer to `GenerateReqInput` for the documentation.
|
||||||
|
"""
|
||||||
|
obj = GenerateReqInput(
|
||||||
|
text=prompt,
|
||||||
|
input_ids=input_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
logprob_start_len=logprob_start_len,
|
||||||
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
lora_path=lora_path,
|
||||||
|
custom_logit_processor=custom_logit_processor,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
|
||||||
|
def generator_wrapper():
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
chunk = loop.run_until_complete(generator.__anext__())
|
||||||
|
yield chunk
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
return generator_wrapper()
|
||||||
|
else:
|
||||||
|
ret = loop.run_until_complete(generator.__anext__())
|
||||||
|
return ret
|
||||||
|
|
||||||
|
async def async_generate(
|
||||||
|
self,
|
||||||
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
|
prompt: Optional[Union[List[str], str]] = None,
|
||||||
|
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
||||||
|
# The token ids for text; one can either specify text or input_ids.
|
||||||
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||||
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
|
lora_path: Optional[List[Optional[str]]] = None,
|
||||||
|
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Union[Dict, AsyncIterator[Dict]]:
|
||||||
|
"""
|
||||||
|
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
||||||
|
Please refer to `GenerateReqInput` for the documentation.
|
||||||
|
"""
|
||||||
|
obj = GenerateReqInput(
|
||||||
|
text=prompt,
|
||||||
|
input_ids=input_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
logprob_start_len=logprob_start_len,
|
||||||
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
lora_path=lora_path,
|
||||||
|
stream=stream,
|
||||||
|
custom_logit_processor=custom_logit_processor,
|
||||||
|
)
|
||||||
|
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||||
|
|
||||||
|
if stream is True:
|
||||||
|
return generator
|
||||||
|
else:
|
||||||
|
return await generator.__anext__()
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
|
||||||
|
Please refer to `EmbeddingReqInput` for the documentation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
obj = EmbeddingReqInput(text=prompt)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||||
|
ret = loop.run_until_complete(generator.__anext__())
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
"""Shutdown the engine"""
|
||||||
|
kill_process_tree(os.getpid(), include_parent=False)
|
||||||
|
|
||||||
|
def start_profile(self):
|
||||||
|
self.tokenizer_manager.start_profile()
|
||||||
|
|
||||||
|
def stop_profile(self):
|
||||||
|
self.tokenizer_manager.stop_profile()
|
||||||
|
|
||||||
|
def get_server_info(self):
|
||||||
|
return {
|
||||||
|
**dataclasses.asdict(self.tokenizer_manager.server_args), # server args
|
||||||
|
**self.scheduler_info,
|
||||||
|
"version": __version__,
|
||||||
|
}
|
||||||
|
|
||||||
|
def init_weights_update_group(
|
||||||
|
self,
|
||||||
|
master_address: str,
|
||||||
|
master_port: int,
|
||||||
|
rank_offset: int,
|
||||||
|
world_size: int,
|
||||||
|
group_name: str,
|
||||||
|
backend: str = "nccl",
|
||||||
|
):
|
||||||
|
"""Initialize parameter update group."""
|
||||||
|
obj = InitWeightsUpdateGroupReqInput(
|
||||||
|
master_address=master_address,
|
||||||
|
master_port=master_port,
|
||||||
|
rank_offset=rank_offset,
|
||||||
|
world_size=world_size,
|
||||||
|
group_name=group_name,
|
||||||
|
backend=backend,
|
||||||
|
)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(
|
||||||
|
self.tokenizer_manager.init_weights_update_group(obj, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_weights_from_distributed(self, name: str, dtype, shape):
|
||||||
|
"""Update weights from distributed source."""
|
||||||
|
obj = UpdateWeightsFromDistributedReqInput(
|
||||||
|
name=name,
|
||||||
|
dtype=dtype,
|
||||||
|
shape=shape,
|
||||||
|
)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(
|
||||||
|
self.tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||||
|
"""Update weights from distributed source."""
|
||||||
|
obj = UpdateWeightsFromTensorReqInput(
|
||||||
|
serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
|
||||||
|
)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(
|
||||||
|
self.tokenizer_manager.update_weights_from_tensor(obj, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_by_name(self, name: str, truncate_size: int = 100):
|
||||||
|
"""Get weights by parameter name."""
|
||||||
|
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(
|
||||||
|
self.tokenizer_manager.get_weights_by_name(obj, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
def release_memory_occupation(self):
|
||||||
|
"""Release GPU occupation temporarily."""
|
||||||
|
obj = ReleaseMemoryOccupationReqInput()
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(
|
||||||
|
self.tokenizer_manager.release_memory_occupation(obj, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
def resume_memory_occupation(self):
|
||||||
|
"""Resume GPU occupation."""
|
||||||
|
obj = ResumeMemoryOccupationReqInput()
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(
|
||||||
|
self.tokenizer_manager.resume_memory_occupation(obj, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_envs_and_config(server_args: ServerArgs):
|
||||||
|
# Set global environments
|
||||||
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
|
os.environ["NCCL_CUMEM_ENABLE"] = "0"
|
||||||
|
os.environ["NCCL_NVLS_ENABLE"] = "0"
|
||||||
|
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||||
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
||||||
|
|
||||||
|
# Set prometheus env vars
|
||||||
|
if server_args.enable_metrics:
|
||||||
|
set_prometheus_multiproc_dir()
|
||||||
|
|
||||||
|
# Set ulimit
|
||||||
|
set_ulimit()
|
||||||
|
|
||||||
|
# Fix triton bugs
|
||||||
|
if server_args.tp_size * server_args.dp_size > 1:
|
||||||
|
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
|
||||||
|
maybe_set_triton_cache_manager()
|
||||||
|
|
||||||
|
# Check flashinfer version
|
||||||
|
if server_args.attention_backend == "flashinfer":
|
||||||
|
assert_pkg_version(
|
||||||
|
"flashinfer",
|
||||||
|
"0.1.6",
|
||||||
|
"Please uninstall the old version and "
|
||||||
|
"reinstall the latest version by following the instructions "
|
||||||
|
"at https://docs.flashinfer.ai/installation.html.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register the signal handler.
|
||||||
|
# The child processes will send SIGQUIT to this process when any error happens
|
||||||
|
# This process then clean up the whole process tree
|
||||||
|
def sigquit_handler(signum, frame):
|
||||||
|
logger.error(
|
||||||
|
"Received sigquit from a child proces. It usually means the child failed."
|
||||||
|
)
|
||||||
|
kill_process_tree(os.getpid())
|
||||||
|
|
||||||
|
signal.signal(signal.SIGQUIT, sigquit_handler)
|
||||||
|
|
||||||
|
# Set mp start method
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
|
||||||
|
"""
|
||||||
|
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
|
||||||
|
"""
|
||||||
|
# Configure global environment
|
||||||
|
configure_logger(server_args)
|
||||||
|
server_args.check_server_args()
|
||||||
|
_set_envs_and_config(server_args)
|
||||||
|
|
||||||
|
# Allocate ports for inter-process communications
|
||||||
|
port_args = PortArgs.init_new(server_args)
|
||||||
|
logger.info(f"{server_args=}")
|
||||||
|
|
||||||
|
# If using model from www.modelscope.cn, first download the model.
|
||||||
|
server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
|
||||||
|
server_args.model_path, server_args.tokenizer_path
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler_procs = []
|
||||||
|
if server_args.dp_size == 1:
|
||||||
|
# Launch tensor parallel scheduler processes
|
||||||
|
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
|
enable=server_args.enable_memory_saver
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler_pipe_readers = []
|
||||||
|
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
||||||
|
tp_rank_range = range(
|
||||||
|
tp_size_per_node * server_args.node_rank,
|
||||||
|
tp_size_per_node * (server_args.node_rank + 1),
|
||||||
|
)
|
||||||
|
for tp_rank in tp_rank_range:
|
||||||
|
reader, writer = mp.Pipe(duplex=False)
|
||||||
|
gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
|
||||||
|
proc = mp.Process(
|
||||||
|
target=run_scheduler_process,
|
||||||
|
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
|
||||||
|
)
|
||||||
|
with memory_saver_adapter.configure_subprocess():
|
||||||
|
proc.start()
|
||||||
|
scheduler_procs.append(proc)
|
||||||
|
scheduler_pipe_readers.append(reader)
|
||||||
|
else:
|
||||||
|
# Launch the data parallel controller
|
||||||
|
reader, writer = mp.Pipe(duplex=False)
|
||||||
|
scheduler_pipe_readers = [reader]
|
||||||
|
proc = mp.Process(
|
||||||
|
target=run_data_parallel_controller_process,
|
||||||
|
args=(server_args, port_args, writer),
|
||||||
|
)
|
||||||
|
proc.start()
|
||||||
|
scheduler_procs.append(proc)
|
||||||
|
|
||||||
|
if server_args.node_rank >= 1:
|
||||||
|
# In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
|
||||||
|
# so they can just wait here.
|
||||||
|
|
||||||
|
for reader in scheduler_pipe_readers:
|
||||||
|
data = reader.recv()
|
||||||
|
assert data["status"] == "ready"
|
||||||
|
|
||||||
|
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
|
||||||
|
# When using `Engine` as a Python API, we don't want to block here.
|
||||||
|
return
|
||||||
|
|
||||||
|
for proc in scheduler_procs:
|
||||||
|
proc.join()
|
||||||
|
logger.error(
|
||||||
|
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Launch detokenizer process
|
||||||
|
detoken_proc = mp.Process(
|
||||||
|
target=run_detokenizer_process,
|
||||||
|
args=(
|
||||||
|
server_args,
|
||||||
|
port_args,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
detoken_proc.start()
|
||||||
|
|
||||||
|
# Launch tokenizer process
|
||||||
|
tokenizer_manager = TokenizerManager(server_args, port_args)
|
||||||
|
if server_args.chat_template:
|
||||||
|
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
|
||||||
|
|
||||||
|
# Wait for the model to finish loading
|
||||||
|
scheduler_infos = []
|
||||||
|
for i in range(len(scheduler_pipe_readers)):
|
||||||
|
try:
|
||||||
|
data = scheduler_pipe_readers[i].recv()
|
||||||
|
except EOFError:
|
||||||
|
logger.error(
|
||||||
|
f"Rank {i} scheduler is dead. Please check if there are relevant logs."
|
||||||
|
)
|
||||||
|
scheduler_procs[i].join()
|
||||||
|
logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
if data["status"] != "ready":
|
||||||
|
raise RuntimeError(
|
||||||
|
"Initialization failed. Please see the error messages above."
|
||||||
|
)
|
||||||
|
scheduler_infos.append(data)
|
||||||
|
|
||||||
|
# Assume all schedulers have the same scheduler_info
|
||||||
|
scheduler_info = scheduler_infos[0]
|
||||||
|
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
|
||||||
|
return tokenizer_manager, scheduler_info
|
||||||
579
python/sglang/srt/entrypoints/http_server.py
Normal file
579
python/sglang/srt/entrypoints/http_server.py
Normal file
@@ -0,0 +1,579 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
The entry point of inference server. (SRT = SGLang Runtime)
|
||||||
|
|
||||||
|
This file implements HTTP APIs for the inferenc engine via fastapi.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
import multiprocessing as multiprocessing
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import AsyncIterator, Dict, Optional
|
||||||
|
|
||||||
|
# Fix a bug of Python threading
|
||||||
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
import requests
|
||||||
|
import uvicorn
|
||||||
|
import uvloop
|
||||||
|
from fastapi import FastAPI, File, Form, Request, UploadFile
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||||
|
from sglang.srt.managers.io_struct import (
|
||||||
|
CloseSessionReqInput,
|
||||||
|
ConfigureLoggingReq,
|
||||||
|
EmbeddingReqInput,
|
||||||
|
GenerateReqInput,
|
||||||
|
GetWeightsByNameReqInput,
|
||||||
|
InitWeightsUpdateGroupReqInput,
|
||||||
|
OpenSessionReqInput,
|
||||||
|
ReleaseMemoryOccupationReqInput,
|
||||||
|
ResumeMemoryOccupationReqInput,
|
||||||
|
UpdateWeightFromDiskReqInput,
|
||||||
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||||
|
from sglang.srt.openai_api.adapter import (
|
||||||
|
v1_batches,
|
||||||
|
v1_cancel_batch,
|
||||||
|
v1_chat_completions,
|
||||||
|
v1_completions,
|
||||||
|
v1_delete_file,
|
||||||
|
v1_embeddings,
|
||||||
|
v1_files_create,
|
||||||
|
v1_retrieve_batch,
|
||||||
|
v1_retrieve_file,
|
||||||
|
v1_retrieve_file_content,
|
||||||
|
)
|
||||||
|
from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.utils import (
|
||||||
|
add_api_key_middleware,
|
||||||
|
add_prometheus_middleware,
|
||||||
|
delete_directory,
|
||||||
|
kill_process_tree,
|
||||||
|
set_uvicorn_logging_configs,
|
||||||
|
)
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
from sglang.version import __version__
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
# Fast API
|
||||||
|
app = FastAPI()
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Store global states
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class _GlobalState:
|
||||||
|
tokenizer_manager: TokenizerManager
|
||||||
|
scheduler_info: Dict
|
||||||
|
|
||||||
|
|
||||||
|
_global_state: Optional[_GlobalState] = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_global_state(global_state: _GlobalState):
|
||||||
|
global _global_state
|
||||||
|
_global_state = global_state
|
||||||
|
|
||||||
|
|
||||||
|
##### Native API endpoints #####
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health() -> Response:
|
||||||
|
"""Check the health of the http server."""
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health_generate")
|
||||||
|
async def health_generate(request: Request) -> Response:
|
||||||
|
"""Check the health of the inference server by generating one token."""
|
||||||
|
|
||||||
|
sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
|
||||||
|
|
||||||
|
if _global_state.tokenizer_manager.is_generation:
|
||||||
|
gri = GenerateReqInput(
|
||||||
|
input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
gri = EmbeddingReqInput(
|
||||||
|
input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
|
||||||
|
break
|
||||||
|
return Response(status_code=200)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(e)
|
||||||
|
return Response(status_code=503)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_model_info")
|
||||||
|
async def get_model_info():
|
||||||
|
"""Get the model information."""
|
||||||
|
result = {
|
||||||
|
"model_path": _global_state.tokenizer_manager.model_path,
|
||||||
|
"tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
|
||||||
|
"is_generation": _global_state.tokenizer_manager.is_generation,
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_server_info")
|
||||||
|
async def get_server_info():
|
||||||
|
return {
|
||||||
|
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
|
||||||
|
**_global_state.scheduler_info,
|
||||||
|
"version": __version__,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# fastapi implicitly converts json in the request to obj (dataclass)
|
||||||
|
@app.api_route("/generate", methods=["POST", "PUT"])
|
||||||
|
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||||
|
"""Handle a generate request."""
|
||||||
|
if obj.stream:
|
||||||
|
|
||||||
|
async def stream_results() -> AsyncIterator[bytes]:
|
||||||
|
try:
|
||||||
|
async for out in _global_state.tokenizer_manager.generate_request(
|
||||||
|
obj, request
|
||||||
|
):
|
||||||
|
yield b"data: " + orjson.dumps(
|
||||||
|
out, option=orjson.OPT_NON_STR_KEYS
|
||||||
|
) + b"\n\n"
|
||||||
|
except ValueError as e:
|
||||||
|
out = {"error": {"message": str(e)}}
|
||||||
|
yield b"data: " + orjson.dumps(
|
||||||
|
out, option=orjson.OPT_NON_STR_KEYS
|
||||||
|
) + b"\n\n"
|
||||||
|
yield b"data: [DONE]\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
stream_results(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
background=_global_state.tokenizer_manager.create_abort_task(obj),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
ret = await _global_state.tokenizer_manager.generate_request(
|
||||||
|
obj, request
|
||||||
|
).__anext__()
|
||||||
|
return ret
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Error: {e}")
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||||
|
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||||
|
"""Handle an embedding request."""
|
||||||
|
try:
|
||||||
|
ret = await _global_state.tokenizer_manager.generate_request(
|
||||||
|
obj, request
|
||||||
|
).__anext__()
|
||||||
|
return ret
|
||||||
|
except ValueError as e:
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/classify", methods=["POST", "PUT"])
|
||||||
|
async def classify_request(obj: EmbeddingReqInput, request: Request):
|
||||||
|
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
||||||
|
try:
|
||||||
|
ret = await _global_state.tokenizer_manager.generate_request(
|
||||||
|
obj, request
|
||||||
|
).__anext__()
|
||||||
|
return ret
|
||||||
|
except ValueError as e:
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/flush_cache")
|
||||||
|
async def flush_cache():
|
||||||
|
"""Flush the radix cache."""
|
||||||
|
_global_state.tokenizer_manager.flush_cache()
|
||||||
|
return Response(
|
||||||
|
content="Cache flushed.\nPlease check backend logs for more details. "
|
||||||
|
"(When there are running or waiting requests, the operation will not be performed.)\n",
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/start_profile", methods=["GET", "POST"])
|
||||||
|
async def start_profile_async():
|
||||||
|
"""Start profiling."""
|
||||||
|
_global_state.tokenizer_manager.start_profile()
|
||||||
|
return Response(
|
||||||
|
content="Start profiling.\n",
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/stop_profile", methods=["GET", "POST"])
|
||||||
|
async def stop_profile_async():
|
||||||
|
"""Stop profiling."""
|
||||||
|
_global_state.tokenizer_manager.stop_profile()
|
||||||
|
return Response(
|
||||||
|
content="Stop profiling. This will take some time.\n",
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/update_weights_from_disk")
|
||||||
|
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
|
||||||
|
"""Update the weights from disk in-place without re-launching the server."""
|
||||||
|
success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
|
||||||
|
obj, request
|
||||||
|
)
|
||||||
|
content = {"success": success, "message": message}
|
||||||
|
if success:
|
||||||
|
return ORJSONResponse(
|
||||||
|
content,
|
||||||
|
status_code=HTTPStatus.OK,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ORJSONResponse(
|
||||||
|
content,
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/init_weights_update_group")
|
||||||
|
async def init_weights_update_group(
|
||||||
|
obj: InitWeightsUpdateGroupReqInput, request: Request
|
||||||
|
):
|
||||||
|
"""Initialize the parameter update group."""
|
||||||
|
success, message = await _global_state.tokenizer_manager.init_weights_update_group(
|
||||||
|
obj, request
|
||||||
|
)
|
||||||
|
content = {"success": success, "message": message}
|
||||||
|
if success:
|
||||||
|
return ORJSONResponse(content, status_code=200)
|
||||||
|
else:
|
||||||
|
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/update_weights_from_distributed")
|
||||||
|
async def update_weights_from_distributed(
|
||||||
|
obj: UpdateWeightsFromDistributedReqInput, request: Request
|
||||||
|
):
|
||||||
|
"""Update model parameter from distributed online."""
|
||||||
|
success, message = (
|
||||||
|
await _global_state.tokenizer_manager.update_weights_from_distributed(
|
||||||
|
obj, request
|
||||||
|
)
|
||||||
|
)
|
||||||
|
content = {"success": success, "message": message}
|
||||||
|
if success:
|
||||||
|
return ORJSONResponse(content, status_code=200)
|
||||||
|
else:
|
||||||
|
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
||||||
|
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
||||||
|
"""Get model parameter by name."""
|
||||||
|
try:
|
||||||
|
ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request)
|
||||||
|
if ret is None:
|
||||||
|
return _create_error_response("Get parameter by name failed")
|
||||||
|
else:
|
||||||
|
return ORJSONResponse(ret, status_code=200)
|
||||||
|
except Exception as e:
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
|
||||||
|
async def release_memory_occupation(
|
||||||
|
obj: ReleaseMemoryOccupationReqInput, request: Request
|
||||||
|
):
|
||||||
|
"""Release GPU occupation temporarily"""
|
||||||
|
try:
|
||||||
|
await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
|
||||||
|
except Exception as e:
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
|
||||||
|
async def resume_memory_occupation(
|
||||||
|
obj: ResumeMemoryOccupationReqInput, request: Request
|
||||||
|
):
|
||||||
|
"""Resume GPU occupation"""
|
||||||
|
try:
|
||||||
|
await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
|
||||||
|
except Exception as e:
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/open_session", methods=["GET", "POST"])
|
||||||
|
async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||||
|
"""Open a session, and return its unique session id."""
|
||||||
|
try:
|
||||||
|
session_id = await _global_state.tokenizer_manager.open_session(obj, request)
|
||||||
|
if session_id is None:
|
||||||
|
raise Exception(
|
||||||
|
"Failed to open the session. Check if a session with the same id is still open."
|
||||||
|
)
|
||||||
|
return session_id
|
||||||
|
except Exception as e:
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/close_session", methods=["GET", "POST"])
|
||||||
|
async def close_session(obj: CloseSessionReqInput, request: Request):
|
||||||
|
"""Close the session"""
|
||||||
|
try:
|
||||||
|
await _global_state.tokenizer_manager.close_session(obj, request)
|
||||||
|
return Response(status_code=200)
|
||||||
|
except Exception as e:
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/configure_logging", methods=["GET", "POST"])
|
||||||
|
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
||||||
|
"""Close the session"""
|
||||||
|
_global_state.tokenizer_manager.configure_logging(obj)
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
##### OpenAI-compatible API endpoints #####
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/completions")
|
||||||
|
async def openai_v1_completions(raw_request: Request):
|
||||||
|
return await v1_completions(_global_state.tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def openai_v1_chat_completions(raw_request: Request):
|
||||||
|
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/embeddings", response_class=ORJSONResponse)
|
||||||
|
async def openai_v1_embeddings(raw_request: Request):
|
||||||
|
response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/models", response_class=ORJSONResponse)
|
||||||
|
def available_models():
|
||||||
|
"""Show available models."""
|
||||||
|
served_model_names = [_global_state.tokenizer_manager.served_model_name]
|
||||||
|
model_cards = []
|
||||||
|
for served_model_name in served_model_names:
|
||||||
|
model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
|
||||||
|
return ModelList(data=model_cards)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/files")
|
||||||
|
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
|
||||||
|
return await v1_files_create(
|
||||||
|
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/v1/files/{file_id}")
|
||||||
|
async def delete_file(file_id: str):
|
||||||
|
# https://platform.openai.com/docs/api-reference/files/delete
|
||||||
|
return await v1_delete_file(file_id)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/batches")
|
||||||
|
async def openai_v1_batches(raw_request: Request):
|
||||||
|
return await v1_batches(_global_state.tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/batches/{batch_id}/cancel")
|
||||||
|
async def cancel_batches(batch_id: str):
|
||||||
|
# https://platform.openai.com/docs/api-reference/batch/cancel
|
||||||
|
return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/batches/{batch_id}")
|
||||||
|
async def retrieve_batch(batch_id: str):
|
||||||
|
return await v1_retrieve_batch(batch_id)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/files/{file_id}")
|
||||||
|
async def retrieve_file(file_id: str):
|
||||||
|
# https://platform.openai.com/docs/api-reference/files/retrieve
|
||||||
|
return await v1_retrieve_file(file_id)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/files/{file_id}/content")
|
||||||
|
async def retrieve_file_content(file_id: str):
|
||||||
|
# https://platform.openai.com/docs/api-reference/files/retrieve-contents
|
||||||
|
return await v1_retrieve_file_content(file_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_error_response(e):
|
||||||
|
return ORJSONResponse(
|
||||||
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def launch_server(
|
||||||
|
server_args: ServerArgs,
|
||||||
|
pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Launch SRT (SGLang Runtime) Server.
|
||||||
|
|
||||||
|
The SRT server consists of an HTTP server and an SRT engine.
|
||||||
|
|
||||||
|
- HTTP server: A FastAPI server that routes requests to the engine.
|
||||||
|
- The engine consists of three components:
|
||||||
|
1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
|
||||||
|
2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
|
||||||
|
3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
|
||||||
|
2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
|
||||||
|
"""
|
||||||
|
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
|
||||||
|
set_global_state(
|
||||||
|
_GlobalState(
|
||||||
|
tokenizer_manager=tokenizer_manager,
|
||||||
|
scheduler_info=scheduler_info,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add api key authorization
|
||||||
|
if server_args.api_key:
|
||||||
|
add_api_key_middleware(app, server_args.api_key)
|
||||||
|
|
||||||
|
# Add prometheus middleware
|
||||||
|
if server_args.enable_metrics:
|
||||||
|
add_prometheus_middleware(app)
|
||||||
|
enable_func_timer()
|
||||||
|
|
||||||
|
# Send a warmup request
|
||||||
|
t = threading.Thread(
|
||||||
|
target=_wait_and_warmup,
|
||||||
|
args=(
|
||||||
|
server_args,
|
||||||
|
pipe_finish_writer,
|
||||||
|
_global_state.tokenizer_manager.image_token_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Update logging configs
|
||||||
|
set_uvicorn_logging_configs()
|
||||||
|
|
||||||
|
# Listen for HTTP requests
|
||||||
|
uvicorn.run(
|
||||||
|
app,
|
||||||
|
host=server_args.host,
|
||||||
|
port=server_args.port,
|
||||||
|
log_level=server_args.log_level_http or server_args.log_level,
|
||||||
|
timeout_keep_alive=5,
|
||||||
|
loop="uvloop",
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
|
||||||
|
headers = {}
|
||||||
|
url = server_args.url()
|
||||||
|
if server_args.api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {server_args.api_key}"
|
||||||
|
|
||||||
|
# Wait until the server is launched
|
||||||
|
success = False
|
||||||
|
for _ in range(120):
|
||||||
|
time.sleep(1)
|
||||||
|
try:
|
||||||
|
res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
||||||
|
assert res.status_code == 200, f"{res=}, {res.text=}"
|
||||||
|
success = True
|
||||||
|
break
|
||||||
|
except (AssertionError, requests.exceptions.RequestException):
|
||||||
|
last_traceback = get_exception_traceback()
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
if pipe_finish_writer is not None:
|
||||||
|
pipe_finish_writer.send(last_traceback)
|
||||||
|
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||||
|
kill_process_tree(os.getpid())
|
||||||
|
return
|
||||||
|
|
||||||
|
model_info = res.json()
|
||||||
|
|
||||||
|
# Send a warmup request
|
||||||
|
request_name = "/generate" if model_info["is_generation"] else "/encode"
|
||||||
|
max_new_tokens = 8 if model_info["is_generation"] else 1
|
||||||
|
json_data = {
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if server_args.skip_tokenizer_init:
|
||||||
|
json_data["input_ids"] = [10, 11, 12]
|
||||||
|
else:
|
||||||
|
json_data["text"] = "The capital city of France is"
|
||||||
|
|
||||||
|
try:
|
||||||
|
for _ in range(server_args.dp_size):
|
||||||
|
res = requests.post(
|
||||||
|
url + request_name,
|
||||||
|
json=json_data,
|
||||||
|
headers=headers,
|
||||||
|
timeout=600,
|
||||||
|
)
|
||||||
|
assert res.status_code == 200, f"{res}"
|
||||||
|
except Exception:
|
||||||
|
last_traceback = get_exception_traceback()
|
||||||
|
if pipe_finish_writer is not None:
|
||||||
|
pipe_finish_writer.send(last_traceback)
|
||||||
|
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||||
|
kill_process_tree(os.getpid())
|
||||||
|
return
|
||||||
|
|
||||||
|
# Debug print
|
||||||
|
# logger.info(f"{res.json()=}")
|
||||||
|
|
||||||
|
logger.info("The server is fired up and ready to roll!")
|
||||||
|
if pipe_finish_writer is not None:
|
||||||
|
pipe_finish_writer.send("ready")
|
||||||
|
|
||||||
|
if server_args.delete_ckpt_after_loading:
|
||||||
|
delete_directory(server_args.model_path)
|
||||||
@@ -22,7 +22,6 @@ from enum import Enum
|
|||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Store states
|
# Store states
|
||||||
self.to_create_loop = True
|
self.no_create_loop = False
|
||||||
self.rid_to_state: Dict[str, ReqState] = {}
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
self.dump_requests_folder = "" # By default do not dump
|
self.dump_requests_folder = "" # By default do not dump
|
||||||
self.dump_requests_threshold = 1000
|
self.dump_requests_threshold = 1000
|
||||||
@@ -684,7 +684,6 @@ class TokenizerManager:
|
|||||||
async def close_session(
|
async def close_session(
|
||||||
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
|
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
|
||||||
):
|
):
|
||||||
assert not self.to_create_loop, "close session should not be the first request"
|
|
||||||
await self.send_to_scheduler.send_pyobj(obj)
|
await self.send_to_scheduler.send_pyobj(obj)
|
||||||
|
|
||||||
def configure_logging(self, obj: ConfigureLoggingReq):
|
def configure_logging(self, obj: ConfigureLoggingReq):
|
||||||
@@ -713,10 +712,10 @@ class TokenizerManager:
|
|||||||
return background_tasks
|
return background_tasks
|
||||||
|
|
||||||
def auto_create_handle_loop(self):
|
def auto_create_handle_loop(self):
|
||||||
if not self.to_create_loop:
|
if self.no_create_loop:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.to_create_loop = False
|
self.no_create_loop = True
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
self.asyncio_tasks.add(
|
self.asyncio_tasks.add(
|
||||||
loop.create_task(print_exception_wrapper(self.handle_loop))
|
loop.create_task(print_exception_wrapper(self.handle_loop))
|
||||||
|
|||||||
@@ -11,949 +11,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""
|
|
||||||
The entry point of inference server.
|
|
||||||
SRT = SGLang Runtime.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
# Some shortcuts for backward compatbility.
|
||||||
import atexit
|
# They will be removed in new versions.
|
||||||
import dataclasses
|
from sglang.srt.entrypoints.engine import Engine
|
||||||
import json
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
import logging
|
|
||||||
import multiprocessing as mp
|
|
||||||
import os
|
|
||||||
import signal
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from http import HTTPStatus
|
|
||||||
from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
||||||
|
|
||||||
# Fix a bug of Python threading
|
|
||||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import orjson
|
|
||||||
import requests
|
|
||||||
import uvicorn
|
|
||||||
import uvloop
|
|
||||||
from fastapi import FastAPI, File, Form, Request, UploadFile
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
|
||||||
|
|
||||||
from sglang.srt.managers.data_parallel_controller import (
|
|
||||||
run_data_parallel_controller_process,
|
|
||||||
)
|
|
||||||
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
|
||||||
from sglang.srt.managers.io_struct import (
|
|
||||||
CloseSessionReqInput,
|
|
||||||
ConfigureLoggingReq,
|
|
||||||
EmbeddingReqInput,
|
|
||||||
GenerateReqInput,
|
|
||||||
GetWeightsByNameReqInput,
|
|
||||||
InitWeightsUpdateGroupReqInput,
|
|
||||||
OpenSessionReqInput,
|
|
||||||
ReleaseMemoryOccupationReqInput,
|
|
||||||
ResumeMemoryOccupationReqInput,
|
|
||||||
UpdateWeightFromDiskReqInput,
|
|
||||||
UpdateWeightsFromDistributedReqInput,
|
|
||||||
UpdateWeightsFromTensorReqInput,
|
|
||||||
)
|
|
||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
|
||||||
from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency
|
|
||||||
from sglang.srt.openai_api.adapter import (
|
|
||||||
load_chat_template_for_openai_api,
|
|
||||||
v1_batches,
|
|
||||||
v1_cancel_batch,
|
|
||||||
v1_chat_completions,
|
|
||||||
v1_completions,
|
|
||||||
v1_delete_file,
|
|
||||||
v1_embeddings,
|
|
||||||
v1_files_create,
|
|
||||||
v1_retrieve_batch,
|
|
||||||
v1_retrieve_file,
|
|
||||||
v1_retrieve_file_content,
|
|
||||||
)
|
|
||||||
from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
||||||
from sglang.srt.utils import (
|
|
||||||
MultiprocessingSerializer,
|
|
||||||
add_api_key_middleware,
|
|
||||||
add_prometheus_middleware,
|
|
||||||
assert_pkg_version,
|
|
||||||
configure_logger,
|
|
||||||
delete_directory,
|
|
||||||
kill_process_tree,
|
|
||||||
maybe_set_triton_cache_manager,
|
|
||||||
prepare_model_and_tokenizer,
|
|
||||||
set_prometheus_multiproc_dir,
|
|
||||||
set_ulimit,
|
|
||||||
set_uvicorn_logging_configs,
|
|
||||||
)
|
|
||||||
from sglang.utils import get_exception_traceback
|
|
||||||
from sglang.version import __version__
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|
||||||
|
|
||||||
# Fast API
|
|
||||||
app = FastAPI()
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer_manager: TokenizerManager = None
|
|
||||||
scheduler_info: Dict = None
|
|
||||||
|
|
||||||
|
|
||||||
##### Native API endpoints #####
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health() -> Response:
|
|
||||||
"""Check the health of the http server."""
|
|
||||||
return Response(status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health_generate")
|
|
||||||
async def health_generate(request: Request) -> Response:
|
|
||||||
"""Check the health of the inference server by generating one token."""
|
|
||||||
|
|
||||||
sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
|
|
||||||
|
|
||||||
if tokenizer_manager.is_generation:
|
|
||||||
gri = GenerateReqInput(
|
|
||||||
input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
gri = EmbeddingReqInput(
|
|
||||||
input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for _ in tokenizer_manager.generate_request(gri, request):
|
|
||||||
break
|
|
||||||
return Response(status_code=200)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(e)
|
|
||||||
return Response(status_code=503)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/get_model_info")
|
|
||||||
async def get_model_info():
|
|
||||||
"""Get the model information."""
|
|
||||||
result = {
|
|
||||||
"model_path": tokenizer_manager.model_path,
|
|
||||||
"tokenizer_path": tokenizer_manager.server_args.tokenizer_path,
|
|
||||||
"is_generation": tokenizer_manager.is_generation,
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/get_server_info")
|
|
||||||
async def get_server_info():
|
|
||||||
return {
|
|
||||||
**dataclasses.asdict(tokenizer_manager.server_args),
|
|
||||||
**scheduler_info,
|
|
||||||
"version": __version__,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# fastapi implicitly converts json in the request to obj (dataclass)
|
|
||||||
@app.api_route("/generate", methods=["POST", "PUT"])
|
|
||||||
@time_func_latency
|
|
||||||
async def generate_request(obj: GenerateReqInput, request: Request):
|
|
||||||
"""Handle a generate request."""
|
|
||||||
if obj.stream:
|
|
||||||
|
|
||||||
async def stream_results() -> AsyncIterator[bytes]:
|
|
||||||
try:
|
|
||||||
async for out in tokenizer_manager.generate_request(obj, request):
|
|
||||||
yield b"data: " + orjson.dumps(
|
|
||||||
out, option=orjson.OPT_NON_STR_KEYS
|
|
||||||
) + b"\n\n"
|
|
||||||
except ValueError as e:
|
|
||||||
out = {"error": {"message": str(e)}}
|
|
||||||
yield b"data: " + orjson.dumps(
|
|
||||||
out, option=orjson.OPT_NON_STR_KEYS
|
|
||||||
) + b"\n\n"
|
|
||||||
yield b"data: [DONE]\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
stream_results(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
background=tokenizer_manager.create_abort_task(obj),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
|
||||||
return ret
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"Error: {e}")
|
|
||||||
return _create_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
|
||||||
@time_func_latency
|
|
||||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
|
||||||
"""Handle an embedding request."""
|
|
||||||
try:
|
|
||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
|
||||||
return ret
|
|
||||||
except ValueError as e:
|
|
||||||
return _create_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/classify", methods=["POST", "PUT"])
|
|
||||||
@time_func_latency
|
|
||||||
async def classify_request(obj: EmbeddingReqInput, request: Request):
|
|
||||||
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
|
||||||
try:
|
|
||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
|
||||||
return ret
|
|
||||||
except ValueError as e:
|
|
||||||
return _create_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/flush_cache")
|
|
||||||
async def flush_cache():
|
|
||||||
"""Flush the radix cache."""
|
|
||||||
tokenizer_manager.flush_cache()
|
|
||||||
return Response(
|
|
||||||
content="Cache flushed.\nPlease check backend logs for more details. "
|
|
||||||
"(When there are running or waiting requests, the operation will not be performed.)\n",
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/start_profile", methods=["GET", "POST"])
|
|
||||||
async def start_profile_async():
|
|
||||||
"""Start profiling."""
|
|
||||||
tokenizer_manager.start_profile()
|
|
||||||
return Response(
|
|
||||||
content="Start profiling.\n",
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/stop_profile", methods=["GET", "POST"])
|
|
||||||
async def stop_profile_async():
|
|
||||||
"""Stop profiling."""
|
|
||||||
tokenizer_manager.stop_profile()
|
|
||||||
return Response(
|
|
||||||
content="Stop profiling. This will take some time.\n",
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/update_weights_from_disk")
|
|
||||||
@time_func_latency
|
|
||||||
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
|
|
||||||
"""Update the weights from disk in-place without re-launching the server."""
|
|
||||||
success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
|
|
||||||
content = {"success": success, "message": message}
|
|
||||||
if success:
|
|
||||||
return ORJSONResponse(
|
|
||||||
content,
|
|
||||||
status_code=HTTPStatus.OK,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return ORJSONResponse(
|
|
||||||
content,
|
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/init_weights_update_group")
|
|
||||||
async def init_weights_update_group(
|
|
||||||
obj: InitWeightsUpdateGroupReqInput, request: Request
|
|
||||||
):
|
|
||||||
"""Initialize the parameter update group."""
|
|
||||||
success, message = await tokenizer_manager.init_weights_update_group(obj, request)
|
|
||||||
content = {"success": success, "message": message}
|
|
||||||
if success:
|
|
||||||
return ORJSONResponse(content, status_code=200)
|
|
||||||
else:
|
|
||||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/update_weights_from_distributed")
|
|
||||||
async def update_weights_from_distributed(
|
|
||||||
obj: UpdateWeightsFromDistributedReqInput, request: Request
|
|
||||||
):
|
|
||||||
"""Update model parameter from distributed online."""
|
|
||||||
success, message = await tokenizer_manager.update_weights_from_distributed(
|
|
||||||
obj, request
|
|
||||||
)
|
|
||||||
content = {"success": success, "message": message}
|
|
||||||
if success:
|
|
||||||
return ORJSONResponse(content, status_code=200)
|
|
||||||
else:
|
|
||||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
|
||||||
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
|
||||||
"""Get model parameter by name."""
|
|
||||||
try:
|
|
||||||
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
|
||||||
if ret is None:
|
|
||||||
return _create_error_response("Get parameter by name failed")
|
|
||||||
else:
|
|
||||||
return ORJSONResponse(ret, status_code=200)
|
|
||||||
except Exception as e:
|
|
||||||
return _create_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
|
|
||||||
async def release_memory_occupation(
|
|
||||||
obj: ReleaseMemoryOccupationReqInput, request: Request
|
|
||||||
):
|
|
||||||
"""Release GPU occupation temporarily"""
|
|
||||||
try:
|
|
||||||
await tokenizer_manager.release_memory_occupation(obj, request)
|
|
||||||
except Exception as e:
|
|
||||||
return _create_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
|
|
||||||
async def resume_memory_occupation(
|
|
||||||
obj: ResumeMemoryOccupationReqInput, request: Request
|
|
||||||
):
|
|
||||||
"""Resume GPU occupation"""
|
|
||||||
try:
|
|
||||||
await tokenizer_manager.resume_memory_occupation(obj, request)
|
|
||||||
except Exception as e:
|
|
||||||
return _create_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/open_session", methods=["GET", "POST"])
|
|
||||||
async def open_session(obj: OpenSessionReqInput, request: Request):
|
|
||||||
"""Open a session, and return its unique session id."""
|
|
||||||
try:
|
|
||||||
session_id = await tokenizer_manager.open_session(obj, request)
|
|
||||||
if session_id is None:
|
|
||||||
raise Exception(
|
|
||||||
"Failed to open the session. Check if a session with the same id is still open."
|
|
||||||
)
|
|
||||||
return session_id
|
|
||||||
except Exception as e:
|
|
||||||
return _create_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/close_session", methods=["GET", "POST"])
|
|
||||||
async def close_session(obj: CloseSessionReqInput, request: Request):
|
|
||||||
"""Close the session"""
|
|
||||||
try:
|
|
||||||
await tokenizer_manager.close_session(obj, request)
|
|
||||||
return Response(status_code=200)
|
|
||||||
except Exception as e:
|
|
||||||
return _create_error_response(e)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/configure_logging", methods=["GET", "POST"])
|
|
||||||
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
|
||||||
"""Close the session"""
|
|
||||||
tokenizer_manager.configure_logging(obj)
|
|
||||||
return Response(status_code=200)
|
|
||||||
|
|
||||||
|
|
||||||
##### OpenAI-compatible API endpoints #####
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
|
||||||
@time_func_latency
|
|
||||||
async def openai_v1_completions(raw_request: Request):
|
|
||||||
return await v1_completions(tokenizer_manager, raw_request)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
|
||||||
@time_func_latency
|
|
||||||
async def openai_v1_chat_completions(raw_request: Request):
|
|
||||||
return await v1_chat_completions(tokenizer_manager, raw_request)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/embeddings", response_class=ORJSONResponse)
@time_func_latency
async def openai_v1_embeddings(raw_request: Request):
    """OpenAI-compatible embeddings endpoint."""
    # Return the handler's response directly; no post-processing is needed.
    return await v1_embeddings(tokenizer_manager, raw_request)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models", response_class=ORJSONResponse)
def available_models():
    """Show available models."""
    # Only the single served model is advertised here.
    served_model_names = [tokenizer_manager.served_model_name]
    model_cards = [
        ModelCard(id=name, root=name) for name in served_model_names
    ]
    return ModelList(data=model_cards)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    """OpenAI-compatible file upload endpoint (used for batch jobs)."""
    storage_path = tokenizer_manager.server_args.file_storage_pth
    return await v1_files_create(file, purpose, storage_path)
|
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    """Delete an uploaded file.

    See https://platform.openai.com/docs/api-reference/files/delete
    """
    return await v1_delete_file(file_id)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    """OpenAI-compatible batch creation endpoint."""
    return await v1_batches(tokenizer_manager, raw_request)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
    """Cancel an in-progress batch.

    See https://platform.openai.com/docs/api-reference/batch/cancel
    """
    return await v1_cancel_batch(tokenizer_manager, batch_id)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    """Retrieve the status of a batch job."""
    return await v1_retrieve_batch(batch_id)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    """Retrieve metadata for an uploaded file.

    See https://platform.openai.com/docs/api-reference/files/retrieve
    """
    return await v1_retrieve_file(file_id)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    """Retrieve the raw content of an uploaded file.

    See https://platform.openai.com/docs/api-reference/files/retrieve-contents
    """
    return await v1_retrieve_file_content(file_id)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_error_response(e):
    """Wrap an exception into an OpenAI-style JSON error body (HTTP 400)."""
    body = {"error": {"message": str(e)}}
    return ORJSONResponse(body, status_code=HTTPStatus.BAD_REQUEST)
|
|
||||||
|
|
||||||
|
|
||||||
def launch_engine(
    server_args: ServerArgs,
):
    """
    Launch the TokenizerManager in the main process, the Scheduler in a
    subprocess, and the DetokenizerManager in another subprocess.

    On success, sets the module globals ``tokenizer_manager`` and
    ``scheduler_info``. On a non-zero-rank node in multi-node deployments it
    blocks waiting for the schedulers instead (unless
    SGLANG_BLOCK_NONZERO_RANK_CHILDREN=0).
    """

    global tokenizer_manager
    global scheduler_info

    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports for inter-process communications
    port_args = PortArgs.init_new(server_args)
    logger.info(f"{server_args=}")

    # If using model from www.modelscope.cn, first download the model.
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    scheduler_procs = []
    if server_args.dp_size == 1:
        # Launch tensor parallel scheduler processes
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
            tp_size_per_node * server_args.node_rank,
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
            )
            with memory_saver_adapter.configure_subprocess():
                proc.start()
            scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()
        scheduler_procs.append(proc)

    if server_args.node_rank >= 1:
        # In multi-node cases, non-zero rank nodes do not need to run tokenizer
        # or detokenizer, so they can just wait here.

        for reader in scheduler_pipe_readers:
            data = reader.recv()
            assert data["status"] == "ready"

        if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
            # When using `Engine` as a Python API, we don't want to block here.
            return

        for proc in scheduler_procs:
            proc.join()
            logger.error(
                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
            )
        return

    # Launch detokenizer process
    detoken_proc = mp.Process(
        target=run_detokenizer_process,
        args=(
            server_args,
            port_args,
        ),
    )
    detoken_proc.start()

    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

    # Wait for model to finish loading
    scheduler_infos = []
    # Use enumerate instead of range(len(...)) — same iteration order, clearer intent.
    for i, pipe_reader in enumerate(scheduler_pipe_readers):
        try:
            data = pipe_reader.recv()
        except EOFError as e:
            logger.exception(e)
            logger.error(
                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
            )
            scheduler_procs[i].join()
            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
            raise

        if data["status"] != "ready":
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )
        scheduler_infos.append(data)

    # Assume all schedulers have same scheduler_info
    scheduler_info = scheduler_infos[0]
    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
|
|
||||||
|
|
||||||
|
|
||||||
def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """
    Launch SRT (SGLang Runtime) Server

    The SRT server consists of an HTTP server and the SRT engine.

    1. HTTP server: A FastAPI server that routes requests to the engine.
    2. SRT engine:
        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server and TokenizerManager both run in the main process.
    2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
    """
    # Bring up the engine (tokenizer manager + scheduler/detokenizer subprocesses).
    launch_engine(server_args=server_args)

    # Optional API-key authorization middleware.
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Optional Prometheus metrics middleware.
    if server_args.enable_metrics:
        add_prometheus_middleware(app)
        enable_func_timer()

    # Fire a warmup request from a background thread once the HTTP server is up.
    warmup_thread = threading.Thread(
        target=_wait_and_warmup,
        args=(
            server_args,
            pipe_finish_writer,
            tokenizer_manager.image_token_id,
        ),
    )
    warmup_thread.start()

    try:
        # Update logging configs, then serve HTTP until shutdown.
        set_uvicorn_logging_configs()
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        warmup_thread.join()
|
|
||||||
|
|
||||||
|
|
||||||
def _set_envs_and_config(server_args: ServerArgs):
    """Set process-wide environment variables, signal handlers, and the
    multiprocessing start method before launching any subprocess."""
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"

    # Set prometheus env vars
    if server_args.enable_metrics:
        set_prometheus_multiproc_dir()

    # Set ulimit
    set_ulimit()

    # Fix triton bugs
    if server_args.tp_size * server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    # Check flashinfer version
    if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
            "flashinfer",
            "0.1.6",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )

    # Register the signal handler.
    # The child processes will send SIGQUIT to this process when any error happens
    # This process then cleans up the whole process tree
    def sigquit_handler(signum, frame):
        # Fixed typo in the original message ("proces" -> "process").
        logger.error(
            "Received sigquit from a child process. It usually means the child failed."
        )
        kill_process_tree(os.getpid())

    signal.signal(signal.SIGQUIT, sigquit_handler)

    # Set mp start method ("spawn" is required for CUDA-using subprocesses).
    mp.set_start_method("spawn", force=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
    """Poll the HTTP server until it is up, then send one warmup request.

    Runs in a background thread started by `launch_server`. On failure, the
    traceback is sent through `pipe_finish_writer` (if any) and the whole
    process tree is killed; on success, "ready" is sent instead.

    Note: `image_token_text` is currently unused here — presumably kept for
    signature compatibility with the caller; verify before removing.
    """
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched (up to ~120 seconds).
    success = False
    last_traceback = ""  # defensively initialized; also set on each failed probe
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res=}, {res.text=}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            # Removed a redundant `pass` after this assignment.
            last_traceback = get_exception_traceback()

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
        return

    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
        # One warmup request per data-parallel replica.
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
        return

    # Debug print
    # logger.info(f"{res.json()=}")

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")

    if server_args.delete_ckpt_after_loading:
        delete_directory(server_args.model_path)
|
|
||||||
|
|
||||||
|
|
||||||
# Sentinels for parsing server-sent-event (SSE) chunks produced by the
# streaming generate endpoint: each chunk starts with b"data:", and the
# stream terminates with b"data: [DONE]".
STREAM_END_SYMBOL = b"data: [DONE]"
STREAM_CHUNK_START_SYMBOL = b"data:"
|
|
||||||
|
|
||||||
|
|
||||||
class Engine:
    """
    SRT Engine without an HTTP server layer.

    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
    launching the HTTP server adds unnecessary complexity or overhead.
    """

    def __init__(self, log_level: str = "error", *args, **kwargs):
        """See the arguments in server_args.py::ServerArgs"""

        # Before the Python program terminates, call shutdown implicitly.
        # Therefore, users don't have to explicitly call .shutdown().
        atexit.register(self.shutdown)

        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
        launch_engine(server_args=server_args)

    @staticmethod
    def _build_generate_req(
        prompt,
        sampling_params,
        input_ids,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        lora_path,
        custom_logit_processor,
        stream,
    ):
        """Build the GenerateReqInput shared by generate() and async_generate().

        Factored out because both entry points constructed an identical
        request object.
        """
        return GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
            custom_logit_processor=custom_logit_processor,
        )

    @staticmethod
    def _decode_stream_chunk(chunk, offset):
        """Decode one SSE chunk and strip the already-yielded text prefix.

        Shared by the sync and async streaming wrappers, which previously
        duplicated this logic verbatim.
        """
        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
        data["text"] = data["text"][offset:]
        return data

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[str], str]] = None,
        stream: bool = False,
    ):
        """Synchronous generation.

        Returns the decoded result dict, or a generator of incremental
        result dicts when ``stream=True``.
        """
        obj = self._build_generate_req(
            prompt,
            sampling_params,
            input_ids,
            return_logprob,
            logprob_start_len,
            top_logprobs_num,
            lora_path,
            custom_logit_processor,
            stream,
        )

        # Get the current event loop and drive the async request to completion.
        loop = asyncio.get_event_loop()
        ret = loop.run_until_complete(generate_request(obj, None))

        if stream is True:

            def generator_wrapper():
                offset = 0
                loop = asyncio.get_event_loop()
                generator = ret.body_iterator
                while True:
                    chunk = loop.run_until_complete(generator.__anext__())

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    data = self._decode_stream_chunk(chunk, offset)
                    offset += len(data["text"])
                    yield data

            # We cannot yield in the scope of generate() because Python does not
            # allow yield + return in the same function; however, it allows
            # wrapping the generator as a subfunction and returning it.
            return generator_wrapper()
        else:
            return ret

    async def async_generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Dict] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[str, List[str]]] = None,
        stream: bool = False,
    ):
        """Asynchronous generation; async counterpart of generate()."""
        obj = self._build_generate_req(
            prompt,
            sampling_params,
            input_ids,
            return_logprob,
            logprob_start_len,
            top_logprobs_num,
            lora_path,
            custom_logit_processor,
            stream,
        )

        ret = await generate_request(obj, None)

        if stream is True:
            generator = ret.body_iterator

            async def generator_wrapper():
                offset = 0

                while True:
                    chunk = await generator.__anext__()

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    data = self._decode_stream_chunk(chunk, offset)
                    offset += len(data["text"])
                    yield data

            return generator_wrapper()
        else:
            return ret

    def shutdown(self):
        """Kill all engine subprocesses (schedulers, detokenizer)."""
        kill_process_tree(os.getpid(), include_parent=False)

    def get_tokenizer(self):
        """Return the tokenizer held by the global TokenizerManager."""
        global tokenizer_manager

        if tokenizer_manager is None:
            raise ReferenceError("Tokenizer Manager is not initialized.")
        else:
            return tokenizer_manager.tokenizer

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        """Compute embeddings for the given prompt(s)."""
        obj = EmbeddingReqInput(text=prompt)

        # Get the current event loop and drive the async request to completion.
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(encode_request(obj, None))

    def start_profile(self):
        """Start the profiler in the scheduler processes."""
        tokenizer_manager.start_profile()

    def stop_profile(self):
        """Stop the profiler in the scheduler processes."""
        tokenizer_manager.stop_profile()

    def get_server_info(self):
        """Return server args, scheduler info, and the package version."""
        return {
            **dataclasses.asdict(tokenizer_manager.server_args),  # server args
            **scheduler_info,
            "version": __version__,
        }

    def init_weights_update_group(
        self,
        master_address: str,
        master_port: int,
        rank_offset: int,
        world_size: int,
        group_name: str,
        backend: str = "nccl",
    ):
        """Initialize parameter update group."""
        obj = InitWeightsUpdateGroupReqInput(
            master_address=master_address,
            master_port=master_port,
            rank_offset=rank_offset,
            world_size=world_size,
            group_name=group_name,
            backend=backend,
        )
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            tokenizer_manager.init_weights_update_group(obj, None)
        )

    def update_weights_from_distributed(self, name, dtype, shape):
        """Update weights from distributed source."""
        obj = UpdateWeightsFromDistributedReqInput(
            name=name,
            dtype=dtype,
            shape=shape,
        )
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            tokenizer_manager.update_weights_from_distributed(obj, None)
        )

    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
        """Update weights from in-process tensors."""
        obj = UpdateWeightsFromTensorReqInput(
            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
        )
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            tokenizer_manager.update_weights_from_tensor(obj, None)
        )

    def get_weights_by_name(self, name, truncate_size=100):
        """Get weights by parameter name."""
        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))

    def release_memory_occupation(self):
        """Release GPU occupation temporarily"""
        obj = ReleaseMemoryOccupationReqInput()
        loop = asyncio.get_event_loop()
        loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None))

    def resume_memory_occupation(self):
        """Resume GPU occupation"""
        obj = ResumeMemoryOccupationReqInput()
        loop = asyncio.get_event_loop()
        loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
import json
|
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -22,8 +21,8 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.engine import Engine
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.server import Engine
|
|
||||||
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
|
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
|
||||||
|
|
||||||
DEFAULT_PROMPTS = [
|
DEFAULT_PROMPTS = [
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import requests
|
|||||||
from setproctitle import setproctitle
|
from setproctitle import setproctitle
|
||||||
from sglang_router.launch_router import RouterArgs, launch_router
|
from sglang_router.launch_router import RouterArgs, launch_router
|
||||||
|
|
||||||
from sglang.srt.server import launch_server
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import is_port_available
|
from sglang.srt.utils import is_port_available
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ class TestEnableMetrics(unittest.TestCase):
|
|||||||
"sglang:gen_throughput",
|
"sglang:gen_throughput",
|
||||||
"sglang:num_queue_reqs",
|
"sglang:num_queue_reqs",
|
||||||
"sglang:cache_hit_rate",
|
"sglang:cache_hit_rate",
|
||||||
"sglang:func_latency_seconds",
|
|
||||||
"sglang:prompt_tokens_total",
|
"sglang:prompt_tokens_total",
|
||||||
"sglang:generation_tokens_total",
|
"sglang:generation_tokens_total",
|
||||||
"sglang:num_requests_total",
|
"sglang:num_requests_total",
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def parse_models(model_string):
|
|||||||
return [model.strip() for model in model_string.split(",") if model.strip()]
|
return [model.strip() for model in model_string.split(",") if model.strip()]
|
||||||
|
|
||||||
|
|
||||||
def launch_server(base_url, model, is_fp8, is_tp2):
|
def popen_launch_server_wrapper(base_url, model, is_fp8, is_tp2):
|
||||||
other_args = ["--log-level-http", "warning", "--trust-remote-code"]
|
other_args = ["--log-level-http", "warning", "--trust-remote-code"]
|
||||||
if is_fp8:
|
if is_fp8:
|
||||||
if "Llama-3" in model or "gemma-2" in model:
|
if "Llama-3" in model or "gemma-2" in model:
|
||||||
@@ -148,7 +148,9 @@ class TestNightlyGsm8KEval(unittest.TestCase):
|
|||||||
for model_group, is_fp8, is_tp2 in self.model_groups:
|
for model_group, is_fp8, is_tp2 in self.model_groups:
|
||||||
for model in model_group:
|
for model in model_group:
|
||||||
with self.subTest(model=model):
|
with self.subTest(model=model):
|
||||||
process = launch_server(self.base_url, model, is_fp8, is_tp2)
|
process = popen_launch_server_wrapper(
|
||||||
|
self.base_url, model, is_fp8, is_tp2
|
||||||
|
)
|
||||||
|
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import signal
|
|||||||
import subprocess
|
import subprocess
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from test_nightly_gsm8k_eval import launch_server, parse_models
|
from test_nightly_gsm8k_eval import parse_models, popen_launch_server_wrapper
|
||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
@@ -93,7 +93,7 @@ class TestNightlyHumanEval(unittest.TestCase):
|
|||||||
# NOTE: only Llama for now
|
# NOTE: only Llama for now
|
||||||
if "Llama" in model:
|
if "Llama" in model:
|
||||||
with self.subTest(model=model):
|
with self.subTest(model=model):
|
||||||
self.process = launch_server(
|
self.process = popen_launch_server_wrapper(
|
||||||
self.base_url, model, is_fp8, is_tp2
|
self.base_url, model, is_fp8, is_tp2
|
||||||
)
|
)
|
||||||
self.run_evalplus(model)
|
self.run_evalplus(model)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
python3 -m unittest test_srt_engine.TestSRTEngine.test_3_sync_streaming_combination
|
python3 -m unittest test_srt_engine.TestSRTEngine.test_4_sync_async_stream_combination
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -44,83 +44,29 @@ class TestSRTEngine(unittest.TestCase):
|
|||||||
print(out2)
|
print(out2)
|
||||||
self.assertEqual(out1, out2)
|
self.assertEqual(out1, out2)
|
||||||
|
|
||||||
def test_2_engine_multiple_generate(self):
|
def test_2_engine_runtime_encode_consistency(self):
|
||||||
|
prompt = "Today is a sunny day and I like"
|
||||||
|
model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
|
engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42)
|
||||||
|
out1 = torch.tensor(engine.encode(prompt)["embedding"])
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42)
|
||||||
|
out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"])
|
||||||
|
runtime.shutdown()
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
|
||||||
|
|
||||||
|
def test_3_engine_token_ids_consistency(self):
|
||||||
# just to ensure there is no issue running multiple generate calls
|
# just to ensure there is no issue running multiple generate calls
|
||||||
prompt = "Today is a sunny day and I like"
|
prompt = "Today is a sunny day and I like"
|
||||||
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
||||||
|
|
||||||
engine = sgl.Engine(model_path=model_path, random_seed=42)
|
|
||||||
engine.generate(prompt, sampling_params)
|
|
||||||
engine.generate(prompt, sampling_params)
|
|
||||||
engine.shutdown()
|
|
||||||
|
|
||||||
def test_3_sync_streaming_combination(self):
|
|
||||||
|
|
||||||
prompt = "AI safety is..."
|
|
||||||
sampling_params = {"temperature": 0.8, "top_p": 0.95}
|
|
||||||
|
|
||||||
async def async_streaming(engine):
|
|
||||||
|
|
||||||
generator = await engine.async_generate(
|
|
||||||
prompt, sampling_params, stream=True
|
|
||||||
)
|
|
||||||
|
|
||||||
async for output in generator:
|
|
||||||
print(output["text"], end="", flush=True)
|
|
||||||
print()
|
|
||||||
|
|
||||||
# Create an LLM.
|
|
||||||
llm = sgl.Engine(
|
|
||||||
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. sync + non streaming
|
|
||||||
print("\n\n==== 1. sync + non streaming ====")
|
|
||||||
output = llm.generate(prompt, sampling_params)
|
|
||||||
|
|
||||||
print(output["text"])
|
|
||||||
|
|
||||||
# 2. sync + streaming
|
|
||||||
print("\n\n==== 2. sync + streaming ====")
|
|
||||||
output_generator = llm.generate(prompt, sampling_params, stream=True)
|
|
||||||
for output in output_generator:
|
|
||||||
print(output["text"], end="", flush=True)
|
|
||||||
print()
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
# 3. async + non_streaming
|
|
||||||
print("\n\n==== 3. async + non streaming ====")
|
|
||||||
output = loop.run_until_complete(llm.async_generate(prompt, sampling_params))
|
|
||||||
print(output["text"])
|
|
||||||
|
|
||||||
# 4. async + streaming
|
|
||||||
print("\n\n==== 4. async + streaming ====")
|
|
||||||
loop.run_until_complete(async_streaming(llm))
|
|
||||||
|
|
||||||
llm.shutdown()
|
|
||||||
|
|
||||||
def test_4_gsm8k(self):
|
|
||||||
|
|
||||||
args = SimpleNamespace(
|
|
||||||
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
|
||||||
local_data_path=None,
|
|
||||||
num_shots=5,
|
|
||||||
num_questions=200,
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = run_eval(args)
|
|
||||||
self.assertGreater(metrics["accuracy"], 0.3)
|
|
||||||
|
|
||||||
def test_5_prompt_input_ids_consistency(self):
|
|
||||||
prompt = "The capital of UK is"
|
|
||||||
|
|
||||||
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
|
||||||
engine = sgl.Engine(
|
engine = sgl.Engine(
|
||||||
model_path=model_path, random_seed=42, disable_radix_cache=True
|
model_path=model_path, random_seed=42, disable_radix_cache=True
|
||||||
)
|
)
|
||||||
sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
|
||||||
out1 = engine.generate(prompt, sampling_params)["text"]
|
out1 = engine.generate(prompt, sampling_params)["text"]
|
||||||
|
|
||||||
tokenizer = get_tokenizer(model_path)
|
tokenizer = get_tokenizer(model_path)
|
||||||
@@ -138,21 +84,69 @@ class TestSRTEngine(unittest.TestCase):
|
|||||||
print(out2)
|
print(out2)
|
||||||
self.assertEqual(out1, out2)
|
self.assertEqual(out1, out2)
|
||||||
|
|
||||||
def test_6_engine_runtime_encode_consistency(self):
|
def test_4_sync_async_stream_combination(self):
|
||||||
prompt = "Today is a sunny day and I like"
|
prompt = "AI safety is"
|
||||||
model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
|
sampling_params = {"temperature": 0.8, "top_p": 0.95}
|
||||||
|
|
||||||
engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42)
|
# Create an LLM.
|
||||||
out1 = torch.tensor(engine.encode(prompt)["embedding"])
|
llm = sgl.Engine(
|
||||||
engine.shutdown()
|
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
)
|
||||||
|
|
||||||
runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42)
|
if True:
|
||||||
out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"])
|
# 1. sync + non streaming
|
||||||
runtime.shutdown()
|
print("\n\n==== 1. sync + non streaming ====")
|
||||||
|
output = llm.generate(prompt, sampling_params)
|
||||||
|
print(output["text"])
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
|
# 2. sync + streaming
|
||||||
|
print("\n\n==== 2. sync + streaming ====")
|
||||||
|
output_generator = llm.generate(prompt, sampling_params, stream=True)
|
||||||
|
offset = 0
|
||||||
|
for output in output_generator:
|
||||||
|
print(output["text"][offset:], end="", flush=True)
|
||||||
|
offset = len(output["text"])
|
||||||
|
print()
|
||||||
|
|
||||||
def test_7_engine_cpu_offload(self):
|
if True:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
# 3. async + non_streaming
|
||||||
|
print("\n\n==== 3. async + non streaming ====")
|
||||||
|
output = loop.run_until_complete(
|
||||||
|
llm.async_generate(prompt, sampling_params)
|
||||||
|
)
|
||||||
|
print(output["text"])
|
||||||
|
|
||||||
|
# 4. async + streaming
|
||||||
|
async def async_streaming(engine):
|
||||||
|
generator = await engine.async_generate(
|
||||||
|
prompt, sampling_params, stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
async for output in generator:
|
||||||
|
print(output["text"][offset:], end="", flush=True)
|
||||||
|
offset = len(output["text"])
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("\n\n==== 4. async + streaming ====")
|
||||||
|
loop.run_until_complete(async_streaming(llm))
|
||||||
|
|
||||||
|
llm.shutdown()
|
||||||
|
|
||||||
|
def test_5_gsm8k(self):
|
||||||
|
|
||||||
|
args = SimpleNamespace(
|
||||||
|
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
local_data_path=None,
|
||||||
|
num_shots=5,
|
||||||
|
num_questions=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval(args)
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.3)
|
||||||
|
|
||||||
|
def test_6_engine_cpu_offload(self):
|
||||||
prompt = "Today is a sunny day and I like"
|
prompt = "Today is a sunny day and I like"
|
||||||
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
@@ -182,7 +176,7 @@ class TestSRTEngine(unittest.TestCase):
|
|||||||
print(out2)
|
print(out2)
|
||||||
self.assertEqual(out1, out2)
|
self.assertEqual(out1, out2)
|
||||||
|
|
||||||
def test_8_engine_offline_throughput(self):
|
def test_7_engine_offline_throughput(self):
|
||||||
server_args = ServerArgs(
|
server_args = ServerArgs(
|
||||||
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user