diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 23c98ea93..60551b2c1 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -91,7 +91,7 @@ Here is how you can do it: ```python from sglang.srt.models.registry import ModelRegistry -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server # for a single model, you can add it to the registry ModelRegistry.models[model_name] = model_class diff --git a/python/sglang/api.py b/python/sglang/api.py index a9c5fa9da..7ef306380 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -40,7 +40,7 @@ def Runtime(*args, **kwargs): def Engine(*args, **kwargs): # Avoid importing unnecessary dependency - from sglang.srt.server import Engine + from sglang.srt.entrypoints.engine import Engine return Engine(*args, **kwargs) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index 6b31ac40e..b0a715e61 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -28,7 +28,7 @@ from sglang.bench_serving import ( set_ulimit, ) from sglang.lang.backend.runtime_endpoint import Runtime -from sglang.srt.server import Engine +from sglang.srt.entrypoints.engine import Engine from sglang.srt.server_args import ServerArgs diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 99fba8be9..473f478ad 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -57,12 +57,12 @@ import torch import torch.distributed as dist from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.entrypoints.engine import _set_envs_and_config from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.server import _set_envs_and_config from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers diff --git a/python/sglang/bench_one_batch_server.py b/python/sglang/bench_one_batch_server.py index 01cc561e1..5f0759a7c 100644 --- a/python/sglang/bench_one_batch_server.py +++ b/python/sglang/bench_one_batch_server.py @@ -22,7 +22,7 @@ from typing import Tuple import numpy as np import requests -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index c139db6f0..01f10b9f0 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -351,7 +351,7 @@ class Runtime: """See the arguments in server_args.py::ServerArgs""" # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run # client code without installing SRT server and its dependency if they want. 
-        from sglang.srt.server import launch_server
+        from sglang.srt.entrypoints.http_server import launch_server
         from sglang.srt.server_args import ServerArgs
         from sglang.srt.utils import is_port_available
diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index 6b0c25711..caae7b0f6 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -3,7 +3,7 @@ import os
 import sys
 
-from sglang.srt.server import launch_server
+from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import prepare_server_args
 from sglang.srt.utils import kill_process_tree
 
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
new file mode 100644
index 000000000..310e92c23
--- /dev/null
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -0,0 +1,449 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+The entry point of the inference server. (SRT = SGLang Runtime)
+
+This file implements the Python APIs for the inference engine.
+"""
+
+import asyncio
+import atexit
+import dataclasses
+import logging
+import multiprocessing as mp
+import os
+import signal
+import threading
+from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
+
+# Fix a bug in Python threading
+setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
+
+import torch
+import uvloop
+
+from sglang.srt.managers.data_parallel_controller import (
+    run_data_parallel_controller_process,
+)
+from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    ReleaseMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqInput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
+)
+from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils import (
+    MultiprocessingSerializer,
+    assert_pkg_version,
+    configure_logger,
+    kill_process_tree,
+    maybe_set_triton_cache_manager,
+    prepare_model_and_tokenizer,
+    set_prometheus_multiproc_dir,
+    set_ulimit,
+)
+from sglang.version import __version__
+
+logger = logging.getLogger(__name__)
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+
+class Engine:
+    """
+    The entry point to the inference engine.
+
+    - The engine consists of three components:
+        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
+        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
+        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
+
+    Note:
+    1. The Engine and the TokenizerManager both run in the main process.
+    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
+    """
+
+    def __init__(self, **kwargs):
+        """
+        The arguments of this function are the same as `sglang/srt/server_args.py::ServerArgs`.
+        Please refer to `ServerArgs` for the documentation.
+        """
+        if "server_args" in kwargs:
+            # Directly load server_args
+            server_args = kwargs["server_args"]
+        else:
+            # Construct server_args from kwargs
+            if "log_level" not in kwargs:
+                # Do not print logs by default
+                kwargs["log_level"] = "error"
+            server_args = ServerArgs(**kwargs)
+
+        # Shut down the subprocesses automatically when the program exits
+        atexit.register(self.shutdown)
+
+        # Launch subprocesses
+        tokenizer_manager, scheduler_info = _launch_subprocesses(
+            server_args=server_args
+        )
+        self.tokenizer_manager = tokenizer_manager
+        self.scheduler_info = scheduler_info
+
+    def generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        custom_logit_processor: Optional[Union[List[str], str]] = None,
+        stream: bool = False,
+    ) -> Union[Dict, Iterator[Dict]]:
+        """
+        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
+        Please refer to `GenerateReqInput` for the documentation.
+        """
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            custom_logit_processor=custom_logit_processor,
+            stream=stream,
+        )
+        loop = asyncio.get_event_loop()
+        generator = self.tokenizer_manager.generate_request(obj, None)
+
+        if stream:
+
+            def generator_wrapper():
+                while True:
+                    try:
+                        chunk = loop.run_until_complete(generator.__anext__())
+                        yield chunk
+                    except StopAsyncIteration:
+                        break
+
+            return generator_wrapper()
+        else:
+            ret = loop.run_until_complete(generator.__anext__())
+            return ret
+
+    async def async_generate(
+        self,
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        custom_logit_processor: Optional[Union[List[str], str]] = None,
+        stream: bool = False,
+    ) -> Union[Dict, AsyncIterator[Dict]]:
+        """
+        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
+        Please refer to `GenerateReqInput` for the documentation.
+        """
+        obj = GenerateReqInput(
+            text=prompt,
+            input_ids=input_ids,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+            custom_logit_processor=custom_logit_processor,
+        )
+        generator = self.tokenizer_manager.generate_request(obj, None)
+
+        if stream:
+            return generator
+        else:
+            return await generator.__anext__()
+
+    def encode(
+        self,
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+    ) -> Dict:
+        """
+        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
+        Please refer to `EmbeddingReqInput` for the documentation.
+        """
+
+        obj = EmbeddingReqInput(text=prompt)
+        loop = asyncio.get_event_loop()
+        generator = self.tokenizer_manager.generate_request(obj, None)
+        ret = loop.run_until_complete(generator.__anext__())
+        return ret
+
+    def shutdown(self):
+        """Shutdown the engine"""
+        kill_process_tree(os.getpid(), include_parent=False)
+
+    def start_profile(self):
+        self.tokenizer_manager.start_profile()
+
+    def stop_profile(self):
+        self.tokenizer_manager.stop_profile()
+
+    def get_server_info(self):
+        return {
+            **dataclasses.asdict(self.tokenizer_manager.server_args),  # server args
+            **self.scheduler_info,
+            "version": __version__,
+        }
+
+    def init_weights_update_group(
+        self,
+        master_address: str,
+        master_port: int,
+        rank_offset: int,
+        world_size: int,
+        group_name: str,
+        backend: str = "nccl",
+    ):
+        """Initialize the parameter update group."""
+        obj = InitWeightsUpdateGroupReqInput(
+            master_address=master_address,
+            master_port=master_port,
+            rank_offset=rank_offset,
+            world_size=world_size,
+            group_name=group_name,
+            backend=backend,
+        )
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.init_weights_update_group(obj, None)
+        )
+
+    def update_weights_from_distributed(self, name: str, dtype, shape):
+        """Update weights from a distributed source."""
+        obj = UpdateWeightsFromDistributedReqInput(
+            name=name,
+            dtype=dtype,
+            shape=shape,
+        )
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.update_weights_from_distributed(obj, None)
+        )
+
+    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
+        """Update weights from in-memory named tensors."""
+        obj = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
+        )
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.update_weights_from_tensor(obj, None)
+        )
+
+    def get_weights_by_name(self, name: str, truncate_size: int = 100):
+        """Get weights by parameter name."""
+        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.get_weights_by_name(obj, None)
+        )
+
+    def release_memory_occupation(self):
+        """Release GPU memory occupation temporarily."""
+        obj = ReleaseMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.release_memory_occupation(obj, None)
+        )
+
+    def resume_memory_occupation(self):
+        """Resume GPU memory occupation."""
+        obj = ResumeMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.resume_memory_occupation(obj, None)
+        )
+
+
+def _set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
+
+    # Set prometheus env vars
+    if server_args.enable_metrics:
+        set_prometheus_multiproc_dir()
+
+    # Set ulimit
+    set_ulimit()
+
+    # Fix triton bugs
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
+    # Check flashinfer version
+    if server_args.attention_backend == "flashinfer":
+        assert_pkg_version(
+            "flashinfer",
+            "0.1.6",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
+
+    # Register the signal handler.
+    # The child processes will send SIGQUIT to this process when any error happens.
+    # This process then cleans up the whole process tree.
+    def sigquit_handler(signum, frame):
+        logger.error(
+            "Received SIGQUIT from a child process. It usually means the child failed."
+        )
+        kill_process_tree(os.getpid())
+
+    signal.signal(signal.SIGQUIT, sigquit_handler)
+
+    # Set mp start method
+    mp.set_start_method("spawn", force=True)
+
+
+def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
+    """
+    Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
+    """
+    # Configure global environment
+    configure_logger(server_args)
+    server_args.check_server_args()
+    _set_envs_and_config(server_args)
+
+    # Allocate ports for inter-process communications
+    port_args = PortArgs.init_new(server_args)
+    logger.info(f"{server_args=}")
+
+    # If using a model from www.modelscope.cn, first download the model.
+ server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( + server_args.model_path, server_args.tokenizer_path + ) + + scheduler_procs = [] + if server_args.dp_size == 1: + # Launch tensor parallel scheduler processes + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + + scheduler_pipe_readers = [] + tp_size_per_node = server_args.tp_size // server_args.nnodes + tp_rank_range = range( + tp_size_per_node * server_args.node_rank, + tp_size_per_node * (server_args.node_rank + 1), + ) + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node + proc = mp.Process( + target=run_scheduler_process, + args=(server_args, port_args, gpu_id, tp_rank, None, writer), + ) + with memory_saver_adapter.configure_subprocess(): + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + else: + # Launch the data parallel controller + reader, writer = mp.Pipe(duplex=False) + scheduler_pipe_readers = [reader] + proc = mp.Process( + target=run_data_parallel_controller_process, + args=(server_args, port_args, writer), + ) + proc.start() + scheduler_procs.append(proc) + + if server_args.node_rank >= 1: + # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer, + # so they can just wait here. + + for reader in scheduler_pipe_readers: + data = reader.recv() + assert data["status"] == "ready" + + if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": + # When using `Engine` as a Python API, we don't want to block here. + return + + for proc in scheduler_procs: + proc.join() + logger.error( + f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" + ) + return + + # Launch detokenizer process + detoken_proc = mp.Process( + target=run_detokenizer_process, + args=( + server_args, + port_args, + ), + ) + detoken_proc.start() + + # Launch tokenizer process + tokenizer_manager = TokenizerManager(server_args, port_args) + if server_args.chat_template: + load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) + + # Wait for the model to finish loading + scheduler_infos = [] + for i in range(len(scheduler_pipe_readers)): + try: + data = scheduler_pipe_readers[i].recv() + except EOFError: + logger.error( + f"Rank {i} scheduler is dead. Please check if there are relevant logs." + ) + scheduler_procs[i].join() + logger.error(f"Exit code: {scheduler_procs[i].exitcode}") + raise + + if data["status"] != "ready": + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) + scheduler_infos.append(data) + + # Assume all schedulers have the same scheduler_info + scheduler_info = scheduler_infos[0] + tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + return tokenizer_manager, scheduler_info diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py new file mode 100644 index 000000000..0ebce1a85 --- /dev/null +++ b/python/sglang/srt/entrypoints/http_server.py @@ -0,0 +1,579 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+The entry point of the inference server. (SRT = SGLang Runtime)
+
+This file implements the HTTP APIs for the inference engine via FastAPI.
+"""
+
+import asyncio
+import dataclasses
+import logging
+import multiprocessing as multiprocessing
+import os
+import threading
+import time
+from http import HTTPStatus
+from typing import AsyncIterator, Dict, Optional
+
+# Fix a bug in Python threading
+setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
+
+import orjson
+import requests
+import uvicorn
+import uvloop
+from fastapi import FastAPI, File, Form, Request, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import ORJSONResponse, Response, StreamingResponse
+
+from sglang.srt.entrypoints.engine import _launch_subprocesses
+from sglang.srt.managers.io_struct import (
+    CloseSessionReqInput,
+    ConfigureLoggingReq,
+    EmbeddingReqInput,
+    GenerateReqInput,
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    OpenSessionReqInput,
+    ReleaseMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
+)
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.metrics.func_timer import enable_func_timer
+from sglang.srt.openai_api.adapter import (
+    v1_batches,
+    v1_cancel_batch,
+    v1_chat_completions,
+    v1_completions,
+    v1_delete_file,
+    v1_embeddings,
+    v1_files_create,
+    v1_retrieve_batch,
+    v1_retrieve_file,
+    v1_retrieve_file_content,
+)
+from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import (
+    add_api_key_middleware,
+    add_prometheus_middleware,
+    delete_directory,
+    kill_process_tree,
+    set_uvicorn_logging_configs,
+)
+from sglang.utils import get_exception_traceback
+from sglang.version import __version__
+
+logger = logging.getLogger(__name__)
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+# Fast API
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+# Store global states
+@dataclasses.dataclass
+class _GlobalState:
+    tokenizer_manager: TokenizerManager
+    scheduler_info: Dict
+
+
+_global_state: Optional[_GlobalState] = None
+
+
+def set_global_state(global_state: _GlobalState):
+    global _global_state
+    _global_state = global_state
+
+
+##### Native API endpoints #####
+
+
+@app.get("/health")
+async def health() -> Response:
+    """Check the health of the HTTP server."""
+    return Response(status_code=200)
+
+
+@app.get("/health_generate")
+async def health_generate(request: Request) -> Response:
+    """Check the health of the inference server by generating one token."""
+
+    sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
+
+    if _global_state.tokenizer_manager.is_generation:
+        gri = GenerateReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )
+    else:
+        gri = EmbeddingReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )
+
+    try:
+        async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
+            break
+        return Response(status_code=200)
+    except Exception as e:
+        logger.exception(e)
+        return Response(status_code=503)
+
+
+@app.get("/get_model_info")
+async def get_model_info():
+    """Get the model information."""
+    result = {
+        "model_path": _global_state.tokenizer_manager.model_path,
+        "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
+        "is_generation": _global_state.tokenizer_manager.is_generation,
+    }
+    return result
+
+
+@app.get("/get_server_info")
+async def get_server_info():
+    return {
+        **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
+        **_global_state.scheduler_info,
+        "version": __version__,
+    }
+
+
+# FastAPI implicitly converts the JSON in the request body to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
+async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
+    if obj.stream:
+
+        async def stream_results() -> AsyncIterator[bytes]:
+            try:
+                async for out in _global_state.tokenizer_manager.generate_request(
+                    obj, request
+                ):
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=_global_state.tokenizer_manager.create_abort_task(obj),
+        )
+    else:
+        try:
+            ret = await _global_state.tokenizer_manager.generate_request(
+                obj, request
+            ).__anext__()
+            return ret
+        except ValueError as e:
+            logger.error(f"Error: {e}")
+            return _create_error_response(e)
+
+
+@app.api_route("/encode", methods=["POST", "PUT"])
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await _global_state.tokenizer_manager.generate_request(
+            obj, request
+        ).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/classify", methods=["POST", "PUT"])
+async def classify_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. For now, the arguments and return values are the same as for embedding models."""
+    try:
+        ret = await _global_state.tokenizer_manager.generate_request(
+            obj, request
+        ).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
+@app.post("/flush_cache")
+async def flush_cache():
+    """Flush the radix cache."""
+    _global_state.tokenizer_manager.flush_cache()
+    return Response(
+        content="Cache flushed.\nPlease check backend logs for more details. "
+        "(When there are running or waiting requests, the operation will not be performed.)\n",
+        status_code=200,
+    )
+
+
+@app.api_route("/start_profile", methods=["GET", "POST"])
+async def start_profile_async():
+    """Start profiling."""
+    _global_state.tokenizer_manager.start_profile()
+    return Response(
+        content="Start profiling.\n",
+        status_code=200,
+    )
+
+
+@app.api_route("/stop_profile", methods=["GET", "POST"])
+async def stop_profile_async():
+    """Stop profiling."""
+    _global_state.tokenizer_manager.stop_profile()
+    return Response(
+        content="Stop profiling. This will take some time.\n",
+        status_code=200,
+    )
+
+
+@app.post("/update_weights_from_disk")
+async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
+    """Update the weights from disk in-place without re-launching the server."""
+    success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(
+            content,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return ORJSONResponse(
+            content,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
+@app.post("/init_weights_update_group")
+async def init_weights_update_group(
+    obj: InitWeightsUpdateGroupReqInput, request: Request
+):
+    """Initialize the parameter update group."""
+    success, message = await _global_state.tokenizer_manager.init_weights_update_group(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.post("/update_weights_from_distributed")
+async def update_weights_from_distributed(
+    obj: UpdateWeightsFromDistributedReqInput, request: Request
+):
+    """Update model parameters online from a distributed source."""
+    success, message = (
+        await _global_state.tokenizer_manager.update_weights_from_distributed(
+            obj, request
+        )
+    )
+    content = {"success": success, "message": message}
+    if success:
+        return ORJSONResponse(content, status_code=200)
+    else:
+        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
+async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
+    """Get model parameter by name."""
+    try:
+        ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request)
+        if ret is None:
+            return _create_error_response("Get parameter by name failed")
+        else:
+            return ORJSONResponse(ret, status_code=200)
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
+async def release_memory_occupation(
+    obj: ReleaseMemoryOccupationReqInput, request: Request
+):
+    """Release GPU memory occupation temporarily."""
+    try:
+        await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
+async def resume_memory_occupation(
+    obj: ResumeMemoryOccupationReqInput, request: Request
+):
+    """Resume GPU memory occupation."""
+    try:
+        await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/open_session", methods=["GET", "POST"])
+async def open_session(obj: OpenSessionReqInput, request: Request):
+    """Open a session, and return its unique session id."""
+    try:
+        session_id = await _global_state.tokenizer_manager.open_session(obj, request)
+        if session_id is None:
+            raise Exception(
+                "Failed to open the session. Check if a session with the same id is still open."
+            )
+        return session_id
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/close_session", methods=["GET", "POST"])
+async def close_session(obj: CloseSessionReqInput, request: Request):
+    """Close the session."""
+    try:
+        await _global_state.tokenizer_manager.close_session(obj, request)
+        return Response(status_code=200)
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/configure_logging", methods=["GET", "POST"])
+async def configure_logging(obj: ConfigureLoggingReq, request: Request):
+    """Configure the request logging options."""
+    _global_state.tokenizer_manager.configure_logging(obj)
+    return Response(status_code=200)
+
+
+##### OpenAI-compatible API endpoints #####
+
+
+@app.post("/v1/completions")
+async def openai_v1_completions(raw_request: Request):
+    return await v1_completions(_global_state.tokenizer_manager, raw_request)
+
+
+@app.post("/v1/chat/completions")
+async def openai_v1_chat_completions(raw_request: Request):
+    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+
+
+@app.post("/v1/embeddings", response_class=ORJSONResponse)
+async def openai_v1_embeddings(raw_request: Request):
+    response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
+    return response
+
+
+@app.get("/v1/models", response_class=ORJSONResponse)
+def available_models():
+    """Show available models."""
+    served_model_names = [_global_state.tokenizer_manager.served_model_name]
+    model_cards = []
+    for served_model_name in served_model_names:
+        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+    return ModelList(data=model_cards)
+
+
+@app.post("/v1/files")
+async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
+    return await v1_files_create(
+        file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth
+    )
+
+
+@app.delete("/v1/files/{file_id}")
+async def delete_file(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/delete
+    return await v1_delete_file(file_id)
+
+
+@app.post("/v1/batches")
+async def openai_v1_batches(raw_request: Request):
+    return await v1_batches(_global_state.tokenizer_manager, raw_request)
+
+
+@app.post("/v1/batches/{batch_id}/cancel")
+async def cancel_batches(batch_id: str):
+    # https://platform.openai.com/docs/api-reference/batch/cancel
+    return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
+
+
+@app.get("/v1/batches/{batch_id}")
+async def retrieve_batch(batch_id: str):
+    return await v1_retrieve_batch(batch_id)
+
+
+@app.get("/v1/files/{file_id}")
+async def retrieve_file(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/retrieve
+    return await v1_retrieve_file(file_id)
+
+
+@app.get("/v1/files/{file_id}/content")
+async def retrieve_file_content(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
+    return await v1_retrieve_file_content(file_id)
+
+
+def _create_error_response(e):
+    return ORJSONResponse(
+        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+    )
+
+
+def launch_server(
+    server_args: ServerArgs,
+    pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
+):
+    """
+    Launch SRT (SGLang Runtime) Server.
+
+    The SRT server consists of an HTTP server and an SRT engine.
+
+    - HTTP server: A FastAPI server that routes requests to the engine.
+    - The engine consists of three components:
+        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
+        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
+        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
+
+    Note:
+    1. The HTTP server, Engine, and TokenizerManager all run in the main process.
+    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
+    """
+    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
+    set_global_state(
+        _GlobalState(
+            tokenizer_manager=tokenizer_manager,
+            scheduler_info=scheduler_info,
+        )
+    )
+
+    # Add api key authorization
+    if server_args.api_key:
+        add_api_key_middleware(app, server_args.api_key)
+
+    # Add prometheus middleware
+    if server_args.enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()
+
+    # Send a warmup request
+    t = threading.Thread(
+        target=_wait_and_warmup,
+        args=(
+            server_args,
+            pipe_finish_writer,
+            _global_state.tokenizer_manager.image_token_id,
+        ),
+    )
+    t.start()
+
+    try:
+        # Update logging configs
+        set_uvicorn_logging_configs()
+
+        # Listen for HTTP requests
+        uvicorn.run(
+            app,
+            host=server_args.host,
+            port=server_args.port,
+            log_level=server_args.log_level_http or server_args.log_level,
+            timeout_keep_alive=5,
+            loop="uvloop",
+        )
+    finally:
+        t.join()
+
+
+def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
+    headers = {}
+    url = server_args.url()
+    if server_args.api_key:
+        headers["Authorization"] = f"Bearer {server_args.api_key}"
+
+    # Wait until the server is launched
+    success = False
+    for _ in range(120):
+        time.sleep(1)
+        try:
+            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            assert res.status_code == 200, f"{res=}, {res.text=}"
+            success = True
+            break
+        except (AssertionError, requests.exceptions.RequestException):
+            last_traceback = get_exception_traceback()
+            pass
+
+    if not success:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(last_traceback)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_process_tree(os.getpid())
+        return
+
+    model_info = res.json()
+
+    # Send a warmup request
+    request_name = "/generate" if model_info["is_generation"] else "/encode"
+    max_new_tokens = 8 if model_info["is_generation"] else 1
+    json_data = {
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": max_new_tokens,
+        },
+    }
+    if server_args.skip_tokenizer_init:
+        json_data["input_ids"] = [10, 11, 12]
+    else:
+        json_data["text"] = "The capital city of France is"
+
+    try:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + request_name,
+                json=json_data,
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200, f"{res}"
+    except Exception:
+        last_traceback = get_exception_traceback()
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(last_traceback)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_process_tree(os.getpid())
+        return
+
+    # Debug print
+    # logger.info(f"{res.json()=}")
+
+    logger.info("The server is fired up and ready to roll!")
+    if pipe_finish_writer is not None:
+        pipe_finish_writer.send("ready")
+
+    if server_args.delete_ckpt_after_loading:
+        delete_directory(server_args.model_path)
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5a803dd99..918323983 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -22,7 +22,6 @@ from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import SamplingParams
 
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index d6178a959..162f10624 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -176,7 +176,7 @@ class TokenizerManager:
         )
 
         # Store states
-        self.to_create_loop = True
+        self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
@@ -684,7 +684,6 @@ class TokenizerManager:
     async def close_session(
         self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
     ):
-        assert not self.to_create_loop, "close session should not be the first request"
         await self.send_to_scheduler.send_pyobj(obj)
 
     def configure_logging(self, obj: ConfigureLoggingReq):
@@ -713,10 +712,10 @@ class TokenizerManager:
         return background_tasks
 
     def auto_create_handle_loop(self):
-        if not self.to_create_loop:
+        if self.no_create_loop:
             return
 
-        self.to_create_loop = False
+        self.no_create_loop = True
         loop = asyncio.get_event_loop()
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.handle_loop))
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 0b4d9c372..8b0c56186 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -11,949 +11,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""
-The entry point of inference server.
-SRT = SGLang Runtime.
-""" -import asyncio -import atexit -import dataclasses -import json -import logging -import multiprocessing as mp -import os -import signal -import threading -import time -from http import HTTPStatus -from typing import AsyncIterator, Dict, List, Optional, Tuple, Union - -import torch - -from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter - -# Fix a bug of Python threading -setattr(threading, "_register_atexit", lambda *args, **kwargs: None) - -import aiohttp -import orjson -import requests -import uvicorn -import uvloop -from fastapi import FastAPI, File, Form, Request, UploadFile -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import ORJSONResponse, Response, StreamingResponse - -from sglang.srt.managers.data_parallel_controller import ( - run_data_parallel_controller_process, -) -from sglang.srt.managers.detokenizer_manager import run_detokenizer_process -from sglang.srt.managers.io_struct import ( - CloseSessionReqInput, - ConfigureLoggingReq, - EmbeddingReqInput, - GenerateReqInput, - GetWeightsByNameReqInput, - InitWeightsUpdateGroupReqInput, - OpenSessionReqInput, - ReleaseMemoryOccupationReqInput, - ResumeMemoryOccupationReqInput, - UpdateWeightFromDiskReqInput, - UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromTensorReqInput, -) -from sglang.srt.managers.scheduler import run_scheduler_process -from sglang.srt.managers.tokenizer_manager import TokenizerManager -from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency -from sglang.srt.openai_api.adapter import ( - load_chat_template_for_openai_api, - v1_batches, - v1_cancel_batch, - v1_chat_completions, - v1_completions, - v1_delete_file, - v1_embeddings, - v1_files_create, - v1_retrieve_batch, - v1_retrieve_file, - v1_retrieve_file_content, -) -from sglang.srt.openai_api.protocol import ModelCard, ModelList -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import ( - MultiprocessingSerializer, - add_api_key_middleware, - add_prometheus_middleware, - assert_pkg_version, - configure_logger, - delete_directory, - kill_process_tree, - maybe_set_triton_cache_manager, - prepare_model_and_tokenizer, - set_prometheus_multiproc_dir, - set_ulimit, - set_uvicorn_logging_configs, -) -from sglang.utils import get_exception_traceback -from sglang.version import __version__ - -logger = logging.getLogger(__name__) - -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - -# Fast API -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -tokenizer_manager: TokenizerManager = None -scheduler_info: Dict = None - - -##### Native API endpoints ##### - - -@app.get("/health") -async def health() -> Response: - """Check the health of the http server.""" - return Response(status_code=200) - - -@app.get("/health_generate") -async def health_generate(request: Request) -> Response: - """Check the health of the inference server by generating one token.""" - - sampling_params = {"max_new_tokens": 1, "temperature": 0.7} - - if tokenizer_manager.is_generation: - gri = GenerateReqInput( - input_ids=[0], sampling_params=sampling_params, log_metrics=False - ) - else: - gri = EmbeddingReqInput( - input_ids=[0], sampling_params=sampling_params, log_metrics=False - ) - - try: - async for _ in tokenizer_manager.generate_request(gri, request): - break - return Response(status_code=200) - except Exception as e: - logger.exception(e) - return 
Response(status_code=503) - - -@app.get("/get_model_info") -async def get_model_info(): - """Get the model information.""" - result = { - "model_path": tokenizer_manager.model_path, - "tokenizer_path": tokenizer_manager.server_args.tokenizer_path, - "is_generation": tokenizer_manager.is_generation, - } - return result - - -@app.get("/get_server_info") -async def get_server_info(): - return { - **dataclasses.asdict(tokenizer_manager.server_args), - **scheduler_info, - "version": __version__, - } - - -# fastapi implicitly converts json in the request to obj (dataclass) -@app.api_route("/generate", methods=["POST", "PUT"]) -@time_func_latency -async def generate_request(obj: GenerateReqInput, request: Request): - """Handle a generate request.""" - if obj.stream: - - async def stream_results() -> AsyncIterator[bytes]: - try: - async for out in tokenizer_manager.generate_request(obj, request): - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - except ValueError as e: - out = {"error": {"message": str(e)}} - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - yield b"data: [DONE]\n\n" - - return StreamingResponse( - stream_results(), - media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(obj), - ) - else: - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - logger.error(f"Error: {e}") - return _create_error_response(e) - - -@app.api_route("/encode", methods=["POST", "PUT"]) -@time_func_latency -async def encode_request(obj: EmbeddingReqInput, request: Request): - """Handle an embedding request.""" - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return _create_error_response(e) - - -@app.api_route("/classify", methods=["POST", "PUT"]) -@time_func_latency -async def classify_request(obj: EmbeddingReqInput, request: Request): - """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return _create_error_response(e) - - -@app.post("/flush_cache") -async def flush_cache(): - """Flush the radix cache.""" - tokenizer_manager.flush_cache() - return Response( - content="Cache flushed.\nPlease check backend logs for more details. " - "(When there are running or waiting requests, the operation will not be performed.)\n", - status_code=200, - ) - - -@app.api_route("/start_profile", methods=["GET", "POST"]) -async def start_profile_async(): - """Start profiling.""" - tokenizer_manager.start_profile() - return Response( - content="Start profiling.\n", - status_code=200, - ) - - -@app.api_route("/stop_profile", methods=["GET", "POST"]) -async def stop_profile_async(): - """Stop profiling.""" - tokenizer_manager.stop_profile() - return Response( - content="Stop profiling. 
This will take some time.\n", - status_code=200, - ) - - -@app.post("/update_weights_from_disk") -@time_func_latency -async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): - """Update the weights from disk in-place without re-launching the server.""" - success, message = await tokenizer_manager.update_weights_from_disk(obj, request) - content = {"success": success, "message": message} - if success: - return ORJSONResponse( - content, - status_code=HTTPStatus.OK, - ) - else: - return ORJSONResponse( - content, - status_code=HTTPStatus.BAD_REQUEST, - ) - - -@app.post("/init_weights_update_group") -async def init_weights_update_group( - obj: InitWeightsUpdateGroupReqInput, request: Request -): - """Initialize the parameter update group.""" - success, message = await tokenizer_manager.init_weights_update_group(obj, request) - content = {"success": success, "message": message} - if success: - return ORJSONResponse(content, status_code=200) - else: - return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) - - -@app.post("/update_weights_from_distributed") -async def update_weights_from_distributed( - obj: UpdateWeightsFromDistributedReqInput, request: Request -): - """Update model parameter from distributed online.""" - success, message = await tokenizer_manager.update_weights_from_distributed( - obj, request - ) - content = {"success": success, "message": message} - if success: - return ORJSONResponse(content, status_code=200) - else: - return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) - - -@app.api_route("/get_weights_by_name", methods=["GET", "POST"]) -async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): - """Get model parameter by name.""" - try: - ret = await tokenizer_manager.get_weights_by_name(obj, request) - if ret is None: - return _create_error_response("Get parameter by name failed") - else: - return ORJSONResponse(ret, status_code=200) - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/release_memory_occupation", methods=["GET", "POST"]) -async def release_memory_occupation( - obj: ReleaseMemoryOccupationReqInput, request: Request -): - """Release GPU occupation temporarily""" - try: - await tokenizer_manager.release_memory_occupation(obj, request) - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/resume_memory_occupation", methods=["GET", "POST"]) -async def resume_memory_occupation( - obj: ResumeMemoryOccupationReqInput, request: Request -): - """Resume GPU occupation""" - try: - await tokenizer_manager.resume_memory_occupation(obj, request) - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/open_session", methods=["GET", "POST"]) -async def open_session(obj: OpenSessionReqInput, request: Request): - """Open a session, and return its unique session id.""" - try: - session_id = await tokenizer_manager.open_session(obj, request) - if session_id is None: - raise Exception( - "Failed to open the session. Check if a session with the same id is still open." 
- ) - return session_id - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/close_session", methods=["GET", "POST"]) -async def close_session(obj: CloseSessionReqInput, request: Request): - """Close the session""" - try: - await tokenizer_manager.close_session(obj, request) - return Response(status_code=200) - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/configure_logging", methods=["GET", "POST"]) -async def configure_logging(obj: ConfigureLoggingReq, request: Request): - """Close the session""" - tokenizer_manager.configure_logging(obj) - return Response(status_code=200) - - -##### OpenAI-compatible API endpoints ##### - - -@app.post("/v1/completions") -@time_func_latency -async def openai_v1_completions(raw_request: Request): - return await v1_completions(tokenizer_manager, raw_request) - - -@app.post("/v1/chat/completions") -@time_func_latency -async def openai_v1_chat_completions(raw_request: Request): - return await v1_chat_completions(tokenizer_manager, raw_request) - - -@app.post("/v1/embeddings", response_class=ORJSONResponse) -@time_func_latency -async def openai_v1_embeddings(raw_request: Request): - response = await v1_embeddings(tokenizer_manager, raw_request) - return response - - -@app.get("/v1/models", response_class=ORJSONResponse) -def available_models(): - """Show available models.""" - served_model_names = [tokenizer_manager.served_model_name] - model_cards = [] - for served_model_name in served_model_names: - model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) - return ModelList(data=model_cards) - - -@app.post("/v1/files") -async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): - return await v1_files_create( - file, purpose, tokenizer_manager.server_args.file_storage_pth - ) - - -@app.delete("/v1/files/{file_id}") -async def delete_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/delete - return await v1_delete_file(file_id) - - -@app.post("/v1/batches") -async def openai_v1_batches(raw_request: Request): - return await v1_batches(tokenizer_manager, raw_request) - - -@app.post("/v1/batches/{batch_id}/cancel") -async def cancel_batches(batch_id: str): - # https://platform.openai.com/docs/api-reference/batch/cancel - return await v1_cancel_batch(tokenizer_manager, batch_id) - - -@app.get("/v1/batches/{batch_id}") -async def retrieve_batch(batch_id: str): - return await v1_retrieve_batch(batch_id) - - -@app.get("/v1/files/{file_id}") -async def retrieve_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve - return await v1_retrieve_file(file_id) - - -@app.get("/v1/files/{file_id}/content") -async def retrieve_file_content(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve-contents - return await v1_retrieve_file_content(file_id) - - -def _create_error_response(e): - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) - - -def launch_engine( - server_args: ServerArgs, -): - """ - Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. 
- """ - - global tokenizer_manager - global scheduler_info - - # Configure global environment - configure_logger(server_args) - server_args.check_server_args() - _set_envs_and_config(server_args) - - # Allocate ports for inter-process communications - port_args = PortArgs.init_new(server_args) - logger.info(f"{server_args=}") - - # If using model from www.modelscope.cn, first download the model. - server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( - server_args.model_path, server_args.tokenizer_path - ) - - scheduler_procs = [] - if server_args.dp_size == 1: - # Launch tensor parallel scheduler processes - memory_saver_adapter = TorchMemorySaverAdapter.create( - enable=server_args.enable_memory_saver - ) - - scheduler_pipe_readers = [] - tp_size_per_node = server_args.tp_size // server_args.nnodes - tp_rank_range = range( - tp_size_per_node * server_args.node_rank, - tp_size_per_node * (server_args.node_rank + 1), - ) - for tp_rank in tp_rank_range: - reader, writer = mp.Pipe(duplex=False) - gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node - proc = mp.Process( - target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, None, writer), - ) - with memory_saver_adapter.configure_subprocess(): - proc.start() - scheduler_procs.append(proc) - scheduler_pipe_readers.append(reader) - else: - # Launch the data parallel controller - reader, writer = mp.Pipe(duplex=False) - scheduler_pipe_readers = [reader] - proc = mp.Process( - target=run_data_parallel_controller_process, - args=(server_args, port_args, writer), - ) - proc.start() - scheduler_procs.append(proc) - - if server_args.node_rank >= 1: - # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer, - # so they can just wait here. - - for reader in scheduler_pipe_readers: - data = reader.recv() - assert data["status"] == "ready" - - if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": - # When using `Engine` as a Python API, we don't want to block here. - return - - for proc in scheduler_procs: - proc.join() - logger.error( - f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" - ) - return - - # Launch detokenizer process - detoken_proc = mp.Process( - target=run_detokenizer_process, - args=( - server_args, - port_args, - ), - ) - detoken_proc.start() - - # Launch tokenizer process - tokenizer_manager = TokenizerManager(server_args, port_args) - if server_args.chat_template: - load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) - - # Wait for model to finish loading - scheduler_infos = [] - for i in range(len(scheduler_pipe_readers)): - try: - data = scheduler_pipe_readers[i].recv() - except EOFError as e: - logger.exception(e) - logger.error( - f"Rank {i} scheduler is dead. Please check if there are relevant logs." - ) - scheduler_procs[i].join() - logger.error(f"Exit code: {scheduler_procs[i].exitcode}") - raise - - if data["status"] != "ready": - raise RuntimeError( - "Initialization failed. Please see the error messages above." - ) - scheduler_infos.append(data) - - # Assume all schedulers have same scheduler_info - scheduler_info = scheduler_infos[0] - tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] - - -def launch_server( - server_args: ServerArgs, - pipe_finish_writer: Optional[mp.connection.Connection] = None, -): - """ - Launch SRT (SGLang Runtime) Server - - The SRT server consists of an HTTP server and the SRT engine. - - 1. 
HTTP server: A FastAPI server that routes requests to the engine. - 2. SRT engine: - 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. - 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. - 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. - - Note: - 1. The HTTP server and TokenizerManager both run in the main process. - 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. - """ - launch_engine(server_args=server_args) - - # Add api key authorization - if server_args.api_key: - add_api_key_middleware(app, server_args.api_key) - - # Add prometheus middleware - if server_args.enable_metrics: - add_prometheus_middleware(app) - enable_func_timer() - - # Send a warmup request - t = threading.Thread( - target=_wait_and_warmup, - args=( - server_args, - pipe_finish_writer, - tokenizer_manager.image_token_id, - ), - ) - t.start() - - try: - # Update logging configs - set_uvicorn_logging_configs() - - # Listen for HTTP requests - uvicorn.run( - app, - host=server_args.host, - port=server_args.port, - log_level=server_args.log_level_http or server_args.log_level, - timeout_keep_alive=5, - loop="uvloop", - ) - finally: - t.join() - - -def _set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" - - # Set prometheus env vars - if server_args.enable_metrics: - set_prometheus_multiproc_dir() - - # Set ulimit - set_ulimit() - - # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() - - # Check flashinfer version - if server_args.attention_backend == "flashinfer": - assert_pkg_version( - "flashinfer", - "0.1.6", - "Please uninstall the old version and " - "reinstall the latest version by following the instructions " - "at https://docs.flashinfer.ai/installation.html.", - ) - - # Register the signal handler. - # The child processes will send SIGQUIT to this process when any error happens - # This process then clean up the whole process tree - def sigquit_handler(signum, frame): - logger.error( - "Received sigquit from a child proces. It usually means the child failed." - ) - kill_process_tree(os.getpid()) - - signal.signal(signal.SIGQUIT, sigquit_handler) - - # Set mp start method - mp.set_start_method("spawn", force=True) - - -def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text): - headers = {} - url = server_args.url() - if server_args.api_key: - headers["Authorization"] = f"Bearer {server_args.api_key}" - - # Wait until the server is launched - success = False - for _ in range(120): - time.sleep(1) - try: - res = requests.get(url + "/get_model_info", timeout=5, headers=headers) - assert res.status_code == 200, f"{res=}, {res.text=}" - success = True - break - except (AssertionError, requests.exceptions.RequestException): - last_traceback = get_exception_traceback() - pass - - if not success: - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - logger.error(f"Initialization failed. 
-def _set_envs_and_config(server_args: ServerArgs):
-    # Set global environment variables
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
-
-    # Set prometheus env vars
-    if server_args.enable_metrics:
-        set_prometheus_multiproc_dir()
-
-    # Set ulimit
-    set_ulimit()
-
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
-    # Check flashinfer version
-    if server_args.attention_backend == "flashinfer":
-        assert_pkg_version(
-            "flashinfer",
-            "0.1.6",
-            "Please uninstall the old version and "
-            "reinstall the latest version by following the instructions "
-            "at https://docs.flashinfer.ai/installation.html.",
-        )
-
-    # Register the signal handler.
-    # The child processes will send SIGQUIT to this process when any error happens;
-    # this process then cleans up the whole process tree.
-    def sigquit_handler(signum, frame):
-        logger.error(
-            "Received SIGQUIT from a child process. It usually means the child failed."
-        )
-        kill_process_tree(os.getpid())
-
-    signal.signal(signal.SIGQUIT, sigquit_handler)
-
-    # Set the multiprocessing start method
-    mp.set_start_method("spawn", force=True)
-
-
-def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
-    headers = {}
-    url = server_args.url()
-    if server_args.api_key:
-        headers["Authorization"] = f"Bearer {server_args.api_key}"
-
-    # Wait until the server is launched
-    success = False
-    for _ in range(120):
-        time.sleep(1)
-        try:
-            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
-            assert res.status_code == 200, f"{res=}, {res.text=}"
-            success = True
-            break
-        except (AssertionError, requests.exceptions.RequestException):
-            last_traceback = get_exception_traceback()
-            pass
-
-    if not success:
-        if pipe_finish_writer is not None:
-            pipe_finish_writer.send(last_traceback)
-        logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_process_tree(os.getpid())
-        return
-
-    model_info = res.json()
-
-    # Send a warmup request
-    request_name = "/generate" if model_info["is_generation"] else "/encode"
-    max_new_tokens = 8 if model_info["is_generation"] else 1
-    json_data = {
-        "sampling_params": {
-            "temperature": 0,
-            "max_new_tokens": max_new_tokens,
-        },
-    }
-    if server_args.skip_tokenizer_init:
-        json_data["input_ids"] = [10, 11, 12]
-    else:
-        json_data["text"] = "The capital city of France is"
-
-    try:
-        for _ in range(server_args.dp_size):
-            res = requests.post(
-                url + request_name,
-                json=json_data,
-                headers=headers,
-                timeout=600,
-            )
-            assert res.status_code == 200, f"{res}"
-    except Exception:
-        last_traceback = get_exception_traceback()
-        if pipe_finish_writer is not None:
-            pipe_finish_writer.send(last_traceback)
-        logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_process_tree(os.getpid())
-        return
-
-    # Debug print
-    # logger.info(f"{res.json()=}")
-
-    logger.info("The server is fired up and ready to roll!")
-    if pipe_finish_writer is not None:
-        pipe_finish_writer.send("ready")
-
-    if server_args.delete_ckpt_after_loading:
-        delete_directory(server_args.model_path)
-
-
-STREAM_END_SYMBOL = b"data: [DONE]"
-STREAM_CHUNK_START_SYMBOL = b"data:"
-
-
-class Engine:
-    """
-    SRT Engine without an HTTP server layer.
-
-    This class provides a direct inference engine without the need for an HTTP server.
-    It is designed for use cases where launching the HTTP server adds unnecessary
-    complexity or overhead.
-    """
-
-    def __init__(self, log_level: str = "error", *args, **kwargs):
-        """See the arguments in server_args.py::ServerArgs"""
-
-        # shutdown() is registered via atexit, so it is called implicitly before the
-        # Python program terminates; users don't have to call .shutdown() explicitly.
-        atexit.register(self.shutdown)
-
-        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
-        launch_engine(server_args=server_args)
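A minimal offline-usage sketch of this class (the model path is a placeholder; the sampling parameters mirror the warmup request above, and the same pattern appears in the updated tests later in this diff):

```python
# Offline engine sketch: no HTTP server, direct Python calls.
import sglang as sgl

llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
out = llm.generate(
    "The capital city of France is",
    {"temperature": 0, "max_new_tokens": 8},
)
print(out["text"])
llm.shutdown()  # optional; also registered via atexit in __init__
```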
-    def generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Union[List[Dict], Dict]] = None,
-        # The token ids for text; one can specify either text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        custom_logit_processor: Optional[Union[List[str], str]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-            custom_logit_processor=custom_logit_processor,
-        )
-
-        # Get the current event loop
-        loop = asyncio.get_event_loop()
-        ret = loop.run_until_complete(generate_request(obj, None))
-
-        if stream is True:
-
-            def generator_wrapper():
-                offset = 0
-                loop = asyncio.get_event_loop()
-                generator = ret.body_iterator
-                while True:
-                    chunk = loop.run_until_complete(generator.__anext__())
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            # We cannot yield inside generate() itself, because a function containing
-            # `yield` becomes a generator even in the non-streaming case; instead, we
-            # wrap the streaming logic in a nested generator function and return it.
-            return generator_wrapper()
-        else:
-            return ret
-
-    async def async_generate(
-        self,
-        # The input prompt. It can be a single prompt or a batch of prompts.
-        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Dict] = None,
-        # The token ids for text; one can specify either text or input_ids.
-        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        return_logprob: Optional[Union[List[bool], bool]] = False,
-        logprob_start_len: Optional[Union[List[int], int]] = None,
-        top_logprobs_num: Optional[Union[List[int], int]] = None,
-        lora_path: Optional[List[Optional[str]]] = None,
-        custom_logit_processor: Optional[Union[str, List[str]]] = None,
-        stream: bool = False,
-    ):
-        obj = GenerateReqInput(
-            text=prompt,
-            input_ids=input_ids,
-            sampling_params=sampling_params,
-            return_logprob=return_logprob,
-            logprob_start_len=logprob_start_len,
-            top_logprobs_num=top_logprobs_num,
-            lora_path=lora_path,
-            stream=stream,
-            custom_logit_processor=custom_logit_processor,
-        )
-
-        ret = await generate_request(obj, None)
-
-        if stream is True:
-            generator = ret.body_iterator
-
-            async def generator_wrapper():
-                offset = 0
-
-                while True:
-                    chunk = await generator.__anext__()
-
-                    if chunk.startswith(STREAM_END_SYMBOL):
-                        break
-                    else:
-                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
-                        data["text"] = data["text"][offset:]
-                        offset += len(data["text"])
-                        yield data
-
-            return generator_wrapper()
-        else:
-            return ret
-
-    def shutdown(self):
-        kill_process_tree(os.getpid(), include_parent=False)
-
-    def get_tokenizer(self):
-        global tokenizer_manager
-
-        if tokenizer_manager is None:
-            raise ReferenceError("Tokenizer Manager is not initialized.")
-        else:
-            return tokenizer_manager.tokenizer
-
-    def encode(
-        self,
-        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-    ):
-        obj = EmbeddingReqInput(text=prompt)
-
-        # Get the current event loop
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(encode_request(obj, None))
-
-    def start_profile(self):
-        tokenizer_manager.start_profile()
-
-    def stop_profile(self):
-        tokenizer_manager.stop_profile()
-
-    def get_server_info(self):
-        return {
-            **dataclasses.asdict(tokenizer_manager.server_args),  # server args
-            **scheduler_info,
-            "version": __version__,
-        }
-
-    def init_weights_update_group(
-        self,
-        master_address: str,
-        master_port: int,
-        rank_offset: int,
-        world_size: int,
-        group_name: str,
-        backend: str = "nccl",
-    ):
-        """Initialize a parameter update group."""
-        obj = InitWeightsUpdateGroupReqInput(
-            master_address=master_address,
-            master_port=master_port,
-            rank_offset=rank_offset,
-            world_size=world_size,
-            group_name=group_name,
-            backend=backend,
-        )
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(
-            tokenizer_manager.init_weights_update_group(obj, None)
-        )
-
-    def update_weights_from_distributed(self, name, dtype, shape):
-        """Update weights from a distributed source."""
-        obj = UpdateWeightsFromDistributedReqInput(
-            name=name,
-            dtype=dtype,
-            shape=shape,
-        )
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(
-            tokenizer_manager.update_weights_from_distributed(obj, None)
-        )
-
-    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
-        """Update weights from in-memory tensors."""
-        obj = UpdateWeightsFromTensorReqInput(
-            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
-        )
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(
-            tokenizer_manager.update_weights_from_tensor(obj, None)
-        )
-
-    def get_weights_by_name(self, name, truncate_size=100):
-        """Get weights by parameter name."""
-        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
-
-    def release_memory_occupation(self):
-        """Release GPU memory occupation temporarily."""
-        obj = ReleaseMemoryOccupationReqInput()
-        loop = asyncio.get_event_loop()
-        loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None))
-
-    def resume_memory_occupation(self):
-        """Resume GPU memory occupation."""
-        obj = ResumeMemoryOccupationReqInput()
-        loop = asyncio.get_event_loop()
-        loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
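Seen from the caller's side, the streaming contract above is a plain generator in sync mode and an async generator in async mode. A short sketch, assuming a placeholder model (note that this pre-refactor wrapper yields incremental text chunks):

```python
import asyncio
import sglang as sgl

llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
params = {"temperature": 0.8, "top_p": 0.95}

# Sync streaming: generate(..., stream=True) returns a regular generator.
for chunk in llm.generate("AI safety is", params, stream=True):
    print(chunk["text"], end="", flush=True)
print()

# Async streaming: awaiting async_generate(..., stream=True) yields an async generator.
async def stream_async():
    generator = await llm.async_generate("AI safety is", params, stream=True)
    async for chunk in generator:
        print(chunk["text"], end="", flush=True)
    print()

asyncio.get_event_loop().run_until_complete(stream_async())
llm.shutdown()
```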
+# Some shortcuts for backward compatibility.
+# They will be removed in future versions.
+from sglang.srt.entrypoints.engine import Engine
+from sglang.srt.entrypoints.http_server import launch_server
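The weight-update and memory-occupation methods above target RL-style workflows (train elsewhere, then refresh the serving weights). A hedged sketch, where the parameter name and tensor shape are purely illustrative:

```python
import torch
import sglang as sgl

# enable_memory_saver mirrors the TorchMemorySaverAdapter flag used at launch.
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    enable_memory_saver=True,
)

# Temporarily release GPU memory (e.g., while a co-located trainer runs), then resume.
llm.release_memory_occupation()
llm.resume_memory_occupation()

# Push new weights from host tensors; the name and shape are illustrative only.
updated = torch.zeros(2048, 2048, dtype=torch.bfloat16)
llm.update_weights_from_tensor([("model.layers.0.self_attn.q_proj.weight", updated)])

# Fetch a truncated view of a parameter to verify the update.
print(llm.get_weights_by_name("model.layers.0.self_attn.q_proj.weight", truncate_size=4))
```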
diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py
index fc9a97937..bae0fcf2a 100644
--- a/python/sglang/test/runners.py
+++ b/python/sglang/test/runners.py
@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================
 
-import json
 import multiprocessing as mp
 import os
 from dataclasses import dataclass
@@ -22,8 +21,8 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCausalLM
 
+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.server import Engine
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
 
 DEFAULT_PROMPTS = [
diff --git a/sgl-router/py_src/sglang_router/launch_server.py b/sgl-router/py_src/sglang_router/launch_server.py
index 2f433269e..93bc2345d 100644
--- a/sgl-router/py_src/sglang_router/launch_server.py
+++ b/sgl-router/py_src/sglang_router/launch_server.py
@@ -13,7 +13,7 @@ import requests
 from setproctitle import setproctitle
 from sglang_router.launch_router import RouterArgs, launch_router
 
-from sglang.srt.server import launch_server
+from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import is_port_available
diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py
index 69babf795..2837107a1 100644
--- a/test/srt/test_metrics.py
+++ b/test/srt/test_metrics.py
@@ -56,7 +56,6 @@ class TestEnableMetrics(unittest.TestCase):
             "sglang:gen_throughput",
             "sglang:num_queue_reqs",
             "sglang:cache_hit_rate",
-            "sglang:func_latency_seconds",
             "sglang:prompt_tokens_total",
             "sglang:generation_tokens_total",
             "sglang:num_requests_total",
diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py
index 2e379c111..06c83048f 100644
--- a/test/srt/test_nightly_gsm8k_eval.py
+++ b/test/srt/test_nightly_gsm8k_eval.py
@@ -45,7 +45,7 @@ def parse_models(model_string):
     return [model.strip() for model in model_string.split(",") if model.strip()]
 
 
-def launch_server(base_url, model, is_fp8, is_tp2):
+def popen_launch_server_wrapper(base_url, model, is_fp8, is_tp2):
     other_args = ["--log-level-http", "warning", "--trust-remote-code"]
     if is_fp8:
         if "Llama-3" in model or "gemma-2" in model:
@@ -148,7 +148,9 @@ class TestNightlyGsm8KEval(unittest.TestCase):
         for model_group, is_fp8, is_tp2 in self.model_groups:
             for model in model_group:
                 with self.subTest(model=model):
-                    process = launch_server(self.base_url, model, is_fp8, is_tp2)
+                    process = popen_launch_server_wrapper(
+                        self.base_url, model, is_fp8, is_tp2
+                    )
 
                     args = SimpleNamespace(
                         base_url=self.base_url,
diff --git a/test/srt/test_nightly_human_eval.py b/test/srt/test_nightly_human_eval.py
index 0b682937a..6558b9eff 100644
--- a/test/srt/test_nightly_human_eval.py
+++ b/test/srt/test_nightly_human_eval.py
@@ -4,7 +4,7 @@ import signal
 import subprocess
 import unittest
 
-from test_nightly_gsm8k_eval import launch_server, parse_models
+from test_nightly_gsm8k_eval import parse_models, popen_launch_server_wrapper
 
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
@@ -93,7 +93,7 @@ class TestNightlyHumanEval(unittest.TestCase):
                 # NOTE: only Llama for now
                 if "Llama" in model:
                     with self.subTest(model=model):
-                        self.process = launch_server(
+                        self.process = popen_launch_server_wrapper(
                             self.base_url, model, is_fp8, is_tp2
                         )
                         self.run_evalplus(model)
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index 7479b6468..c535d5c06 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -1,6 +1,6 @@
 """
 Usage:
-python3 -m unittest test_srt_engine.TestSRTEngine.test_3_sync_streaming_combination
+python3 -m unittest test_srt_engine.TestSRTEngine.test_4_sync_async_stream_combination
 """
 
 import asyncio
@@ -44,83 +44,29 @@ class TestSRTEngine(unittest.TestCase):
         print(out2)
 
         self.assertEqual(out1, out2)
 
-    def test_2_engine_multiple_generate(self):
+    def test_2_engine_runtime_encode_consistency(self):
+        prompt = "Today is a sunny day and I like"
+        model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
+
+        engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42)
+        out1 = torch.tensor(engine.encode(prompt)["embedding"])
+        engine.shutdown()
+
+        runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42)
+        out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"])
+        runtime.shutdown()
+
+        self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
+
+    def test_3_engine_token_ids_consistency(self):
         # just to ensure there is no issue running multiple generate calls
         prompt = "Today is a sunny day and I like"
         model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
-        sampling_params = {"temperature": 0, "max_new_tokens": 8}
-
-        engine = sgl.Engine(model_path=model_path, random_seed=42)
-        engine.generate(prompt, sampling_params)
-        engine.generate(prompt, sampling_params)
-        engine.shutdown()
-
-    def test_3_sync_streaming_combination(self):
-
-        prompt = "AI safety is..."
-        sampling_params = {"temperature": 0.8, "top_p": 0.95}
-
-        async def async_streaming(engine):
-
-            generator = await engine.async_generate(
-                prompt, sampling_params, stream=True
-            )
-
-            async for output in generator:
-                print(output["text"], end="", flush=True)
-            print()
-
-        # Create an LLM.
-        llm = sgl.Engine(
-            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
-        )
-
-        # 1. sync + non streaming
-        print("\n\n==== 1. sync + non streaming ====")
-        output = llm.generate(prompt, sampling_params)
-
-        print(output["text"])
-
-        # 2. sync + streaming
-        print("\n\n==== 2. sync + streaming ====")
-        output_generator = llm.generate(prompt, sampling_params, stream=True)
-        for output in output_generator:
-            print(output["text"], end="", flush=True)
-        print()
-
-        loop = asyncio.get_event_loop()
-        # 3. async + non_streaming
-        print("\n\n==== 3. async + non streaming ====")
-        output = loop.run_until_complete(llm.async_generate(prompt, sampling_params))
-        print(output["text"])
-
-        # 4. async + streaming
-        print("\n\n==== 4. async + streaming ====")
-        loop.run_until_complete(async_streaming(llm))
-
-        llm.shutdown()
-
-    def test_4_gsm8k(self):
-
-        args = SimpleNamespace(
-            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
-            local_data_path=None,
-            num_shots=5,
-            num_questions=200,
-        )
-
-        metrics = run_eval(args)
-        self.assertGreater(metrics["accuracy"], 0.3)
-
-    def test_5_prompt_input_ids_consistency(self):
-        prompt = "The capital of UK is"
-
-        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
 
         engine = sgl.Engine(
             model_path=model_path, random_seed=42, disable_radix_cache=True
         )
 
-        sampling_params = {"temperature": 0, "max_new_tokens": 8}
         out1 = engine.generate(prompt, sampling_params)["text"]
 
         tokenizer = get_tokenizer(model_path)
@@ -138,21 +84,69 @@ class TestSRTEngine(unittest.TestCase):
         print(out2)
 
         self.assertEqual(out1, out2)
 
-    def test_6_engine_runtime_encode_consistency(self):
-        prompt = "Today is a sunny day and I like"
-        model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
+    def test_4_sync_async_stream_combination(self):
+        prompt = "AI safety is"
+        sampling_params = {"temperature": 0.8, "top_p": 0.95}
 
-        engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42)
-        out1 = torch.tensor(engine.encode(prompt)["embedding"])
-        engine.shutdown()
+        # Create an LLM.
+        llm = sgl.Engine(
+            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+        )
 
-        runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42)
-        out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"])
-        runtime.shutdown()
+        if True:
+            # 1. sync + non streaming
+            print("\n\n==== 1. sync + non streaming ====")
+            output = llm.generate(prompt, sampling_params)
+            print(output["text"])
 
-        self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
+            # 2. sync + streaming
+            print("\n\n==== 2. sync + streaming ====")
+            output_generator = llm.generate(prompt, sampling_params, stream=True)
+            offset = 0
+            for output in output_generator:
+                print(output["text"][offset:], end="", flush=True)
+                offset = len(output["text"])
+            print()
 
-    def test_7_engine_cpu_offload(self):
+        if True:
+            loop = asyncio.get_event_loop()
+            # 3. async + non_streaming
+            print("\n\n==== 3. async + non streaming ====")
+            output = loop.run_until_complete(
+                llm.async_generate(prompt, sampling_params)
+            )
+            print(output["text"])
+
+            # 4. async + streaming
+            async def async_streaming(engine):
+                generator = await engine.async_generate(
+                    prompt, sampling_params, stream=True
+                )
+
+                offset = 0
+                async for output in generator:
+                    print(output["text"][offset:], end="", flush=True)
+                    offset = len(output["text"])
+                print()
+
+            print("\n\n==== 4. async + streaming ====")
+            loop.run_until_complete(async_streaming(llm))
+
+        llm.shutdown()
+
+    def test_5_gsm8k(self):
+
+        args = SimpleNamespace(
+            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+            local_data_path=None,
+            num_shots=5,
+            num_questions=200,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreater(metrics["accuracy"], 0.3)
+
+    def test_6_engine_cpu_offload(self):
         prompt = "Today is a sunny day and I like"
 
         model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -182,7 +176,7 @@ class TestSRTEngine(unittest.TestCase):
         print(out2)
 
         self.assertEqual(out1, out2)
 
-    def test_8_engine_offline_throughput(self):
+    def test_7_engine_offline_throughput(self):
         server_args = ServerArgs(
             model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
         )