release initial code

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com>
Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Lianmin Zheng
2024-01-08 04:37:50 +00:00
parent f6d40df0ee
commit 22085081bb
145 changed files with 17802 additions and 2 deletions

python/pyproject.toml

@@ -0,0 +1,31 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "sglang"
version = "0.1.0"
description = "A structured generation language for LLMs."
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE"}
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"requests",
]
[project.optional-dependencies]
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
"interegular", "lark"]
openai = ["openai>=1.0"]
anthropic = ["anthropic"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
[tool.wheel]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]

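Note how the extras compose: "sglang[all]" expands to the "srt" (local serving), "openai", and "anthropic" dependency groups, so pip install "sglang[all]" pulls in everything, while the base package depends only on requests.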
python/sglang/__init__.py

@@ -0,0 +1,2 @@
from sglang.api import *
from sglang.global_config import global_config

python/sglang/api.py

@@ -0,0 +1,161 @@
"""Public API"""
import re
from typing import Callable, List, Optional, Union
from sglang.backend.anthropic import Anthropic
from sglang.backend.base_backend import BaseBackend
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
from sglang.lang.ir import (
SglExpr,
SglExprList,
SglFunction,
SglGen,
SglImage,
SglRoleBegin,
SglRoleEnd,
SglSelect,
)
from sglang.srt.server import Runtime
def function(func: Callable):
return SglFunction(func)
def set_default_backend(backend: BaseBackend):
global_config.default_backend = backend
def gen(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
dtype: Optional[type] = None,
choices: Optional[List[str]] = None,
regex: Optional[str] = None,
):
if choices:
return SglSelect(name, choices, temperature)
# Validate the regex early so errors surface at definition time.
if regex is not None:
re.compile(regex)
return SglGen(
name,
max_tokens,
stop,
temperature,
top_p,
top_k,
frequency_penalty,
presence_penalty,
dtype,
regex,
)
def gen_int(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
):
return SglGen(
name,
max_tokens,
stop,
temperature,
top_p,
top_k,
frequency_penalty,
presence_penalty,
int,
None,
)
def gen_string(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
):
return SglGen(
name,
max_tokens,
stop,
temperature,
top_p,
top_k,
frequency_penalty,
presence_penalty,
str,
None,
)
def image(expr: SglExpr):
return SglImage(expr)
def select(
name: Optional[str] = None,
choices: List[str] = None,
temperature: float = 0.0,
):
assert choices is not None
return SglSelect(name, choices, temperature)
def _role_common(name: str, expr: Optional[SglExpr] = None):
if expr is None:
return SglExprList([SglRoleBegin(name), SglRoleEnd(name)])
else:
return SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
def system(expr: Optional[SglExpr] = None):
return _role_common("system", expr)
def user(expr: Optional[SglExpr] = None):
return _role_common("user", expr)
def assistant(expr: Optional[SglExpr] = None):
return _role_common("assistant", expr)
def user_begin():
return SglRoleBegin("user")
def user_end():
return SglRoleEnd("user")
def assistant_begin():
return SglRoleBegin("assistant")
def assistant_end():
return SglRoleEnd("assistant")

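Taken together, the primitives above compose into programs. A minimal usage sketch (the endpoint URL is illustrative; run() is the single-program counterpart of the run_batch call used in the cache-flush script later in this commit):

import sglang as sgl

@sgl.function
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = qa.run(question="What is the capital of France?")
print(state["answer"])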
python/sglang/backend/anthropic.py

@@ -0,0 +1,57 @@
from typing import List, Optional, Union
import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
try:
import anthropic
except ImportError as e:
anthropic = e
class Anthropic(BaseBackend):
def __init__(self, model_name):
super().__init__()
if isinstance(anthropic, Exception):
raise anthropic
self.model_name = model_name
self.chat_template = get_chat_template("claude")
def get_chat_template(self):
return self.chat_template
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
prompt = s.text_
ret = anthropic.Anthropic().completions.create(
model=self.model_name,
prompt=prompt,
**sampling_params.to_anthropic_kwargs(),
)
comp = ret.completion
return comp, {}
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
prompt = s.text_
generator = anthropic.Anthropic().completions.create(
model=self.model_name,
prompt=prompt,
stream=True,
**sampling_params.to_anthropic_kwargs(),
)
for ret in generator:
yield ret.completion, {}

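A usage sketch: the "claude" template registered in chat_template.py renders turns as "\n\nHuman: ..." / "\n\nAssistant:", which is the prompt format this legacy completions endpoint expects (the model name below is illustrative; the API key comes from the environment):

import sglang as sgl

sgl.set_default_backend(sgl.Anthropic("claude-2"))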
python/sglang/backend/base_backend.py

@@ -0,0 +1,74 @@
from typing import Callable, List, Optional, Union
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
class BaseBackend:
def __init__(self) -> None:
self.support_concate_and_append = False
self.chat_template = get_chat_template("default")
def get_model_name(self):
raise NotImplementedError()
def get_chat_template(self):
return self.chat_template
def cache_prefix(self, prefix_str: str):
pass
def uncache_prefix(self, rid: str):
pass
def end_request(self, rid: Union[str, List[str]]):
pass
def begin_program(self, s: StreamExecutor):
pass
def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
pass
def commit_lazy_operations(self, s: StreamExecutor):
pass
def fork_program(
self,
src: StreamExecutor,
dst: List[StreamExecutor],
position_ids_offset: Optional[List[int]] = None,
):
pass
def fill_image(self, s: StreamExecutor):
pass
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
raise NotImplementedError()
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
raise NotImplementedError()
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
):
raise NotImplementedError()
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
raise NotImplementedError()
def shutdown(self):
pass

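Only generate, generate_stream, select, and concatenate_and_append lack working defaults; a new backend can start as small as this sketch and grow into the optional hooks (prefix caching, forking) later:

from sglang.backend.base_backend import BaseBackend

class EchoBackend(BaseBackend):
    """Toy backend for illustration: returns a canned completion."""

    def generate(self, s, sampling_params):
        return "echo", {}  # (completion text, meta_info)

    def generate_stream(self, s, sampling_params):
        yield "echo", {}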

@@ -0,0 +1,349 @@
import functools
from enum import Enum, auto
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import transformers
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import ProgramState
from sglang.utils import get_available_gpu_memory
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
)
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
class StopReason(Enum):
EOS_TOKEN = auto()
STOP_STR = auto()
LENGTH = auto()
def load_model(
model_name: str,
device,
num_gpus,
max_gpu_memory,
model_kwargs=None,
tokenizer_kwargs=None,
):
model_kwargs = model_kwargs or {}
tokenizer_kwargs = tokenizer_kwargs or {}
if device == "cuda":
model_kwargs["torch_dtype"] = torch.float16
if num_gpus != 1:
model_kwargs["device_map"] = "auto"
if max_gpu_memory is None:
model_kwargs[
"device_map"
] = "sequential" # This is important for not the same VRAM sizes
available_gpu_memory = [
get_available_gpu_memory(i, False) for i in range(num_gpus)
]
model_kwargs["max_memory"] = {
i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
for i in range(num_gpus)
}
else:
model_kwargs["max_memory"] = {
i: max_gpu_memory for i in range(num_gpus)
}
elif device == "cpu":
model_kwargs["torch_dtype"] = torch.float32
else:
raise ValueError(f"Invalid device: {device}")
model = AutoModelForCausalLM.from_pretrained(
model_name, low_cpu_mem_usage=True, **model_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
if num_gpus == 1:
model.to(device).eval()
return model, tokenizer
def prepare_logits_processor(
temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
# TemperatureLogitsWarper does not accept 0.0, and 1.0 makes it a no-op, so we skip both cases.
if temperature >= 1e-5 and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
if repetition_penalty > 1.0:
processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
if 1e-8 <= top_p < 1.0:
processor_list.append(TopPLogitsWarper(top_p))
if top_k > 0:
processor_list.append(TopKLogitsWarper(top_k))
return processor_list
@functools.lru_cache
def get_token_healing_mask(tokenizer, prompt_last_token):
last_str = tokenizer.convert_ids_to_tokens(prompt_last_token)
disallowed = torch.zeros(len(tokenizer), dtype=bool)
for s, t_id in tokenizer.get_vocab().items():
if not s.startswith(last_str):
disallowed[t_id] = 1
return disallowed
@functools.lru_cache
def get_int_token_mask(tokenizer):
disallowed = torch.zeros(len(tokenizer), dtype=bool)
for s, t_id in tokenizer.get_vocab().items():
s = s.replace("▁", "").strip()  # strip the SentencePiece word-boundary marker
if not (s.isdigit() or len(s) == 0 or s == ","):
disallowed[t_id] = 1
disallowed[tokenizer.eos_token_id] = 0
return disallowed
@torch.inference_mode()
def generate_stream(
model,
tokenizer,
prompt,
max_new_tokens,
stop: List[str],
temperature,
top_p,
token_healing,
logit_mask=None,
):
logits_processor = prepare_logits_processor(
temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
)
device = model.device
input_ids = tokenizer.encode(prompt)
output_ids = list(input_ids)
prompt_len = len(prompt)
# Resolve stop
stop_token_ids = [tokenizer.eos_token_id]
# Token healing
token_healing = token_healing and len(input_ids) > 0
if token_healing:
token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
del output_ids[-1]
# Generate
past_key_values = None
stop_reason = None
for i in range(max_new_tokens):
# Forward
if i == 0: # prefill
out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
else: # decoding
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_values
# Logit mask
if token_healing and i == 0:
logits[0, -1, token_healing_mask] = -1e4
if logit_mask is not None:
logits[0, -1, logit_mask] = -1e4
# Sample next token
last_token_logits = logits_processor(None, logits[:, -1, :])[0]
if temperature < 1e-5 or top_p < 1e-8: # greedy
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
# Stop condition
if token in stop_token_ids:
stop_reason = StopReason.EOS_TOKEN
break
output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
for stop_str in stop:
pos = output_str[prompt_len:].find(stop_str)
if pos != -1:
stop_reason = StopReason.STOP_STR
output_str = output_str[: prompt_len + pos]
break
if stop_reason:
break
return output_str[prompt_len:]
class HuggingFaceTransformers(BaseBackend):
def __init__(
self,
model_name,
device="cuda",
num_gpus=1,
max_gpu_memory=None,
model_kwargs=None,
tokenizer_kwargs=None,
):
self.model_name = model_name
self.device = device
self.model, self.tokenizer = load_model(
model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
)
self.chat_template = get_chat_template_by_model_path(model_name)
def get_chat_template(self):
return self.chat_template
def cache_prefix(self, prefix_str: str):
pass
def uncache_prefix(self, rid: str):
pass
def end_request(self, rid: str):
pass
def begin_program(self, s: ProgramState):
pass
def end_program(self, s: ProgramState):
pass
def fill(self, s: ProgramState, text: str):
return False
def generate_internal(
self,
prompt: str,
max_tokens: int,
stop: Union[str, List[str]],
temperature: float,
top_p: float,
dtype: Optional[str] = None,
):
if dtype is None:
comp = generate_stream(
self.model,
self.tokenizer,
prompt,
max_new_tokens=max_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
token_healing=True,
)
elif dtype in [str, "str", "string"]:
comp = generate_stream(
self.model,
self.tokenizer,
prompt + '"',
max_new_tokens=max_tokens,
stop=['"'],
temperature=temperature,
top_p=top_p,
token_healing=False,
)
comp = '"' + comp + '"'
elif dtype in [int, "int"]:
logit_mask = get_int_token_mask(self.tokenizer)
comp = generate_stream(
self.model,
self.tokenizer,
prompt,
max_new_tokens=max_tokens,
stop=stop + [" ", ","],
temperature=temperature,
top_p=top_p,
token_healing=False,
logit_mask=logit_mask,
)
return comp
def generate(
self,
s: ProgramState,
max_tokens: int,
stop: Union[str, List[str]],
temperature: float,
top_p: float,
dtype: Optional[str] = None,
):
prompt = s.text
comp = self.generate_internal(
prompt, max_tokens, stop, temperature, top_p, dtype
)
return comp
def parallel_generate(
self,
s: ProgramState,
prefixes: List[str],
join_func: Callable,
max_tokens: int,
stop: Union[str, List[str]],
temperature: float,
top_p: float,
dtype: Optional[str] = None,
):
prompt = s.text
parallel_prompts = [prompt + prefix for prefix in prefixes]
comps = []
for i in range(len(parallel_prompts)):
comps.append(
self.generate_internal(
parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
)
)
joined = join_func([p + c for p, c in zip(prefixes, comps)])
return joined, comps
@torch.inference_mode()
def select(
self, s: ProgramState, choices: List[str], temperature: float, top_p: float
):
loss_fct = torch.nn.CrossEntropyLoss()
prompt = s.text
prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
prompt_choices = [prompt + choice for choice in choices]
scores = []
for i in range(len(choices)):
choice_ids = self.tokenizer.encode(
prompt_choices[i], return_tensors="pt"
).to(self.model.device)
logits = self.model(choice_ids).logits
# score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()
logprobs = torch.log(torch.softmax(logits, dim=-1))
idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
idx2 = choice_ids[0, 1:]
selected_logprobs = logprobs[0, idx1, idx2]
score = selected_logprobs.mean().item()
scores.append(score)
decision = choices[np.argmax(scores)]
return decision, scores

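For intuition, prepare_logits_processor applies the warpers in the order they were appended (temperature, then top-p, then top-k; the repetition penalty is skipped at 1.0), matching how generate_stream calls it on the last position's logits. A self-contained sketch, assuming this module is importable:

import torch

processors = prepare_logits_processor(
    temperature=0.7, repetition_penalty=1.0, top_p=0.9, top_k=40
)
logits = torch.randn(1, 32000)  # fake logits for one position over a toy vocab
warped = processors(None, logits)  # input_ids is unused by these warpers
probs = torch.softmax(warped[0], dim=-1)  # ready for torch.multinomial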
python/sglang/backend/openai.py

@@ -0,0 +1,241 @@
from typing import Callable, List, Optional, Union
import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
try:
import openai
import tiktoken
except ImportError as e:
openai = tiktoken = e
def create_logit_bias_int(tokenizer):
"""Get logit bias for integer numbers."""
int_token_ids = []
tokens = tokenizer._mergeable_ranks
for token, token_id in tokens.items():
s = tokenizer.decode([token_id])
if all([c.isdigit() for c in s]) or s in [" "]:
int_token_ids.append(token_id)
if len(int_token_ids) >= 300: # OpenAI API limit
break
special_tokens = tokenizer._special_tokens
mask = {t: 100 for t in int_token_ids[:299]}
mask[special_tokens["<|endoftext|>"]] = 100
return mask
CHAT_MODEL_NAMES = [
# GPT-4
"gpt-4",
"gpt-4-32k",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4-0613",
"gpt-4-0314",
# GPT-3.5
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-0301",
]
class OpenAI(BaseBackend):
def __init__(self, model_name, *args, **kwargs):
super().__init__()
if isinstance(openai, Exception):
raise openai
self.client = openai.OpenAI(*args, **kwargs)
self.model_name = model_name
self.tokenizer = tiktoken.encoding_for_model(model_name)
self.logit_bias_int = create_logit_bias_int(self.tokenizer)
if model_name in CHAT_MODEL_NAMES:
self.is_chat_model = True
else:
self.is_chat_model = False
self.chat_template = get_chat_template("default")
def get_chat_template(self):
return self.chat_template
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
if sampling_params.dtype is None:
if self.is_chat_model:
assert s.text_.endswith("ASSISTANT:")
prompt = s.messages_
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
comp = openai_completion(
client=self.client,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
**kwargs,
)
elif sampling_params.dtype in [str, "str", "string"]:
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_ + '"',
stop='"',
**kwargs,
)
comp = '"' + comp + '"'
elif sampling_params.dtype in [int, "int"]:
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_,
logit_bias=self.logit_bias_int,
stop=[" "],
**kwargs,
)
else:
raise ValueError(f"Unknown dtype: {dtype}")
return comp, {}
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
if sampling_params.dtype is None:
if self.is_chat_model:
assert s.text_.endswith("ASSISTANT:")
prompt = s.messages_
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
generator = openai_completion_stream(
client=self.client,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
**kwargs,
)
return generator
else:
raise ValueError(f"Unknown dtype: {dtype}")
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
):
n_choices = len(choices)
token_ids = [self.tokenizer.encode(x) for x in choices]
scores = [0] * n_choices
valid = [len(x) > 0 for x in token_ids]
prompt_tokens = self.tokenizer.encode(s.text_)
max_len = max([len(x) for x in token_ids])
for step in range(max_len):
# Build logit bias
logit_bias = {}
for i in range(n_choices):
if valid[i]:
logit_bias[token_ids[i][step]] = 100
# Call API
ret = self.client.completions.create(
model=self.model_name,
prompt=prompt_tokens,
logit_bias=logit_bias,
max_tokens=1,
temperature=temperature,
)
ret_str = ret.choices[0].text
ret_token = self.tokenizer.encode(ret_str)[0]
# TODO:
# 1. return logits as the scores
# 2. compute logits of the full choice
# 3. consider chunk-based decoding
# Update valid
hit = False
for i in range(n_choices):
if valid[i]:
if step == len(token_ids[i]) - 1:
valid[i] = False
if ret_token == token_ids[i][step]:
scores[i] += 1
hit = True
else:
valid[i] = False
assert hit
if np.sum(valid) <= 1:
break
prompt_tokens.append(ret_token)
decision = choices[np.argmax(scores)]
return decision, scores
def openai_completion(client, is_chat=None, prompt=None, **kwargs):
try:
if is_chat:
if kwargs["stop"] is None:
kwargs.pop("stop")
ret = client.chat.completions.create(messages=prompt, **kwargs)
comp = ret.choices[0].message.content
else:
ret = client.completions.create(prompt=prompt, **kwargs)
if isinstance(prompt, (list, tuple)):
comp = [c.text for c in ret.choices]
else:
comp = ret.choices[0].text
except openai.OpenAIError as e:
print(f"OpenAI Error: {e}")
raise e
return comp
def openai_completion_stream(client, is_chat=None, prompt=None, **kwargs):
try:
if is_chat:
generator = client.chat.completions.create(
messages=prompt, stream=True, **kwargs
)
for ret in generator:
content = ret.choices[0].delta.content
yield content or "", {}
else:
generator = client.completions.create(prompt=prompt, stream=True, **kwargs)
for ret in generator:
content = ret.choices[0].text
yield content or "", {}
except openai.OpenAIError as e:
print(f"OpenAI Error: {e}")
raise e

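A usage sketch (the API key is read from the environment by the openai client; the model name is one of CHAT_MODEL_NAMES, so prompts are sent as chat messages):

import sglang as sgl

backend = sgl.OpenAI("gpt-3.5-turbo")
sgl.set_default_backend(backend)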
python/sglang/backend/runtime_endpoint.py

@@ -0,0 +1,171 @@
import json
from typing import Callable, List, Optional, Union
import numpy as np
import requests
from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams, SglArgument
from sglang.utils import encode_image_base64, find_printable_text, http_request
class RuntimeEndpoint(BaseBackend):
def __init__(self, base_url):
super().__init__()
self.support_concate_and_append = True
self.base_url = base_url
res = http_request(self.base_url + "/get_model_info")
assert res.status_code == 200
self.model_info = res.json()
self.chat_template = get_chat_template_by_model_path(
self.model_info["model_path"]
)
def get_model_name(self):
return self.model_info["model_path"]
def get_chat_template(self):
return self.chat_template
def cache_prefix(self, prefix_str: str):
res = http_request(
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
)
assert res.status_code == 200
def commit_lazy_operations(self, s: StreamExecutor):
res = http_request(
self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
)
assert res.status_code == 200
def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
assert res.status_code == 200
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
if sampling_params.dtype is None:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
**sampling_params.to_srt_kwargs(),
},
}
elif sampling_params.dtype in [int, "int"]:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
}
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
obj = res.json()
comp = obj["text"]
return comp, obj["meta_info"]
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
):
if sampling_params.dtype is None:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
**sampling_params.to_srt_kwargs(),
},
}
elif sampling_params.dtype in [int, "int"]:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
}
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
data["stream"] = True
self._add_images(s, data)
response = http_request(self.base_url + "/generate", json=data, stream=True)
pos = 0
incomplete_text = ""
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
text = find_printable_text(data["text"][pos:])
meta_info = data["meta_info"]
pos += len(text)
incomplete_text = data["text"][pos:]
yield text, meta_info
if len(incomplete_text) > 0:
yield incomplete_text, meta_info
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
):
assert temperature <= 1e-5
# Cache common prefix
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
assert res.status_code == 200
prompt_len = res.json()["meta_info"]["prompt_tokens"]
# Compute logprob
data = {
"text": [s.text_ + c for c in choices],
"sampling_params": {"max_new_tokens": 0},
"return_normalized_logprob": True,
"normalized_logprob_start_len": prompt_len,
}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
assert res.status_code == 200
logps = [r["meta_info"]["normalized_logprob"] for r in res.json()]
decision = choices[np.argmax(logps)]
return decision, logps
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
res = http_request(
self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid},
)
assert res.status_code == 200
def _add_images(self, s: StreamExecutor, data):
if s.images_:
assert len(s.images_) == 1, "Only one image is supported."
data["image_data"] = s.images_[0][1]

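The wire format this backend speaks is plain JSON over HTTP; the equivalent of a non-streaming generate() call, issued directly (URL and prompt are illustrative):

import requests

res = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
obj = res.json()
print(obj["text"], obj["meta_info"])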

@@ -0,0 +1,190 @@
import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from itertools import repeat
from typing import List, Optional, Union
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.utils import http_request
class TGI(BaseBackend):
def __init__(self, base_url):
super().__init__()
self.base_url = base_url
res = http_request(self.base_url + "/info")
assert res.status_code == 200
self.model_info = res.json()
self.chat_template = get_chat_template_by_model_path(
self.model_info["model_id"]
)
def get_model_name(self):
return self.model_info["model_id"]
def get_chat_template(self):
return self.chat_template
@staticmethod
def adapt_params(max_tokens, stop, sampling_params, **override_params):
temperature = sampling_params.temperature
do_sample = True
if temperature == 0:
do_sample = False
temperature = None
if stop is None:
stop = []
elif isinstance(stop, str):
stop = [stop]
top_p = sampling_params.top_p
if top_p == 0:
top_p = 0.001
if top_p == 1:
top_p = 0.999
top_k = sampling_params.top_k
if top_k == -1:
top_k = None
params = {
"decoder_input_details": False,
"details": False,
"do_sample": do_sample,
"max_new_tokens": max_tokens,
"stop": stop,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"return_full_text": False,
}
params.update(override_params)
return params
@staticmethod
def _extract_int(text):
words = re.split(r"\ |'|/|\(|\)|\n|\.|,", text)
for word in words:
try:
int(word)
return word
except ValueError:
continue
raise ValueError
@staticmethod
def _extract_choice(choices, text):
# FIXME: Currently only supports the case where the choices are single words.
words = re.split(r"\ |'|/|\(|\)|\n|\.|,", text)
for word in words:
if word in choices:
return word
raise ValueError
@staticmethod
def _truncate_to_stop(text, stop):
# A stop sequence may span multiple tokens, in which case TGI can generate
# past it, so we truncate the output ourselves.
if stop:
stop = [stop] if isinstance(stop, str) else stop
for stop_seq in stop:
pos = text.find(stop_seq)
if pos != -1:
return text[:pos]
return text
def _make_request(self, params):
res = http_request(self.base_url + "/generate", json=params)
if res.status_code != 200:
raise ValueError(f"Error from TGI backend: {res.text}")
return res.json()
def retry_for_expected(self, prompt, params, extract_fn, retry=5):
# TGI does not support logit_bias (yet), so we have to use an inefficient retry hack.
failed = []
while retry > 0:
res_json = self._make_request(
{
"inputs": prompt,
"parameters": params,
}
)
text = res_json["generated_text"]
try:
return extract_fn(text)
except ValueError:
retry -= 1
failed.append(text)
msg = "=" * 20 + "\n"
msg += f"Prompt:\n{prompt}\n"
msg += "=" * 20 + "\n"
for i, text in enumerate(failed):
msg += f"====== Try {i+1}:\n{text}\n"
raise ValueError(
f"Model {self.model_info['model_id']} served by TGI backend does not generate"
"expected output. Please improve the prompt, increase the temperature, or "
f"use different models.\n{msg}"
)
def select(
self,
s: StreamExecutor,
choices: List[str],
sampling_params: SamplingParams,
):
decision = self.retry_for_expected(
s.text_,
self.adapt_params(16, [], sampling_params),
partial(self._extract_choice, choices),
)
return decision, [1 if choice == decision else 0 for choice in choices]
def generate(
self,
s: StreamExecutor,
max_tokens: int,
stop: Union[str, List[str]],
sampling_params: SamplingParams,
dtype: Optional[str] = None,
):
if dtype is None:
res_json = self._make_request(
{
"inputs": s.text_,
"parameters": self.adapt_params(max_tokens, stop, sampling_params),
}
)
return self._truncate_to_stop(res_json["generated_text"], stop), {}
if dtype in [str, "str", "string"]:
stop = ['"']
res_json = self._make_request(
{
"inputs": f'{s.text_}"',
"parameters": self.adapt_params(max_tokens, stop, sampling_params),
}
)
return (
'"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
{},
)
if dtype in [int, "int"]:
return (
self.retry_for_expected(
s.text_,
self.adapt_params(max_tokens, stop, sampling_params),
self._extract_int,
),
{},
)
raise ValueError(f"Unknown dtype: {dtype}")

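For intuition, adapt_params maps SGLang's sampling parameters onto TGI's API; a sketch, assuming this module is on the import path:

from sglang.lang.ir import SamplingParams

# Greedy decoding becomes do_sample=False with temperature dropped; top_p is
# nudged into TGI's open interval (0, 1); top_k=-1 ("disabled") becomes None.
params = TGI.adapt_params(32, "END", SamplingParams(temperature=0.0, top_p=1.0))
assert params["do_sample"] is False and params["temperature"] is None
assert params["stop"] == ["END"] and params["top_p"] == 0.999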

@@ -0,0 +1,60 @@
"""Flush cache in the backend by sending random requests."""
import argparse
import random
import string
import time
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
import sglang as sgl
@sgl.function
def flush_radix_cache(s, prompt):
s += prompt + sgl.gen("flush", max_tokens=1, stop="END")
def main(args, max_total_tokens, context_length, print_flag):
backend = select_sglang_backend(args)
flush_length = int(context_length * 0.8)
batch_size = int(max_total_tokens / flush_length)
prompt_length = flush_length * 2
prompts = [
" ".join(random.choices(string.ascii_letters, k=int(prompt_length)))
for _ in range(batch_size)
]
arguments = [{"prompt": prompts[i]} for i in range(batch_size)]
start_time = time.time()
flush_radix_cache.run_batch(
arguments, temperature=0, backend=backend, num_threads=1
)
end_time = time.time()
if print_flag:
print(
f"Flush length: {flush_length}\n",
f"Prompt length: {prompt_length}\n",
f"Total Prompt letters: {batch_size * prompt_length}\n",
f"Flush radix cache latency: {end_time - start_time:.3f}",
sep="",
)
# Give the backend a moment to finish any outstanding work
time.sleep(1)
def run_flush(args, max_total_tokens=20000, context_length=1024, print_flag=False):
main(args, max_total_tokens, context_length, print_flag=print_flag)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--max-total-tokens", type=int, default=20000)
parser.add_argument("--context-length", type=int, default=1024)
args = add_common_sglang_args_and_parse(parser)
random.seed(0)
main(args, args.max_total_tokens, args.context_length, print_flag=True)

python/sglang/global_config.py

@@ -0,0 +1,28 @@
"""Global configurations"""
class GlobalConfig:
def __init__(self):
# Verbosity level
# 0: do not output anything
# 2: output final text after every run
self.verbosity = 0
self.default_backend = None
# Output configs
self.skip_special_tokens_in_output = True
# Optimization configs
self.eager_fill_image = False
self.enable_prefix_sharing = True
self.enable_parallel_encoding = True
self.enable_parallel_decoding = True
# Choices: ["no_adjust", "adjust_cache"]
# no_adjust: Do not adjust the position embedding of KV cache.
# adjust_cache: Adjust the position embedding of KV cache.
self.concate_and_append_mode = "no_adjust"
global_config = GlobalConfig()

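The flags are plain attributes on a module-level singleton, so they can be toggled at runtime before running programs:

from sglang.global_config import global_config

global_config.verbosity = 2  # print the final text after every run
global_config.enable_prefix_sharing = False  # opt out of prefix pinning/caching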
python/sglang/lang/chat_template.py

@@ -0,0 +1,186 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Tuple
class ChatTemplateStyle(Enum):
PLAIN = auto()
LLAMA2 = auto()
@dataclass
class ChatTemplate:
name: str
default_system_prompt: str
role_prefix_and_suffix: Dict[str, Tuple[str, str]]
image_token: str = "<image>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
def get_prefix_and_suffix(self, role, hist_messages):
if self.style == ChatTemplateStyle.PLAIN:
return self.role_prefix_and_suffix[role]
elif self.style == ChatTemplateStyle.LLAMA2:
if len(hist_messages) == 0 and role == "system":
return (
self.role_prefix_and_suffix["user"][0]
+ self.role_prefix_and_suffix["system"][0],
self.role_prefix_and_suffix["system"][1],
)
elif (
len(hist_messages) == 1
and role == "user"
and hist_messages[0]["content"] is not None
):
return ("", self.role_prefix_and_suffix["user"][1])
return self.role_prefix_and_suffix[role]
else:
raise ValueError(f"Invalid style: {self.style}")
def get_prompt(self, messages):
prompt = ""
for i in range(len(messages)):
role, content = messages[i]["role"], messages[i]["content"]
if role == "system" and content is None:
content = self.default_system_prompt
if content is None:
continue
prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
prompt += prefix + content + suffix
return prompt
chat_template_registry: Dict[str, ChatTemplate] = {}
matching_function_registry: List[Callable] = []
def register_chat_template(template):
chat_template_registry[template.name] = template
def register_chat_template_matching_function(func):
matching_function_registry.append(func)
def get_chat_template(name):
return chat_template_registry[name]
def get_chat_template_by_model_path(model_path):
for matching_func in matching_function_registry:
template = matching_func(model_path)
if template is not None:
return template
return get_chat_template("default")
register_chat_template(
ChatTemplate(
name="default",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("SYSTEM:", "\n"),
"user": ("USER:", "\n"),
"assistant": ("ASSISTANT:", "\n"),
},
)
)
register_chat_template(
ChatTemplate(
name="claude",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", ""),
"user": ("\n\nHuman: ", ""),
"assistant": ("\n\nAssistant:", ""),
},
)
)
register_chat_template(
ChatTemplate(
name="chatml",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
"user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
)
)
register_chat_template(
ChatTemplate(
name="vicuna_v1.1",
default_system_prompt=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
role_prefix_and_suffix={
"system": ("", " "),
"user": ("USER:", " "),
"assistant": ("ASSISTANT:", "</s>"),
},
image_token=" <image>\n",
)
)
register_chat_template(
ChatTemplate(
name="llama-2-chat",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
"user": ("[INST] ", " [/INST]"),
"assistant": ("", " </s><s>"),
},
style=ChatTemplateStyle.LLAMA2,
)
)
@register_chat_template_matching_function
def match_vicuna(model_path: str):
if "vicuna" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if "llava" in model_path.lower():
return get_chat_template("vicuna_v1.1")
@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
model_path = model_path.lower()
if "llama-2" in model_path and "chat" in model_path:
return get_chat_template("llama-2-chat")
if (
"mistral" in model_path or "mixtral" in model_path
) and "instruct" in model_path:
return get_chat_template("llama-2-chat")
if "codellama" in model_path and "instruct" in model_path:
return get_chat_template("llama-2-chat")
@register_chat_template_matching_function
def match_chat_ml(model_path: str):
if "tinyllama" in model_path.lower():
return get_chat_template("chatml")
if __name__ == "__main__":
messages = [
{"role": "system", "content": None}, # None means default
# {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi!"},
{"role": "user", "content": "What can you do?"},
{"role": "assistant", "content": "I can chat with you."},
]
template = get_chat_template("llama-2-chat")
print(template.get_prompt(messages))

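New formats plug in through the two registries; a custom template plus a matching rule can live entirely outside this file (the template name and role markers below are hypothetical):

from sglang.lang.chat_template import (
    ChatTemplate,
    get_chat_template,
    register_chat_template,
    register_chat_template_matching_function,
)

register_chat_template(
    ChatTemplate(
        name="my-model",  # hypothetical
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<sys>", "</sys>\n"),
            "user": ("<usr>", "</usr>\n"),
            "assistant": ("<asst>", "</asst>\n"),
        },
    )
)

@register_chat_template_matching_function
def match_my_model(model_path: str):
    if "my-model" in model_path.lower():
        return get_chat_template("my-model")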

@@ -0,0 +1,237 @@
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
SamplingParams,
SglArgument,
SglConstantText,
SglExpr,
SglVariable,
)
def compile_func(function, backend):
tracer = function.trace(backend=backend)
compiler = CompiledFunction(tracer, function)
return compiler
class CompiledFunction:
def __init__(self, tracer, function):
self.function = function
self.last_node = CompGraphNode(tracer.last_node)
self.expr_to_node = {}
self.build_graph(tracer)
self.topological_sort()
def build_graph(self, tracer):
self.nodes = [self.last_node]
self.expr_to_node[tracer.last_node] = self.nodes[-1]
rename_pid = {}
visited = set([tracer.last_node])
head = 0
while head < len(self.nodes):
cur_node = self.nodes[head]
# add prev node
prev_node = cur_node.expr.prev_node
if prev_node is not None:
if prev_node not in visited:
visited.add(prev_node)
self.nodes.append(CompGraphNode(prev_node))
self.expr_to_node[prev_node] = self.nodes[-1]
cur_node.prev_node = self.expr_to_node[prev_node]
self.expr_to_node[prev_node].add_next_node(cur_node)
# add source node
if isinstance(cur_node.expr, SglVariable):
if cur_node.expr.name in tracer.variables:
source = tracer.variables[cur_node.expr.name].source
else:
source = cur_node.expr.source
if source not in visited:
visited.add(source)
self.nodes.append(CompGraphNode(source))
self.expr_to_node[source] = self.nodes[-1]
cur_node.source_node = self.expr_to_node[source]
self.expr_to_node[source].add_next_node(cur_node)
head += 1
# rename pid
if cur_node.expr.pid not in rename_pid:
rename_pid[cur_node.expr.pid] = len(rename_pid)
cur_node.expr.pid = rename_pid[cur_node.expr.pid]
def topological_sort(self):
prevd = {}
cand = Queue()
for x in self.nodes:
prevd[x] = (x.prev_node is not None) + (x.source_node is not None)
if prevd[x] == 0:
cand.put(x)
new_list = []
while cand.qsize() > 0:
head = cand.get()
new_list.append(head)
for x in head.next_nodes:
prevd[x] -= 1
if prevd[x] == 0:
cand.put(x)
self.nodes = new_list
def print_graph(
self,
):
for node in self.nodes:
print(node)
def run_internal(
self,
backend,
kwargs,
default_sampling_para,
):
stream_executor_ids = set([x.expr.pid for x in self.nodes])
stream_executors = {}
for x in stream_executor_ids:
arguments = kwargs if x == self.last_node.expr.pid else {}
stream_executors[x] = StreamExecutor(
backend, arguments, default_sampling_para, None, False
)
for node in self.nodes:
se_id = node.expr.pid
expr = node.expr
if isinstance(expr, SglVariable):
# Make a copy for SglVariable
expr = SglVariable(expr.name, expr.source)
expr.source_stream_executor = stream_executors[
node.source_node.expr.pid
]
elif isinstance(expr, SglArgument):
# Substitute SglArgument
expr = kwargs[expr.name]
stream_executors[se_id].submit(expr)
for stream_executor in stream_executors.values():
stream_executor.end()
return ProgramState(stream_executors[self.last_node.expr.pid])
def run(
self,
*,
max_new_tokens: int = 16,
stop: Union[str, List[str]] = (),
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
**kwargs,
):
backend = backend or global_config.default_backend
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
kwargs.update(self.function.bind_arguments)
default_sampling_para = SamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
return self.run_internal(backend, kwargs, default_sampling_para)
def run_batch(
self,
batch_kwargs,
*,
max_new_tokens: int = 16,
stop: Union[str, List[str]] = (),
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
num_threads: Union[str, int] = "auto",
):
assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
return []
assert isinstance(batch_kwargs[0], dict)
backend = backend or global_config.default_backend
default_sampling_para = SamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
batch_kwargs = [
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
]
# Extract prefix by tracing and cache it
if len(batch_kwargs) > 1:
pin_program(self.function, backend)
# Run all programs
if num_threads == "auto":
num_threads = multiprocessing.cpu_count()
num_threads = min(num_threads, len(batch_kwargs))
if num_threads == 1:
rets = []
for arguments in batch_kwargs:
rets.append(
self.run_internal(backend, arguments, default_sampling_para)
)
else:
with ThreadPoolExecutor(num_threads) as executor:
futures = []
for arguments in batch_kwargs:
futures.append(
executor.submit(
self.run_internal, backend, arguments, default_sampling_para
)
)
rets = [f.result() for f in futures]
rets[-1].sync()
return rets
class CompGraphNode:
def __init__(
self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
):
self.expr = expr
self.next_nodes = next_nodes or []
self.prev_node = prev_node
self.source_node = source_node
def add_next_node(self, other):
self.next_nodes.append(other)
def __repr__(self):
re = f"stream {self.expr.pid:2d}: "
re += f"%{self.expr.node_id} = "
if self.prev_node is not None:
re += f"%{self.prev_node.expr.node_id} + "
re += repr(self.expr)
return re

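The intended flow, as a sketch (this hunk's file path is not shown, so assume compile_func is importable; qa is the example function sketched after api.py above):

compiled = compile_func(qa, backend)  # trace once and build the dependency graph
compiled.print_graph()  # nodes in topological order, one stream id per executor
state = compiled.run(backend=backend, question="What is 2 + 2?")
print(state.text())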
python/sglang/lang/interpreter.py

@@ -0,0 +1,697 @@
"""The interpreter that executes SGL programs"""
import asyncio
import multiprocessing
import queue
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
import tqdm
from sglang.global_config import global_config
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglFunction,
SglGen,
SglImage,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
)
from sglang.utils import encode_image_base64
def run_internal(state, program, func_args, func_kwargs, sync):
try:
state.ret_value = program.func(state, *func_args, **func_kwargs)
finally:
state.stream_executor.end()
if sync:
state.stream_executor.sync()
if global_config.verbosity >= 2:
print(state.text())
def run_program(
program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
):
assert backend is not None, "Please specify a backend"
func_kwargs.update(program.bind_arguments)
stream_executor = StreamExecutor(
backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
)
state = ProgramState(stream_executor)
if stream:
t = threading.Thread(
target=run_internal, args=(state, program, func_args, func_kwargs, sync)
)
t.start()
return state
else:
run_internal(state, program, func_args, func_kwargs, sync)
return state
def run_program_batch(
program,
backend,
batch_arguments,
default_sampling_para,
num_threads,
progress_bar,
):
# Extract prefix by tracing and cache it
if len(batch_arguments) > 1:
pin_program(program, backend)
# Run all programs
if num_threads == "auto":
num_threads = multiprocessing.cpu_count()
num_threads = min(num_threads, len(batch_arguments))
if num_threads == 1:
rets = []
for arguments in batch_arguments:
rets.append(
run_program(
program, backend, (), arguments, default_sampling_para, False, False
)
)
else:
if progress_bar:
pbar = tqdm.tqdm(total=len(batch_arguments))
with ThreadPoolExecutor(num_threads) as executor:
futures = []
for arguments in batch_arguments:
futures.append(
executor.submit(
run_program,
program,
backend,
(),
arguments,
default_sampling_para,
False,
False,
)
)
if progress_bar:
futures[-1].add_done_callback(lambda _: pbar.update())
rets = [f.result() for f in futures]
rets[-1].sync()
if progress_bar:
pbar.close()
return rets
def pin_program(program, backend):
if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
# TODO: handle multiple backends
from sglang.lang.tracer import extract_prefix_by_tracing
prefix = extract_prefix_by_tracing(program, backend)
if prefix and len(prefix) > 64:
prefix_rid = backend.cache_prefix(prefix)
program.pin_prefix_rid = prefix_rid
return prefix_rid
return None
def unpin_program(program, backend):
pass
class StreamExecutor:
"""A stream executor that executes SGL expressions in a background thread."""
def __init__(
self,
backend,
arguments,
default_sampling_para,
chat_template,
stream,
use_thread=True,
):
self.sid = uuid.uuid4().hex
self.backend = backend
self.arguments: Dict[str, Any] = arguments
self.default_sampling_para = default_sampling_para
self.stream = stream
if hasattr(backend, "endpoint"):
self.backend = backend.endpoint
self.variables = {} # Dict[name: str -> value: str]
self.variable_event = {} # Dict[name: str -> event: threading.Event]
self.meta_info = {} # Dict[name: str -> info: str]
self.is_finished = False
# For completion
self.text_ = "" # The full text
# For chat
self.messages_ = [] # The messages in the OpenAI API format
self.chat_template = chat_template or self.backend.get_chat_template()
self.cur_role = None
self.cur_role_begin_pos = None
# For vision
self.images_ = []
self.cur_images = []
# For fork/join
self.fork_start_text_pos = None
# Worker thread
self.use_thread = use_thread
if self.use_thread:
self.queue = queue.Queue()
self.worker = threading.Thread(target=self._thread_worker_func)
self.worker.start()
# For streaming
if stream:
self.stream_text_event = threading.Event()
self.stream_var_event = {}
else:
self.stream_text_event = None
self.stream_var_event = None
def submit(self, expr: SglExpr):
if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
self.variable_event[expr.name] = threading.Event()
if self.stream:
self.stream_var_event[expr.name] = threading.Event()
elif isinstance(expr, SglExprList):
for e in expr.expr_list:
if isinstance(e, (SglGen, SglSelect, SglVarScopeBegin)):
self.variable_event[e.name] = threading.Event()
if self.stream:
self.stream_var_event[e.name] = threading.Event()
if self.use_thread:
self.queue.put(expr)
else:
self._execute(expr)
def sync(self):
if self.use_thread:
self.queue.join()
def get_var(self, name):
if name in self.variable_event:
self.variable_event[name].wait()
return self.variables[name]
def get_meta_info(self, name):
if name in self.variable_event:
self.variable_event[name].wait()
ret = self.meta_info.get(name, None)
return ret
def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
if number > 1:
self.submit(SglCommitLazy())
self.sync()
number = int(number)
exes = [
StreamExecutor(
self.backend,
self.arguments,
self.default_sampling_para,
self.chat_template,
self.stream,
)
for _ in range(number)
]
for i in range(number):
exes[i].variables = dict(self.variables)
exes[i].text_ = str(self.text_)
exes[i].messages_ = list(self.messages_)
exes[i].cur_role = self.cur_role
exes[i].fork_start_text_pos = len(self.text_)
return exes
def text(self):
self.sync()
return self.text_
def messages(self):
self.sync()
return self.messages_
def end(self):
if self.use_thread:
if self.worker.is_alive():
self.queue.put(None)
self.backend.end_program(self)
def _thread_worker_func(self):
while True:
expr = self.queue.get()
if expr is None:
self.queue.task_done()
break
self._execute(expr)
self.queue.task_done()
if self.stream_text_event:
self.stream_text_event.set()
if self.stream_text_event:
self.stream_text_event.set()
self.is_finished = True
def _execute(self, other):
if isinstance(other, str):
other = SglConstantText(other)
assert isinstance(other, SglExpr), f"{other}"
if isinstance(other, (SglConstantText, SglArgument)):
self._execute_fill(other.value)
elif isinstance(other, SglGen):
self._execute_gen(other)
elif isinstance(other, SglSelect):
self._execute_select(other)
elif isinstance(other, SglExprList):
for x in other.expr_list:
self._execute(x)
elif isinstance(other, SglRoleBegin):
self._execute_role_begin(other)
elif isinstance(other, SglRoleEnd):
self._execute_role_end(other)
elif isinstance(other, SglImage):
self._execute_image(other)
elif isinstance(other, SglVariable):
self._execute_variable(other)
elif isinstance(other, SglVarScopeBegin):
self._execute_var_scope_begin(other)
elif isinstance(other, SglVarScopeEnd):
self._execute_var_scope_end(other)
elif isinstance(other, SglCommitLazy):
self._execute_commit_lazy_operations(other)
elif isinstance(other, SglConcateAndAppend):
if (
global_config.enable_parallel_encoding
and self.backend.support_concate_and_append
):
self._execute_concatenate_and_append_kv_cache(other)
else:
self._execute_concatenate_and_append_text(other)
else:
raise ValueError(f"Unknown type: {type(other)}")
def _execute_fill(self, value: str):
value = str(value)
self.text_ += value
def _execute_image(self, expr: SglImage):
path = expr.path
if isinstance(path, SglArgument):
path = path.value
base64_data = encode_image_base64(path)
self.images_.append((path, base64_data))
self.cur_images.append((path, base64_data))
self.text_ += self.chat_template.image_token
# if global_config.eager_fill_image:
# self.backend.fill_image(self)
def _execute_gen(self, expr: SglGen):
sampling_params = self._resolve_sampling_params(expr.sampling_params)
name = expr.name
if not self.stream:
comp, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
self.text_ += comp
self.variables[name] = comp
self.meta_info[name] = meta_info
self.variable_event[name].set()
else:
generator = self.backend.generate_stream(
self, sampling_params=sampling_params
)
self.stream_var_event[name].set()
self.variables[name] = ""
for comp, meta_info in generator:
self.text_ += comp
self.variables[name] += comp
self.stream_var_event[name].set()
self.stream_text_event.set()
self.meta_info[name] = meta_info
self.variable_event[name].set()
self.stream_var_event[name].set()
def _execute_select(self, expr: SglSelect):
decision, scores = self.backend.select(self, expr.choices, expr.temperature)
if expr.name is not None:
name = expr.name
self.variables[name] = decision
self.variable_event[name].set()
self.text_ += decision
def _execute_variable(self, expr: SglVariable):
src_executor = expr.source_stream_executor
value = src_executor.get_var(expr.name)
self._execute_fill(value)
def _execute_role_begin(self, expr: SglRoleBegin):
assert self.cur_role is None, "Nested roles are not allowed."
if len(self.messages_) == 0 and expr.role != "system":
# Insert the default system message
default_system = self.chat_template.default_system_prompt
if default_system:
self._execute_role_begin(SglRoleBegin("system"))
self._execute_fill(default_system)
self._execute_role_end(SglRoleEnd("system"))
self.cur_role = expr.role
prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
self._execute_fill(prefix)
self.cur_role_begin_pos = len(self.text_)
def _execute_role_end(self, expr: SglRoleEnd):
new_text = self.text_[self.cur_role_begin_pos :].lstrip()
_, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
self._execute_fill(suffix)
if self.cur_images:
# OpenAI vision API format
last_msg = {
"role": expr.role,
"content": [{"type": "text", "text": new_text}],
}
for (image_path, image_base64_data) in self.cur_images:
last_msg["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64_data}"
},
}
)
self.messages_.append(last_msg)
self.cur_images = []
else:
self.messages_.append({"role": expr.role, "content": new_text})
self.cur_role = None
def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
self.variables[expr.name] = int(len(self.text_))
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
self.variables[expr.name] = self.text_[self.variables[expr.name] :]
self.variable_event[expr.name].set()
def _execute_commit_lazy_operations(self, expr: SglCommitLazy):
self.backend.commit_lazy_operations(self)
def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend):
new_text = ""
for s in expr.states:
exe = s.stream_executor
exe.sync()
new_text += exe.text_[exe.fork_start_text_pos :]
self._execute_fill(new_text)
def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):
self_len = len(self.text_)
for i, s in enumerate(expr.states):
exe = s.stream_executor
exe.submit(SglCommitLazy())
for i, s in enumerate(expr.states):
exe = s.stream_executor
exe.sync()
assert exe.fork_start_text_pos == self_len
self.text_ += exe.text_[exe.fork_start_text_pos :]
src_rids = [state.stream_executor.sid for state in expr.states]
self.backend.concatenate_and_append(src_rids, self.sid)
def _resolve_sampling_params(self, sampling_params):
clone = None
for item in [
"max_new_tokens",
"stop",
"temperature",
"top_p",
"top_k",
"frequency_penalty",
"presence_penalty",
"dtype",
"regex",
]:
value = getattr(sampling_params, item, None)
if value is not None:
if clone is None:
clone = self.default_sampling_para.clone()
setattr(clone, item, value)
return clone or self.default_sampling_para
def __del__(self):
self.end()
class ProgramState:
"""The state of an SGL program."""
def __init__(self, stream_executor: StreamExecutor):
self.stream_executor = stream_executor
def _role_common(self, name: str, expr: Optional[SglExpr] = None):
if expr is not None:
self.stream_executor.submit(
SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
)
else:
@contextmanager
def role_scope():
self.stream_executor.submit(SglRoleBegin(name))
yield
self.stream_executor.submit(SglRoleEnd(name))
return role_scope()
def system(self, expr: Optional[SglExpr] = None):
return self._role_common("system", expr)
def user(self, expr: Optional[SglExpr] = None):
return self._role_common("user", expr)
def assistant(self, expr: Optional[SglExpr] = None):
return self._role_common("assistant", expr)
@contextmanager
def var_scope(self, name: str):
self.stream_executor.submit(SglVarScopeBegin(name))
yield
self.stream_executor.submit(SglVarScopeEnd(name))
def fork(self, number: int = 1, position_ids_offset: Optional[List[int]] = None):
stream_executors = self.stream_executor.fork(number, position_ids_offset)
states = [ProgramState(x) for x in stream_executors]
state_group = ProgramStateGroup(states, self)
return state_group
@contextmanager
def copy(self, position_ids_offset: Optional[List[int]] = None):
state_group = self.fork(1, position_ids_offset)
try:
yield state_group[0]
finally:
state_group.join()
def text(self):
return self.stream_executor.text()
def messages(self):
return self.stream_executor.messages()
def sync(self):
return self.stream_executor.sync()
def text_iter(self, var_name=None):
if self.stream_executor.stream:
prev = 0
if var_name is None:
event = self.stream_executor.stream_text_event
while True:
event.wait()
event.clear()
out = str(self.stream_executor.text_[prev:])
prev += len(out)
if out:
yield out
if self.stream_executor.is_finished:
break
else:
event = self.stream_executor.stream_var_event[var_name]
while True:
event.wait()
event.clear()
out = str(self.stream_executor.variables[var_name][prev:])
prev += len(out)
if out:
yield out
if self.stream_executor.variable_event[var_name].is_set():
break
else:
if var_name is None:
yield self.text()
else:
yield self.get_var(var_name)
async def text_async_iter(self, var_name=None):
loop = asyncio.get_running_loop()
if self.stream_executor.stream:
prev = 0
if var_name is None:
event = self.stream_executor.stream_text_event
while True:
await loop.run_in_executor(None, event.wait)
event.clear()
out = str(self.stream_executor.text_[prev:])
prev += len(out)
if out:
yield out
if self.stream_executor.is_finished:
break
else:
event = self.stream_executor.stream_var_event[var_name]
while True:
await loop.run_in_executor(None, event.wait)
event.clear()
out = str(self.stream_executor.variables[var_name][prev:])
prev += len(out)
if out:
yield out
if self.stream_executor.variable_event[var_name].is_set():
break
else:
if var_name is None:
yield self.text()
else:
yield self.get_var(var_name)
def get_var(self, name):
return self.stream_executor.get_var(name)
def get_meta_info(self, name):
return self.stream_executor.get_meta_info(name)
def __iadd__(self, other):
self.stream_executor.submit(other)
return self
def __getitem__(self, name):
return self.get_var(name)
def __del__(self):
self.stream_executor.end()
def __repr__(self) -> str:
msgs = self.messages()
ret = ""
for msg in msgs:
ret += msg["role"] + ":\n" + msg["content"] + "\n"
return ret
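# Usage sketch (illustrative): streaming the output of a hypothetical @function
# program `qa` that generates a variable named "answer"; `qa` and its argument
# are assumptions for illustration.
state = qa.run(question="What is SGLang?", stream=True)
for chunk in state.text_iter("answer"):
    print(chunk, end="", flush=True)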
class ProgramStateGroup:
def __init__(
self, states: List[ProgramState], src_state: Optional[ProgramState] = None
):
self.states = states
self.src_state = src_state
def join(self, mode: str = "gather_variable"):
if mode == "gather_variable":
# Copy variables back
src_vars = self.src_state.stream_executor.variables
src_var_set = set(src_vars.keys())
for child_state in self.states:
child_state.stream_executor.sync()
child_vars = child_state.stream_executor.variables
new_vars = set(child_vars.keys()) - src_var_set
for k in new_vars:
if k in src_vars:
src_vars[k].append(child_vars[k])
else:
src_vars[k] = [child_vars[k]]
elif mode == "concate_and_append":
# Concatenate and append KV cache
self.src_state += SglConcateAndAppend(self.states)
# Need a sync here. Otherwise, `states` can be deleted.
self.src_state.stream_executor.sync()
else:
raise ValueError(f"Invalid join mode: {mode}")
for s in self.states:
s.stream_executor.end()
def __getitem__(self, i: int):
return self.states[i]
def __setitem__(self, i: int, value):
assert self.states[i] == value
def __iadd__(self, other):
if isinstance(other, Callable):
# lambda function
for i in range(len(self.states)):
self.states[i] += other(i)
elif isinstance(other, SglExpr):
for i in range(len(self.states)):
self.states[i] += other
elif isinstance(other, (list, tuple)):
for i in range(len(self.states)):
self.states[i] += other[i]
else:
raise ValueError(f"Invalid value: {other}")
return self
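# Usage sketch (illustrative): driving fork/join from inside an @function body.
# The prompt text, variable names, and tip count are assumptions for illustration.
from sglang import function, gen

@function
def two_tips(s):
    s += "Here are two tips for staying healthy.\n"
    forks = s.fork(2)                                   # returns a ProgramStateGroup
    forks += lambda i: f"Tip {i + 1}:" + gen(f"tip_{i}", max_tokens=32)
    forks.join(mode="gather_variable")                  # copy child variables back, end children
    s += "In summary, " + gen("summary", max_tokens=64)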

442
python/sglang/lang/ir.py Normal file
View File

@@ -0,0 +1,442 @@
"""The intermediate representation."""
import dataclasses
import inspect
from typing import List, Optional, Union
from sglang.global_config import global_config
@dataclasses.dataclass
class SamplingParams:
max_new_tokens: int = 16
stop: Union[str, List[str]] = ()
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1 # -1 means disable
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
# For constrained generation. `dtype` is not passed to any backend kwargs; `regex` is only passed to SRT.
dtype: Optional[str] = None
regex: Optional[str] = None
def clone(self):
return SamplingParams(
self.max_new_tokens,
self.stop,
self.temperature,
self.top_p,
self.top_k,
self.frequency_penalty,
self.presence_penalty,
)
def to_openai_kwargs(self):
# OpenAI does not support top_k, so we drop it here
return {
"max_tokens": self.max_new_tokens,
"stop": self.stop or None,
"temperature": self.temperature,
"top_p": self.top_p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
}
def to_anthropic_kwargs(self):
# Anthropic does not support frequency_penalty or presence_penalty, so we drop them here
return {
"max_tokens_to_sample": self.max_new_tokens,
"stop_sequences": self.stop,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
}
def to_srt_kwargs(self):
return {
"max_new_tokens": self.max_new_tokens,
"stop": self.stop,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"regex": self.regex,
}
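# Sketch (illustrative): the same SamplingParams maps to slightly different kwargs
# per backend, e.g. top_k is dropped for OpenAI but kept for the SRT backend.
params = SamplingParams(max_new_tokens=32, stop="\n", temperature=0.7, top_k=40)
assert "top_k" not in params.to_openai_kwargs()
assert params.to_srt_kwargs()["top_k"] == 40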
class SglFunction:
def __init__(self, func, bind_arguments=None):
self.func = func
self.bind_arguments = bind_arguments or {}
self.pin_prefix_rid = None
# Parse arguments
argspec = inspect.getfullargspec(func)
assert argspec.args[0] == "s", 'The first argument must be "s"'
self.arg_names = argspec.args[1:]
def bind(self, **kwargs):
assert all(key in self.arg_names for key in kwargs)
new_bind_dict = {**self.bind_arguments, **kwargs}
return SglFunction(self.func, bind_arguments=new_bind_dict)
def run(
self,
*args,
max_new_tokens: int = 16,
stop: Union[str, List[str]] = (),
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
stream: bool = False,
backend=None,
**kwargs,
):
from sglang.lang.interpreter import run_program
default_sampling_para = SamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
backend = backend or global_config.default_backend
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
def run_batch(
self,
batch_kwargs,
*,
max_new_tokens: int = 16,
stop: Union[str, List[str]] = (),
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
num_threads: Union[str, int] = "auto",
progress_bar: bool = False,
):
from sglang.lang.interpreter import run_program_batch
assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
return []
assert isinstance(batch_kwargs[0], dict)
default_sampling_para = SamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
backend = backend or global_config.default_backend
batch_kwargs = [
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
]
return run_program_batch(
self,
backend,
batch_kwargs,
default_sampling_para,
num_threads,
progress_bar,
)
def trace(self, *, backend=None, **kwargs):
from sglang.lang.tracer import trace_program
backend = backend or global_config.default_backend
return trace_program(self, kwargs, backend)
def pin(self, backend=None):
from sglang.lang.interpreter import pin_program
backend = backend or global_config.default_backend
return pin_program(self, backend)
def unpin(self, backend=None):
from sglang.lang.interpreter import unpin_program
backend = backend or global_config.default_backend
return unpin_program(self, backend)
def compile(self, *, backend=None):
from sglang.lang.compiler import compile_func
return compile_func(self, backend)
def __call__(self, *args, **kwargs):
from sglang.lang.tracer import TracingScope
tracing_scope = TracingScope.get_current_scope()
if tracing_scope is None:
return self.run(*args, **kwargs)
else:
kwargs["backend"] = tracing_scope.tracer_state.backend
return self.trace(*args, **kwargs)
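# Usage sketch (illustrative): decorating a program and running it in batch.
# The program body, argument names, and endpoint URL are assumptions for illustration.
from sglang import RuntimeEndpoint, function, gen, set_default_backend

@function
def qa(s, question):
    s += "Q: " + question + "\nA: " + gen("answer", max_tokens=64, temperature=0)

set_default_backend(RuntimeEndpoint("http://localhost:30000"))
states = qa.run_batch([{"question": "What is 1 + 1?"}, {"question": "Name a prime."}])
print(states[0]["answer"])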
class SglExpr:
node_ct = 0
def __init__(self):
self.node_id = SglExpr.node_ct
self.prev_node = None
self.pid = None
SglExpr.node_ct += 1
def __add__(self, other):
if isinstance(other, str):
other = SglConstantText(other)
assert isinstance(other, SglExpr)
return self.concatenate_ir(self, other)
def __radd__(self, other):
if isinstance(other, str):
other = SglConstantText(other)
assert isinstance(other, SglExpr), f"{other}"
return self.concatenate_ir(other, self)
def concatenate_ir(self, a, b):
if isinstance(a, SglExprList):
if isinstance(b, SglExprList):
return SglExprList(a.expr_list + b.expr_list)
else:
return SglExprList(a.expr_list + [b])
elif isinstance(b, SglExprList):
return SglExprList([a] + b.expr_list)
return SglExprList([a, b])
def print_graph_dfs(self):
ret = [""]
visited = set()
def dfs_print(x):
if x is None or x in visited:
return
visited.add(x)
# Print dependency
if x.prev_node is not None:
dfs_print(x.prev_node)
if isinstance(x, SglExprList):
for y in x.expr_list:
dfs_print(y)
# elif isinstance(x, SglRole):
# dfs_print(x.expr)
elif isinstance(x, SglVariable):
dfs_print(x.source)
# Print the node itself
if isinstance(x, (SglFork, SglGetForkItem)):
ret[0] += f"%{x.node_id} = {x}\n"
else:
if x.prev_node is not None:
ret[0] += (
f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n"
)
else:
ret[0] += f"%{x.node_id} = " + str(x) + "\n"
dfs_print(self)
return ret[0]
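# Sketch (illustrative): "+" lifts plain strings into SglConstantText (defined below)
# and flattens nested lists into a single SglExprList; print_graph_dfs() renders the IR.
expr = "Hello " + SglConstantText("world") + "!"
print(expr.print_graph_dfs())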
class SglExprList(SglExpr):
def __init__(self, expr_list: List[SglExpr]):
super().__init__()
self.expr_list = expr_list
def __repr__(self):
return f"ExprList({self.expr_list})"
class SglArgument(SglExpr):
def __init__(self, name: str, value: str):
super().__init__()
self.name = name
self.value = value
def __repr__(self):
return f"Argument(name={self.name}, value={repr(self.value)})"
def __len__(self):
return len(self.value)
def __getitem__(self, i):
return self.value[i]
def __int__(self):
return self.value
def __bool__(self):
return self.value
def __format__(self, *args):
raise TypeError(
"Cannot put argument inside a f-string. "
"This is not compatible with the tracer. "
)
class SglImage(SglExpr):
def __init__(self, path):
super().__init__()
self.path = path
def __repr__(self) -> str:
return f"SglImage({self.path})"
class SglGen(SglExpr):
def __init__(
self,
name,
max_new_tokens,
stop,
temperature,
top_p,
top_k,
frequency_penalty,
presence_penalty,
dtype,
regex,
):
super().__init__()
self.name = name
self.sampling_params = SamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
dtype=dtype,
regex=regex,
)
def __repr__(self):
return f"Gen('{self.name}')"
class SglConstantText(SglExpr):
def __init__(self, value):
super().__init__()
self.value = value
def __repr__(self):
return f"Constant({repr(self.value)})"
class SglRoleBegin(SglExpr):
def __init__(self, role):
super().__init__()
self.role = role
def __repr__(self):
return f"RoleBegin({self.role})"
class SglRoleEnd(SglExpr):
def __init__(self, role):
super().__init__()
self.role = role
def __repr__(self):
return f"RoleEnd({self.role})"
class SglSelect(SglExpr):
def __init__(self, name, choices, temperature):
super().__init__()
self.name = name
self.choices = choices
self.temperature = temperature
def __repr__(self):
return f"Select({self.name}, choices={self.choices})"
class SglFork(SglExpr):
def __init__(self, number, position_ids_offset=None):
super().__init__()
self.number = number
self.position_ids_offset = position_ids_offset
def __repr__(self):
return (
f"Fork(%{self.prev_node.node_id}, number={self.number}, "
f"position_ids_offset={self.position_ids_offset})"
)
class SglGetForkItem(SglExpr):
def __init__(self, index):
super().__init__()
self.index = index
def __repr__(self):
return f"GetForkItem(%{self.prev_node.node_id}, index={self.index})"
class SglVariable(SglExpr):
def __init__(self, name, source):
super().__init__()
self.name = name
self.source = source
def __repr__(self):
return f"Variable('{self.name}', source=%{self.source.node_id})"
class SglVarScopeBegin(SglExpr):
def __init__(self, name):
super().__init__()
self.name = name
def __repr__(self):
return f"VarScopeBegin('{self.name}')"
class SglVarScopeEnd(SglExpr):
def __init__(self, name):
super().__init__()
self.name = name
def __repr__(self):
return f"VarScopeEnd('{self.name}')"
class SglConcateAndAppend(SglExpr):
def __init__(self, states):
super().__init__()
self.states = states
def __repr__(self):
return f"ConcatenateAndAppend('{self.states}')"
class SglCommitLazy(SglExpr):
def __init__(self):
super().__init__()
def __repr__(self):
return f"CommitLazy()"

View File

@@ -0,0 +1,279 @@
"""Tracing a program."""
import uuid
from typing import Any, Callable, Dict, List, Optional, Union
from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
SglArgument,
SglCommitLazy,
SglConcateAndAppend,
SglConstantText,
SglExpr,
SglExprList,
SglFork,
SglFunction,
SglGen,
SglGetForkItem,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
)
class StopTracing(Exception):
pass
def extract_prefix_by_tracing(program, backend):
# Create dummy arguments
dummy_arguments = {name: SglArgument(name, None) for name in program.arg_names}
arguments = dummy_arguments
arguments.update(program.bind_arguments)
# Trace
tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
try:
with TracingScope(tracer):
tracer.ret_value = program.func(tracer, **arguments)
except StopTracing:
pass
# Run and cache prefix
prefix = ""
for expr in tracer.flatten_nodes():
if isinstance(expr, SglConstantText):
prefix += expr.value
else:
break
return prefix
def trace_program(program, arguments, backend):
# Create dummy backend
if backend is None:
backend = BaseBackend()
# Create dummy arguments
dummy_arguments = {
name: SglArgument(name, None)
for name in program.arg_names
if name not in arguments
}
arguments.update(dummy_arguments)
arguments.update(program.bind_arguments)
# Trace
tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
with TracingScope(tracer):
tracer.ret_value = program.func(tracer, **arguments)
return tracer
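# Usage sketch (illustrative): tracing an @function program to inspect its IR without
# contacting a model; `qa` is a hypothetical decorated program.
tracer = trace_program(qa, {}, backend=None)   # a dummy BaseBackend is created internally
for node in tracer.flatten_nodes():
    print(node)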
class TracerProgramState(ProgramState):
def __init__(self, backend, arguments, only_trace_prefix):
self.pid = uuid.uuid4().hex
self.backend = backend
self.arguments: Dict[str, Any] = arguments
self.only_trace_prefix = only_trace_prefix
if hasattr(backend, "endpoint"):
self.backend = backend.endpoint
self.nodes = []
self.last_node = None
self.variables = {}
self.ret_value = None
# For completion
# For chat
self.messages_ = []
self.cur_role = None
self.chat_template = self.backend.get_chat_template()
# For multi states
self.child_states = []
cur_scope = TracingScope.get_current_scope()
if cur_scope is not None:
cur_scope.add_child_state(self)
##################################
########### Public API ###########
##################################
def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
if self.only_trace_prefix:
raise StopTracing()
fork_node = SglFork(number)
fork_node.prev_node = self.last_node
states = [
TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
for _ in range(number)
]
for i in range(number):
node = SglGetForkItem(i)
node.prev_node = fork_node
states[i].last_node = node
states[i].variables = dict(self.variables)
states[i].messages_ = list(self.messages_)
states[i].cur_role = self.cur_role
states[i].chat_template = self.chat_template
state_group = ProgramStateGroup(states, self)
return state_group
##################################
########## Internal API ##########
##################################
def _append_node(self, other: SglExpr):
self.nodes.append(other)
other.prev_node = self.last_node
self.last_node = other
def _execute(self, other: SglExpr):
if isinstance(other, str):
other = SglConstantText(other)
other.pid = self.pid
if isinstance(other, SglConstantText):
self._execute_fill(other)
elif isinstance(other, SglGen):
self._execute_gen(other)
elif isinstance(other, SglSelect):
self._execute_select(other)
elif isinstance(other, SglExprList):
for x in other.expr_list:
self._execute(x)
elif isinstance(other, SglRoleBegin):
self._execute_role_begin(other)
elif isinstance(other, SglRoleEnd):
self._execute_role_end(other)
elif isinstance(other, SglVarScopeBegin):
self._execute_var_scope_begin(other)
elif isinstance(other, SglVarScopeEnd):
self._execute_var_scope_end(other)
else:
if self.only_trace_prefix:
raise StopTracing()
else:
self._append_node(other)
return self
def __iadd__(self, other):
self._execute(other)
return self
def _execute_fill(self, expr: SglConstantText):
if isinstance(expr, str):
expr = SglConstantText(expr)
self._append_node(expr)
def _execute_gen(self, expr: SglGen):
name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
new_node = SglVariable(name, source=expr)
self.variables[name] = new_node
self._append_node(expr)
def _execute_select(self, expr: SglSelect):
name = (
expr.name if expr.name is not None else "select_" + str(len(self.variables))
)
new_node = SglVariable(name, source=expr)
self.variables[name] = new_node
self._append_node(expr)
def _execute_role_begin(self, expr: SglRoleBegin):
assert self.cur_role is None, "Nested roles are not allowed."
if len(self.messages_) == 0 and expr.role != "system":
# Insert default system message
default_system = self.chat_template.default_system_prompt
if default_system:
self._execute_role_begin(SglRoleBegin("system"))
self._execute_fill(default_system)
self._execute_role_end(SglRoleEnd("system"))
self.cur_role = expr.role
prefix, suffix = self.chat_template.get_prefix_and_suffix(
expr.role, self.messages_
)
self._execute_fill(prefix)
def _execute_role_end(self, expr: SglRoleEnd):
prefix, suffix = self.chat_template.get_prefix_and_suffix(
expr.role, self.messages_
)
self._execute_fill(suffix)
self.messages_.append({"role": expr.role, "content": ""})
self.cur_role = None
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
new_node = SglVariable(expr.name, source=self.last_node)
self.variables[expr.name] = new_node
def get_var(self, name):
ret = self.arguments.get(name, None)
if ret is not None:
return ret
v = self.variables[name]
return SglVariable(v.name, v.source)
def flatten_nodes(self):
def traverse(cur):
if isinstance(cur, SglExprList):
for child in cur.expr_list:
traverse(child)
else:
ret.append(cur)
ret = []
for x in self.nodes:
traverse(x)
return ret
def __del__(self):
pass
class TracingScope:
cur_scope = None
def __init__(self, tracer_state: TracerProgramState):
self.tracer_state = tracer_state
self.last_scope = TracingScope.cur_scope
def __enter__(self):
TracingScope.cur_scope = self
return self
def __exit__(self, exc_type, exc_value, traceback):
TracingScope.cur_scope = self.last_scope
@staticmethod
def get_current_scope():
return TracingScope.cur_scope
def add_child_state(self, state: TracerProgramState):
cur_scope = self
while cur_scope is not None:
cur_scope.tracer_state.child_states.append(state)
cur_scope = cur_scope.last_scope

View File

@@ -0,0 +1,11 @@
import argparse
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args, None)

View File

@@ -0,0 +1,385 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
from typing import List, NewType, Protocol
import interegular
from lark import Lark
# from outlines.fsm.parsing import PartialLark
from sglang.srt.constrained.regex import (
create_fsm_index_tokenizer,
make_deterministic_fsm,
)
from sglang.srt.constrained.tokenizer import Tokenizer
FSMState = NewType("FSMState", int)
class FSM(Protocol):
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
...
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
...
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
...
def reset(self) -> None:
...
class StopAtTokenFSM(FSM):
"""FSM to generate text until a specified token id is generated or
a specified number of tokens has been generated.
Text is usually produced until the EOS token is generated by the
model.
"""
def __init__(
self,
tokenizer: "Tokenizer",
stop_token_id: int,
):
self.stop_token_id = stop_token_id
self.num_tokens_generated = 0
self.vocabulary = tokenizer.vocabulary.values()
self.final_states = {1}
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
"""Generate a list of allowed tokens for the next step.
When in the initial state we allow every token to be generated.
In the final state the only allowed token is `stop_token_id`.
Parameters
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
A list that contains the allowed token ids.
"""
if state == 0:
return list(self.vocabulary)
else:
return [self.stop_token_id]
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
"""Update the state of the FSM.
The FSM stays in the initial state `0` unless the specified stop token
has been generated or the maximum number of tokens has been reached. In
which case the FSM moves to the final state `1`.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.stop_token_id:
return FSMState(1)
return FSMState(0)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
"""Determine whether the current state of the FSM is a final state."""
return state in self.final_states
def reset(self) -> None:
"""Reset the FSM to its initial state. Here this only resets the token counter."""
self.num_tokens_generated = 0
class RegexFSM(FSM):
"""FSM to generate text that is in the language of a regular expression."""
def __init__(
self,
regex_string: str,
tokenizer: "Tokenizer",
):
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
(
self.states_to_token_maps,
self.empty_token_ids,
) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
# We make sure that it is possible to generate strings in the language
# of the regular expression with the tokens present in the model's
# vocabulary.
if not any(
regex_fsm.finals.intersection(v.values())
for v in self.states_to_token_maps.values()
):
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)
self.final_states = regex_fsm.finals | {
-1
} # Include the EOS token in final states
self.num_tokens_generated = 0
self.vocabulary = tokenizer.vocabulary.values()
self.end_token_id = tokenizer.eos_token_id
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
"""Generate a list of allowed tokens for the next step.
The initialization of the FSM builds an index which maps FSM states to a
map from authorized tokens to the state in which the FSM needs to move
if said token is generated. Therefore the authorized tokens at the
current state are the keys of the map returned by the value of the index
for the current state.
If the current state is not contained in the index, this means that we are
in a final state of the FSM. We only authorize EOS tokens in the final
state.
Parameters
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
A list that contains the allowed token ids.
"""
next_tokens_to_end_states = self.states_to_token_maps.get(state)
if next_tokens_to_end_states is None:
return [self.end_token_id]
else:
return list(next_tokens_to_end_states.keys())
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
"""Update the state of the FSM.
We use the index to determine to which state the FSM should transition
given the token that was just generated.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.end_token_id:
return FSMState(-1)
last_token_to_end_state = self.states_to_token_maps[state]
next_state = last_token_to_end_state.get(token_id)
if next_state is None:
next_state = -1
return FSMState(next_state)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
"""Determine whether the current state of the FSM is a final state."""
return state in self.final_states
def reset(self) -> None:
"""Reset the FSM to its initial state. Here this only resets the token counter."""
self.num_tokens_generated = 0
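# Usage sketch (illustrative): a greedy constrained-decoding loop. `tokenizer` is assumed
# to be a TransformerTokenizer and `score_fn` stands in for a model forward pass that
# returns a score per token id; both are assumptions for illustration.
fsm = RegexFSM(r"[0-9]{1,3}", tokenizer)
state, token_ids = FSMState(0), []
while not fsm.is_final_state(state):
    allowed = fsm.allowed_token_ids(state)
    scores = score_fn(token_ids)                       # hypothetical model call
    next_id = max(allowed, key=lambda t: scores[t])    # best token among the allowed ones
    token_ids.append(next_id)
    state = fsm.next_state(state, next_id)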
class CFGFSM(FSM):
"""FSM to generate text that is in the language of a context-free grammar."""
def __init__(
self,
cfg_string: str,
tokenizer: "Tokenizer",
):
# self.parser = PartialLark(cfg_string, parser="lalr")
self.parser = Lark(
cfg_string,
parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
)
self.terminal_regexps = dict()
for terminal in self.parser.terminals:
if terminal.pattern is not None:
self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
self.terminal_regexps["$END"] = tokenizer.eos_token
self.tokenizer = tokenizer
self.num_tokens_generated = 0
self.generations: List[str] = []
self.regex_fsms: List[RegexFSM] = []
self.reset_state: List[bool] = []
self.allow_eos: List[bool] = []
self.done: List[bool] = []
def _set_next_regex_fsm(self, idx: int = 0) -> None:
"""Use the CFG incremental parser to set the next regex FSM.
Check what the CFG incremental parser proposes next.
If the only proposal is the EOS token,
we set the state to done and return.
If there are other proposals,
we set a new regex FSM and return.
"""
interactive = self.parser.parse_interactive(self.generations[idx])
interactive.exhaust_lexer()
options = {self.terminal_regexps[x] for x in interactive.accepts()}
if self.terminal_regexps["$END"] in options:
options.remove(self.terminal_regexps["$END"])
if len(options) == 0:
self.done[idx] = True
return
self.allow_eos[idx] = True
options.add("")
assert len(options) > 1
regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
args = (
regex_string,
self.tokenizer,
)
if len(self.regex_fsms) <= idx:
self.regex_fsms.append(RegexFSM(*args))
else:
self.regex_fsms[idx] = RegexFSM(*args)
self.reset_state[idx] = True
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
"""Generate a list of allowed tokens for the next step.
Upon initialization, the CFG incremental parser is used to determine the first regex.
This regex is used for proposals until either:
- the regex is exhausted, and its only remaining option is the EOS token,
in which case we always transition to the next regex
- the regex can be exhausted, but the EOS token is not the only remaining option,
in which case we transition to the next regex with probability P (TODO)
or remove the possibility of generating the EOS token and continue with the current regex
The CFG incremental parser is allowed to propose the EOS token from any final state,
and once it is generated, the FSM will continue to always generate the EOS token.
Parameters
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
A list that contains the allowed token ids.
"""
if len(self.generations) <= idx:
self.generations.append("")
self.reset_state.append(False)
self.allow_eos.append(False)
self.done.append(False)
if len(self.regex_fsms) > idx:
proposal = self.regex_fsms[idx].allowed_token_ids(state)
if self.tokenizer.eos_token_id not in proposal:
return proposal
if set(proposal) != {self.tokenizer.eos_token_id}:
if False: # TODO: THIS NEEDS TO BE SAMPLED
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
return proposal
self._set_next_regex_fsm(idx)
if self.done[idx]:
return [self.tokenizer.eos_token_id]
if self.reset_state[idx]:
state = FSMState(0)
proposal = self.regex_fsms[idx].allowed_token_ids(state)
if self.allow_eos[idx]:
self.allow_eos[idx] = False
else:
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
assert len(proposal) > 0
return proposal
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
"""Update the state of the FSM.
Transitions the underlying regex FSM to its next state.
If at max tokens or EOS token, transition permanently to the final state.
Update stored partial generations for subsequent incremental parsing.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.tokenizer.eos_token_id:
self.done[idx] = True
return FSMState(-1)
if self.reset_state[idx]:
self.reset_state[idx] = False
state = FSMState(0)
self.generations[idx] += self.tokenizer.decode([token_id])[0]
return self.regex_fsms[idx].next_state(state, token_id, idx)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
"""Return whether the current state of the FSM is a final state."""
return self.done[idx]
def reset(self) -> None:
"""Reset the FSM to its initial state, so it can be called on a fresh batch on inputs."""
self.num_tokens_generated = 0
self.generations = []
self.regex_fsms = []
self.reset_state = []
self.done = []

View File

@@ -0,0 +1,41 @@
import threading
from sglang.srt.constrained.fsm import RegexFSM
from sglang.srt.constrained.tokenizer import TransformerTokenizer
def get_fsm(regex, tokenizer, fsm_cache_entry):
outlines_tokenizer = TransformerTokenizer(tokenizer)
fsm = RegexFSM(regex, outlines_tokenizer)
fsm_cache_entry.fsm = fsm
fsm_cache_entry.event.set()
class FSMCacheEntry:
def __init__(self):
self.fsm = None
self.event = threading.Event()
class FSMCache:
def __init__(self, tokenizer):
self.cache = {}
self.tokenizer = tokenizer
def init_fsm_in_background(self, regex):
if regex not in self.cache:
self.cache[regex] = FSMCacheEntry()
threading.Thread(
target=get_fsm,
args=(
regex,
self.tokenizer,
self.cache[regex],
),
).start()
def get_fsm(self, regex):
self.init_fsm_in_background(regex)
entry = self.cache[regex]
entry.event.wait()
return entry.fsm
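# Usage sketch (illustrative): start compiling the FSM while the request is being scheduled,
# then block only when it is actually needed. `hf_tokenizer` is an assumed Hugging Face
# tokenizer instance.
cache = FSMCache(hf_tokenizer)
cache.init_fsm_in_background(r"(yes|no)")   # returns immediately
fsm = cache.get_fsm(r"(yes|no)")            # waits for the background build to finish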

View File

@@ -0,0 +1,586 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/regex.py
from collections import namedtuple
from functools import lru_cache
from typing import Dict, Generator, List, Sequence, Set, Tuple
import numba
import numpy as np
from interegular.fsm import FSM, Alphabet, OblivionError, anything_else
from numba.typed.typedobjectutils import _nonoptional
from sglang.srt.constrained.tokenizer import Tokenizer
class BetterAlphabet(Alphabet):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert anything_else in self._symbol_mapping
self.anything_value = self._symbol_mapping[anything_else]
def __getitem__(self, item):
return self._symbol_mapping.get(item, self.anything_value)
def copy(self):
return BetterAlphabet(self._symbol_mapping.copy())
class BetterFSM(FSM):
flat_transition_map: Dict[Tuple[int, int], int]
trans_key_to_states: Dict[int, List[int]]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not isinstance(self.alphabet, BetterAlphabet):
self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping)
flat_transition_map = {}
trans_key_to_states = {}
for from_state, trans_map in self.map.items():
for trans_key, to_state in trans_map.items():
flat_transition_map[(from_state, trans_key)] = to_state
trans_key_to_states.setdefault(trans_key, set()).add(from_state)
self.__dict__["trans_key_to_states"] = trans_key_to_states
self.__dict__["flat_transition_map"] = flat_transition_map
self.__dict__["_fsm_info"] = None
def copy(self):
return BetterFSM(
alphabet=self.alphabet.copy(),
states=self.states.copy(),
initial=self.initial,
finals=self.finals.copy(),
map=self.map.copy(),
__no_validation__=True,
)
@property
def fsm_info(self):
if self._fsm_info is None:
flat_transition_map_items = np.fromiter(
((a[0], a[1], b) for a, b in self.flat_transition_map.items()),
dtype=np.dtype("i8, i8, i8"),
)
trans_key_to_states_items = np.fromiter(
((k, z) for k, v in self.trans_key_to_states.items() for z in v),
dtype=np.dtype("i8, i8"),
)
alphabet_symbol_mapping_items = np.fromiter(
(
it
for it in self.alphabet._symbol_mapping.items()
if it[0] != anything_else
),
dtype=np.dtype("U1, i8"),
)
nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8"))
self.__dict__["_fsm_info"] = create_fsm_info(
self.initial,
nb_finals,
flat_transition_map_items,
trans_key_to_states_items,
self.alphabet.anything_value,
alphabet_symbol_mapping_items,
)
return self._fsm_info
nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_1_type = numba.types.UnicodeCharSeq(1)
@numba.njit(cache=True)
def create_fsm_info(
py_initial,
py_finals,
flat_transition_map_items,
trans_key_to_states_items,
py_anything_value,
alphabet_symbol_mapping_items,
):
trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type)
for trans_key_and_state in trans_key_to_states_items:
trans_key_to_states.setdefault(
trans_key_and_state[0], numba.typed.List.empty_list(numba.int64)
).append(trans_key_and_state[1])
flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64)
for trans_key_and_state in flat_transition_map_items:
flat_transition_map[
(trans_key_and_state[0], trans_key_and_state[1])
] = trans_key_and_state[2]
alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64)
for symbol_and_trans_key in alphabet_symbol_mapping_items:
alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]
initial = numba.int64(py_initial)
finals = set()
for final in py_finals:
finals.add(final)
anything_value = numba.int64(py_anything_value)
return FSMInfo(
initial,
finals,
flat_transition_map,
trans_key_to_states,
anything_value,
alphabet_symbol_map,
)
FSMInfo = namedtuple(
"FSMInfo",
[
"initial",
"finals",
"transitions",
"trans_key_to_states",
"alphabet_anything_value",
"alphabet_symbol_mapping",
],
)
def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
"""Construct an equivalent FSM with deterministic state labels."""
old_to_new_trans_keys = {
trans_key: i
for i, (trans_key, _) in enumerate(
sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1]))
)
}
new_symbol_mapping = {
symbol: old_to_new_trans_keys[trans_key]
for symbol, trans_key in fsm.alphabet._symbol_mapping.items()
}
new_alphabet = BetterAlphabet(new_symbol_mapping)
new_map = {
from_state: {
old_to_new_trans_keys[trans_key]: to_state
for trans_key, to_state in trans_map.items()
}
for from_state, trans_map in fsm.map.items()
}
old_to_new_states = {}
old_to_new_states[fsm.initial] = 0
i = 0
seen = {fsm.initial}
old_state_queue = [fsm.initial]
while old_state_queue:
old_state = old_state_queue.pop(-1)
transitions = new_map[old_state]
sorted_transitions = sorted(transitions.items(), key=lambda v: v[0])
for _, old_state in sorted_transitions:
if old_state not in seen:
old_state_queue.append(old_state)
seen.add(old_state)
if old_state not in old_to_new_states:
i += 1
old_to_new_states[old_state] = i
new_map = dict(
sorted(
(
(
old_to_new_states[from_state],
dict(
sorted(
(
(trans_key, old_to_new_states[to_state])
for trans_key, to_state in trans_map.items()
),
key=lambda v: v[0],
)
),
)
for from_state, trans_map in new_map.items()
),
key=lambda v: v[0],
)
)
new_initial = 0
new_finals = frozenset(
sorted(old_to_new_states[old_state] for old_state in fsm.finals)
)
new_states = frozenset(sorted(new_map.keys()))
new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map)
return new_fsm, old_to_new_states
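# Sketch (illustrative): building the relabeled deterministic FSM that RegexFSM starts from.
# The regex is an arbitrary example.
import interegular

pattern = interegular.parse_pattern(r"(yes|no)")
det_fsm, old_to_new = make_deterministic_fsm(pattern.to_fsm().reduce())
assert det_fsm.initial == 0   # states are relabeled so the initial state is 0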
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: str,
start_state: int,
full_match: bool = True,
) -> List[int]:
state = start_state
accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
last_final_idx: int = numba.uint64(0)
for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
new_state = fsm_transitions.get((state, trans_key))
if new_state is None:
if not full_match and last_final_idx > 0:
return accepted_states[:last_final_idx]
return numba.typed.List.empty_list(numba.int64)
state = new_state
if state in fsm_finals:
last_final_idx = numba.uint64(i + 1)
accepted_states.append(_nonoptional(state))
if full_match and last_final_idx - 1 != i:
return numba.typed.List.empty_list(numba.int64)
return accepted_states
def walk_fsm(
fsm: BetterFSM,
input_string: str,
start_state: int,
full_match: bool = True,
) -> List[int]:
fsm_finals = fsm.finals
state = start_state
accepted_states: List[int] = []
last_final_idx: int = 0
alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
alphabet_anything_value = fsm.alphabet.anything_value
fsm_transitions = fsm.flat_transition_map
for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
new_state = fsm_transitions.get((state, trans_key))
if new_state is None:
if not full_match and last_final_idx > 0:
return accepted_states[:last_final_idx]
return []
state = new_state
if state in fsm_finals:
last_final_idx = i + 1
accepted_states.append(state)
if full_match and last_final_idx - 1 != i:
return []
return accepted_states
def fsm_union(
fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
"""Construct an FSM representing the union of the FSMs in `fsms`.
This is an updated version of `interegular.fsm.FSM.union` made to return an
extra map of component FSMs to the sets of state transitions that
correspond to them in the new FSM.
"""
alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])
indexed_fsms = tuple(enumerate(fsms))
initial = {i: fsm.initial for (i, fsm) in indexed_fsms}
# Dedicated function accepting a "superset" and returning the next
# "superset" obtained by following this transition in the new FSM
def follow(current_state, new_transition: int):
next = {}
for i, f in indexed_fsms:
old_transition = new_to_old[i][new_transition]
if (
i in current_state
and current_state[i] in f.map
and old_transition in f.map[current_state[i]]
):
next[i] = f.map[current_state[i]][old_transition]
if not next:
raise OblivionError
return next
states = [initial]
finals: Set[int] = set()
map: Dict[int, Dict[int, int]] = {}
# Map component FSMs to their new state-to-state transitions, finals, and a
# map translating component FSM states to aggregate FSM states
fsms_to_trans_finals: Dict[
int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
] = {}
i = 0
while i < len(states):
state = states[i]
# Add to the finals of the aggregate FSM whenever we hit a final in a
# component FSM
if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
finals.add(i)
# Compute the map for this state
map[i] = {}
for transition in alphabet.by_transition:
try:
next = follow(state, transition)
except OblivionError:
# Reached an oblivion state; don't list it
continue
else:
try:
# TODO: Seems like this could--and should--be avoided
j = states.index(next)
except ValueError:
j = len(states)
states.append(next)
map[i][transition] = j
for fsm_id, fsm_state in next.items():
(
fsm_transitions,
fsm_finals,
fsm_old_to_new,
) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
old_from = state[fsm_id]
old_to = fsm_state
fsm_old_to_new.setdefault(old_from, set()).add(i)
fsm_old_to_new.setdefault(old_to, set()).add(j)
fsm_transitions.add((i, j))
if fsm_state in fsms[fsm_id].finals:
fsm_finals.add(j)
i += 1
fsm = FSM(
alphabet=alphabet,
states=range(len(states)),
initial=0,
finals=finals,
map=map,
__no_validation__=True,
)
fsm, old_to_new_states = make_deterministic_fsm(fsm)
_fsms_to_trans_finals = {
fsm_id: (
{(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
{old_to_new_states[s] for s in finals},
{
old_state: {old_to_new_states[new_state] for new_state in new_states}
for old_state, new_states in old_to_new.items()
},
)
for fsm_id, (transitions, finals, old_to_new) in sorted(
fsms_to_trans_finals.items(), key=lambda x: x[0]
)
}
return (
fsm,
_fsms_to_trans_finals,
)
def get_sub_fsms_from_seq(
state_seq: Sequence[int],
fsms_to_trans_finals: Dict[
int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
],
) -> Generator[Tuple[int, bool, bool], None, None]:
"""Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.
Parameters
----------
state_seq
A state sequence.
fsms_to_trans_finals
A map from FSM indices to tuples containing sets of their state transitions
and sets of the final/accept states.
Returns
-------
A generator returning tuples containing each sub-FSM index (in the order
they were union-ed to construct `fsm`) and booleans indicating whether or
not there is another valid transition from the last state in the sequence
for the associated sub-FSM (i.e. if the FSM can continue
accepting/matching) and whether or not the sequence ends in a final state
of the sub-FSM.
"""
state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
last_fsm_state = state_seq[-1]
yield from (
(
# The sub-FSM index
fsm_idx,
# Is there another possible transition in this sub-FSM?
any(last_fsm_state == from_s for (from_s, to_s) in transitions),
# Is this sub-FSM in a final state?
state_seq[-1] in finals,
)
for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items()
if state_seq_transitions.issubset(transitions)
)
@numba.njit(cache=True, nogil=True)
def state_scan_tokens(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: Dict[str, List[int]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()
for token, token_ids in vocabulary.items():
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
alphabet_anything_value,
fsm_initial,
fsm_finals,
token,
start_state,
False,
)
if state_seq is not None and len(state_seq) < len(token):
continue
for token_id in token_ids:
res.add((token_id, state_seq[-1]))
return res
def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: Dict[str, List[int]],
) -> Dict[int, Set[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""
# TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
# code, too.
states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {}
seen: Set[int] = set()
next_states = {fsm_info.initial}
while next_states:
start_state = next_states.pop()
token_ids_end_states = state_scan_tokens(
fsm_info.transitions,
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
fsm_info.initial,
fsm_info.finals,
vocabulary,
start_state,
)
for token_id_and_end_state in token_ids_end_states:
states_to_token_subsets.setdefault(start_state, set()).add(
token_id_and_end_state
)
end_state = token_id_and_end_state[1]
if end_state not in seen:
next_states.add(end_state)
seen.add(start_state)
return states_to_token_subsets
# TODO: Cannot cache typed collections to disk, yet. See
# https://github.com/numba/numba/issues/4698
@lru_cache
def reduced_vocabulary(tokenizer: "Tokenizer"):
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
vocabulary = numba.typed.Dict.empty(
numba.types.string, numba.types.ListType(numba.int64)
)
empty_token_ids = set()
for token, token_idx in tokenizer.vocabulary.items():
if token in tokenizer.special_tokens:
continue
token_str = tokenizer.convert_token_to_string(token)
if token_str:
vocabulary.setdefault(
token_str,
numba.typed.List.empty_list(numba.int64),
).append(numba.int64(token_idx))
else:
empty_token_ids.add(numba.int64(token_idx))
return vocabulary, empty_token_ids
def create_fsm_index_tokenizer(
fsm: BetterFSM,
tokenizer: "Tokenizer",
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
"""Construct an FMS index from a tokenizer.
This uses the end-to-end approach of `create_fsm_index_end_to_end`.
.. warning::
`fsm` needs to be deterministically ordered so that future caching makes sense.
"""
vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)
states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)
# Allow transitions to EOS from all reachable terminal FSM states
# TODO: Do we really need this anymore?
for state in fsm.fsm_info.finals:
subset = states_to_token_subsets.get(state)
if subset is not None:
subset.add((tokenizer.eos_token_id, state))
# Convert to token-to-end-state maps
states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()}
return states_to_token_subsets, empty_token_ids
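# Usage sketch (illustrative): end-to-end index construction. `outlines_tok` is assumed to be
# a TransformerTokenizer wrapping a Hugging Face tokenizer (see the constrained tokenizer module).
import interegular

det_fsm, _ = make_deterministic_fsm(interegular.parse_pattern(r"[ab]+").to_fsm().reduce())
index, empty_ids = create_fsm_index_tokenizer(det_fsm, outlines_tok)
first_step_ids = list(index[det_fsm.initial].keys())   # token ids usable from the initial state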

View File

@@ -0,0 +1,266 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
from abc import abstractmethod
from typing import (
TYPE_CHECKING,
Dict,
Hashable,
List,
Optional,
Protocol,
Set,
Tuple,
Union,
)
import numpy as np
import torch
from numpy.typing import NDArray
class Tokenizer(Protocol, Hashable):
eos_token: str
eos_token_id: int
pad_token_id: int
vocabulary: Dict[str, int]
special_tokens: Set[int]
@abstractmethod
def encode(
self, prompt: Union[str, List[str]]
) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
"""Translate the input prompts into NumPy arrays of token ids and attention mask."""
...
@abstractmethod
def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
"""Translate an array of token ids to a string or list of strings."""
...
@abstractmethod
def convert_token_to_string(self, token: str) -> str:
"""Convert a token to its equivalent string.
This is for instance useful for BPE tokenizers where whitespaces are
represented by the special character `Ġ`. This prevents matching a raw
token that includes `Ġ` with a string.
"""
...
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
__all__ = ["transformers"]
KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
def get_llama_tokenizer_types():
"""Get all the Llama tokenizer types/classes that need work-arounds.
When they can't be imported, a dummy class is created.
"""
try:
from transformers.models.llama import LlamaTokenizer
except ImportError:
class LlamaTokenizer: # type: ignore
pass
try:
from transformers.models.llama import LlamaTokenizerFast
except ImportError:
class LlamaTokenizerFast: # type: ignore
pass
try:
from transformers.models.code_llama import CodeLlamaTokenizer
except ImportError:
class CodeLlamaTokenizer: # type: ignore
pass
try:
from transformers.models.code_llama import CodeLlamaTokenizerFast
except ImportError:
class CodeLlamaTokenizerFast: # type: ignore
pass
return (
LlamaTokenizer,
LlamaTokenizerFast,
CodeLlamaTokenizer,
CodeLlamaTokenizerFast,
)
class Transformer:
"""Represents a `transformers` model."""
def __init__(
self,
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
):
self.device = model.device
self.model = model
self.tokenizer = tokenizer
@torch.inference_mode
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
past_key_values: Optional[Tuple] = None,
) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
"""Compute a forward pass through the transformer model.
Parameters
----------
input_ids
The input token ids. Must be one or two dimensional.
attention_mask
The attention mask. Must be one or two dimensional.
past_key_values
A tuple of tuples containing the cached key and value tensors for each
attention head.
Returns
-------
The computed logits and the new cached key and value tensors.
"""
assert 0 < input_ids.ndim < 3
if past_key_values:
input_ids = input_ids[..., -1].unsqueeze(-1)
output = self.model(
input_ids,
attention_mask=attention_mask,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
past_key_values=past_key_values,
)
return output.logits, output.past_key_values
def __call__(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
past_key_values: Optional[Tuple] = None,
) -> torch.FloatTensor:
logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
next_token_logits = logits[..., -1, :]
return next_token_logits, kv_cache
class TransformerTokenizer(Tokenizer):
"""Represents a tokenizer for models in the `transformers` library."""
def __init__(self, tokenizer):
# TODO: Do something to make this hashable?
self.tokenizer = tokenizer
self.eos_token_id = self.tokenizer.eos_token_id
self.eos_token = self.tokenizer.eos_token
if not self.tokenizer.pad_token_id:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.pad_token_id = self.eos_token_id
else:
self.pad_token_id = self.tokenizer.pad_token_id
self.pad_token = self.tokenizer.pad_token
self.special_tokens = set(self.tokenizer.all_special_tokens)
self.vocabulary = self.tokenizer.get_vocab()
self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
def encode(
self, prompt: Union[str, List[str]], **kwargs
) -> Tuple[torch.LongTensor, torch.LongTensor]:
kwargs["padding"] = True
kwargs["return_tensors"] = "pt"
output = self.tokenizer(prompt, **kwargs)
return output["input_ids"], output["attention_mask"]
def decode(self, token_ids: torch.LongTensor) -> List[str]:
text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
return text
def convert_token_to_string(self, token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = self.tokenizer.convert_tokens_to_string([token])
if self.is_llama:
# A hack to handle missing spaces in HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
def __eq__(self, other):
if isinstance(other, type(self)):
return False
# TODO(lsyin): the lru_cache for the TransformerTokenizer is useless?
# return other.model_name == self.model_name and other.kwargs == self.kwargs
return NotImplemented
def __hash__(self):
from datasets.fingerprint import Hasher
return hash(Hasher.hash(self.tokenizer))
def transformers(
model_name: str,
device: Optional[str] = None,
model_kwargs: dict = {},
tokenizer_kwargs: dict = {},
):
"""Instantiate a model from the `transformers` library and its tokenizer.
Parameters
----------
model_name
The name of the model as listed on Hugging Face's model page.
device
The device(s) on which the model should be loaded. This overrides
the `device_map` entry in `model_kwargs` when provided.
model_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the model.
tokenizer_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the tokenizer.
Returns
-------
A `Transformer` model instance.
"""
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
raise ImportError(
"The `transformers` library needs to be installed in order to use `transformers` models."
)
if device is not None:
model_kwargs["device_map"] = device
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
tokenizer = TransformerTokenizer(AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs))
return Transformer(model, tokenizer)
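# Usage sketch (illustrative): loading a small model and running one forward step.
# The model name and device are assumptions for illustration.
model = transformers("gpt2", device="cpu")
input_ids, attention_mask = model.tokenizer.encode("Hello")
next_token_logits, kv_cache = model(input_ids, attention_mask)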

View File

@@ -0,0 +1,164 @@
"""Utilities for Huggingface Transformers."""
import json
import os
import warnings
from typing import List, Optional, Tuple, Union
from huggingface_hub import snapshot_download
from sglang.srt.utils import is_multimodal_model
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
def download_from_hf(model_path: str):
if os.path.exists(model_path):
return model_path
return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
def get_config_json(model_path: str):
with open(os.path.join(model_path, "config.json")) as f:
config = json.load(f)
return config
def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision
)
return config
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# have a preference for which value gets used.
CONTEXT_LENGTH_KEYS = [
"max_sequence_length",
"seq_length",
"max_position_embeddings",
"max_seq_len",
"model_max_length",
]
def get_context_length(config):
"""Get the context length of a model from a huggingface model config."""
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling:
rope_scaling_factor = config.rope_scaling["factor"]
else:
rope_scaling_factor = 1
for key in CONTEXT_LENGTH_KEYS:
val = getattr(config, key, None)
if val is not None:
return int(rope_scaling_factor * val)
return 2048
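# Sketch (illustrative): with max_position_embeddings=4096 and a RoPE scaling factor of 2.0,
# the derived context length is 2.0 * 4096 = 8192 tokens.
from types import SimpleNamespace

cfg = SimpleNamespace(rope_scaling={"factor": 2.0}, max_position_embeddings=4096)
assert get_context_length(cfg) == 8192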
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
def get_tokenizer(
tokenizer_name: str,
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
if is_multimodal_model(tokenizer_name):
processor = get_processor(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
**kwargs,
)
tokenizer = processor.tokenizer
return tokenizer
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
if (
"llama" in tokenizer_name.lower()
and kwargs.get("use_fast", True)
and tokenizer_name != _FAST_LLAMA_TOKENIZER
):
pass
# warnings.warn(
# "For some LLaMA V1 models, initializing the fast tokenizer may "
# "take a long time. To reduce the initialization time, consider "
# f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
# "tokenizer."
# )
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
**kwargs,
)
except TypeError as e:
# The LLaMA tokenizer causes a protobuf error in some environments.
err_msg = (
"Failed to load the tokenizer. If you are using a LLaMA V1 model "
f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
"original tokenizer."
)
raise RuntimeError(err_msg) from e
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if not trust_remote_code and (
"does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e)
):
err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI."
)
raise RuntimeError(err_msg) from e
else:
raise e
if not isinstance(tokenizer, PreTrainedTokenizerFast):
warnings.warn(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead."
)
return tokenizer
def get_processor(
tokenizer_name: str,
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
**kwargs,
):
processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
**kwargs,
)
return processor

View File

@@ -0,0 +1,181 @@
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
import torch
import triton
import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
B_Start_Loc,
B_Seqlen,
Out,
stride_qbs,
stride_qh,
stride_kbs,
stride_kh,
stride_vbs,
stride_vh,
stride_obs,
stride_oh,
kv_group_num: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :]
)
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
other=0.0,
)
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
other=0.0,
)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :]
)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
cached_kernel = None
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
BLOCK = 128
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
num_warps = 4 if Lk <= 64 else 8
global cached_kernel
if cached_kernel:
cached_kernel(
grid,
num_warps,
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
o,
q.stride(0),
q.stride(1),
k.stride(0),
k.stride(1),
v.stride(0),
v.stride(1),
o.stride(0),
o.stride(1),
)
return
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
b_start_loc,
b_seq_len,
o,
q.stride(0),
q.stride(1),
k.stride(0),
k.stride(1),
v.stride(0),
v.stride(1),
o.stride(0),
o.stride(1),
kv_group_num=kv_group_num,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
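# Reference sketch (not used by the kernel above, illustrative names only): the
# block loop in _fwd_kernel implements the numerically stable online-softmax
# recurrence from FlashAttention, carrying a running max, a running normalizer,
# and a rescaled accumulator. A minimal, unoptimized PyTorch version of the same
# recurrence, normalizing once at the end instead of per step:
def _online_softmax_attention_reference(q, k, v, block=128):
    # q: (M, D) queries; k, v: (N, D) keys/values; causal masking omitted for brevity.
    m = torch.full((q.shape[0],), float("-inf"), dtype=torch.float32)
    l = torch.zeros(q.shape[0], dtype=torch.float32)
    acc = torch.zeros(q.shape[0], v.shape[1], dtype=torch.float32)
    scale = q.shape[-1] ** -0.5
    for s in range(0, k.shape[0], block):
        qk = (q.float() @ k[s : s + block].float().T) * scale
        m_new = torch.maximum(m, qk.max(dim=1).values)
        p = torch.exp(qk - m_new[:, None])
        alpha = torch.exp(m - m_new)  # rescale previously accumulated stats
        l = l * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ v[s : s + block].float()
        m = m_new
    return acc / l[:, None]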

View File

@@ -0,0 +1,371 @@
import torch
import triton
import triton.language as tl
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@triton.jit
def _fwd_kernel(
Q_Extend,
K_Extend,
V_Extend,
O_Extend,
K_Buffer,
V_Buffer,
Req_to_tokens,
B_req_idx,
B_Seq_Len,
B_Start_Loc_Extend,
B_Seq_Len_Extend,
sm_scale,
kv_group_num,
stride_qbs,
stride_qh,
stride_kbs,
stride_kh,
stride_vbs,
stride_vh,
stride_obs,
stride_oh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_req_to_tokens_b,
BLOCK_DMODEL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
cur_block_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
cur_seq_len = tl.load(B_Seq_Len + cur_seq)
cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
cur_seq_prefix_start_in_loc = 0
cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = tl.arange(0, BLOCK_M)
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
offs_q = (
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :]
)
q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
# stage1: compute scores with prefix
offs_n = tl.arange(0, BLOCK_N)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_seq_len_prefix
offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
cur_seq_prefix_start_in_loc + start_n + offs_n
)
offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
# load k in transposed way
offs_buf_k = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
offs_buf_v = (
offs_kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_d[None, :]
)
v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
# stage2: compute the triangle part
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
for start_n in range(0, cur_block_m_end, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_block_m_end
# load k in transposed way
offs_k = (
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_d[:, None]
)
k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
mask_causal = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
mask_causal &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_causal, qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
offs_v = (
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+ cur_kv_head * stride_vh
+ offs_d[None, :]
)
v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
offs_o = (
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_obs
+ cur_head * stride_oh
+ offs_d[None, :]
)
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
def extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
b_start_loc_extend,
b_seq_len_extend,
max_len_in_batch,
max_len_extend,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
"""
BLOCK_M, BLOCK_N = 128, 128
Lq, Lk, Lv, Lo = (
q_extend.shape[-1],
k_extend.shape[-1],
v_extend.shape[-1],
o_extend.shape[-1],
)
assert Lq == Lk and Lk == Lv and Lv == Lo
assert Lq in {16, 32, 64, 128}
sm_scale = 1.0 / (Lq**0.5)
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_warps = 4 if Lk <= 64 else 8
num_stages = 1
_fwd_kernel[grid](
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_seq_len,
b_start_loc_extend,
b_seq_len_extend,
sm_scale,
kv_group_num,
q_extend.stride(0),
q_extend.stride(1),
k_extend.stride(0),
k_extend.stride(1),
v_extend.stride(0),
v_extend.stride(1),
o_extend.stride(0),
o_extend.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
v_buffer.stride(0),
v_buffer.stride(1),
req_to_tokens.stride(0),
BLOCK_DMODEL=Lq,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
)
def redundant_attention(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
max_len_in_batch,
):
total_token_num = k_buffer.shape[0]
B, H_Q, D = b_req_idx.shape[0], q_extend.shape[-2], q_extend.shape[-1]
q_buffer = torch.empty(
(total_token_num, H_Q, D), dtype=q_extend.dtype, device=q_extend.device
)
pt = 0
for i in range(B):
cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]
pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
q_buffer[pl:pr] = q_extend[pt : pt + cur_seq_len_extend]
pt += cur_seq_len_extend
o_buffer = torch.empty_like(q_buffer)
context_attention_fwd(
q_buffer, k_buffer, v_buffer, o_buffer, b_start_loc, b_seq_len, max_len_in_batch
)
pt = 0
for i in range(B):
cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]
pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
pt += cur_seq_len_extend
def test():
torch.manual_seed(0)
B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128
dtype = torch.float16
b_seq_len_prefix = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len_extend = torch.randint(
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda")
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
for i in range(B):
req_to_tokens[i, : b_seq_len[i]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len[i]
)
total_token_num = torch.sum(b_seq_len).item()
extend_token_num = torch.sum(b_seq_len_extend).item()
k_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
v_buffer = torch.empty(
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.empty(
(b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
).normal_(mean=0.1, std=0.2)
o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
b_seq_len_extend = b_seq_len - b_seq_len_prefix
b_start_loc_extend = torch.zeros_like(b_seq_len)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
b_start_loc_extend,
b_seq_len_extend,
max_len_in_batch,
max_len_extend,
)
redundant_attention(
q_extend,
k_extend,
v_extend,
o_redundant,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
max_len_in_batch,
)
print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant)))
print("Max: ", torch.max(torch.abs(o_extend - o_redundant)))
assert torch.allclose(o_extend, o_redundant, rtol=1e-2)
if __name__ == "__main__":
test()

View File

@@ -0,0 +1,79 @@
import torch
import triton
import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher
@triton.jit
def _fwd_segmented_gather(
all_logits,
len_add_1,
cum_len,
input_ids,
logprobs,
max_seq_len,
voc_size: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
cur_req = tl.program_id(0)
cur_l = tl.load(len_add_1 + cur_req)
cum_l = tl.load(cum_len + cur_req)
for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = off < cur_l - 1
idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask)
data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask)
tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask)
cached_kernel = None
def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
voc_size = all_logits.shape[1]
grid = (len_add_1.shape[0], 1, 1)
max_seq_len = len_add_1.max().item()
global cached_kernel
if cached_kernel:
cached_kernel(
grid,
4,
all_logits,
len_add_1,
cum_len,
input_ids,
logprobs,
max_seq_len,
)
return
_fwd_segmented_gather[grid](
all_logits,
len_add_1,
cum_len,
input_ids,
logprobs,
max_seq_len,
voc_size,
BLOCK_SIZE=128,
)
cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)
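# Reference sketch (pure PyTorch, not used by the kernel): for each request,
# gather all_logits[pos, input_ids[pos + 1]] for every position except the last
# one, then concatenate across requests. This matches the Triton kernel's
# output layout (one value per predicted token, no per-request padding).
def _get_selected_logprob_reference(all_logits, len_add_1, input_ids):
    out = []
    start = 0
    for length in len_add_1.tolist():
        rows = torch.arange(start, start + length - 1, device=all_logits.device)
        cols = input_ids[start + 1 : start + length].long()
        out.append(all_logits[rows, cols])
        start += length
    return torch.cat(out)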
if __name__ == "__main__":
all_logits = torch.tensor(
# s s s
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
dtype=torch.float32,
device="cuda",
)
len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
logprobs = torch.empty((3), dtype=torch.float32, device="cuda")
get_selected_logprob(all_logits, len_add_1, input_ids, logprobs)
print(logprobs)
# assert logprobs == [2, 2, 4]

View File

@@ -0,0 +1,77 @@
import torch
from sglang.srt.layers.get_selected_logprob import get_selected_logprob
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from torch import nn
from vllm.model_executor.parallel_utils.communication_op import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
class LogitsProcessor(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
def forward(self, input_ids, hidden_states, weight, input_metadata):
if not input_metadata.return_normalized_logprob:
if input_metadata.forward_mode == ForwardMode.DECODE:
last_hidden = hidden_states
else:
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
- 1
)
last_hidden = hidden_states[last_index]
hidden_states = None
last_logits = torch.matmul(last_hidden, weight.T)
if self.tp_size > 1:
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size]
return last_logits, None
else:
assert input_metadata.forward_mode != ForwardMode.DECODE
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
- 1
)
logits = torch.matmul(hidden_states, weight.T)
if self.tp_size > 1:
logits = tensor_model_parallel_all_gather(logits)
logits = logits[:, : self.config.vocab_size]
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
normalized_logprobs = compute_normalized_logprobs(
all_logprobs,
input_metadata.seq_lens - input_metadata.prefix_lens,
input_ids,
)
last_logits = logits[last_index]
return last_logits, normalized_logprobs
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
# assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0]
logprobs = torch.zeros(
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
)
get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs)
cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
end = torch.cumsum(len_add_1.sub_(1), dim=0)
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
end.sub_(1)
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
res = sum_logp / len_add_1
return res
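# Worked example (hypothetical values): with len_add_1 = [2, 3] the kernel fills
# logprobs = [a, b, c] (one entry per predicted token). After len_add_1.sub_(1),
# end = cumsum([1, 2]) = [1, 3], start = [0, 1], end.sub_(1) = [0, 2], so
# sum_logp = [a, b + c] and res = [a / 1, (b + c) / 2], i.e. the mean selected
# log probability of each request.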

View File

@@ -0,0 +1,158 @@
from typing import List
import torch
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from torch import nn
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
class RadixAttention(nn.Module):
def __init__(
self,
num_heads,
head_dim,
scaling,
num_kv_heads,
layer_id,
):
super().__init__()
self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads
self.tp_v_head_num = num_kv_heads
self.head_dim = head_dim
self.layer_id = layer_id
from sglang.srt.managers.router.model_runner import global_model_mode
self.use_flashinfer = "flashinfer" in global_model_mode
if self.use_flashinfer:
self.prefill_forward = self.prefill_forward_flashinfer
self.extend_forward = self.prefill_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
else:
self.prefill_forward = self.prefill_forward_triton
self.extend_forward = self.extend_forward_triton
self.decode_forward = self.decode_forward_triton
def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
context_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
k,
v,
o.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.start_loc,
input_metadata.seq_lens,
input_metadata.max_seq_len,
)
self.store_kv_cache(k, v, input_metadata)
return o
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
extend_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.start_loc,
input_metadata.seq_lens,
input_metadata.prefix_lens,
input_metadata.extend_start_loc,
input_metadata.extend_seq_lens,
input_metadata.max_seq_len,
input_metadata.max_extend_len,
)
return o
def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
token_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
o.view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.start_loc,
input_metadata.seq_lens,
input_metadata.max_seq_len,
input_metadata.other_kv_index,
input_metadata.total_num_tokens,
)
return o
def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
self.store_kv_cache(k, v, input_metadata)
o = input_metadata.prefill_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.qo_indptr,
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
input_metadata.kv_indptr,
input_metadata.kv_indices,
input_metadata.kv_last_page_len,
allow_fp16_qk_reduction=True,
)
return o.view(-1, self.tp_q_head_num * self.head_dim)
def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
self.store_kv_cache(k, v, input_metadata)
o = input_metadata.decode_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
input_metadata.kv_indptr,
input_metadata.kv_indices,
input_metadata.kv_last_page_len,
)
return o.view(-1, self.tp_q_head_num * self.head_dim)
def forward(self, q, k, v, input_metadata: InputMetadata):
k = k.view(-1, self.tp_k_head_num, self.head_dim)
v = v.view(-1, self.tp_v_head_num, self.head_dim)
if input_metadata.forward_mode == ForwardMode.PREFILL:
return self.prefill_forward(q, k, v, input_metadata)
elif input_metadata.forward_mode == ForwardMode.EXTEND:
return self.extend_forward(q, k, v, input_metadata)
elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.decode_forward(q, k, v, input_metadata)
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
if input_metadata.out_cache_loc is not None:
key_buffer[input_metadata.out_cache_loc] = cache_k
value_buffer[input_metadata.out_cache_loc] = cache_v
elif input_metadata.out_cache_cont_start is not None:
key_buffer[
input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
] = cache_k
value_buffer[
input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
] = cache_v
else:
raise RuntimeError()

View File

@@ -0,0 +1,324 @@
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
import torch
import triton
import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher
@triton.jit
def _fwd_kernel_stage1(
Q,
K_Buffer,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
att_stride_h,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_n = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_batch_start_index = 0
cur_batch_end_index = cur_batch_seq_len
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
block_start_index = start_n * BLOCK_N
block_mask = tl.where(block_start_index < cur_batch_seq_len, 1, 0)
for start_mark in range(0, block_mask, 1):
q = tl.load(Q + off_q + start_mark)
offs_n_new = cur_batch_start_index + offs_n
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
mask=offs_n_new < cur_batch_end_index,
other=0,
)
offs_buf_k = (
k_loc[:, None] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[None, :]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=offs_n_new[:, None] < cur_batch_end_index,
other=0.0,
)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
@triton.jit
def _fwd_kernel_stage2(
Logics,
V_Buffer,
Out,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
stride_logic_h,
stride_buf_vbs,
stride_buf_vh,
stride_obs,
stride_oh,
stride_req_to_token_b,
other_kv_index, # To fix a NAN issue
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
v_ptrs = V_Buffer + offs_buf_v
e_max = float("-inf")
e_sum = 0.0
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_seq_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
v_index = tl.load(
Req_to_tokens
+ cur_batch_req_idx * stride_req_to_token_b
+ (start_n + offs_n),
mask=(start_n + offs_n) < cur_batch_seq_len,
other=other_kv_index,
)
qk = tl.load(
Logics
+ cur_head * stride_logic_h
+ (cur_batch_start_loc + start_n + offs_n),
mask=start_n + offs_n < cur_batch_seq_len,
other=float("-inf"),
)
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
old_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max)
e_sum = e_sum * old_scale + tl.sum(p, 0)
v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
e_max = n_e_max
acc = acc / e_sum
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
cached_kernel_stage1 = None
cached_kernel_stage2 = None
def _token_att_m_fwd(
q,
k_buffer,
att_out,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
max_len_in_batch,
):
BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lk**0.5)
batch, head_num = B_req_idx.shape[0], q.shape[1]
grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))
kv_group_num = q.shape[1] // k_buffer.shape[1]
if kv_group_num == 1:
num_warps = 4
else:
num_warps = 2
global cached_kernel_stage1
if cached_kernel_stage1:
cached_kernel_stage1(
grid,
num_warps,
q,
k_buffer,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
att_out.stride(0),
)
return
_fwd_kernel_stage1[grid](
q,
k_buffer,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
att_out.stride(0),
kv_group_num=kv_group_num,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)
def _token_softmax_reducev_fwd(
logics,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
other_kv_index,
):
BLOCK = 64
batch, head = b_seq_len.shape[0], logics.shape[0]
grid = (batch, head, 1)
kv_group_num = logics.shape[0] // v_buffer.shape[1]
num_warps = 1
global cached_kernel_stage2
if cached_kernel_stage2:
cached_kernel_stage2(
grid,
num_warps,
logics,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
logics.stride(0),
v_buffer.stride(0),
v_buffer.stride(1),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
other_kv_index,
)
return
_fwd_kernel_stage2[grid](
logics,
v_buffer,
o,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
logics.stride(0),
v_buffer.stride(0),
v_buffer.stride(1),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
other_kv_index,
kv_group_num=kv_group_num,
BLOCK_DMODEL=v_buffer.shape[-1],
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=3,
)
cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)
def token_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
max_len_in_batch,
other_kv_index,
total_num_tokens,
att_m=None,
):
if att_m is None:
att_m = torch.empty(
(q.shape[-2], total_num_tokens), dtype=q.dtype, device="cuda"
)
_token_att_m_fwd(
q,
k_buffer,
att_m,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
max_len_in_batch,
)
_token_softmax_reducev_fwd(
att_m,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
other_kv_index,
)
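# Reference sketch (pure PyTorch, not used above, illustrative names only):
# decode-time attention for one new query token per request, gathering that
# request's cached K/V rows through req_to_token and broadcasting KV heads for
# grouped-query attention. Shapes follow the buffers used above.
def _token_attention_reference(q, k_buffer, v_buffer, req_to_token, b_req_idx, b_seq_len):
    o = torch.empty_like(q)
    scale = q.shape[-1] ** -0.5
    group = q.shape[1] // k_buffer.shape[1]
    for i in range(q.shape[0]):
        idx = req_to_token[b_req_idx[i], : b_seq_len[i]].long()
        k = k_buffer[idx].repeat_interleave(group, dim=1).float()  # (seq, H_q, D)
        v = v_buffer[idx].repeat_interleave(group, dim=1).float()
        scores = torch.einsum("hd,shd->hs", q[i].float(), k) * scale
        probs = torch.softmax(scores, dim=-1)
        o[i] = torch.einsum("hs,shd->hd", probs, v).to(q.dtype)
    return o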

View File

@@ -0,0 +1,85 @@
import asyncio
import uvloop
import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class DetokenizerManager:
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
):
context = zmq.asyncio.Context(2)
self.recv_from_router = context.socket(zmq.PULL)
self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
self.send_to_tokenizer = context.socket(zmq.PUSH)
self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
async def handle_loop(self):
while True:
recv_obj = await self.recv_from_router.recv_pyobj()
if isinstance(recv_obj, BatchTokenIDOut):
output_tokens = recv_obj.output_tokens
# TODO(lmzheng): handle skip_special_tokens per request
output_strs = self.tokenizer.batch_decode(
output_tokens,
skip_special_tokens=recv_obj.skip_special_tokens[0],
)
# Trim stop str
# TODO(lmzheng): handle the case where multiple stop strs are hit
for i in range(len(output_strs)):
if recv_obj.hit_stop_str[i] is not None:
pos = output_strs[i].find(recv_obj.hit_stop_str[i])
if pos != -1:
output_strs[i] = output_strs[i][:pos]
if len(output_tokens[i]) > 0:
first_token = self.tokenizer.convert_ids_to_tokens(
int(output_tokens[i][0])
)
if first_token.startswith("▁"):
output_strs[i] = " " + output_strs[i]
self.send_to_tokenizer.send_pyobj(
BatchStrOut(
recv_obj.rids,
output_strs,
recv_obj.meta_info,
recv_obj.finished,
)
)
else:
raise ValueError(f"Invalid object: {recv_obj}")
def start_detokenizer_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
):
try:
manager = DetokenizerManager(server_args, port_args)
except Exception as e:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
loop = asyncio.get_event_loop()
loop.run_until_complete(manager.handle_loop())

View File

@@ -0,0 +1,88 @@
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from sglang.srt.sampling_params import SamplingParams
@dataclass
class GenerateReqInput:
text: Union[List[str], str]
image_data: Optional[Union[List[str], str]] = None
sampling_params: Union[List[Dict], Dict] = None
rid: Optional[Union[List[str], str]] = None
return_normalized_logprob: Optional[Union[List[bool], bool]] = None
normalized_logprob_start_len: Optional[Union[List[int], int]] = None
stream: bool = False
def post_init(self):
is_single = isinstance(self.text, str)
if is_single:
if self.sampling_params is None:
self.sampling_params = {}
if self.rid is None:
self.rid = uuid.uuid4().hex
if self.return_normalized_logprob is None:
self.return_normalized_logprob = False
if self.normalized_logprob_start_len is None:
self.normalized_logprob_start_len = 0
else:
num = len(self.text)
if self.image_data is None:
self.image_data = [None] * num
elif not isinstance(self.image_data, list):
self.image_data = [self.image_data] * num
if self.sampling_params is None:
self.sampling_params = [{}] * num
elif not isinstance(self.sampling_params, list):
self.sampling_params = [self.sampling_params] * num
if self.rid is None:
self.rid = [uuid.uuid4().hex for _ in range(num)]
else:
assert isinstance(self.rid, list)
if self.return_normalized_logprob is None:
self.return_normalized_logprob = [False] * num
elif not isinstance(self.return_normalized_logprob, list):
self.return_normalized_logprob = [self.return_normalized_logprob] * num
if self.normalized_logprob_start_len is None:
self.normalized_logprob_start_len = [0] * num
elif not isinstance(self.normalized_logprob_start_len, list):
self.normalized_logprob_start_len = [
self.normalized_logprob_start_len
] * num
@dataclass
class TokenizedGenerateReqInput:
rid: str
input_ids: List[int]
pixel_values: List[float]
image_hash: int
sampling_params: SamplingParams
return_normalized_logprob: bool
normalized_logprob_start_len: int
stream: bool
@dataclass
class BatchTokenIDOut:
rids: List[str]
output_tokens: List[List[int]]
hit_stop_str: List[Optional[str]]
skip_special_tokens: List[bool]
meta_info: List[Dict]
finished: List[bool]
@dataclass
class BatchStrOut:
rids: List[str]
output_str: List[str]
meta_info: List[Dict]
finished: List[bool]
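# Usage sketch (illustrative, not part of the server flow): post_init fills
# defaults for a single prompt and broadcasts per-request fields for a batch.
if __name__ == "__main__":
    req = GenerateReqInput(
        text=["Hello", "World"], sampling_params={"max_new_tokens": 8}
    )
    req.post_init()
    # sampling_params is now a list of two dicts; rid is a list of two hex ids.
    print(req.sampling_params, req.rid)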

View File

@@ -0,0 +1,12 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Union
@dataclass
class CompletionRequest:
prompt: Union[str, List[Any]]
model: str = "default"
temperature: Optional[float] = 0.7
max_tokens: Optional[int] = 16
n: Optional[int] = 1
stop: Optional[Union[str, List[str]]] = None

View File

@@ -0,0 +1,326 @@
from enum import Enum, auto
from typing import List
import numpy as np
import torch
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
class ForwardMode(Enum):
PREFILL = auto()
EXTEND = auto()
DECODE = auto()
class FinishReason(Enum):
LENGTH = auto()
EOS_TOKEN = auto()
STOP_STR = auto()
class Req:
def __init__(self, rid):
self.rid = rid
self.input_ids = []
self.output_ids = []
self.pixel_values = None
self.image_offset = 0
self.sampling_params = None
self.return_normalized_logprob = False
self.normalized_logprob_start_len = 0
self.stream = False
self.tokenizer = None
self.finished = False
self.finish_reason = None
self.hit_stop_str = None
self.adjust_input_len = 0
self.prefix_indices = []
self.normalized_logprob = None
# for constrained decoding
self.regex_fsm = None
self.regex_fsm_state = None
def max_new_tokens(self):
return self.sampling_params.max_new_tokens
def check_finished(self):
if self.finished:
return
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
self.finished = True
self.finish_reason = FinishReason.LENGTH
return
if (
self.output_ids[-1] == self.tokenizer.eos_token_id
and not self.sampling_params.ignore_eos
):
self.finished = True
self.finish_reason = FinishReason.EOS_TOKEN
return
if len(self.sampling_params.stop_strs) > 0:
tail_str = self.tokenizer.decode(
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
)
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str:
self.finished = True
self.finish_reason = FinishReason.STOP_STR
self.hit_stop_str = stop_str
return
def __repr__(self):
return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
class Batch:
def __init__(
self,
reqs: List[Req],
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: TokenToKVPool,
tree_cache: RadixCache,
):
self.reqs = reqs
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.tree_cache = tree_cache
self.return_normalized_logprob = any(
req.return_normalized_logprob for req in reqs
)
def is_empty(self):
return len(self.reqs) == 0
def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
device = "cuda"
bs = len(self.reqs)
reqs = self.reqs
input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
prefix_indices = [r.prefix_indices for r in reqs]
# Handle prefix
flatten_input_ids = []
extend_lens = []
prefix_lens = []
seq_lens = []
req_pool_indices = self.req_to_token_pool.alloc(bs)
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
for i in range(bs):
flatten_input_ids.extend(input_ids[i])
extend_lens.append(len(input_ids[i]))
if len(prefix_indices[i]) == 0:
prefix_lens.append(0)
else:
prefix_lens.append(len(prefix_indices[i]))
self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
: len(prefix_indices[i])
] = prefix_indices[i]
seq_lens.append(prefix_lens[-1] + extend_lens[-1])
position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
# Alloc mem
seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
if out_cache_loc is None:
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
if out_cache_loc is None:
print("Prefill out of memory.")
self.tree_cache.pretty_print()
exit()
pt = 0
for i in range(bs):
self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
prefix_lens[i] : prefix_lens[i] + extend_lens[i]
] = out_cache_loc[pt : pt + extend_lens[i]]
pt += extend_lens[i]
# Handle logit bias
logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
for i in range(bs):
if reqs[i].sampling_params.dtype == "int":
logit_bias[i] = int_token_logit_bias
# Set fields
self.input_ids = torch.tensor(
flatten_input_ids, dtype=torch.int32, device=device
)
self.pixel_values = [r.pixel_values for r in reqs]
self.image_offsets = [
r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
]
self.req_pool_indices = req_pool_indices
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
self.position_ids_offsets = position_ids_offsets
self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc
self.temperatures = torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
device=device,
).view(-1, 1)
self.top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
).view(-1, 1)
self.top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
).view(-1, 1)
self.frequency_penalties = torch.tensor(
[r.sampling_params.frequency_penalty for r in reqs],
dtype=torch.float,
device=device,
)
self.presence_penalties = torch.tensor(
[r.sampling_params.presence_penalty for r in reqs],
dtype=torch.float,
device=device,
)
self.logit_bias = logit_bias
def update_for_decode(self, input_ids=None):
if input_ids is None:
input_ids = [
r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
]
self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
self.seq_lens.add_(1)
self.prefix_lens = None
# Alloc mem
bs = len(self.reqs)
alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
if alloc_res is None:
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
if self.out_cache_loc is None:
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
if self.out_cache_loc is None:
print("Decode out of memory.")
self.tree_cache.pretty_print()
exit()
self.out_cache_cont_start = None
self.out_cache_cont_end = None
else:
self.out_cache_loc = alloc_res[0]
self.out_cache_cont_start = alloc_res[1]
self.out_cache_cont_end = alloc_res[2]
self.req_to_token_pool.req_to_token[
self.req_pool_indices, self.seq_lens - 1
] = self.out_cache_loc
def filter_batch(self, unfinished_indices: List[int]):
self.reqs = [self.reqs[i] for i in unfinished_indices]
new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
self.seq_lens = self.seq_lens[new_indices]
self.input_ids = None
self.req_pool_indices = self.req_pool_indices[new_indices]
self.prefix_lens = None
self.position_ids_offsets = self.position_ids_offsets[new_indices]
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
for item in [
"temperatures",
"top_ps",
"top_ks",
"frequency_penalties",
"presence_penalties",
"logit_bias",
]:
setattr(self, item, getattr(self, item)[new_indices])
def merge(self, other):
self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices]
)
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
self.prefix_lens = None
self.position_ids_offsets = torch.concat(
[self.position_ids_offsets, other.position_ids_offsets]
)
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
for item in [
"temperatures",
"top_ps",
"top_ks",
"frequency_penalties",
"presence_penalties",
"logit_bias",
]:
setattr(
self, item, torch.concat([getattr(self, item), getattr(other, item)])
)
def sample(self, logits: torch.Tensor):
# Post process logits
logits = logits.contiguous()
logits.div_(self.temperatures)
logits.add_(self.logit_bias)
has_regex = any(req.regex_fsm is not None for req in self.reqs)
if has_regex:
allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
for i, req in enumerate(self.reqs):
if req.regex_fsm is not None:
allowed_mask.zero_()
allowed_mask[
req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
] = 1
logits[i].masked_fill_(~allowed_mask, float("-inf"))
# TODO(lmzheng): apply penalty
probs = torch.softmax(logits, dim=-1)
probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
sampled_index = torch.multinomial(probs_sort, num_samples=1)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-1
)
batch_next_token_probs = torch.gather(
probs_sort, dim=1, index=sampled_index
).view(-1)
if has_regex:
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
for i, req in enumerate(self.reqs):
if req.regex_fsm is not None:
req.regex_fsm_state = req.regex_fsm.next_state(
req.regex_fsm_state, batch_next_token_ids_cpu[i]
)
return batch_next_token_ids, batch_next_token_probs
def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
probs_sort[
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
return probs_sort, probs_idx
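# Usage sketch (illustrative): filter a batch of probability rows with
# per-request top-p / top-k and draw one token per row, mirroring how
# Batch.sample uses _top_p_top_k with (bs, 1)-shaped thresholds.
if __name__ == "__main__":
    probs = torch.softmax(torch.randn(2, 8), dim=-1)
    probs_sort, probs_idx = _top_p_top_k(
        probs, torch.tensor([[0.9], [1.0]]), torch.tensor([[4], [8]])
    )
    sampled = torch.multinomial(probs_sort, num_samples=1)
    print(torch.gather(probs_idx, dim=1, index=sampled).view(-1))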

View File

@@ -0,0 +1,71 @@
import asyncio
import logging
from typing import List, Tuple
import uvloop
import zmq
import zmq.asyncio
from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class RouterManager:
def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
# Init communication
context = zmq.asyncio.Context(2)
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
)
# Init status
self.model_client = model_client
self.recv_reqs = []
async def loop_for_forward(self):
while True:
next_step_input = list(self.recv_reqs)
self.recv_reqs = []
out_pyobjs = await self.model_client.step(next_step_input)
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
# await for a while to accept input requests
await asyncio.sleep(0.001)
async def loop_for_recv_requests(self):
while True:
recv_req = await self.recv_from_tokenizer.recv_pyobj()
self.recv_reqs.append(recv_req)
def start_router_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
):
logging.basicConfig(
level=getattr(logging, server_args.log_level.upper()),
format="%(message)s",
)
try:
model_client = ModelRpcClient(server_args, port_args)
router = RouterManager(model_client, port_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(router.loop_for_recv_requests())
loop.run_until_complete(router.loop_for_forward())

View File

@@ -0,0 +1,497 @@
import asyncio
import logging
import multiprocessing
import time
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
get_exception_traceback,
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
)
logger = logging.getLogger("model_rpc")
class ModelRpcServer(rpyc.Service):
def exposed_init_model(
self,
tp_rank: int,
server_args: ServerArgs,
port_args: PortArgs,
):
server_args, port_args = [obtain(x) for x in [server_args, port_args]]
# Copy arguments
self.model_mode = server_args.model_mode
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic
# Init model and tokenizer
self.model_config = ModelConfig(
server_args.model_path, server_args.trust_remote_code
)
self.model_runner = ModelRunner(
self.model_config,
server_args.mem_fraction_static,
tp_rank,
server_args.tp_size,
port_args.nccl_port,
server_args.load_format,
server_args.trust_remote_code,
server_args.model_mode,
)
if is_multimodal_model(server_args.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.tokenizer = self.processor.tokenizer
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.eos_token_id = self.tokenizer.eos_token_id
self.max_total_num_token = self.model_runner.max_total_num_token
self.max_num_running_seq = self.max_total_num_token // 2
self.max_prefill_num_token = max(
self.model_config.context_len, self.max_total_num_token // 6
)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
)
set_random_seed(server_args.random_seed)
logger.info(
f"Rank {self.tp_rank}: "
f"max_total_num_token={self.max_total_num_token}, "
f"max_prefill_num_token={self.max_prefill_num_token}, "
f"context_len={self.model_config.context_len}, "
f"model_mode={self.model_mode}"
)
# Init cache
self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
self.scheduler = Scheduler(
self.schedule_heuristic,
self.max_num_running_seq,
self.max_prefill_num_token,
self.max_total_num_token,
self.tree_cache,
)
self.req_to_token_pool = self.model_runner.req_to_token_pool
self.token_to_kv_pool = self.model_runner.token_to_kv_pool
# Init running status
self.forward_queue: List[Req] = []
self.running_batch: Batch = None
self.out_pyobjs = []
self.decode_forward_ct = 0
self.stream_interval = 2
# Init the FSM cache for constrained generation
self.regex_fsm_cache = FSMCache(self.tokenizer)
def exposed_step(self, recv_reqs):
if self.tp_size != 1:
recv_reqs = obtain(recv_reqs)
try:
# Recv requests
for recv_req in recv_reqs:
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
else:
raise ValueError(f"Invalid request: {recv_req}")
# Forward
self.forward_step()
except Exception:
logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
# Return results
ret = self.out_pyobjs
self.out_pyobjs = []
return ret
@torch.inference_mode()
def forward_step(self):
new_batch = self.get_new_fill_batch()
if new_batch is not None:
# Run new fill batch
self.forward_fill_batch(new_batch)
if not new_batch.is_empty():
if self.running_batch is None:
self.running_batch = new_batch
else:
self.running_batch.merge(new_batch)
else:
# Run decode batch
if self.running_batch is not None:
# Run a few decode batches continuously for reducing overhead
for _ in range(10):
self.forward_decode_batch(self.running_batch)
if self.running_batch.is_empty():
self.running_batch = None
break
if self.running_batch is not None and self.tp_rank == 0:
if self.decode_forward_ct >= 20:
self.decode_forward_ct = 0
num_used = self.max_total_num_token - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
logger.info(
f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_token:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
)
def handle_generate_request(
self,
recv_req: TokenizedGenerateReqInput,
):
req = Req(recv_req.rid)
req.input_ids = recv_req.input_ids
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
req.input_ids, pad_value
)
req.sampling_params = recv_req.sampling_params
req.return_normalized_logprob = recv_req.return_normalized_logprob
req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len
req.stream = recv_req.stream
req.tokenizer = self.tokenizer
# init the regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm_state = 0
req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
# Truncate long prompts
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
req.sampling_params.max_new_tokens = min(
req.sampling_params.max_new_tokens,
self.model_config.context_len - 1 - len(req.input_ids),
)
self.forward_queue.append(req)
def get_new_fill_batch(self):
if (
self.running_batch is not None
and len(self.running_batch.reqs) > self.max_num_running_seq
):
return None
for req in self.forward_queue:
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
if req.return_normalized_logprob:
prefix_indices = prefix_indices[: req.normalized_logprob_start_len]
req.adjust_input_len = len(req.input_ids) - len(prefix_indices)
req.prefix_indices = prefix_indices
req.last_node = last_node
# Get priority queue
self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
# Add requests if there is available space
can_run_list = []
new_batch_total_tokens = 0
new_batch_input_tokens = 0
new_batch_prefix_tokens = 0
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
new_ratio = self.scheduler.new_token_estimation_ratio()
if self.running_batch:
available_size -= sum(
[
(r.max_new_tokens() - len(r.output_ids)) * new_ratio
for r in self.running_batch.reqs
]
)
for req in self.forward_queue:
if req.return_normalized_logprob:
# Need at least two tokens to compute normalized logprob
if req.adjust_input_len < 2:
delta = 2 - req.adjust_input_len
req.adjust_input_len += delta
req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None:
req.image_offset += delta
if req.adjust_input_len == 0 and req.max_new_tokens() > 0:
# Need at least one token to compute logits
req.adjust_input_len = 1
req.prefix_indices = req.prefix_indices[:-1]
if req.image_offset is not None:
req.image_offset += 1
if (
req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
and req.adjust_input_len + new_batch_input_tokens
< self.max_prefill_num_token
):
delta = self.tree_cache.inc_ref_counter(req.last_node)
available_size += delta
if not (
req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
):
delta = self.tree_cache.dec_ref_counter(req.last_node)
available_size += delta
else:
self.token_to_kv_pool.add_refs(req.prefix_indices)
can_run_list.append(req)
new_batch_total_tokens += (
req.adjust_input_len + req.max_new_tokens()
)
new_batch_input_tokens += req.adjust_input_len
if len(can_run_list) == 0:
return None
if self.tp_rank == 0:
logger.info(
f"new fill batch. #seq: {len(can_run_list)}. "
f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. "
f"#new_token: {new_batch_input_tokens}. "
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
)
new_batch = Batch(
can_run_list,
self.req_to_token_pool,
self.token_to_kv_pool,
self.tree_cache,
)
self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
return new_batch
def forward_fill_batch(self, batch: Batch):
# Build batch tensors
batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias)
if batch.extend_num_tokens != 0:
# Forward
logits, normalized_logprobs = self.model_runner.forward(
batch, ForwardMode.EXTEND, batch.return_normalized_logprob
)
# print("extend logits", logits)
if normalized_logprobs is not None:
normalized_logprobs = normalized_logprobs.cpu().tolist()
next_token_ids, next_token_probs = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
normalized_logprobs = None
# Check finish condition
reqs = batch.reqs
for i in range(len(reqs)):
reqs[i].output_ids = [next_token_ids[i]]
reqs[i].check_finished()
if normalized_logprobs is not None:
reqs[i].normalized_logprob = normalized_logprobs[i]
self.handle_finished_requests(batch)
def forward_decode_batch(self, batch: Batch):
# Update batch tensors
self.decode_forward_ct += 1
batch.update_for_decode()
# Forward
logits = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids, next_token_probs = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()
# Check finish condition
reqs = batch.reqs
for i in range(len(reqs)):
reqs[i].output_ids.append(next_token_ids[i])
reqs[i].check_finished()
self.handle_finished_requests(batch)
def handle_finished_requests(self, batch: Batch):
output_rids = []
output_tokens = []
output_hit_stop_str = []
output_skip_special_tokens = []
output_meta_info = []
output_finished = []
finished_indices = []
unfinished_indices = []
for i, req in enumerate(batch.reqs):
if req.finished:
finished_indices.append(i)
else:
unfinished_indices.append(i)
if req.finished or (
req.stream and self.decode_forward_ct % self.stream_interval == 0
):
output_rids.append(req.rid)
output_tokens.append(req.output_ids)
output_hit_stop_str.append(req.hit_stop_str)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
meta_info = {
"prompt_tokens": len(req.input_ids),
"completion_tokens": len(req.output_ids),
}
if req.return_normalized_logprob:
meta_info["normalized_logprob"] = req.normalized_logprob
output_meta_info.append(meta_info)
output_finished.append(req.finished)
# Send to detokenizer
if output_rids:
self.out_pyobjs.append(
BatchTokenIDOut(
output_rids,
output_tokens,
output_hit_stop_str,
output_skip_special_tokens,
output_meta_info,
output_finished,
)
)
# Remove finished reqs
if finished_indices:
# Update radix cache
req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
for i in finished_indices:
req = batch.reqs[i]
req_pool_idx = req_pool_indices_cpu[i]
token_ids = tuple(req.input_ids + req.output_ids)
seq_len = len(token_ids) - 1
indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
prefix_len = self.tree_cache.insert(token_ids, indices.clone())
self.token_to_kv_pool.free(indices[:prefix_len])
self.req_to_token_pool.free(req_pool_idx)
self.tree_cache.dec_ref_counter(req.last_node)
# Update batch tensors
if unfinished_indices:
batch.filter_batch(unfinished_indices)
else:
batch.reqs = []
class ModelRpcClient:
def __init__(self, server_args: ServerArgs, port_args: PortArgs):
tp_size = server_args.tp_size
if tp_size == 1:
# Init model
self.model_server = ModelRpcServer()
self.model_server.exposed_init_model(0, server_args, port_args)
# Wrap functions
def async_wrap(f):
async def _func(*args, **kwargs):
return f(*args, **kwargs)
return _func
self.step = async_wrap(self.model_server.exposed_step)
else:
with ThreadPoolExecutor(tp_size) as executor:
# Launch model processes
rets = executor.map(start_model_process, port_args.model_rpc_ports)
self.model_servers = [x[0] for x in rets]
self.procs = [x[1] for x in rets]
# Init model
def init_model(i):
return self.model_servers[i].init_model(i, server_args, port_args)
rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
# Wrap functions
def async_wrap(func_name):
fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
async def _func(*args, **kwargs):
tasks = [f(*args, **kwargs) for f in fs]
await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
return obtain(tasks[0].value)
return _func
self.step = async_wrap("step")
def start_model_process(port):
def _init_service(port):
t = ThreadedServer(
ModelRpcServer(),
port=port,
protocol_config={"allow_pickle": True, "sync_request_timeout": 600},
)
t.start()
proc = multiprocessing.Process(target=_init_service, args=(port,))
proc.start()
time.sleep(1)
repeat_count = 0
while repeat_count < 20:
try:
con = rpyc.connect(
"localhost",
port,
config={"allow_pickle": True, "sync_request_timeout": 600},
)
break
except ConnectionRefusedError:
time.sleep(1)
repeat_count += 1
if repeat_count == 20:
raise RuntimeError("init rpc env error!")
assert proc.is_alive()
return con.root, proc

View File

@@ -0,0 +1,458 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import List
import numpy as np
import torch
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model
from sglang.utils import get_available_gpu_memory
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
# for model_mode
global_model_mode: List[str] = []
@dataclass
class InputMetadata:
model_runner: "ModelRunner"
forward_mode: ForwardMode
batch_size: int
total_num_tokens: int
max_seq_len: int
req_pool_indices: torch.Tensor
start_loc: torch.Tensor
seq_lens: torch.Tensor
prefix_lens: torch.Tensor
positions: torch.Tensor
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: TokenToKVPool
# for extend
extend_seq_lens: torch.Tensor = None
extend_start_loc: torch.Tensor = None
max_extend_len: int = 0
out_cache_loc: torch.Tensor = None
out_cache_cont_start: torch.Tensor = None
out_cache_cont_end: torch.Tensor = None
other_kv_index: torch.Tensor = None
return_normalized_logprob: bool = False
# for flashinfer
use_flashinfer: bool = False
qo_indptr: torch.Tensor = None
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None
kv_last_page_len: torch.Tensor = None
prefill_wrapper = None
decode_wrapper = None
def init_flashinfer_args(self, tp_size):
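# Build the paged KV-cache indptr/indices tensors expected by the flashinfer wrappers; with a page size of 1, kv_last_page_len is all ones.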
self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
self.kv_indices = torch.cat(
[
self.req_to_token_pool.req_to_token[
self.req_pool_indices[i].item(), : self.seq_lens[i].item()
]
for i in range(self.batch_size)
],
dim=0,
).contiguous()
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
from flashinfer.ops import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
if (
self.forward_mode == ForwardMode.PREFILL
or self.forward_mode == ForwardMode.EXTEND
):
self.qo_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
self.prefill_wrapper.begin_forward(
self.qo_indptr,
self.batch_size,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
)
else:
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
self.decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_last_page_len,
self.batch_size,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
self.model_runner.model_config.head_dim,
1,
"NONE",
"float16",
)
def init_extend_args(self):
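# Per-request lengths of the newly extended (non-prefix) tokens and their start offsets within the flattened input.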
self.extend_seq_lens = self.seq_lens - self.prefix_lens
self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
self.max_extend_len = int(torch.max(self.extend_seq_lens))
@classmethod
def create(
cls,
model_runner,
tp_size,
forward_mode,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
out_cache_cont_start=None,
out_cache_cont_end=None,
return_normalized_logprob=False,
):
batch_size = len(req_pool_indices)
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
total_num_tokens = int(torch.sum(seq_lens))
max_seq_len = int(torch.max(seq_lens))
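# Decode adds exactly one token per request, so its position is seq_len - 1 (plus any offset); prefill/extend enumerate positions from the cached prefix length up to the full sequence length.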
if forward_mode == ForwardMode.DECODE:
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
other_kv_index = model_runner.req_to_token_pool.req_to_token[
req_pool_indices[0], seq_lens[0] - 1
].item()
else:
seq_lens_np = seq_lens.cpu().numpy()
prefix_lens_np = prefix_lens.cpu().numpy()
position_ids_offsets_np = position_ids_offsets.cpu().numpy()
positions = torch.tensor(
np.concatenate(
[
np.arange(
prefix_lens_np[i] + position_ids_offsets_np[i],
seq_lens_np[i] + position_ids_offsets_np[i],
)
for i in range(batch_size)
],
axis=0,
),
device="cuda",
)
other_kv_index = None
ret = cls(
model_runner=model_runner,
forward_mode=forward_mode,
batch_size=batch_size,
total_num_tokens=total_num_tokens,
max_seq_len=max_seq_len,
req_pool_indices=req_pool_indices,
start_loc=start_loc,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
positions=positions,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end,
return_normalized_logprob=return_normalized_logprob,
other_kv_index=other_kv_index,
)
if forward_mode == ForwardMode.EXTEND:
ret.init_extend_args()
ret.use_flashinfer = "flashinfer" in model_runner.model_mode
if ret.use_flashinfer:
ret.init_flashinfer_args(tp_size)
return ret
class ModelRunner:
def __init__(
self,
model_config,
mem_fraction_static,
tp_rank,
tp_size,
nccl_port,
load_format="auto",
trust_remote_code=True,
model_mode: List[str] = (),
):
self.model_config = model_config
self.mem_fraction_static = mem_fraction_static
self.tp_rank = tp_rank
self.tp_size = tp_size
self.nccl_port = nccl_port
self.load_format = load_format
self.trust_remote_code = trust_remote_code
self.model_mode = model_mode
global global_model_mode
global_model_mode = model_mode
# Init torch distributed
torch.cuda.set_device(self.tp_rank)
torch.distributed.init_process_group(
backend="nccl",
world_size=self.tp_size,
rank=self.tp_rank,
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
)
# A small all_reduce for warmup.
if self.tp_size > 1:
torch.distributed.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
total_gpu_memory = get_available_gpu_memory(
self.tp_rank, distributed=self.tp_size > 1
) * (1 << 30)
self.load_model()
self.init_memory_pool(total_gpu_memory)
self.is_multimodal_model = is_multimodal_model(self.model_config)
def load_model(self):
"""See also vllm/model_executor/model_loader.py::get_model"""
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llava import LlavaLlamaForCausalLM
from sglang.srt.models.mixtral import MixtralForCausalLM
# Select model class
architectures = getattr(self.model_config.hf_config, "architectures", [])
model_class = None
for arch in architectures:
if arch == "LlamaForCausalLM":
model_class = LlamaForCausalLM
break
if arch == "MistralForCausalLM":
model_class = LlamaForCausalLM
break
if arch == "LlavaLlamaForCausalLM":
model_class = LlavaLlamaForCausalLM
break
if arch == "MixtralForCausalLM":
model_class = MixtralForCausalLM
break
if model_class is None:
raise ValueError(f"Unsupported architectures: {architectures}")
# Load weights
linear_method = None
with _set_default_torch_dtype(torch.float16):
with torch.device("cuda"):
hf_quant_config = getattr(
self.model_config.hf_config, "quantization_config", None
)
if hf_quant_config is not None:
# TODO: support quantization methods beyond AWQ
quant_config = AWQConfig.from_config(hf_quant_config)
print(f"quant_config: {quant_config}")
linear_method = quant_config.get_linear_method()
model = model_class(
config=self.model_config.hf_config, linear_method=linear_method
)
model.load_weights(
self.model_config.path,
cache_dir=None,
load_format=self.load_format,
revision=None,
)
self.model = model
def profile_max_num_token(self, total_gpu_memory):
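# Estimate how many KV-cache tokens fit in the static memory budget. cell_size is bytes per token: key and value (2) in fp16 (2 bytes) for every layer and every local KV head.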
available_gpu_memory = get_available_gpu_memory(
self.tp_rank, distributed=self.tp_size > 1
) * (1 << 30)
head_dim = (
self.model_config.hidden_size // self.model_config.num_attention_heads
)
head_num = self.model_config.num_key_value_heads // self.tp_size
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
max_num_token = int(rest_memory // cell_size)
return max_num_token
def init_memory_pool(self, total_gpu_memory):
self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
self.req_to_token_pool = ReqToTokenPool(
int(self.max_total_num_token / self.model_config.context_len * 256),
self.model_config.context_len + 8,
)
self.token_to_kv_pool = TokenToKVPool(
self.max_total_num_token,
dtype=torch.float16,
head_num=self.model_config.num_key_value_heads // self.tp_size,
head_dim=self.model_config.hidden_size
// self.model_config.num_attention_heads,
layer_num=self.model_config.num_hidden_layers,
)
@torch.inference_mode()
def forward_prefill(
self,
input_ids,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
return_normalized_logprob,
):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.PREFILL,
tp_size=self.tp_size,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc,
return_normalized_logprob=return_normalized_logprob,
)
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
@torch.inference_mode()
def forward_extend(
self,
input_ids,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
return_normalized_logprob,
):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
tp_size=self.tp_size,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc,
return_normalized_logprob=return_normalized_logprob,
)
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
@torch.inference_mode()
def forward_decode(
self,
input_ids,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
out_cache_cont_start,
out_cache_cont_end,
):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.DECODE,
tp_size=self.tp_size,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end,
)
return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
0
]
@torch.inference_mode()
def forward_extend_multi_modal(
self,
input_ids,
pixel_values,
image_offsets,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
return_normalized_logprob,
):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.EXTEND,
tp_size=self.tp_size,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc,
return_normalized_logprob=return_normalized_logprob,
)
return self.model.forward(
input_ids,
input_metadata.positions,
input_metadata,
pixel_values,
image_offsets,
)
def forward(
self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False
):
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
kwargs = {
"input_ids": batch.input_ids,
"pixel_values": batch.pixel_values,
"image_offsets": batch.image_offsets,
"req_pool_indices": batch.req_pool_indices,
"seq_lens": batch.seq_lens,
"prefix_lens": batch.prefix_lens,
"position_ids_offsets": batch.position_ids_offsets,
"out_cache_loc": batch.out_cache_loc,
}
kwargs["return_normalized_logprob"] = return_normalized_logprob
return self.forward_extend_multi_modal(**kwargs)
else:
kwargs = {
"input_ids": batch.input_ids,
"req_pool_indices": batch.req_pool_indices,
"seq_lens": batch.seq_lens,
"prefix_lens": batch.prefix_lens,
"position_ids_offsets": batch.position_ids_offsets,
"out_cache_loc": batch.out_cache_loc,
}
if forward_mode == ForwardMode.DECODE:
kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
return self.forward_decode(**kwargs)
elif forward_mode == ForwardMode.EXTEND:
kwargs["return_normalized_logprob"] = return_normalized_logprob
return self.forward_extend(**kwargs)
elif forward_mode == ForwardMode.PREFILL:
kwargs["return_normalized_logprob"] = return_normalized_logprob
return self.forward_prefill(**kwargs)
else:
raise ValueError(f"Invaid forward mode: {forward_mode}")

View File

@@ -0,0 +1,220 @@
import heapq
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Tuple
import torch
class TreeNode:
def __init__(self):
self.children = defaultdict(TreeNode)
self.parent = None
self.value = None
self.ref_counter = 0
self.last_access_time = time.time()
def __lt__(self, other):
return self.last_access_time < other.last_access_time
def match(key, seq):
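# Length of the common prefix of two token sequences.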
i = 0
for k, w in zip(key, seq):
if k != w:
break
i += 1
return i
class RadixCache:
def __init__(self, disable=False):
self.root_node = TreeNode()
self.root_node.value = []
self.root_node.ref_counter = 1
self.evictable_size_ = 0
self.disable = disable
##### Public API #####
def match_prefix(self, key):
if self.disable:
return [], self.root_node
value = []
last_node = [self.root_node]
self._match_prefix_helper(self.root_node, key, value, last_node)
if value:
value = torch.concat(value)
return value, last_node[0]
def insert(self, key, value=None):
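# Returns the number of leading tokens of `key` already cached; only the new suffix is added to the tree and counted toward evictable_size_.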
if self.disable:
return len(key)
if value is None:
value = [x for x in key]
return self._insert_helper(self.root_node, key, value)
def pretty_print(self):
self._print_helper(self.root_node, 0)
print(f"#tokens: {self.total_size()}")
def total_size(self):
return self._total_size_helper(self.root_node)
def evict(self, num_tokens, evict_callback):
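# Evict unreferenced leaves, least recently used first, until at least num_tokens tokens are freed; evict_callback releases the backing KV indices and returns how many were freed.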
if self.disable:
raise RuntimeError()
leaves = self._collect_leaves()
heapq.heapify(leaves)
num_evicted = 0
while num_evicted < num_tokens and len(leaves):
x = heapq.heappop(leaves)
if x == self.root_node:
break
if x.ref_counter > 0:
continue
num_evicted += evict_callback(x.value)
self._delete_leaf(x)
if len(x.parent.children) == 0:
heapq.heappush(leaves, x.parent)
def inc_ref_counter(self, node):
delta = 0
while node != self.root_node:
if node.ref_counter == 0:
self.evictable_size_ -= len(node.value)
delta -= len(node.value)
node.ref_counter += 1
node = node.parent
return delta
def dec_ref_counter(self, node):
delta = 0
while node != self.root_node:
if node.ref_counter == 1:
self.evictable_size_ += len(node.value)
delta += len(node.value)
node.ref_counter -= 1
node = node.parent
return delta
def evictable_size(self):
return self.evictable_size_
##### Internal Helper Functions #####
def _match_prefix_helper(self, node, key, value, last_node):
node.last_access_time = time.time()
for c_key, child in node.children.items():
prefix_len = match(c_key, key)
if prefix_len != 0:
if prefix_len == len(key) and prefix_len != len(c_key):
new_node = self._split_node(c_key, child, prefix_len)
value.append(new_node.value)
last_node[0] = new_node
else:
value.append(child.value[:prefix_len])
last_node[0] = child
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
break
def _split_node(self, key, child, split_len):
# new_node -> child
new_node = TreeNode()
new_node.children = {key[split_len:]: child}
new_node.parent = child.parent
new_node.ref_counter = child.ref_counter
new_node.value = child.value[:split_len]
child.parent = new_node
child.value = child.value[split_len:]
new_node.parent.children[key[:split_len]] = new_node
del new_node.parent.children[key]
return new_node
def _insert_helper(self, node, key, value):
node.last_access_time = time.time()
for c_key, child in node.children.items():
prefix_len = match(c_key, key)
if prefix_len == len(c_key):
if prefix_len == len(key):
return prefix_len
else:
key = key[prefix_len:]
value = value[prefix_len:]
return prefix_len + self._insert_helper(child, key, value)
if prefix_len:
new_node = self._split_node(c_key, child, prefix_len)
return prefix_len + self._insert_helper(
new_node, key[prefix_len:], value[prefix_len:]
)
if len(key):
new_node = TreeNode()
new_node.parent = node
new_node.value = value
node.children[key] = new_node
self.evictable_size_ += len(value)
return 0
def _print_helper(self, node, indent):
for key, child in node.children.items():
print(" " * indent, len(key), key[:10], f"r={child.ref_counter}")
self._print_helper(child, indent=indent + 2)
def _delete_leaf(self, node):
for k, v in node.parent.children.items():
if v == node:
break
del node.parent.children[k]
self.evictable_size_ -= len(k)
def _total_size_helper(self, node):
x = len(node.value)
for child in node.children.values():
x += self._total_size_helper(child)
return x
def _collect_leaves(self):
ret_list = []
def dfs_(cur_node):
if len(cur_node.children) == 0:
ret_list.append(cur_node)
for x in cur_node.children.values():
dfs_(x)
dfs_(self.root_node)
return ret_list
if __name__ == "__main__":
tree = RadixCache(disable=False)
tree.insert("Hello")
tree.insert("Hello")
tree.insert("Hello_L.A.!")
# tree.insert("Hello_world! Happy")
# tree.insert("I love you!")
tree.pretty_print()
# print(tree.match_prefix("I love you! aha"))
# def evict_callback(x):
# print("evict", x)
# return len(x)
# tree.evict(5, evict_callback)
# tree.evict(10, evict_callback)
# tree.pretty_print()

View File

@@ -0,0 +1,73 @@
import random
from collections import defaultdict
class Scheduler:
def __init__(
self,
schedule_heuristic,
max_running_seq,
max_prefill_num_token,
max_total_num_token,
tree_cache,
):
self.schedule_heuristic = schedule_heuristic
self.max_running_seq = max_running_seq
self.max_prefill_num_token = max_prefill_num_token
self.max_total_num_token = max_total_num_token
self.tree_cache = tree_cache
def new_token_estimation_ratio(self):
return 0.4 if self.schedule_heuristic != "fcfs" else 0.5
def get_priority_queue(self, forward_queue):
if self.schedule_heuristic == "lpm":
# longest prefix match
forward_queue.sort(key=lambda x: -len(x.prefix_indices))
return forward_queue
elif self.schedule_heuristic == "random":
random.shuffle(forward_queue)
return forward_queue
elif self.schedule_heuristic == "fcfs":
return forward_queue
elif self.schedule_heuristic == "weight":
last_node_to_reqs = defaultdict(list)
for req in forward_queue:
last_node_to_reqs[req.last_node].append(req)
for node in last_node_to_reqs:
last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))
node_to_weight = defaultdict(int)
self._calc_weight_recursive(
self.tree_cache.root_node, last_node_to_reqs, node_to_weight
)
tmp_queue = []
self._get_weight_priority_recursive(
self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue
)
assert len(tmp_queue) == len(forward_queue)
return tmp_queue
else:
raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight):
node_to_weight[cur_node] = 1
if cur_node in last_node_to_reqs:
node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
for child in cur_node.children.values():
self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight)
node_to_weight[cur_node] += node_to_weight[child]
def _get_weight_priority_recursive(
self, cur_node, node_to_weight, last_node_to_reqs, tmp_queue
):
visit_list = [child for child in cur_node.children.values()]
visit_list.sort(key=lambda x: -node_to_weight[x])
# for node in visit_list:
# print(f"{node_to_weight[node]} {len(node.value) if node.value is not None else 0}")
for child in visit_list:
self._get_weight_priority_recursive(
child, node_to_weight, last_node_to_reqs, tmp_queue
)
tmp_queue.extend(last_node_to_reqs[cur_node])

View File

@@ -0,0 +1,219 @@
import asyncio
import concurrent.futures
import dataclasses
import os
from typing import List
import numpy as np
import transformers
import uvloop
import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_processor,
get_tokenizer,
)
from sglang.srt.managers.io_struct import (
BatchStrOut,
GenerateReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@dataclasses.dataclass
class ReqState:
out_list: List
finished: bool
event: asyncio.Event
lock: asyncio.Lock
global global_processor
def init_global_processor(server_args: ServerArgs):
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
def get_pixel_values(image_data, processor=None):
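# When multimodal preprocessing runs in a ProcessPoolExecutor, this executes in a worker process using the processor created by init_global_processor; otherwise the caller passes its own processor.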
try:
processor = processor or global_processor
image = load_image(image_data)
image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"][0]
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash
except Exception:
print("Exception in TokenizerManager:\n" + get_exception_traceback())
class TokenizerManager:
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
):
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.send_to_router = context.socket(zmq.PUSH)
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
self.model_path = server_args.model_path
self.hf_config = get_config(
self.model_path, trust_remote_code=server_args.trust_remote_code
)
self.context_len = get_context_length(self.hf_config)
if is_multimodal_model(self.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor, initargs=(server_args,)
)
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.executor = None
self.to_create_loop = True
self.rid_to_state = {} # Dict[str -> ReqState]
async def get_pixel_values(self, image_data):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor, get_pixel_values, image_data
)
else:
return get_pixel_values(image_data, self.processor)
async def generate_request(self, obj: GenerateReqInput):
if self.to_create_loop:
await self.create_handle_loop()
is_single = isinstance(obj.text, str)
if is_single:
rid = obj.rid
input_ids = self.tokenizer.encode(obj.text)
sampling_params = SamplingParams(**obj.sampling_params)
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data is None:
pixel_values, image_hash = None, None
else:
pixel_values, image_hash = await self.get_pixel_values(obj.image_data)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
sampling_params=sampling_params,
return_normalized_logprob=obj.return_normalized_logprob,
normalized_logprob_start_len=obj.normalized_logprob_start_len,
stream=obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
lock = asyncio.Lock()
event = asyncio.Event()
state = ReqState([], False, event, lock)
self.rid_to_state[rid] = state
while True:
await event.wait()
yield state.out_list[-1]
state.out_list = []
if state.finished:
del self.rid_to_state[rid]
break
event.clear()
else:
assert obj.stream is False
bs = len(obj.text)
for i in range(bs):
rid = obj.rid[i]
input_ids = self.tokenizer.encode(obj.text[i])
sampling_params = SamplingParams(**obj.sampling_params[i])
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data[i] is None:
pixel_values, image_hash = None, None
else:
pixel_values, image_hash = await self.get_pixel_values(
obj.image_data[i]
)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
sampling_params=sampling_params,
return_normalized_logprob=obj.return_normalized_logprob[i],
normalized_logprob_start_len=obj.normalized_logprob_start_len[i],
stream=obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
lock = asyncio.Lock()
event = asyncio.Event()
state = ReqState([], False, event, lock)
self.rid_to_state[rid] = state
output_list = []
for i in range(bs):
rid = obj.rid[i]
state = self.rid_to_state[rid]
await state.event.wait()
output_list.append(state.out_list[-1])
assert state.finished
del self.rid_to_state[rid]
yield output_list
async def create_handle_loop(self):
self.to_create_loop = False
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())
async def handle_loop(self):
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, BatchStrOut):
for i, rid in enumerate(recv_obj.rids):
recv_obj.meta_info[i]["id"] = rid
out_dict = {
"text": recv_obj.output_str[i],
"meta_info": recv_obj.meta_info[i],
}
state = self.rid_to_state[rid]
state.out_list.append(out_dict)
state.finished = recv_obj.finished[i]
state.event.set()
else:
raise ValueError(f"Invalid object: {recv_obj}")

View File

@@ -0,0 +1,111 @@
"""Memory pool."""
import logging
import torch
logger = logging.getLogger(__name__)
class ReqToTokenPool:
def __init__(self, size, max_context_len):
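# mem_state marks free request slots (True = free); req_to_token maps (request slot, position) -> index into the token-level KV pool.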
self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
self.can_use_mem_size = size
self.req_to_token = torch.empty(
(size, max_context_len), dtype=torch.int32, device="cuda"
)
def alloc(self, need_size):
if need_size > self.can_use_mem_size:
return None
select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
self.mem_state[select_index] = 0
self.can_use_mem_size -= need_size
return select_index.to(torch.int32)
def free(self, free_index):
if isinstance(free_index, (int,)):
self.can_use_mem_size += 1
else:
self.can_use_mem_size += free_index.shape[0]
self.mem_state[free_index] = 1
# if self.can_use_mem_size == len(self.mem_state):
# print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
def clear(self):
self.mem_state.fill_(1)
self.can_use_mem_size = len(self.mem_state)
class TokenToKVPool:
def __init__(self, size, dtype, head_num, head_dim, layer_num):
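# mem_state holds a per-slot reference count; a slot is free while its count is zero.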
self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
self.alloc_ct = 0
# [size, key/value, head_num, head_dim] for each layer
self.kv_data = [
torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
for _ in range(layer_num)
]
def get_key_buffer(self, layer_id):
return self.kv_data[layer_id][:, 0]
def get_value_buffer(self, layer_id):
return self.kv_data[layer_id][:, 1]
def alloc(self, need_size):
select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
if select_index.shape[0] < need_size:
return None
self.add_refs(select_index)
return select_index.to(torch.int32)
def alloc_contiguous(self, need_size):
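# Search for `need_size` consecutive free slots; returns (indices, start, end) or None if no contiguous run exists.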
empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
if empty_index.shape[0] < need_size:
return None
empty_size = len(empty_index)
loc_sum = (
empty_index[need_size - 1 :] - empty_index[: empty_size - (need_size - 1)]
)
can_used_loc = empty_index[: empty_size - (need_size - 1)][
loc_sum == need_size - 1
]
if can_used_loc.shape[0] == 0:
return None
start_loc = can_used_loc[0].item()
select_index = torch.arange(start_loc, start_loc + need_size, device="cuda")
self.add_refs(select_index)
return select_index.to(torch.int32), start_loc, start_loc + need_size
def free(self, free_index):
return self.decrease_refs(free_index)
def used_size(self):
return len(torch.nonzero(self.mem_state).squeeze(1))
def available_size(self):
return torch.sum(self.mem_state == 0).item()
def add_refs(self, token_index: torch.Tensor):
self.alloc_ct += len(token_index)
self.mem_state[token_index] += 1
def decrease_refs(self, token_index: torch.Tensor):
self.alloc_ct -= len(token_index)
self.mem_state[token_index] -= 1
num_freed = torch.sum(self.mem_state[token_index] == 0)
# if self.alloc_ct == 0:
# print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")
return num_freed
def clear(self):
self.mem_state.fill_(0)
self.alloc_ct = 0

View File

@@ -0,0 +1,27 @@
import os
from typing import Optional, Union
import torch
from sglang.srt.hf_transformers_utils import get_config, get_context_length
class ModelConfig:
def __init__(
self,
path: str,
trust_remote_code: bool = True,
revision: Optional[str] = None,
) -> None:
self.path = path
self.trust_remote_code = trust_remote_code
self.revision = revision
self.hf_config = get_config(self.path, trust_remote_code, revision)
# Unify the config keys for hf_config
self.context_len = get_context_length(self.hf_config)
self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
self.num_key_value_heads = self.hf_config.num_key_value_heads
self.num_attention_heads = self.hf_config.num_attention_heads
self.hidden_size = self.hf_config.hidden_size
self.num_hidden_layers = self.hf_config.num_hidden_layers
self.vocab_size = self.hf_config.vocab_size

View File

@@ -0,0 +1,316 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple
import torch
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
linear_method=linear_method,
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, linear_method=linear_method
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class LlamaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
layer_id: int = 0,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class LlamaModel(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(config, i, linear_method)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
skip_embed: bool = False,
) -> torch.Tensor:
if not skip_embed:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_ids
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class LlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = LlamaModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
skip_embed: bool = False,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
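# Checkpoint weights listed in stacked_params_mapping are folded into the fused qkv_proj / gate_up_proj parameters; all other weights load directly via their default loader.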
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,213 @@
"""Inference-only LLaVa model compatible with HuggingFace weights."""
import json
import os
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM
from torch import nn
from transformers import CLIPImageProcessor, CLIPVisionModel, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
class LlavaLlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlavaConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.vision_tower = None
self.config.vision_config.hidden_size = config.mm_hidden_size
self.config.text_config.hidden_size = config.hidden_size
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.language_model = LlamaForCausalLM(config, linear_method)
def pad_input_ids(self, input_ids, pad_value):
pad_ids = pad_value * (
(self.image_feature_len + len(pad_value)) // len(pad_value)
)
offset = input_ids.index(self.config.image_token_index)
# old_len + pad_len - 1, because we need to remove image_token_id
new_input_ids = (
input_ids[:offset]
+ pad_ids[: self.image_feature_len]
+ input_ids[offset + 1 :]
)
return new_input_ids, offset
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
pixel_values: Optional[List[Optional[np.array]]] = None,
image_offsets: Optional[List[int]] = None,
) -> torch.Tensor:
if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size
# Embed text input
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input
need_vision = (
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
.cpu()
.numpy()
)
# FIXME: We need to subtract the length of the system prompt
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
need_vision = need_vision & has_pixel
if need_vision.any():
pixel_values = torch.tensor(
np.array([pixel_values[i] for i in range(bs) if need_vision[i]]),
device=self.vision_tower.device,
)
image_outputs = self.vision_tower(
pixel_values, output_hidden_states=True
)
# NOTE: This is not memory efficient. (output_hidden_states=True) saves all the hidden states.
selected_image_feature = image_outputs.hidden_states[
self.vision_feature_layer
]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
)
image_features = self.multi_modal_projector(selected_image_feature)
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
pt = 0
for i in range(bs):
if not need_vision[i]:
continue
start_idx = extend_start_loc_cpu[i]
pad_len, pad_dim = image_features[pt].shape
dim = input_embeds.shape[1]
assert (
pad_dim == dim
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
# Fill in the placeholder for the image
try:
input_embeds[
start_idx
+ image_offsets[i] : start_idx
+ image_offsets[i]
+ pad_len
] = image_features[pt]
except RuntimeError as e:
print(f"RuntimeError in llava image encoding: {e}")
print(input_embeds.shape)
print(start_idx, image_offsets[i])
pt += 1
return self.language_model(
input_embeds, positions, input_metadata, skip_embed=True
)
elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model(
input_ids, positions, input_metadata, skip_embed=False
)
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
self.image_size = self.vision_tower.config.image_size
self.patch_size = self.vision_tower.config.patch_size
self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
if self.vision_feature_select_strategy == "patch":
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
# load mm_projector
# TODO: support TP?
projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
# FIXME: why are the projector weights loaded twice?
if "projector" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# load language model
self.language_model.load_weights(
model_name_or_path, cache_dir, load_format, revision
)
monkey_path_clip_vision_embed_forward()
first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
global first_call
if first_call:
self.patch_embedding.cpu().float()
first_call = False
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
def monkey_path_clip_vision_embed_forward():
import transformers
setattr(
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
"forward",
clip_vision_embed_forward,
)

View File

@@ -0,0 +1,378 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from torch import nn
from transformers import MixtralConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
LinearMethodBase,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
class MixtralMLP(nn.Module):
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
)
self.w2 = ReplicatedLinear(
self.ffn_dim, self.hidden_dim, bias=False, linear_method=linear_method
)
self.w3 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}."
)
# Split experts equally between ranks
self.expert_indicies = np.array_split(
range(self.num_total_experts), self.tp_size
)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
self.experts = nn.ModuleList(
[
MixtralMLP(
self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method,
)
if idx in self.expert_indicies
else None
for idx in range(self.num_total_experts)
]
)
self.gate = ReplicatedLinear(
config.hidden_size, self.num_total_experts, bias=False, linear_method=None
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
router_logits, _ = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = None
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
expert_mask = selected_experts == expert_idx
expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states)
class MixtralAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
layer_id: int = 0,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
linear_method=linear_method,
)
self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
MixtralDecoderLayer(config, i, linear_method=linear_method)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
skip_embed: bool = False,
) -> torch.Tensor:
if not skip_embed:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_ids
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = MixtralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
skip_embed: bool = False,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision, fall_back_to_pt=False
):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if "block_sparse_moe.experts." in name and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -0,0 +1,81 @@
"""Sampling parameters for text generation."""
from typing import List, Optional, Union
_SAMPLING_EPS = 1e-6
class SamplingParams:
def __init__(
self,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
stop: Optional[Union[str, List[str]]] = None,
max_new_tokens: int = 16,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
dtype: Optional[str] = None,
regex: Optional[str] = None,
) -> None:
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.stop_strs = stop
self.max_new_tokens = max_new_tokens
self.ignore_eos = ignore_eos
self.skip_special_tokens = skip_special_tokens
self.dtype = dtype
self.regex = regex
# Process some special cases
if self.temperature < _SAMPLING_EPS:
self.temperature = 1.0
self.top_k = 1
if self.top_k == -1:
self.top_k = 1 << 30 # whole vocabulary
if self.dtype == "int":
self.stop_strs = [" ", "\n"]
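# e.g. temperature=0.0 falls back to greedy decoding (top_k is forced to 1), and dtype="int" stops generation at the first space or newline.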
def verify(self):
if self.temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {self.temperature}."
)
if not 0.0 < self.top_p <= 1.0:
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
if self.top_k < -1 or self.top_k == 0:
raise ValueError(
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
)
if not -2.0 <= self.frequency_penalty <= 2.0:
raise ValueError(
"frequency_penalty must be in [-2, 2], got "
f"{self.frequency_penalty}."
)
if not -2.0 <= self.presence_penalty <= 2.0:
raise ValueError(
"presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
)
if self.max_new_tokens < 0:
raise ValueError(
f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
)
def normalize(self, tokenizer):
# Process stop strings
if self.stop_strs is None:
self.stop_strs = []
self.stop_str_max_len = 0
else:
if isinstance(self.stop_strs, str):
self.stop_strs = [self.stop_strs]
stop_str_max_len = 0
for stop_str in self.stop_strs:
stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
self.stop_str_max_len = stop_str_max_len

222
python/sglang/srt/server.py Normal file
View File

@@ -0,0 +1,222 @@
"""SRT: SGLang Runtime"""
import argparse
import asyncio
import dataclasses
import json
import multiprocessing as mp
import sys
import threading
import time
from typing import List, Optional
# Fix a Python bug
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
import psutil
import requests
import uvicorn
import uvloop
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.openai_protocol import CompletionRequest
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import alloc_usable_network_port
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
app = FastAPI()
tokenizer_manager = None
@app.get("/get_model_info")
async def get_model_info():
result = {
"model_path": tokenizer_manager.model_path,
}
return result
@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
obj.post_init()
result_generator = tokenizer_manager.generate_request(obj)
if obj.stream:
async def stream_results():
async for out in result_generator:
yield (json.dumps(out) + "\0").encode("utf-8")
return StreamingResponse(stream_results(), media_type="text/event-stream")
else:
ret = await result_generator.__anext__()
return ret
@app.post("/v1/completions")
async def v1_completions(obj: CompletionRequest):
assert obj.n == 1
obj = GenerateReqInput(
text=obj.prompt,
sampling_params={
"temperature": obj.temperature,
"max_new_tokens": obj.max_tokens,
"stop": obj.stop,
},
)
ret = await generate_request(obj)
return {
"choices": [{"text": ret["text"]}],
}
def launch_server(server_args, pipe_finish_writer):
global tokenizer_manager
# Allocate ports
can_use_ports = alloc_usable_network_port(
num=4 + server_args.tp_size, used_list=(server_args.port,)
)
port_args = PortArgs(
tokenizer_port=can_use_ports[0],
router_port=can_use_ports[1],
detokenizer_port=can_use_ports[2],
nccl_port=can_use_ports[3],
model_rpc_ports=can_use_ports[4:],
)
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args)
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
proc_router = mp.Process(
target=start_router_process,
args=(
server_args,
port_args,
pipe_router_writer,
),
)
proc_router.start()
proc_detoken = mp.Process(
target=start_detokenizer_process,
args=(
server_args,
port_args,
pipe_detoken_writer,
),
)
proc_detoken.start()
# Wait for the model to finish loading
router_init_state = pipe_router_reader.recv()
detoken_init_state = pipe_detoken_reader.recv()
if router_init_state != "init ok" or detoken_init_state != "init ok":
proc_router.kill()
proc_detoken.kill()
print("router init state:", router_init_state)
print("detoken init state:", detoken_init_state)
sys.exit(1)
assert proc_router.is_alive() and proc_detoken.is_alive()
def _launch_api_server():  # renamed to avoid shadowing the outer launch_server
# Launch api server
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_level=server_args.log_level,
timeout_keep_alive=5,
loop="uvloop",
)
t = threading.Thread(target=_launch_api_server)
t.start()
if pipe_finish_writer:
url = server_args.url()
success = False
last_error = ""
for _ in range(60):
try:
requests.get(url + "/get_model_info", timeout=5)
success = True
break
except requests.exceptions.RequestException as e:
# Keep the error string; `e` itself is unbound outside the except block.
last_error = str(e)
time.sleep(1)
if success:
pipe_finish_writer.send("init ok")
else:
pipe_finish_writer.send(last_error)
class Runtime:
def __init__(
self,
model_path: str,
tokenizer_path: Optional[str] = None,
load_format: str = "auto",
tokenizer_mode: str = "auto",
trust_remote_code: bool = True,
mem_fraction_static: float = 0.9,
tp_size: int = 1,
model_mode: List[str] = (),
schedule_heuristic: str = "lpm",
random_seed: int = 42,
log_level: str = "warning",
):
host = "127.0.0.1"
port = alloc_usable_network_port(1)[0]
server_args = ServerArgs(
model_path=model_path,
tokenizer_path=tokenizer_path,
host=host,
port=port,
load_format=load_format,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
mem_fraction_static=mem_fraction_static,
tp_size=tp_size,
model_mode=model_mode,
schedule_heuristic=schedule_heuristic,
random_seed=random_seed,
log_level=log_level,
)
self.url = server_args.url()
self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
proc = mp.Process(target=launch_server, args=(server_args, pipe_writer))
proc.start()
self.pid = proc.pid
init_state = pipe_reader.recv()
if init_state != "init ok":
self.shutdown()
raise RuntimeError("Launch failed")
self.endpoint = RuntimeEndpoint(self.url)
def shutdown(self):
if self.pid is not None:
parent = psutil.Process(self.pid)
children = parent.children(recursive=True)
for child in children:
child.kill()
psutil.wait_procs(children, timeout=5)
parent.kill()
parent.wait(timeout=5)
self.pid = None
def __del__(self):
self.shutdown()

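A minimal sketch of talking to a running server, assuming the default host and port; the request shape follows GenerateReqInput as consumed by the /generate route above:

import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
)
print(resp.json()["text"])

The Runtime class offers the same server in-process: construct it with a model_path, use runtime.endpoint as a backend, and call runtime.shutdown() when done.
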
138
python/sglang/srt/server_args.py Normal file
View File

@@ -0,0 +1,138 @@
import argparse
import dataclasses
from typing import List, Optional
@dataclasses.dataclass
class ServerArgs:
model_path: str
tokenizer_path: Optional[str] = None
host: str = "127.0.0.1"
port: int = 30000
load_format: str = "auto"
tokenizer_mode: str = "auto"
trust_remote_code: bool = True
mem_fraction_static: float = 0.91
tp_size: int = 1
model_mode: List[str] = ()
schedule_heuristic: str = "lpm"
random_seed: int = 42
disable_log_stats: bool = False
log_stats_interval: int = 10
log_level: str = "info"
def __post_init__(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--model-path",
type=str,
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
required=True,
)
parser.add_argument(
"--tokenizer-path",
type=str,
default=ServerArgs.tokenizer_path,
help="The path of the tokenizer.",
)
parser.add_argument("--host", type=str, default=ServerArgs.host)
parser.add_argument("--port", type=int, default=ServerArgs.port)
parser.add_argument(
"--load-format",
type=str,
default=ServerArgs.load_format,
choices=["auto", "pt", "safetensors", "npcache", "dummy"],
help="The format of the model weights to load. "
'"auto" will try to load the weights in the safetensors format '
"and fall back to the pytorch bin format if safetensors format "
"is not available. "
'"pt" will load the weights in the pytorch bin format. '
'"safetensors" will load the weights in the safetensors format. '
'"npcache" will load the weights in pytorch format and store '
"a numpy cache to speed up the loading. "
'"dummy" will initialize the weights with random values, '
"which is mainly for profiling.",
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default=ServerArgs.tokenizer_mode,
choices=["auto", "slow"],
help="Tokenizer mode. 'auto' will use the fast "
"tokenizer if available, and 'slow' will "
"always use the slow tokenizer.",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
)
parser.add_argument(
"--mem-fraction-static",
type=float,
default=ServerArgs.mem_fraction_static,
help="The fraction of the memory used for static allocation (model weights and KV cache memory pool)",
)
parser.add_argument(
"--tp-size",
type=int,
default=ServerArgs.tp_size,
help="Tensor parallelism degree.",
)
parser.add_argument(
"--model-mode",
type=str,
default=[],
nargs="+",
help="Model mode: [flashinfer, no-cache, aggressive-new-fill]",
)
parser.add_argument(
"--schedule-heuristic",
type=str,
default=ServerArgs.schedule_heuristic,
help="Schudule mode: [lpm, weight, random, fcfs]",
)
parser.add_argument(
"--random-seed",
type=int,
default=ServerArgs.random_seed,
help="Random seed.",
)
parser.add_argument(
"--log-level",
type=str,
default=ServerArgs.log_level,
help="Log level",
)
parser.add_argument(
"--disable-log-stats",
action="store_true",
help="Disable logging throughput stats.",
)
parser.add_argument(
"--log-stats-interval",
type=int,
default=ServerArgs.log_stats_interval,
help="Log stats interval in second.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs})
def url(self):
return f"http://{self.host}:{self.port}"
@dataclasses.dataclass
class PortArgs:
tokenizer_port: int
router_port: int
detokenizer_port: int
nccl_port: int
model_rpc_ports: List[int]

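The two helpers above let ServerArgs round-trip through argparse. A sketch (the model path is illustrative):

import argparse

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(["--model-path", "meta-llama/Llama-2-7b-chat-hf"])
server_args = ServerArgs.from_cli_args(args)
print(server_args.url())  # http://127.0.0.1:30000
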
217
python/sglang/srt/utils.py Normal file
View File

@@ -0,0 +1,217 @@
import base64
import os
import random
import socket
import sys
import time
import traceback
from io import BytesIO
import numpy as np
import requests
import torch
import torch.distributed as dist
is_show_cost_time = False
def mark_cost_time(func_name):
def inner_func(func):
def time_func(*args, **kwargs):
if dist.get_rank() in [0, 1] and is_show_cost_time:
torch.cuda.synchronize()
start_time = time.time()
ans = func(*args, **kwargs)
torch.cuda.synchronize()
print(func_name, "cost time:", (time.time() - start_time) * 1000)
return ans
else:
torch.cuda.synchronize()
ans = func(*args, **kwargs)
torch.cuda.synchronize()
return ans
return time_func
return inner_func
time_mark = {}
def mark_start(key):
torch.cuda.synchronize()
global time_mark
time_mark[key] = time.time()
return
def mark_end(key, print_min_cost=0.0):
torch.cuda.synchronize()
global time_mark
cost_time = (time.time() - time_mark[key]) * 1000
if cost_time > print_min_cost:
print(f"cost {key}:", cost_time)
def calculate_time(show=False, min_cost_ms=0.0):
def wrapper(func):
def inner_func(*args, **kwargs):
torch.cuda.synchronize()
if show:
start_time = time.time()
result = func(*args, **kwargs)
torch.cuda.synchronize()
if show:
cost_time = (time.time() - start_time) * 1000
if cost_time > min_cost_ms:
print(f"Function {func.__name__} took {cost_time} ms to run.")
return result
return inner_func
return wrapper
def set_random_seed(seed: int) -> None:
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def alloc_usable_network_port(num, used_list=()):
port_list = []
for port in range(10000, 65536):
if port in used_list:
continue
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(("", port))
port_list.append(port)
except socket.error:
pass
if len(port_list) == num:
return port_list
return None
def get_exception_traceback():
etype, value, tb = sys.exc_info()
err_str = "".join(traceback.format_exception(etype, value, tb))
return err_str
def get_int_token_logit_bias(tokenizer, vocab_size):
from transformers import LlamaTokenizer, LlamaTokenizerFast
logit_bias = np.zeros(vocab_size, dtype=np.float32)
for t_id in range(vocab_size):
ss = tokenizer.decode(t_id).strip()
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
logit_bias[t_id] = -1e5
# else:
# print(ss, t_id)
return logit_bias
def wrap_kernel_launcher(kernel):
"""A faster launcher for triton kernels."""
import torch.distributed as dist
if dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
kernels = kernel.cache[rank].values()
kernel = next(iter(kernels))
# Different Triton versions use different low-level names
if hasattr(kernel, "cu_function"):
kfunction = kernel.cu_function
else:
kfunction = kernel.function
if hasattr(kernel, "c_wrapper"):
run = kernel.c_wrapper
else:
run = kernel.run
add_cluster_dim = True
def ret_func(grid, num_warps, *args):
nonlocal add_cluster_dim
try:
if add_cluster_dim:
run(
grid[0],
grid[1],
grid[2],
num_warps,
1,
1,
1,
1,
kernel.shared,
0,
kfunction,
None,
None,
kernel,
*args,
)
else:
run(
grid[0],
grid[1],
grid[2],
num_warps,
kernel.shared,
0,
kfunction,
None,
None,
kernel,
*args,
)
except TypeError:
add_cluster_dim = not add_cluster_dim
ret_func(grid, num_warps, *args)
return ret_func
def is_multimodal_model(model):
if isinstance(model, str):
return "llava" in model
from sglang.srt.model_config import ModelConfig
if isinstance(model, ModelConfig):
return "llava" in model.path.lower()
raise Exception("unrecognized type")
def load_image(image_file):
from PIL import Image
image = None
if image_file.startswith("http://") or image_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
response = requests.get(image_file, timeout=timeout)
image = Image.open(BytesIO(response.content))
elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
image = Image.open(image_file)
elif image_file.startswith("data:"):
image_file = image_url.split(",")[1]
image = Image.open(BytesIO(base64.b64decode(image_file)))
else:
image = Image.open(BytesIO(base64.b64decode(image_file)))
return image

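A quick sketch of the port allocator above; it binds each candidate port once to test availability, so the returned ports are free at the time of the call but not reserved:

ports = alloc_usable_network_port(3, used_list=(30000,))
print(ports)  # e.g. [10000, 10001, 10002], or None if not enough ports are free
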
324
python/sglang/test/test_programs.py Normal file
View File

@@ -0,0 +1,324 @@
"""
This file contains the SGL programs used for unit testing.
"""
import json
import re
import sglang as sgl
def test_few_shot_qa():
@sgl.function
def few_shot_qa(s, question):
s += "The following are questions with answers.\n\n"
s += "Q: What is the capital of France?\n"
s += "A: Paris\n"
s += "Q: What is the capital of Germany?\n"
s += "A: Berlin\n"
s += "Q: What is the capital of Italy?\n"
s += "A: Rome\n"
s += "Q: " + question + "\n"
s += "A:" + sgl.gen("answer", stop="\n", temperature=0)
ret = few_shot_qa.run(question="What is the capital of the United States?")
assert "washington" in ret["answer"].strip().lower(), f"answer: {ret['answer']}"
rets = few_shot_qa.run_batch(
[
{"question": "What is the capital of Japan?"},
{"question": "What is the capital of the United Kingdom?"},
{"question": "What is the capital city of China?"},
],
temperature=0.1,
)
answers = [x["answer"].strip().lower() for x in rets]
assert answers == ["tokyo", "london", "beijing"], f"answers: {answers}"
def test_mt_bench():
@sgl.function
def answer_mt_bench(s, question_1, question_2):
s += sgl.system("You are a helpful assistant.")
s += sgl.user(question_1)
s += sgl.assistant(sgl.gen("answer_1"))
with s.user():
s += question_2
with s.assistant():
s += sgl.gen("answer_2")
question_1 = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."
question_2 = (
"Rewrite your previous response. Start every sentence with the letter A."
)
ret = answer_mt_bench.run(
question_1=question_1, question_2=question_2, temperature=0.7, max_new_tokens=64
)
assert len(ret.messages()) in [4, 5]
def test_select(check_answer):
@sgl.function
def true_or_false(s, statement):
s += "Determine whether the statement below is True, False, or Unknown.\n"
s += "Statement: The capital of France is Pairs.\n"
s += "Answer: True\n"
s += "Statement: " + statement + "\n"
s += "Answer:" + sgl.select("answer", ["True", "False", "Unknown"])
ret = true_or_false.run(
statement="The capital of Germany is Berlin.",
)
if check_answer:
assert ret["answer"] == "True", ret.text
else:
assert ret["answer"] in ["True", "False", "Unknown"]
ret = true_or_false.run(
statement="The capital of Canada is Tokyo.",
)
if check_answer:
assert ret["answer"] == "False", ret.text
else:
assert ret["answer"] in ["True", "False", "Unknown"]
ret = true_or_false.run(
statement="Purple is a better color than green.",
)
if check_answer:
assert ret["answer"] == "Unknown", ret.text
else:
assert ret["answer"] in ["True", "False", "Unknown"]
def test_decode_int():
@sgl.function
def decode_int(s):
s += "The number of hours in a day is " + sgl.gen_int("hours") + "\n"
s += "The number of days in a year is " + sgl.gen_int("days") + "\n"
ret = decode_int.run(temperature=0.1)
assert int(ret["hours"]) == 24, ret.text
assert int(ret["days"]) == 365, ret.text
def test_decode_json():
@sgl.function
def decode_json(s):
s += "Generate a JSON object to describe the basic information of a city.\n"
with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen_string() + ",\n"
s += ' "population": ' + sgl.gen_int() + ",\n"
s += ' "area": ' + sgl.gen(dtype=int) + ",\n"
s += ' "country": ' + sgl.gen_string() + ",\n"
s += ' "timezone": ' + sgl.gen(dtype=str) + "\n"
s += "}"
ret = decode_json.run()
js_obj = json.loads(ret["json_output"])
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)
def test_expert_answer():
@sgl.function
def expert_answer(s, question):
s += "Question: " + question + "\n"
s += (
"A good person to answer this question is"
+ sgl.gen("expert", stop=[".", "\n"])
+ ".\n"
)
s += (
"For example,"
+ s["expert"]
+ " would answer that "
+ sgl.gen("answer", stop=".")
+ "."
)
ret = expert_answer.run(question="What is the capital of France?", temperature=0.1)
assert "paris" in ret.text().lower()
def test_tool_use():
def calculate(expression):
return f"{eval(expression)}"
@sgl.function
def tool_use(s, lhs, rhs):
s += "Please perform computations using a calculator. You can use calculate(expression) to get the results.\n"
s += "For example,\ncalculate(1+2)=3\ncalculate(3*4)=12\n"
s += "Question: What is the product of " + lhs + " and " + rhs + "?\n"
s += (
"Answer: The answer is calculate("
+ sgl.gen("expression", stop=")")
+ ") = "
)
with s.var_scope("answer"):
s += calculate(s["expression"])
lhs, rhs = 257, 983
ret = tool_use(lhs=lhs, rhs=rhs, temperature=0)
assert int(ret["answer"]) == lhs * rhs
def test_react():
@sgl.function
def react(s, question):
s += """
Question: Which country does the founder of Microsoft live in?
Thought 1: I need to search for the founder of Microsoft.
Action 1: Search [Founder of Microsoft].
Observation 1: The founder of Microsoft is Bill Gates.
Thought 2: I need to search for the country where Bill Gates lives in.
Action 2: Search [Where does Bill Gates live].
Observation 2: Bill Gates lives in the United States.
Thought 3: The answer is the United States.
Action 3: Finish [United States].\n
"""
s += "Question: " + question + "\n"
for i in range(1, 5):
s += f"Thought {i}:" + sgl.gen(stop=[".", "\n"]) + ".\n"
s += f"Action {i}: " + sgl.select(f"action_{i}", ["Search", "Finish"])
if s[f"action_{i}"] == "Search":
s += " [" + sgl.gen(stop="]") + "].\n"
s += f"Observation {i}:" + sgl.gen(stop=[".", "\n"]) + ".\n"
else:
s += " [" + sgl.gen("answer", stop="]") + "].\n"
break
ret = react.run(
question="What country does the creator of Linux live in?",
temperature=0.1,
)
answer = ret["answer"].lower()
assert "finland" in answer or "states" in answer
def test_parallel_decoding():
max_tokens = 64
number = 5
@sgl.function
def parallel_decoding(s, topic):
s += "Act as a helpful assistant.\n"
s += "USER: Give some tips for " + topic + ".\n"
s += (
"ASSISTANT: Okay. Here are "
+ str(number)
+ " concise tips, each under 8 words:\n"
)
# Generate skeleton
for i in range(1, 1 + number):
s += f"{i}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n"
# Generate detailed tips
forks = s.fork(number)
for i in range(number):
forks[i] += f"Now, I expand tip {i+1} into a detailed paragraph:\nTip {i+1}:"
forks[i] += sgl.gen("detailed_tip", max_tokens, stop=["\n\n"])
forks.join()
# Concatenate tips and summarize
s += "Here are these tips with detailed explanation:\n"
for i in range(number):
s += f"Tip {i+1}:" + forks[i]["detailed_tip"] + "\n"
s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
def test_parallel_encoding(check_answer=True):
max_tokens = 64
@sgl.function
def parallel_encoding(s, question, context_0, context_1, context_2):
s += "USER: I will ask a question based on some statements.\n"
s += "ASSISTANT: Sure. I will give the answer.\n"
s += "USER: Please memorize these statements.\n"
contexts = [context_0, context_1, context_2]
forks = s.fork(len(contexts))
forks += lambda i: f"Statement {i}: " + contexts[i] + "\n"
forks.join(mode="concate_and_append")
s += "Now, please answer the following question. " "Do not list options."
s += "\nQuestion: " + question + "\n"
s += "ASSISTANT:" + sgl.gen("answer", max_tokens=max_tokens)
ret = parallel_encoding.run(
question="Who is the father of Julian?",
context_0="Ethan is the father of Liam.",
context_1="Noah is the father of Julian.",
context_2="Oliver is the father of Carlos.",
temperature=0,
)
answer = ret["answer"]
if check_answer:
assert "Noah" in answer
def test_image_qa():
@sgl.function
def image_qa(s, question):
s += sgl.user(sgl.image("image.png") + question)
s += sgl.assistant(sgl.gen("answer"))
state = image_qa.run(
question="Please describe this image in simple words.",
temperature=0,
max_new_tokens=64,
)
assert "taxi" in state.messages()[-1]["content"]
def test_stream():
@sgl.function
def qa(s, question):
s += sgl.user(question)
s += sgl.assistant(sgl.gen("answer"))
ret = qa(
question="Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.",
stream=True,
)
out = ""
for chunk in ret.text_iter():
out += chunk
ret = qa(
question="Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.",
stream=True,
)
out = ""
for chunk in ret.text_iter("answer"):
out += chunk
def test_regex():
regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
@sgl.function
def regex_gen(s):
s += "Q: What is the IP address of the Google DNS servers?\n"
s += "A: " + sgl.gen(
"answer",
temperature=0,
regex=regex,
)
state = regex_gen.run()
answer = state["answer"]
assert re.match(regex, answer)

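The tests above assume a default backend has been set first. A driver sketch, assuming a GPU and locally available weights (the model path is illustrative):

import sglang as sgl
from sglang.srt.server import Runtime

runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
sgl.set_default_backend(runtime.endpoint)
test_few_shot_qa()
runtime.shutdown()
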
141
python/sglang/test/test_utils.py Normal file
View File

@@ -0,0 +1,141 @@
"""Common utilities for testing and benchmarking"""
import numpy as np
import requests
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
data = {
"inputs": prompt,
"parameters": {
"temperature": temperature,
"max_new_tokens": max_tokens,
"stop_sequences": stop,
},
}
res = requests.post(url, json=data)
assert res.status_code == 200
pred = res.json()["generated_text"][0]
return pred
def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
data = {
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
"stop": stop,
"n": n,
}
res = requests.post(url, json=data)
assert res.status_code == 200
if n == 1:
pred = res.json()["text"][0][len(prompt) :]
else:
pred = [x[len(prompt) :] for x in res.json()["text"]]
return pred
def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
data = {
"text": prompt,
"sampling_params": {
"temperature": temperature,
"max_new_tokens": max_tokens,
"stop": stop,
},
}
res = requests.post(url, json=data)
assert res.status_code == 200
obj = res.json()
pred = obj["text"]
return pred
def call_select_lightllm(context, choices, url):
scores = []
for i in range(len(choices)):
data = {
"inputs": context + choices[i],
"parameters": {
"max_new_tokens": 1,
},
}
res = requests.post(url, json=data)
assert res.status_code == 200
# Placeholder score: the response's logprobs are not read, so the
# argmax below always selects the first choice.
scores.append(0)
return np.argmax(scores)
def call_select_vllm(context, choices, url):
scores = []
for i in range(len(choices)):
data = {
"prompt": context + choices[i],
"max_tokens": 1,
"prompt_logprobs": 1,
}
res = requests.post(url, json=data)
assert res.status_code == 200
scores.append(res.json()["prompt_score"])
return np.argmax(scores)
"""
Modify vllm/entrypoints/api_server.py
if final_output.prompt_logprobs is not None:
score = np.mean([prob[t_id] for t_id, prob in zip(final_output.prompt_token_ids[1:], final_output.prompt_logprobs[1:])])
ret["prompt_score"] = score
"""
def add_common_other_args_and_parse(parser):
parser.add_argument("--parallel", type=int, default=96)
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
parser.add_argument(
"--backend",
type=str,
required=True,
choices=["vllm", "lightllm", "guidance", "lmql", "srt-raw", "llama.cpp"],
)
parser.add_argument(
"--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
)
parser.add_argument("--result-file", type=str, default="result.jsonl")
args = parser.parse_args()
if args.port is None:
default_port = {
"vllm": 21000,
"lightllm": 22000,
"lmql": 23000,
"srt-raw": 30000,
}
args.port = default_port.get(args.backend, None)
return args
def add_common_sglang_args_and_parse(parser):
parser.add_argument("--parallel", type=int, default=64)
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
parser.add_argument("--backend", type=str, default="srt")
parser.add_argument("--result-file", type=str, default="result.jsonl")
args = parser.parse_args()
return args
def select_sglang_backend(args):
if args.backend.startswith("srt"):
if args.backend == "srt-no-parallel":
global_config.enable_parallel_decoding = False
global_config.enable_parallel_encoding = False
backend = RuntimeEndpoint(f"{args.host}:{args.port}")
elif args.backend.startswith("gpt"):
backend = OpenAI(args.backend)
else:
raise ValueError(f"Invalid backend: {args.backend}")
return backend

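A sketch of a benchmark driver built on the helpers above; run it as a script, since add_common_sglang_args_and_parse reads sys.argv:

import argparse

parser = argparse.ArgumentParser()
args = add_common_sglang_args_and_parse(parser)
backend = select_sglang_backend(args)  # RuntimeEndpoint for the default "srt" backend
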
179
python/sglang/utils.py Normal file
View File

@@ -0,0 +1,179 @@
"""Common utilities."""
import base64
import json
import threading
import urllib.request
from io import BytesIO
from json import dumps
import requests
def get_available_gpu_memory(gpu_id, distributed=True):
"""
Get available memory for cuda:gpu_id device.
When distributed is True, the available memory is the minimum available memory of all GPUs.
"""
import torch
num_gpus = torch.cuda.device_count()
assert gpu_id < num_gpus
if torch.cuda.current_device() != gpu_id:
print(
f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
"which may cause useless memory allocation for torch CUDA context.",
)
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
if distributed:
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
torch.device("cuda", gpu_id)
)
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
free_gpu_memory = tensor.item()
return free_gpu_memory / (1 << 30)
def is_same_type(values):
"""Return whether the elements in values are of the same type."""
if len(values) <= 1:
return True
else:
t = type(values[0])
return all(isinstance(v, t) for v in values[1:])
def read_jsonl(filename: str):
"""Read a JSONL file."""
rets = []
with open(filename) as fin:
for line in fin:
if line.startswith("#"):
continue
rets.append(json.loads(line))
return rets
def dump_state_text(filename, states, mode="w"):
"""Dump program state in a text file."""
from sglang.lang.interpreter import ProgramState
with open(filename, mode) as fout:
for i, s in enumerate(states):
if isinstance(s, str):
pass
elif isinstance(s, ProgramState):
s = s.text().strip()
else:
s = str(s)
fout.write(
"=" * 40 + f" {i} " + "=" * 40 + "\n" + s + "\n" + "=" * 80 + "\n\n"
)
class HttpResponse:
def __init__(self, resp):
self.resp = resp
def json(self):
return json.loads(self.resp.read())
@property
def status_code(self):
return self.resp.status
def http_request(url, json=None, stream=False):
"""A faster version of requests.post with low-level urllib API."""
if stream:
return requests.post(url, json=json, stream=True)
else:
req = urllib.request.Request(url)
req.add_header("Content-Type", "application/json; charset=utf-8")
if json is None:
data = None
else:
data = bytes(dumps(json), encoding="utf-8")
resp = urllib.request.urlopen(req, data=data)
return HttpResponse(resp)
def encode_image_base64(image_path):
"""Encode an image in base64."""
if isinstance(image_path, str):
with open(image_path, "rb") as image_file:
data = image_file.read()
return base64.b64encode(data).decode("utf-8")
elif isinstance(image_path, bytes):
return base64.b64encode(image_path).decode("utf-8")
else:
# image_path is PIL.WebPImagePlugin.WebPImageFile
image = image_path
buffered = BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def _is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False
def find_printable_text(text):
"""Returns the longest printable substring of text that contains only entire words."""
# Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99
# After the symbol for a new line, we flush the cache.
if text.endswith("\n"):
return text
# If the last token is a CJK character, we print the characters.
elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
return text
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
# which may change with the subsequent token -- there are probably smarter ways to do this!)
else:
return text[: text.rfind(" ") + 1]
def run_with_timeout(func, args=(), kwargs=None, timeout=None):
"""Run a function with timeout."""
ret_value = []
def _target_func():
ret_value.append(func(*args, **(kwargs or {})))
t = threading.Thread(target=_target_func)
t.start()
t.join(timeout=timeout)
if t.is_alive():
raise TimeoutError("function call timed out")
if not ret_value:
raise RuntimeError("function returned no value")
return ret_value[0]
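
A short sketch of run_with_timeout above; note that a timed-out worker thread is not killed, it merely stops being waited on:

import time

def slow():
    time.sleep(0.1)
    return "done"

print(run_with_timeout(slow, timeout=1.0))  # "done"
# run_with_timeout(time.sleep, args=(10,), timeout=0.5) raises TimeoutError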