Support server-based rollout in VerlEngine (#4848)
Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: Jinn <47354855+jhinpan@users.noreply.github.com>
python/sglang/srt/entrypoints/EngineBase.py (new file, +53)
@@ -0,0 +1,53 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+import torch
+
+
+class EngineBase(ABC):
+    """
+    Abstract base class for engine interfaces that support generation, weight updating, and memory control.
+
+    This base class provides a unified API for both HTTP-based engines and offline engines.
+    """
+
+    @abstractmethod
+    def generate(
+        self,
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        image_data: Optional[Union[List[str], str]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
+        lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
+        custom_logit_processor: Optional[Union[List[str], str]] = None,
+    ) -> Union[Dict, Iterator[Dict]]:
+        """Generate outputs based on given inputs."""
+        pass
+
+    @abstractmethod
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]],
+        load_format: Optional[str] = None,
+        flush_cache: bool = True,
+    ):
+        """Update model weights with in-memory tensor data."""
+        pass
+
+    @abstractmethod
+    def release_memory_occupation(self):
+        """Release GPU memory occupation temporarily."""
+        pass
+
+    @abstractmethod
+    def resume_memory_occupation(self):
+        """Resume GPU memory occupation that was previously released."""
+        pass
+
+    @abstractmethod
+    def shutdown(self):
+        """Shutdown the engine and clean up resources."""
+        pass
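For orientation, here is a minimal sketch of what a concrete implementation of this interface could look like; the EchoEngine class and its trivial behavior are hypothetical illustrations, not part of this commit:

from typing import Dict, Iterator, Union

from sglang.srt.entrypoints.EngineBase import EngineBase


class EchoEngine(EngineBase):
    """Hypothetical engine that just echoes prompts, to illustrate the contract."""

    def generate(self, prompt=None, sampling_params=None, input_ids=None,
                 image_data=None, return_logprob=False, logprob_start_len=None,
                 top_logprobs_num=None, token_ids_logprob=None, lora_path=None,
                 custom_logit_processor=None) -> Union[Dict, Iterator[Dict]]:
        # A real engine would run inference here; we only echo the input.
        return {"text": prompt}

    def update_weights_from_tensor(self, named_tensors, load_format=None, flush_cache=True):
        pass  # a real engine would copy the tensors into the live model

    def release_memory_occupation(self):
        pass

    def resume_memory_occupation(self):
        pass

    def shutdown(self):
        pass

Because all five abstract methods are implemented, EchoEngine() is instantiable, whereas a subclass that omits any of them would raise TypeError at construction.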
@@ -38,6 +38,7 @@ import torch
 import uvloop
 
 from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
+from sglang.srt.entrypoints.EngineBase import EngineBase
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
 )
@@ -78,7 +79,7 @@ logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
-class Engine:
+class Engine(EngineBase):
     """
     The entry point to the inference engine.
 
@@ -25,8 +25,10 @@ import multiprocessing as multiprocessing
 import os
 import threading
 import time
 from http import HTTPStatus
-from typing import AsyncIterator, Callable, Dict, Optional
+from typing import AsyncIterator, Callable, Dict, Optional, Union
+
+from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -60,6 +63,7 @@ from sglang.srt.managers.io_struct import (
     SetInternalStateReq,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -80,6 +84,7 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
@@ -411,6 +416,26 @@ async def init_weights_update_group(
         return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
 
 
+@app.post("/update_weights_from_tensor")
+async def update_weights_from_tensor(
+    obj: UpdateWeightsFromTensorReqInput, request: Request
+):
+    """Update the weights from tensor inplace without re-launching the server.
+
+    Notes:
+    1. Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors.
+    2. HTTP will transmit only the metadata of the tensor, while the tensor itself will be directly copied to the model.
+    3. Any binary data in the named tensors should be base64 encoded.
+    """
+
+    success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    return ORJSONResponse(
+        content, status_code=200 if success else HTTPStatus.BAD_REQUEST
+    )
+
+
 @app.post("/update_weights_from_distributed")
 async def update_weights_from_distributed(
     obj: UpdateWeightsFromDistributedReqInput, request: Request
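To make notes 2 and 3 above concrete, here is a client-side sketch of calling this endpoint; the server address, tp_size of 1, and the tensor name are assumptions for illustration, and the serialized payload carries pickled CUDA IPC handles rather than the raw weight bytes:

import requests
import torch

from sglang.srt.utils import MultiprocessingSerializer

base_url = "http://127.0.0.1:30000"  # assumed server address
# Hypothetical tensor; must live on GPU so only IPC metadata crosses HTTP.
named_tensors = [("lm_head.weight", torch.zeros(8, 8, device="cuda"))]

payload = {
    # One base64-encoded serialized copy per TP worker; one entry assumes tp_size=1.
    "serialized_named_tensors": [
        MultiprocessingSerializer.serialize(named_tensors, output_str=True)
    ],
    "load_format": None,
    "flush_cache": True,
}
resp = requests.post(f"{base_url}/update_weights_from_tensor", json=payload)
print(resp.json())  # {"success": ..., "message": ...}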
python/sglang/srt/entrypoints/http_server_engine.py (new file, +142)
@@ -0,0 +1,142 @@
+import base64
+import copy
+import dataclasses
+import multiprocessing
+import pickle
+import threading
+import time
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import requests
+import torch
+import torch.distributed as dist
+
+from sglang.srt.entrypoints.EngineBase import EngineBase
+from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree
+
+
+def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process:
+    p = multiprocessing.Process(target=launch_server, args=(server_args,))
+    p.start()
+
+    base_url = server_args.url()
+    timeout = 300.0  # Increased timeout to 5 minutes for downloading large models
+    start_time = time.time()
+
+    with requests.Session() as session:
+        while time.time() - start_time < timeout:
+            try:
+                headers = {
+                    "Content-Type": "application/json; charset=utf-8",
+                    "Authorization": f"Bearer {server_args.api_key}",
+                }
+                response = session.get(f"{base_url}/health_generate", headers=headers)
+                if response.status_code == 200:
+                    return p
+            except requests.RequestException:
+                pass
+
+            if not p.is_alive():
+                raise Exception("Server process terminated unexpectedly.")
+
+            time.sleep(2)
+
+    p.terminate()
+    raise TimeoutError("Server failed to start within the timeout period.")
+
+
+class HttpServerEngineAdapter(EngineBase):
+    """
+    You can use this class to launch a server from a VerlEngine instance.
+    We recommend using this class only when you need to use the HTTP server.
+    Otherwise, you can use Engine directly.
+    """
+
+    def __init__(self, **kwargs):
+        self.server_args = ServerArgs(**kwargs)
+        print(
+            f"Launch HttpServerEngineAdapter at: {self.server_args.host}:{self.server_args.port}"
+        )
+        self.process = launch_server_process(self.server_args)
+
+    def _make_request(self, endpoint: str, payload: Optional[dict] = None):
+        """Make a POST request to the specified endpoint with the given payload.
+
+        Args:
+            endpoint: The API endpoint to call
+            payload: The JSON payload to send (default: empty dict)
+
+        Returns:
+            The JSON response from the server
+        """
+        url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}"
+        response = requests.post(url, json=payload or {})
+        response.raise_for_status()
+        return response.json()
+
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]],
+        load_format: Optional[str] = None,
+        flush_cache: bool = False,
+    ):
+        """
+        Update model weights from tensor data. The HTTP server will only post metadata, and the real weights will be copied directly from GPUs.
+
+        Note: The model should be on GPUs rather than CPU for this functionality to work properly.
+        If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.
+        """
+        return self._make_request(
+            "update_weights_from_tensor",
+            {
+                "serialized_named_tensors": [
+                    MultiprocessingSerializer.serialize(named_tensors, output_str=True)
+                    for _ in range(self.server_args.tp_size)
+                ],
+                "load_format": load_format,
+                "flush_cache": flush_cache,
+            },
+        )
+
+    def shutdown(self):
+        kill_process_tree(self.process.pid)
+
+    def generate(
+        self,
+        prompt=None,
+        sampling_params=None,
+        input_ids=None,
+        image_data=None,
+        return_logprob=False,
+        logprob_start_len=None,
+        top_logprobs_num=None,
+        token_ids_logprob=None,
+        lora_path=None,
+        custom_logit_processor=None,
+    ):
+        payload = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "input_ids": input_ids,
+            "image_data": image_data,
+            "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
+            "top_logprobs_num": top_logprobs_num,
+            "token_ids_logprob": token_ids_logprob,
+            "lora_path": lora_path,
+            "custom_logit_processor": custom_logit_processor,
+        }
+        # Filter out None values
+        payload = {k: v for k, v in payload.items() if v is not None}
+
+        return self._make_request("generate", payload)
+
+    def release_memory_occupation(self):
+        return self._make_request("release_memory_occupation")
+
+    def resume_memory_occupation(self):
+        return self._make_request("resume_memory_occupation")
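A minimal usage sketch of the adapter on its own; the model path, host, and port are placeholders, and any other ServerArgs field can be passed through **kwargs:

from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter

# The constructor blocks until /health_generate responds (up to 5 minutes).
engine = HttpServerEngineAdapter(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    host="127.0.0.1",
    port=30000,
)
out = engine.generate(prompt="Hello", sampling_params={"max_new_tokens": 8})
print(out)
engine.shutdown()  # kills the whole server process tree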
@@ -12,16 +12,18 @@
 # limitations under the License.
 # ==============================================================================
 import os
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Literal, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 from PIL.Image import Image
 from torch.distributed.tensor import DeviceMesh, DTensor
 
+from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server import Engine
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
@@ -30,6 +32,7 @@ class VerlEngine:
         self,
         device_mesh_cpu: DeviceMesh,
         nnodes: int = 1,
+        backend: Literal["engine", "server"] = "engine",
         **kwargs,
     ):
         monkey_patch_torch_reductions()
@@ -40,13 +43,25 @@ class VerlEngine:
         node_rank = self._tp_rank // tp_size_per_node
         first_rank_in_node = self._tp_rank % tp_size_per_node == 0
 
-        if first_rank_in_node:
-            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
-            self._engine = Engine(
-                **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
-            )
-        else:
-            self._engine = None
+        # Common engine keyword arguments
+        engine_kwargs = dict(
+            **kwargs, tp_size=self._tp_size, node_rank=node_rank, nnodes=nnodes
+        )
+
+        if backend == "engine":
+            if first_rank_in_node:
+                os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
+                self._engine = Engine(**engine_kwargs)
+            else:
+                self._engine = None
+
+        elif backend == "server":
+            if self._tp_rank == 0:
+                self._engine = HttpServerEngineAdapter(**engine_kwargs)
+            else:
+                self._engine = None
+        else:
+            raise ValueError(f"Unsupported backend: {backend}")
 
         dist.barrier(group=self._device_mesh_cpu.get_group())
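And a sketch of how training code might opt into the new backend; the device-mesh setup, the import path, and the model path are assumptions for illustration, and a real job would run this under torchrun with the distributed environment already initialized:

from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.entrypoints.verl_engine import VerlEngine  # assumed import path

# Assumed 2-way tensor-parallel CPU mesh; adjust to the actual job layout.
device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(2,), mesh_dim_names=("tp",))

engine = VerlEngine(
    device_mesh_cpu=device_mesh_cpu["tp"],
    backend="server",  # tp_rank 0 launches an HttpServerEngineAdapter
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
)

With backend="server", only tp_rank 0 owns the adapter process and the other ranks hold None, mirroring how the "engine" backend keeps the engine on the first rank of each node.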
@@ -700,10 +700,17 @@ class UpdateWeightsFromDistributedReqOutput:
 
 @dataclass
 class UpdateWeightsFromTensorReqInput:
+    """Update model weights from tensor input.
+
+    - Tensors are serialized for transmission
+    - Data is structured in JSON for easy transmission over HTTP
+    """
+
     # List containing one serialized Dict[str, torch.Tensor] per TP worker
-    serialized_named_tensors: List[bytes]
-    load_format: Optional[str]
-    flush_cache: bool
+    serialized_named_tensors: List[Union[str, bytes]]
+    # Optional format specification for loading
+    load_format: Optional[str] = None
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True
 
 
 @dataclass
@@ -1480,14 +1480,43 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:
 
 class MultiprocessingSerializer:
     @staticmethod
-    def serialize(obj):
+    def serialize(obj, output_str: bool = False):
+        """
+        Serialize a Python object using ForkingPickler.
+
+        Args:
+            obj: The object to serialize.
+            output_str (bool): If True, return a base64-encoded string instead of raw bytes.
+
+        Returns:
+            bytes or str: The serialized object.
+        """
         buf = io.BytesIO()
         ForkingPickler(buf).dump(obj)
         buf.seek(0)
-        return buf.read()
+        output = buf.read()
+
+        if output_str:
+            # Convert bytes to base64-encoded string
+            output = base64.b64encode(output).decode("utf-8")
+
+        return output
 
     @staticmethod
     def deserialize(data):
+        """
+        Deserialize a previously serialized object.
+
+        Args:
+            data (bytes or str): The serialized data, optionally base64-encoded.
+
+        Returns:
+            The deserialized Python object.
+        """
+        if isinstance(data, str):
+            # Decode base64 string to bytes
+            data = base64.b64decode(data)
+
         return ForkingPickler.loads(data)
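A quick round-trip check of the new base64 path; plain Python objects pickle fine with ForkingPickler, so this sketch runs without a GPU:

from sglang.srt.utils import MultiprocessingSerializer

obj = {"step": 1, "note": "weights updated"}

# Raw-bytes round trip (original behavior).
blob = MultiprocessingSerializer.serialize(obj)
assert MultiprocessingSerializer.deserialize(blob) == obj

# Base64 string round trip (new, JSON/HTTP-friendly behavior).
text = MultiprocessingSerializer.serialize(obj, output_str=True)
assert isinstance(text, str)
assert MultiprocessingSerializer.deserialize(text) == obj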
@@ -25,7 +25,12 @@ from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.utils import get_bool_env_var, kill_process_tree, retry
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_port_available,
+    kill_process_tree,
+    retry,
+)
 from sglang.test.run_eval import run_eval
 from sglang.utils import get_exception_traceback
@@ -98,6 +103,17 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None)
     return pred
 
 
+def find_available_port(base_port: int):
+    port = base_port + random.randint(100, 1000)
+    while True:
+        if is_port_available(port):
+            return port
+        if port < 60000:
+            port += 42
+        else:
+            port -= 43
+
+
 def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
     assert url is not None
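The random starting offset keeps concurrent test runs from racing for the same base port; a usage sketch (base port 30000 is an arbitrary example):

# Pick a free port somewhere above the base; the exact value varies per run.
port = find_available_port(30000)
print(f"Launching test server on port {port}")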