Files
sglang/python/sglang/backend/runtime_endpoint.py
2024-03-11 04:43:39 -07:00

217 lines
7.0 KiB
Python

import json
from typing import Callable, List, Optional, Union
import numpy as np
import requests
from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglArgument, SglSamplingParams
from sglang.utils import encode_image_base64, find_printable_text, http_request
class RuntimeEndpoint(BaseBackend):
    """Backend that talks to an SRT runtime server over HTTP.

    Every request targets ``base_url`` and forwards the optional
    ``auth_token`` and TLS ``verify`` settings to ``http_request``.
    """

    def __init__(self, base_url, auth_token=None, verify=None):
        super().__init__()
        self.support_concate_and_append = True
        self.base_url = base_url
        self.auth_token = auth_token
        self.verify = verify

        # Probe the server once so a bad URL/token fails fast, and cache
        # the model info for get_model_name().
        res = http_request(
            self.base_url + "/get_model_info",
            auth_token=self.auth_token,
            verify=self.verify,
        )
        assert res.status_code == 200
        self.model_info = res.json()

        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_path"]
        )

    def get_model_name(self):
        """Return the model path reported by the server."""
        return self.model_info["model_path"]

    def get_chat_template(self):
        """Return the chat template inferred from the model path."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Warm the server's prefix cache with ``prefix_str``.

        Uses a zero-token generation request so nothing is sampled.
        """
        res = http_request(
            self.base_url + "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
            auth_token=self.auth_token,
            verify=self.verify,
        )
        assert res.status_code == 200

    def commit_lazy_operations(self, s: StreamExecutor):
        """Flush the executor's accumulated text to the server (no generation)."""
        res = http_request(
            self.base_url + "/generate",
            json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
            auth_token=self.auth_token,
            verify=self.verify,
        )
        assert res.status_code == 200

    def fill_image(self, s: StreamExecutor):
        """Send the current text plus the attached image (no generation)."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            auth_token=self.auth_token,
            verify=self.verify,
        )
        assert res.status_code == 200

    def _build_generate_payload(
        self, s: StreamExecutor, sampling_params: SglSamplingParams
    ):
        """Build the /generate JSON payload shared by generate() and
        generate_stream().

        Raises:
            RuntimeError: if ``sampling_params.dtype`` is neither None
                nor ``int`` / ``"int"``.
        """
        if sampling_params.dtype is None:
            return {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    **sampling_params.to_srt_kwargs(),
                },
            }
        if sampling_params.dtype in [int, "int"]:
            return {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    "dtype": "int",
                    **sampling_params.to_srt_kwargs(),
                },
            }
        raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation.

        Returns:
            (completion_text, meta_info) from the server response.
        """
        data = self._build_generate_payload(s, sampling_params)
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            auth_token=self.auth_token,
            verify=self.verify,
        )
        # Bug fix: the status code was previously unchecked here, unlike
        # every other request in this class, so HTTP errors surfaced as
        # confusing JSON decode/KeyError failures.
        assert res.status_code == 200
        obj = res.json()
        comp = obj["text"]
        return comp, obj["meta_info"]

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a generation, yielding (text_delta, meta_info) pairs."""
        data = self._build_generate_payload(s, sampling_params)
        data["stream"] = True
        self._add_images(s, data)
        response = http_request(
            self.base_url + "/generate",
            json=data,
            stream=True,
            auth_token=self.auth_token,
            verify=self.verify,
        )

        pos = 0
        incomplete_text = ""
        meta_info = None
        for chunk in response.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                # Renamed from `data` to avoid shadowing the request payload.
                payload = json.loads(chunk[5:].strip("\n"))
                # Hold back a possibly-incomplete trailing fragment; only
                # yield the printable prefix of the new text.
                text = find_printable_text(payload["text"][pos:])
                meta_info = payload["meta_info"]
                pos += len(text)
                incomplete_text = payload["text"][pos:]
                yield text, meta_info
        if len(incomplete_text) > 0:
            yield incomplete_text, meta_info

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
    ):
        """Pick the choice with the highest normalized prompt logprob.

        Returns:
            (decision, normalized_prompt_logprob, prompt_logprob).
        """
        assert temperature <= 1e-5  # greedy selection only

        # Cache common prefix so the per-choice scoring below reuses it.
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            auth_token=self.auth_token,
            verify=self.verify,
        )
        assert res.status_code == 200
        prompt_len = res.json()["meta_info"]["prompt_tokens"]

        # Compute logprob for every choice in one batched request.
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {"max_new_tokens": 0},
            "return_logprob": True,
            "logprob_start_len": max(prompt_len - 2, 0),
        }
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            auth_token=self.auth_token,
            verify=self.verify,
        )
        assert res.status_code == 200
        obj = res.json()
        normalized_prompt_logprob = [
            r["meta_info"]["normalized_prompt_logprob"] for r in obj
        ]
        prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]

        decision = choices[np.argmax(normalized_prompt_logprob)]
        return decision, normalized_prompt_logprob, prompt_logprob

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Ask the server to concatenate source requests into ``dst_rid``."""
        res = http_request(
            self.base_url + "/concate_and_append_request",
            json={"src_rids": src_rids, "dst_rid": dst_rid},
            auth_token=self.auth_token,
            verify=self.verify,
        )
        assert res.status_code == 200

    def _add_images(self, s: StreamExecutor, data):
        """Attach the executor's image (if any) to the request payload."""
        if s.images_:
            assert len(s.images_) == 1, "Only support one image."
            data["image_data"] = s.images_[0][1]