Move sgl.Runtime under sglang/lang (#2990)
@@ -1,6 +1,5 @@
"""Public APIs of the language."""

import os
import re
from typing import Callable, List, Optional, Union

@@ -33,17 +32,13 @@ def function(


def Runtime(*args, **kwargs):
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Avoid importing unnecessary dependency
    from sglang.srt.server import Runtime
    from sglang.lang.backend.runtime_endpoint import Runtime

    return Runtime(*args, **kwargs)


def Engine(*args, **kwargs):
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Avoid importing unnecessary dependency
    from sglang.srt.server import Engine

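For context, a minimal usage sketch of the lazy wrapper above (illustrative only, not part of this patch; the model path is a placeholder and any model supported by SGLang would work the same way):

import json

import sglang as sgl

# Runtime() defers the heavy server import until it is actually called.
runtime = sgl.Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct", log_level="error")
out = json.loads(
    runtime.generate(
        "The capital of France is",
        sampling_params={"max_new_tokens": 8, "temperature": 0},
    )
)
print(out["text"])
runtime.shutdown()
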
@@ -27,7 +27,8 @@ from sglang.bench_serving import (
    sample_random_requests,
    set_ulimit,
)
from sglang.srt.server import Engine, Runtime
from sglang.lang.backend.runtime_endpoint import Runtime
from sglang.srt.server import Engine
from sglang.srt.server_args import ServerArgs

@@ -1,6 +1,11 @@
import atexit
import json
import multiprocessing
import warnings
from typing import List, Optional
from typing import Dict, List, Optional, Union

import aiohttp
import requests

from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
@@ -14,6 +19,9 @@ from sglang.lang.ir import (
    REGEX_STR,
    SglSamplingParams,
)
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_port_available, kill_process_tree
from sglang.utils import http_request

@@ -325,3 +333,162 @@ class RuntimeEndpoint(BaseBackend):
def compute_normalized_prompt_logprobs(input_logprobs):
    values = [x[0] for x in input_logprobs if x[0]]
    return sum(values) / len(values)


class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a python program without
    using the command line interface.

    It is mainly used for the frontend language.
    You should use the Engine class if you want to do normal offline processing.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        from sglang.srt.server import launch_server

        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Before the Python program terminates, call shutdown implicitly so
        # users don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # Pre-allocate ports
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store pid instead of proc to fix some issues during __delete__
        self.pid = None
        pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)

        proc = multiprocessing.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        self.shutdown()

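For illustration only (not part of the diff): Runtime.async_generate above is an async generator that yields text increments as the server streams them back. A minimal sketch, assuming a Runtime instance named runtime constructed as shown earlier; the prompt and sampling parameters are arbitrary examples.

import asyncio

async def stream_demo(runtime):
    # Print each streamed text increment as it arrives.
    async for chunk in runtime.async_generate(
        "Write a haiku about GPUs",
        sampling_params={"max_new_tokens": 32, "temperature": 0.7},
    ):
        print(chunk, end="", flush=True)

# asyncio.run(stream_demo(runtime))
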
@@ -1,25 +0,0 @@
"""Launch the inference server for Llava-video model."""

import json
import sys

from sglang.srt.server import launch_server, prepare_server_args

if __name__ == "__main__":
    server_args = prepare_server_args(sys.argv[1:])

    model_override_args = {}
    model_override_args["mm_spatial_pool_stride"] = 2
    model_override_args["architectures"] = ["LlavaVidForCausalLM"]
    model_override_args["num_frames"] = 16
    model_override_args["model_type"] = "llavavid"
    if model_override_args["num_frames"] == 32:
        model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
        model_override_args["max_sequence_length"] = 4096 * 2
        model_override_args["tokenizer_model_max_length"] = 4096 * 2
        model_override_args["model_max_length"] = 4096 * 2
    if "34b" in server_args.model_path.lower():
        model_override_args["image_token_index"] = 64002
    server_args.json_model_override_args = json.dumps(model_override_args)

    launch_server(server_args)

@@ -1,16 +0,0 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# TODO(lmzheng): make this an optional dependency
from sglang.srt.constrained.outlines_backend import build_regex_from_object

@@ -18,6 +18,8 @@ from dataclasses import dataclass
from threading import Event, Lock
from typing import Any, Optional, Tuple

from sglang.srt.server_args import ServerArgs


@dataclass
class CacheEntry:
@@ -69,3 +71,22 @@ class BaseGrammarBackend:
    def reset(self):
        with self.cache_lock:
            self.cache.clear()


def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
    if server_args.grammar_backend == "outlines":
        from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

        grammar_backend = OutlinesGrammarBackend(
            tokenizer,
            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
            allow_jump_forward=not server_args.disable_jump_forward,
        )
    elif server_args.grammar_backend == "xgrammar":
        from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

        grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
    else:
        raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

    return grammar_backend

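A hedged sketch of calling the new factory directly (the Scheduler does the equivalent internally, as a later hunk shows). The model path and backend choice are illustrative, and the Scheduler itself passes model_config.vocab_size rather than the tokenizer length used here:

from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server_args import ServerArgs

# Illustrative arguments; either "outlines" or "xgrammar" would work.
server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    grammar_backend="xgrammar",
)
tokenizer = get_tokenizer(server_args.tokenizer_path)
grammar_backend = create_grammar_backend(server_args, tokenizer, len(tokenizer))
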
@@ -34,6 +34,7 @@ import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -149,9 +150,7 @@ class Scheduler:
            else 1
        )

        # Init inter-process communication
        context = zmq.Context(2)

        # Distributed rank info
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
@@ -162,6 +161,8 @@ class Scheduler:
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -243,7 +244,7 @@ class Scheduler:
            nccl_port=port_args.nccl_port,
        )

        # Launch worker for speculative decoding if need
        # Launch a worker for speculative decoding if needed
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

@@ -316,6 +317,8 @@ class Scheduler:
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval
        self.current_stream = torch.get_device_module(self.device).current_stream()
@@ -337,28 +340,9 @@ class Scheduler:
        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            if server_args.grammar_backend == "outlines":
                from sglang.srt.constrained.outlines_backend import (
                    OutlinesGrammarBackend,
                )

                self.grammar_backend = OutlinesGrammarBackend(
                    self.tokenizer,
                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                    allow_jump_forward=not server_args.disable_jump_forward,
                )
            elif server_args.grammar_backend == "xgrammar":
                from sglang.srt.constrained.xgrammar_backend import (
                    XGrammarGrammarBackend,
                )

                self.grammar_backend = XGrammarGrammarBackend(
                    self.tokenizer, vocab_size=self.model_config.vocab_size
                )
            else:
                raise ValueError(
                    f"Invalid grammar backend: {server_args.grammar_backend}"
                )
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

@@ -424,7 +408,8 @@ class Scheduler:
            },
        )

        self._dispatcher = TypeBasedDispatcher(
        # Init request dispatcher
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
@@ -480,10 +465,6 @@ class Scheduler:
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()

            if self.server_args.enable_dp_attention:  # TODO: simplify this
                batch = self.prepare_dp_attn_batch(batch)

            self.cur_batch = batch

            if batch:
@@ -506,10 +487,6 @@ class Scheduler:
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()

            if self.server_args.enable_dp_attention:  # TODO: simplify this
                batch = self.prepare_dp_attn_batch(batch)

            self.cur_batch = batch

            if batch:
@@ -517,7 +494,7 @@ class Scheduler:
                result_queue.append((batch.copy(), result))

            if self.last_batch is None:
                # Create a dummy first batch to start the pipeline for overlap scheduler.
                # Create a dummy first batch to start the pipeline for overlap schedule.
                # It is now used for triggering the sampling_info_done event.
                tmp_batch = ScheduleBatch(
                    reqs=None,
@@ -593,7 +570,7 @@ class Scheduler:

    def process_input_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            output = self._dispatcher(recv_req)
            output = self._request_dispatcher(recv_req)
            if output is not None:
                self.send_to_tokenizer.send_pyobj(output)

@@ -798,15 +775,32 @@ class Scheduler:
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        if self.spec_algorithm.is_none():
            msg = (
                f"Decode batch. "
                f"#running-req: {num_running_reqs}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"gen throughput (token/s): {gen_throughput:.2f}, "
                f"#queue-req: {len(self.waiting_queue)}"
            )
        else:
            accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg = (
                f"Decode batch. "
                f"#running-req: {num_running_reqs}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"accept len: {accept_length:.2f}, "
                f"gen throughput (token/s): {gen_throughput:.2f}, "
                f"#queue-req: {len(self.waiting_queue)}"
            )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
@@ -855,16 +849,23 @@ class Scheduler:
        else:
            self.running_batch.merge_batch(self.last_batch)

        # Run prefill first if possible
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if self.running_batch is None:
                ret = None
            else:
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch

        # Run decode
        if self.running_batch is None:
            return None
        self.running_batch = self.update_running_batch(self.running_batch)
        return self.running_batch
        # Handle DP attention
        if self.server_args.enable_dp_attention:
            ret = self.prepare_dp_attn_batch(ret)

        return ret

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        # Check if the grammar is ready in the grammar queue
@@ -1053,6 +1054,10 @@ class Scheduler:
                    model_worker_batch,
                    num_accepted_tokens,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                self.spec_num_total_accepted_tokens += (
                    num_accepted_tokens + batch.batch_size()
                )
                self.spec_num_total_forward_ct += batch.batch_size()
                self.num_generated_tokens += num_accepted_tokens
            else:
                assert False, "batch.extend_num_tokens == 0, this is unexpected!"

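A worked example of the bookkeeping added above, with illustrative numbers only: suppose one speculative decode step runs over a batch of 2 requests and the target model accepts 3 draft tokens in total across the batch.

# spec_num_total_accepted_tokens counts one "base" token per request per forward
# plus the accepted draft tokens; spec_num_total_forward_ct counts target forwards.
batch_size = 2
num_accepted_tokens = 3
spec_num_total_accepted_tokens = num_accepted_tokens + batch_size  # 5 tokens emitted
spec_num_total_forward_ct = batch_size                             # 2 target forwards
accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct  # 2.5, as logged above
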
@@ -224,7 +224,7 @@ class TokenizerManager:
            },
        )

        self._dispatcher = TypeBasedDispatcher(
        self._result_dispatcher = TypeBasedDispatcher(
            [
                (BatchStrOut, self._handle_batch_output),
                (BatchEmbeddingOut, self._handle_batch_output),
@@ -760,7 +760,7 @@ class TokenizerManager:

        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            self._dispatcher(recv_obj)
            self._result_dispatcher(recv_obj)

    def _handle_batch_output(
        self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]

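The renames above (Scheduler._request_dispatcher, TokenizerManager._result_dispatcher) both refer to the same helper. For readers unfamiliar with it, a minimal sketch of what a type-based dispatcher does; this is illustrative only, not SGLang's actual implementation:

from typing import Any, Callable, List, Tuple, Type

class _TypeBasedDispatcherSketch:
    """Maps the concrete type of an incoming object to a handler callback."""

    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any):
        # Dispatch on isinstance, in registration order.
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"Invalid object type: {type(obj)}")
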
@@ -45,8 +45,6 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
@@ -90,7 +88,6 @@ from sglang.srt.utils import (
    assert_pkg_version,
    configure_logger,
    delete_directory,
    is_port_available,
    kill_process_tree,
    maybe_set_triton_cache_manager,
    prepare_model_and_tokenizer,
@@ -960,160 +957,3 @@ class Engine:
        obj = ResumeMemoryOccupationReqInput()
        loop = asyncio.get_event_loop()
        loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))


class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a python program without
    using the commond line interface.

    It is mainly used for the frontend language.
    You should use the Engine class above if you want to do normal offline processing.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
        atexit.register(self.shutdown)

        # Pre-allocate ports
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store pid instead of proc to fix some issues during __delete__
        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)

        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        self.shutdown()

@@ -23,7 +23,7 @@ import torch.nn.functional as F
from transformers import AutoModelForCausalLM

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Runtime
from sglang.srt.server import Engine
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER

DEFAULT_PROMPTS = [
@@ -278,7 +278,7 @@ class SRTRunner:
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
        self.runtime = Runtime(
        self.engine = Engine(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
@@ -306,7 +306,7 @@ class SRTRunner:
        top_output_logprobs = []
        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
        for i, prompt in enumerate(prompts):
            response = self.runtime.generate(
            response = self.engine.generate(
                prompt,
                lora_path=lora_paths[i] if lora_paths else None,
                sampling_params=sampling_params,
@@ -314,7 +314,6 @@ class SRTRunner:
                logprob_start_len=0,
                top_logprobs_num=NUM_TOP_LOGPROBS,
            )
            response = json.loads(response)
            output_strs.append(response["text"])
            top_input_logprobs.append(
                [
@@ -343,8 +342,7 @@ class SRTRunner:
                top_output_logprobs=top_output_logprobs,
            )
        else:
            response = self.runtime.encode(prompts)
            response = json.loads(response)
            response = self.engine.encode(prompts)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
@@ -366,20 +364,18 @@ class SRTRunner:
            # the return value contains logprobs from prefill
            output_strs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
            response = self.runtime.generate(
            response = self.engine.generate(
                prompts,
                lora_path=lora_paths if lora_paths else None,
                sampling_params=sampling_params,
            )
            response = json.loads(response)
            output_strs = [r["text"] for r in response]

            return ModelOutput(
                output_strs=output_strs,
            )
        else:
            response = self.runtime.encode(prompts)
            response = json.loads(response)
            response = self.engine.encode(prompts)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
@@ -391,8 +387,8 @@ class SRTRunner:
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.runtime.shutdown()
        del self.runtime
        self.engine.shutdown()
        del self.engine


def monkey_patch_gemma2_sdpa():

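One behavioral detail worth noting in the test-runner hunks above: the removed json.loads calls suggest that, unlike Runtime.generate() and Runtime.encode() (which return JSON strings from the HTTP server), Engine.generate() and Engine.encode() hand back already-parsed Python objects. A hedged before/after sketch of the call site, using the same variable names as the test runner:

# Before: Runtime talked to the HTTP server and returned a JSON string.
response = json.loads(runtime.generate(prompt, sampling_params=sampling_params))

# After: the in-process Engine returns parsed objects directly.
response = engine.generate(prompt, sampling_params=sampling_params)
output_strs.append(response["text"])
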