Files
sglang/python/sglang/srt/entrypoints/EngineBase.py
tianlian yi bc92107b03 Support server based rollout in Verlengine (#4848)
Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: Jinn <47354855+jhinpan@users.noreply.github.com>
2025-04-12 10:07:52 -07:00

54 lines
1.8 KiB
Python

from abc import ABC, abstractmethod
from typing import Dict, Iterator, List, Optional, Tuple, Union
import torch
class EngineBase(ABC):
    """Abstract interface for engines that support generation, weight updates, and memory control.

    Gives HTTP-based (server) engines and engines driven directly without HTTP
    one shared API, so calling code can work with either kind of engine.
    Subclasses must implement every abstract method below before they can be
    instantiated.
    """

    @abstractmethod
    def generate(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        image_data: Optional[Union[List[str], str]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
        lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[str], str]] = None,
    ) -> Union[Dict, Iterator[Dict]]:
        """Generate outputs based on the given inputs.

        Every argument accepts either a single value or a list of values. The
        list form presumably describes a batch, with one entry per request;
        confirm this against the concrete implementations.

        Args:
            prompt: Input text, or a list of texts. Presumably an alternative
                to ``input_ids`` (TODO confirm whether both may be passed).
            sampling_params: Sampling configuration as a dict, or one dict per
                request.
            input_ids: Pre-tokenized input, as a list of token ids or a list of
                such lists.
            image_data: Image input(s) given as strings. The encoding (path,
                URL, base64, ...) is not established here; check the
                implementations.
            return_logprob: Whether to return log-probabilities. Defaults to
                ``False``.
            logprob_start_len: Position(s) at which logprob reporting starts,
                presumably.
            top_logprobs_num: Number of top log-probabilities to report,
                presumably per position.
            token_ids_logprob: Specific token ids whose logprobs are requested,
                presumably.
            lora_path: LoRA adapter path(s). An individual entry may be
                ``None``.
            custom_logit_processor: Custom logit processor(s), given as
                strings (presumably serialized processors).

        Returns:
            A single dict of outputs, or an iterator of dicts. The iterator
            form presumably serves streaming.

        NOTE(review): this signature has no ``stream`` flag, even though an
        ``Iterator[Dict]`` return is allowed. Confirm how implementations
        choose between the two return forms.
        """

    @abstractmethod
    def update_weights_from_tensor(
        self,
        named_tensors: List[Tuple[str, torch.Tensor]],
        load_format: Optional[str] = None,
        flush_cache: bool = True,
    ):
        """Update model weights with in-memory tensor data.

        Args:
            named_tensors: ``(name, tensor)`` pairs holding the new weight
                values. The names presumably match the model's parameter
                names.
            load_format: Optional format hint for how the tensors are loaded.
                Its accepted values are defined by the implementations.
            flush_cache: Whether to flush the cache after the update. Defaults
                to ``True``. Which cache is flushed (presumably the
                KV/prefix cache) is up to the implementation.
        """

    @abstractmethod
    def release_memory_occupation(self):
        """Temporarily release the GPU memory this engine occupies.

        Call ``resume_memory_occupation`` to reclaim the memory.
        """

    @abstractmethod
    def resume_memory_occupation(self):
        """Resume the GPU memory occupation that ``release_memory_occupation`` released."""

    @abstractmethod
    def shutdown(self):
        """Shut down the engine and clean up its resources."""