release initial code
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com>
Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
python/pyproject.toml  (new file, 31 lines)
@@ -0,0 +1,31 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "sglang"
version = "0.1.0"
description = "A structured generation language for LLMs."
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE"}
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
]
dependencies = [
    "requests",
]

[project.optional-dependencies]
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
       "interegular", "lark"]
openai = ["openai>=1.0"]
anthropic = ["anthropic"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]

[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]

[tool.wheel]
exclude = ["assets*", "benchmark*", "docs*", "dist*", "playground*", "scripts*", "tests*"]
python/sglang/__init__.py  (new file, 2 lines)
@@ -0,0 +1,2 @@
from sglang.api import *
from sglang.global_config import global_config
python/sglang/api.py  (new file, 161 lines)
@@ -0,0 +1,161 @@
"""Public API"""
import re
from typing import Callable, List, Optional, Union

from sglang.backend.anthropic import Anthropic
from sglang.backend.base_backend import BaseBackend
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
from sglang.lang.ir import (
    SglExpr,
    SglExprList,
    SglFunction,
    SglGen,
    SglImage,
    SglRoleBegin,
    SglRoleEnd,
    SglSelect,
)
from sglang.srt.server import Runtime


def function(func: Callable):
    return SglFunction(func)


def set_default_backend(backend: BaseBackend):
    global_config.default_backend = backend


def gen(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    dtype: Optional[type] = None,
    choices: Optional[List[str]] = None,
    regex: Optional[str] = None,
):
    if choices:
        return SglSelect(name, choices, temperature)

    # check that the regex is valid
    if regex is not None:
        try:
            re.compile(regex)
        except re.error as e:
            raise e

    return SglGen(
        name,
        max_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        dtype,
        regex,
    )


def gen_int(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
):
    return SglGen(
        name,
        max_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        int,
        None,
    )


def gen_string(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
):
    return SglGen(
        name,
        max_tokens,
        stop,
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        str,
        None,
    )


def image(expr: SglExpr):
    return SglImage(expr)


def select(
    name: Optional[str] = None,
    choices: List[str] = None,
    temperature: float = 0.0,
):
    assert choices is not None
    return SglSelect(name, choices, temperature)


def _role_common(name: str, expr: Optional[SglExpr] = None):
    if expr is None:
        return SglExprList([SglRoleBegin(name), SglRoleEnd(name)])
    else:
        return SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])


def system(expr: Optional[SglExpr] = None):
    return _role_common("system", expr)


def user(expr: Optional[SglExpr] = None):
    return _role_common("user", expr)


def assistant(expr: Optional[SglExpr] = None):
    return _role_common("assistant", expr)


def user_begin():
    return SglRoleBegin("user")


def user_end():
    return SglRoleEnd("user")


def assistant_begin():
    return SglRoleBegin("assistant")


def assistant_end():
    return SglRoleEnd("assistant")
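For orientation, here is a minimal usage sketch of the public API defined above. It is not part of this commit: the endpoint URL, the model behind it, and the questions are assumptions for illustration, and any backend added in this commit (OpenAI, Anthropic, RuntimeEndpoint, TGI) could be substituted via set_default_backend.

# Illustrative sketch only; assumes an SRT server is already running at the
# hypothetical address http://localhost:30000.
import sglang as sgl

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))


@sgl.function
def multi_turn_question(s, question_1, question_2):
    # Each statement appends an expression to the program state `s`.
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=64))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=64))


state = multi_turn_question.run(
    question_1="What is the capital of France?",
    question_2="And of Germany?",
)
print(state.text())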
python/sglang/backend/__init__.py  (new file, 0 lines)
python/sglang/backend/anthropic.py  (new file, 57 lines)
@@ -0,0 +1,57 @@
from typing import List, Optional, Union

import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams

try:
    import anthropic
except ImportError as e:
    anthropic = e


class Anthropic(BaseBackend):
    def __init__(self, model_name):
        super().__init__()

        if isinstance(anthropic, Exception):
            raise anthropic

        self.model_name = model_name
        self.chat_template = get_chat_template("claude")

    def get_chat_template(self):
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        prompt = s.text_
        ret = anthropic.Anthropic().completions.create(
            model=self.model_name,
            prompt=prompt,
            **sampling_params.to_anthropic_kwargs(),
        )
        comp = ret.completion

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        prompt = s.text_
        generator = anthropic.Anthropic().completions.create(
            model=self.model_name,
            prompt=prompt,
            stream=True,
            **sampling_params.to_anthropic_kwargs(),
        )

        for ret in generator:
            yield ret.completion, {}
python/sglang/backend/base_backend.py  (new file, 74 lines)
@@ -0,0 +1,74 @@
from typing import Callable, List, Optional, Union

from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams


class BaseBackend:
    def __init__(self) -> None:
        self.support_concate_and_append = False
        self.chat_template = get_chat_template("default")

    def get_model_name(self):
        raise NotImplementedError()

    def get_chat_template(self):
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        pass

    def uncache_prefix(self, rid: str):
        pass

    def end_request(self, rid: Union[str, List[str]]):
        pass

    def begin_program(self, s: StreamExecutor):
        pass

    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
        pass

    def commit_lazy_operations(self, s: StreamExecutor):
        pass

    def fork_program(
        self,
        src: StreamExecutor,
        dst: List[StreamExecutor],
        position_ids_offset: Optional[List[int]] = None,
    ):
        pass

    def fill_image(self, s: StreamExecutor):
        pass

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        raise NotImplementedError()

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        raise NotImplementedError()

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
    ):
        raise NotImplementedError()

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        raise NotImplementedError()

    def shutdown(self):
        pass
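BaseBackend is the interface that every concrete backend in this commit (OpenAI, Anthropic, RuntimeEndpoint, TGI, HuggingFaceTransformers) fills in; most hooks are no-ops, and only generate, generate_stream, and select must be overridden. The sketch below is illustrative only and not part of the commit: EchoBackend is a hypothetical backend that follows the (completion, meta_info) and (decision, scores) return conventions visible in the other backends.

# Hypothetical minimal backend, for illustration only.
from sglang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams


class EchoBackend(BaseBackend):
    """Pretends to generate text by returning a fixed string."""

    def generate(self, s: StreamExecutor, sampling_params: SamplingParams):
        # Real backends call a model on s.text_; return (completion, meta_info).
        return "echo", {}

    def generate_stream(self, s: StreamExecutor, sampling_params: SamplingParams):
        # Yield (text_chunk, meta_info) pairs, one chunk at a time.
        for piece in ["ec", "ho"]:
            yield piece, {}

    def select(self, s: StreamExecutor, choices, temperature):
        # Trivial scoring: always pick the first choice.
        return choices[0], [1.0] + [0.0] * (len(choices) - 1)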
python/sglang/backend/huggingface.py  (new file, 349 lines)
@@ -0,0 +1,349 @@
import functools
from enum import Enum, auto
from typing import Callable, List, Optional, Union

import numpy as np
import torch
import transformers
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import ProgramState
from sglang.utils import get_available_gpu_memory
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)


class StopReason(Enum):
    EOS_TOKEN = auto()
    STOP_STR = auto()
    LENGTH = auto()


def load_model(
    model_name: str,
    device,
    num_gpus,
    max_gpu_memory,
    model_kwargs=None,
    tokenizer_kwargs=None,
):
    model_kwargs = model_kwargs or {}
    tokenizer_kwargs = tokenizer_kwargs or {}

    if device == "cuda":
        model_kwargs["torch_dtype"] = torch.float16
        if num_gpus != 1:
            model_kwargs["device_map"] = "auto"
            if max_gpu_memory is None:
                model_kwargs[
                    "device_map"
                ] = "sequential"  # This is important for GPUs with different VRAM sizes
                available_gpu_memory = [
                    get_available_gpu_memory(i, False) for i in range(num_gpus)
                ]
                model_kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                model_kwargs["max_memory"] = {
                    i: max_gpu_memory for i in range(num_gpus)
                }
    elif device == "cpu":
        model_kwargs["torch_dtype"] = torch.float32
    else:
        raise ValueError(f"Invalid device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, **model_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)

    if num_gpus == 1:
        model.to(device).eval()

    return model, tokenizer


def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0, and 1.0 makes it a no-op, so we skip these two cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list


@functools.lru_cache
def get_token_healing_mask(tokenizer, prompt_last_token):
    last_str = tokenizer.convert_ids_to_tokens(prompt_last_token)
    disallowed = torch.zeros(len(tokenizer), dtype=bool)
    for s, t_id in tokenizer.get_vocab().items():
        if not s.startswith(last_str):
            disallowed[t_id] = 1
    return disallowed


@functools.lru_cache
def get_int_token_mask(tokenizer):
    disallowed = torch.zeros(len(tokenizer), dtype=bool)
    for s, t_id in tokenizer.get_vocab().items():
        s = s.replace("▁", "").strip()
        if not (s.isdigit() or len(s) == 0 or s == ","):
            disallowed[t_id] = 1
    disallowed[tokenizer.eos_token_id] = 0
    return disallowed


@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    prompt,
    max_new_tokens,
    stop: List[str],
    temperature,
    top_p,
    token_healing,
    logit_mask=None,
):
    logits_processor = prepare_logits_processor(
        temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
    )
    device = model.device
    input_ids = tokenizer.encode(prompt)
    output_ids = list(input_ids)
    prompt_len = len(prompt)

    # Resolve stop
    stop_token_ids = [tokenizer.eos_token_id]

    # Token healing
    token_healing = token_healing and len(input_ids) > 0
    if token_healing:
        token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
        del output_ids[-1]

    # Generate
    past_key_values = None
    stop_reason = None
    for i in range(max_new_tokens):
        # Forward
        if i == 0:  # prefill
            out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
        else:  # decoding
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
        logits = out.logits
        past_key_values = out.past_key_values

        # Logit mask
        if token_healing and i == 0:
            logits[0, -1, token_healing_mask] = -1e4
        if logit_mask is not None:
            logits[0, -1, logit_mask] = -1e4

        # Sample the next token
        last_token_logits = logits_processor(None, logits[:, -1, :])[0]
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))
        output_ids.append(token)

        # Stop condition
        if token in stop_token_ids:
            stop_reason = StopReason.EOS_TOKEN
            break

        output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
        for stop_str in stop:
            pos = output_str[prompt_len:].find(stop_str)
            if pos != -1:
                stop_reason = StopReason.STOP_STR
                output_str = output_str[: prompt_len + pos]
                break

        if stop_reason:
            break

    return output_str[prompt_len:]


class HuggingFaceTransformers(BaseBackend):
    def __init__(
        self,
        model_name,
        device="cuda",
        num_gpus=1,
        max_gpu_memory=None,
        model_kwargs=None,
        tokenizer_kwargs=None,
    ):
        self.model_name = model_name
        self.device = device

        self.model, self.tokenizer = load_model(
            model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
        )

        self.chat_template = get_chat_template_by_model_path(model_name)

    def get_chat_template(self):
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        pass

    def uncache_prefix(self, rid: str):
        pass

    def end_request(self, rid: str):
        pass

    def begin_program(self, s: ProgramState):
        pass

    def end_program(self, s: ProgramState):
        pass

    def fill(self, s: ProgramState, text: str):
        return False

    def generate_internal(
        self,
        prompt: str,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        if dtype is None:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop,
                temperature=temperature,
                top_p=top_p,
                token_healing=True,
            )
        elif dtype in [str, "str", "string"]:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt + '"',
                max_new_tokens=max_tokens,
                stop=['"'],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
            )
            comp = '"' + comp + '"'
        elif dtype in [int, "int"]:
            logit_mask = get_int_token_mask(self.tokenizer)
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop + [" ", ","],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
                logit_mask=logit_mask,
            )
        return comp

    def generate(
        self,
        s: ProgramState,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        prompt = s.text
        comp = self.generate_internal(
            prompt, max_tokens, stop, temperature, top_p, dtype
        )
        return comp

    def parallel_generate(
        self,
        s: ProgramState,
        prefixes: List[str],
        join_func: Callable,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        prompt = s.text
        parallel_prompts = [prompt + prefix for prefix in prefixes]

        comps = []
        for i in range(len(parallel_prompts)):
            comps.append(
                self.generate_internal(
                    parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
                )
            )

        joined = join_func([p + c for p, c in zip(prefixes, comps)])
        return joined, comps

    @torch.inference_mode()
    def select(
        self, s: ProgramState, choices: List[str], temperature: float, top_p: float
    ):
        loss_fct = torch.nn.CrossEntropyLoss()
        prompt = s.text

        prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
        prompt_choices = [prompt + choice for choice in choices]

        scores = []
        for i in range(len(choices)):
            choice_ids = self.tokenizer.encode(
                prompt_choices[i], return_tensors="pt"
            ).to(self.model.device)
            logits = self.model(choice_ids).logits

            # score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()

            logprobs = torch.log(torch.softmax(logits, dim=-1))
            idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
            idx2 = choice_ids[0, 1:]
            selected_logprobs = logprobs[0, idx1, idx2]
            score = selected_logprobs.mean().item()
            scores.append(score)

        decision = choices[np.argmax(scores)]
        return decision, scores
python/sglang/backend/openai.py  (new file, 241 lines)
@@ -0,0 +1,241 @@
from typing import Callable, List, Optional, Union

import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams

try:
    import openai
    import tiktoken
except ImportError as e:
    openai = tiktoken = e


def create_logit_bias_int(tokenizer):
    """Get the logit bias for integer tokens."""
    int_token_ids = []

    tokens = tokenizer._mergeable_ranks
    for token, token_id in tokens.items():
        s = tokenizer.decode([token_id])
        if all([c.isdigit() for c in s]) or s in [" "]:
            int_token_ids.append(token_id)
            if len(int_token_ids) >= 300:  # OpenAI API limit
                break
    special_tokens = tokenizer._special_tokens
    mask = {t: 100 for t in int_token_ids[:299]}
    mask[special_tokens["<|endoftext|>"]] = 100
    return mask


CHAT_MODEL_NAMES = [
    # GPT-4
    "gpt-4",
    "gpt-4-32k",
    "gpt-4-1106-preview",
    "gpt-4-vision-preview",
    "gpt-4-0613",
    "gpt-4-0314",
    # GPT-3.5
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-16k-0613",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-0301",
]


class OpenAI(BaseBackend):
    def __init__(self, model_name, *args, **kwargs):
        super().__init__()

        if isinstance(openai, Exception):
            raise openai

        self.client = openai.OpenAI(*args, **kwargs)

        self.model_name = model_name
        self.tokenizer = tiktoken.encoding_for_model(model_name)
        self.logit_bias_int = create_logit_bias_int(self.tokenizer)

        if model_name in CHAT_MODEL_NAMES:
            self.is_chat_model = True
        else:
            self.is_chat_model = False

        self.chat_template = get_chat_template("default")

    def get_chat_template(self):
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        if sampling_params.dtype is None:
            if self.is_chat_model:
                assert s.text_.endswith("ASSISTANT:")
                prompt = s.messages_
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            comp = openai_completion(
                client=self.client,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
        elif sampling_params.dtype in [str, "str", "string"]:
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            comp = openai_completion(
                client=self.client,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_ + '"',
                stop='"',
                **kwargs,
            )
            comp = '"' + comp + '"'
        elif sampling_params.dtype in [int, "int"]:
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            comp = openai_completion(
                client=self.client,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_,
                logit_bias=self.logit_bias_int,
                stop=[" "],
                **kwargs,
            )
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        if sampling_params.dtype is None:
            if self.is_chat_model:
                assert s.text_.endswith("ASSISTANT:")
                prompt = s.messages_
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            generator = openai_completion_stream(
                client=self.client,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
            return generator
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
    ):
        n_choices = len(choices)
        token_ids = [self.tokenizer.encode(x) for x in choices]
        scores = [0] * n_choices
        valid = [len(x) > 0 for x in token_ids]
        prompt_tokens = self.tokenizer.encode(s.text_)

        max_len = max([len(x) for x in token_ids])
        for step in range(max_len):
            # Build the logit bias
            logit_bias = {}
            for i in range(n_choices):
                if valid[i]:
                    logit_bias[token_ids[i][step]] = 100

            # Call the API
            ret = self.client.completions.create(
                model=self.model_name,
                prompt=prompt_tokens,
                logit_bias=logit_bias,
                max_tokens=1,
                temperature=temperature,
            )
            ret_str = ret.choices[0].text
            ret_token = self.tokenizer.encode(ret_str)[0]

            # TODO:
            # 1. return logits as the scores
            # 2. compute logits of the full choice
            # 3. consider chunk-based decoding

            # Update valid
            hit = False
            for i in range(n_choices):
                if valid[i]:
                    if step == len(token_ids[i]) - 1:
                        valid[i] = False

                    if ret_token == token_ids[i][step]:
                        scores[i] += 1
                        hit = True
                    else:
                        valid[i] = False
            assert hit

            if np.sum(valid) <= 1:
                break

            prompt_tokens.append(ret_token)

        decision = choices[np.argmax(scores)]
        return decision, scores


def openai_completion(client, is_chat=None, prompt=None, **kwargs):
    try:
        if is_chat:
            if kwargs["stop"] is None:
                kwargs.pop("stop")
            ret = client.chat.completions.create(messages=prompt, **kwargs)
            comp = ret.choices[0].message.content
        else:
            ret = client.completions.create(prompt=prompt, **kwargs)
            if isinstance(prompt, (list, tuple)):
                comp = [c.text for c in ret.choices]
            else:
                comp = ret.choices[0].text
    except openai.OpenAIError as e:
        print(f"OpenAI Error: {e}")
        raise e

    return comp


def openai_completion_stream(client, is_chat=None, prompt=None, **kwargs):
    try:
        if is_chat:
            generator = client.chat.completions.create(
                messages=prompt, stream=True, **kwargs
            )
            for ret in generator:
                content = ret.choices[0].delta.content
                yield content or "", {}
        else:
            generator = client.completions.create(prompt=prompt, stream=True, **kwargs)
            for ret in generator:
                content = ret.choices[0].text
                yield content or "", {}
    except openai.OpenAIError as e:
        print(f"OpenAI Error: {e}")
        raise e
python/sglang/backend/runtime_endpoint.py  (new file, 171 lines)
@@ -0,0 +1,171 @@
import json
from typing import Callable, List, Optional, Union

import numpy as np
import requests
from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams, SglArgument
from sglang.utils import encode_image_base64, find_printable_text, http_request


class RuntimeEndpoint(BaseBackend):
    def __init__(self, base_url):
        super().__init__()
        self.support_concate_and_append = True

        self.base_url = base_url

        res = http_request(self.base_url + "/get_model_info")
        assert res.status_code == 200
        self.model_info = res.json()

        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_path"]
        )

    def get_model_name(self):
        return self.model_info["model_path"]

    def get_chat_template(self):
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        res = http_request(
            self.base_url + "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
        )
        assert res.status_code == 200

    def commit_lazy_operations(self, s: StreamExecutor):
        res = http_request(
            self.base_url + "/generate",
            json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
        )
        assert res.status_code == 200

    def fill_image(self, s: StreamExecutor):
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(self.base_url + "/generate", json=data)
        assert res.status_code == 200

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        if sampling_params.dtype is None:
            data = {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    **sampling_params.to_srt_kwargs(),
                },
            }
        elif sampling_params.dtype in [int, "int"]:
            data = {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    "dtype": "int",
                    **sampling_params.to_srt_kwargs(),
                },
            }
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        self._add_images(s, data)

        res = http_request(self.base_url + "/generate", json=data)
        obj = res.json()
        comp = obj["text"]
        return comp, obj["meta_info"]

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SamplingParams,
    ):
        if sampling_params.dtype is None:
            data = {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    **sampling_params.to_srt_kwargs(),
                },
            }
        elif sampling_params.dtype in [int, "int"]:
            data = {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    "dtype": "int",
                    **sampling_params.to_srt_kwargs(),
                },
            }
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        data["stream"] = True
        self._add_images(s, data)

        response = http_request(self.base_url + "/generate", json=data, stream=True)
        pos = 0

        incomplete_text = ""
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                text = find_printable_text(data["text"][pos:])
                meta_info = data["meta_info"]
                pos += len(text)
                incomplete_text = data["text"][pos:]
                yield text, meta_info

        if len(incomplete_text) > 0:
            yield incomplete_text, meta_info

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
    ):
        assert temperature <= 1e-5

        # Cache the common prefix
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(self.base_url + "/generate", json=data)
        assert res.status_code == 200
        prompt_len = res.json()["meta_info"]["prompt_tokens"]

        # Compute logprob
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {"max_new_tokens": 0},
            "return_normalized_logprob": True,
            "normalized_logprob_start_len": prompt_len,
        }
        self._add_images(s, data)
        res = http_request(self.base_url + "/generate", json=data)
        assert res.status_code == 200
        logps = [r["meta_info"]["normalized_logprob"] for r in res.json()]

        decision = choices[np.argmax(logps)]
        return decision, logps

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        res = http_request(
            self.base_url + "/concate_and_append_request",
            json={"src_rids": src_rids, "dst_rid": dst_rid},
        )
        assert res.status_code == 200

    def _add_images(self, s: StreamExecutor, data):
        if s.images_:
            assert len(s.images_) == 1, "Only support one image."
            data["image_data"] = s.images_[0][1]
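RuntimeEndpoint is a thin HTTP client: every operation above is a POST to the server's /generate route with a "text" plus "sampling_params" payload, and the response carries "text" and "meta_info". As a sanity check, the same route can be hit directly; the sketch below is illustrative only and assumes a server at the hypothetical address http://localhost:30000.

# Illustrative direct call against the same /generate route used above.
import requests

res = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
obj = res.json()
# "text" holds the completion; "meta_info" holds token counts and similar fields.
print(obj["text"], obj["meta_info"])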
python/sglang/backend/tgi.py  (new file, 190 lines)
@@ -0,0 +1,190 @@
import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from itertools import repeat
from typing import List, Optional, Union

from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.utils import http_request


class TGI(BaseBackend):
    def __init__(self, base_url):
        super().__init__()

        self.base_url = base_url

        res = http_request(self.base_url + "/info")
        assert res.status_code == 200
        self.model_info = res.json()
        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_id"]
        )

    def get_model_name(self):
        return self.model_info["model_id"]

    def get_chat_template(self):
        return self.chat_template

    @staticmethod
    def adapt_params(max_tokens, stop, sampling_params, **override_params):
        temperature = sampling_params.temperature
        do_sample = True
        if temperature == 0:
            do_sample = False
            temperature = None

        if stop is None:
            stop = []
        elif isinstance(stop, str):
            stop = [stop]

        top_p = sampling_params.top_p
        if top_p == 0:
            top_p = 0.001
        if top_p == 1:
            top_p = 0.999

        top_k = sampling_params.top_k
        if top_k == -1:
            top_k = None

        params = {
            "decoder_input_details": False,
            "details": False,
            "do_sample": do_sample,
            "max_new_tokens": max_tokens,
            "stop": stop,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "return_full_text": False,
        }
        params.update(override_params)
        return params

    @staticmethod
    def _extract_int(text):
        words = re.split(r"\ |'|\/|\(|\)|\n|\.|,", text)
        for word in words:
            try:
                int(word)
                return word
            except ValueError:
                continue
        raise ValueError

    @staticmethod
    def _extract_choice(choices, text):
        # FIXME: Currently only supports the case where the choices are single words.
        words = re.split(r"\ |'|\/|\(|\)|\n|\.|,", text)
        for word in words:
            if word in choices:
                return word
        raise ValueError

    @staticmethod
    def _truncate_to_stop(text, stop):
        # The stop sequence may not be a single token. In this case, TGI will generate
        # too many tokens, so we need to truncate the output.
        if stop:
            stop = [stop] if isinstance(stop, str) else stop
            for stop_seq in stop:
                pos = text.find(stop_seq)
                if pos != -1:
                    return text[:pos]
        return text

    def _make_request(self, params):
        res = http_request(self.base_url + "/generate", json=params)
        if res.status_code != 200:
            raise ValueError(f"Error from TGI backend: {res.text}")
        return res.json()

    def retry_for_expected(self, prompt, params, extract_fn, retry=5):
        # TGI does not support logit_bias (yet), so we have to use an inefficient retry hack.
        failed = []
        while retry > 0:
            res_json = self._make_request(
                {
                    "inputs": prompt,
                    "parameters": params,
                }
            )
            text = res_json["generated_text"]
            try:
                return extract_fn(text)
            except ValueError:
                retry -= 1
                failed.append(text)

        msg = "=" * 20 + "\n"
        msg += f"Prompt:\n{prompt}\n"
        msg += "=" * 20 + "\n"
        for i, text in enumerate(failed):
            msg += f"====== Try {i+1}:\n{text}\n"

        raise ValueError(
            f"Model {self.model_info['model_id']} served by the TGI backend does not generate "
            "the expected output. Please improve the prompt, increase the temperature, or "
            f"use different models.\n{msg}"
        )

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        sampling_params: SamplingParams,
    ):
        decision = self.retry_for_expected(
            s.text_,
            self.adapt_params(16, [], sampling_params),
            partial(self._extract_choice, choices),
        )
        return decision, [1 if choice == decision else 0 for choice in choices]

    def generate(
        self,
        s: StreamExecutor,
        max_tokens: int,
        stop: Union[str, List[str]],
        sampling_params: SamplingParams,
        dtype: Optional[str] = None,
    ):
        if dtype is None:
            res_json = self._make_request(
                {
                    "inputs": s.text_,
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return self._truncate_to_stop(res_json["generated_text"], stop), {}

        if dtype in [str, "str", "string"]:
            stop = ['"']
            res_json = self._make_request(
                {
                    "inputs": f'{s.text_}"',
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return (
                '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
                {},
            )

        if dtype in [int, "int"]:
            return (
                self.retry_for_expected(
                    s.text_,
                    self.adapt_params(max_tokens, stop, sampling_params),
                    self._extract_int,
                ),
                {},
            )

        raise ValueError(f"Unknown dtype: {dtype}")
python/sglang/flush_cache.py  (new file, 60 lines)
@@ -0,0 +1,60 @@
"""Flush cache in the backend by sending random requests."""
import argparse
import random
import string
import time

from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)

import sglang as sgl


@sgl.function
def flush_radix_cache(s, prompt):
    s += prompt + sgl.gen("flush", max_tokens=1, stop="END")


def main(args, max_total_tokens, context_length, print_flag):
    backend = select_sglang_backend(args)
    flush_length = int(context_length * 0.8)
    batch_size = int(max_total_tokens / flush_length)
    prompt_length = flush_length * 2
    prompts = [
        " ".join(random.choices(string.ascii_letters, k=int(prompt_length)))
        for _ in range(batch_size)
    ]
    arguments = [{"prompt": prompts[i]} for i in range(batch_size)]

    start_time = time.time()
    flush_radix_cache.run_batch(
        arguments, temperature=0, backend=backend, num_threads=1
    )
    end_time = time.time()

    if print_flag:
        print(
            f"Flush length: {flush_length}\n",
            f"Prompt length: {prompt_length}\n",
            f"Total Prompt letters: {batch_size * prompt_length}\n",
            f"Flush radix cache latency: {end_time - start_time:.3f}",
            sep="",
        )

    # to prevent the backend still running
    time.sleep(1)


def run_flush(args, max_total_tokens=20000, context_length=1024, print_flag=False):
    main(args, max_total_tokens, context_length, print_flag=print_flag)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-total-tokens", type=int, default=20000)
    parser.add_argument("--context-length", type=int, default=1024)
    args = add_common_sglang_args_and_parse(parser)
    random.seed(0)
    main(args, args.max_total_tokens, args.context_length, print_flag=True)
python/sglang/global_config.py  (new file, 28 lines)
@@ -0,0 +1,28 @@
"""Global configurations"""


class GlobalConfig:
    def __init__(self):
        # Verbosity level
        # 0: do not output anything
        # 2: output the final text after every run
        self.verbosity = 0

        self.default_backend = None

        # Output configs
        self.skip_special_tokens_in_output = True

        # Optimization configs
        self.eager_fill_image = False
        self.enable_prefix_sharing = True
        self.enable_parallel_encoding = True
        self.enable_parallel_decoding = True

        # Choices: ["no_adjust", "adjust_cache"]
        # no_adjust: Do not adjust the position embedding of the KV cache.
        # adjust_cache: Adjust the position embedding of the KV cache.
        self.concate_and_append_mode = "no_adjust"


global_config = GlobalConfig()
python/sglang/lang/__init__.py  (new file, 0 lines)
python/sglang/lang/chat_template.py  (new file, 186 lines)
@@ -0,0 +1,186 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Tuple


class ChatTemplateStyle(Enum):
    PLAIN = auto()
    LLAMA2 = auto()


@dataclass
class ChatTemplate:
    name: str
    default_system_prompt: str
    role_prefix_and_suffix: Dict[str, Tuple[str]]
    image_token: str = "<image>"
    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

    def get_prefix_and_suffix(self, role, hist_messages):
        if self.style == ChatTemplateStyle.PLAIN:
            return self.role_prefix_and_suffix[role]
        elif self.style == ChatTemplateStyle.LLAMA2:
            if len(hist_messages) == 0 and role == "system":
                return (
                    self.role_prefix_and_suffix["user"][0]
                    + self.role_prefix_and_suffix["system"][0],
                    self.role_prefix_and_suffix["system"][1],
                )
            elif (
                len(hist_messages) == 1
                and role == "user"
                and hist_messages[0]["content"] is not None
            ):
                return ("", self.role_prefix_and_suffix["user"][1])
            return self.role_prefix_and_suffix[role]
        else:
            raise ValueError(f"Invalid style: {self.style}")

    def get_prompt(self, messages):
        prompt = ""
        for i in range(len(messages)):
            role, content = messages[i]["role"], messages[i]["content"]
            if role == "system" and content is None:
                content = self.default_system_prompt
                if content is None:
                    continue

            prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
            prompt += prefix + content + suffix
        return prompt


chat_template_registry: Dict[str, ChatTemplate] = {}
matching_function_registry: List[Callable] = []


def register_chat_template(template):
    chat_template_registry[template.name] = template


def register_chat_template_matching_function(func):
    matching_function_registry.append(func)


def get_chat_template(name):
    return chat_template_registry[name]


def get_chat_template_by_model_path(model_path):
    for matching_func in matching_function_registry:
        template = matching_func(model_path)
        if template is not None:
            return template
    return get_chat_template("default")


register_chat_template(
    ChatTemplate(
        name="default",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("SYSTEM:", "\n"),
            "user": ("USER:", "\n"),
            "assistant": ("ASSISTANT:", "\n"),
        },
    )
)


register_chat_template(
    ChatTemplate(
        name="claude",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("\n\nHuman: ", ""),
            "assistant": ("\n\nAssistant:", ""),
        },
    )
)


register_chat_template(
    ChatTemplate(
        name="chatml",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
    )
)


register_chat_template(
    ChatTemplate(
        name="vicuna_v1.1",
        default_system_prompt=(
            "A chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions."
        ),
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("USER:", " "),
            "assistant": ("ASSISTANT:", "</s>"),
        },
        image_token=" <image>\n",
    )
)


register_chat_template(
    ChatTemplate(
        name="llama-2-chat",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
            "user": ("[INST] ", " [/INST]"),
            "assistant": ("", " </s><s>"),
        },
        style=ChatTemplateStyle.LLAMA2,
    )
)


@register_chat_template_matching_function
def match_vicuna(model_path: str):
    if "vicuna" in model_path.lower():
        return get_chat_template("vicuna_v1.1")
    if "llava" in model_path.lower():
        return get_chat_template("vicuna_v1.1")


@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
    model_path = model_path.lower()
    if "llama-2" in model_path and "chat" in model_path:
        return get_chat_template("llama-2-chat")
    if (
        "mistral" in model_path or "mixtral" in model_path
    ) and "instruct" in model_path:
        return get_chat_template("llama-2-chat")
    if "codellama" in model_path and "instruct" in model_path:
        return get_chat_template("llama-2-chat")


@register_chat_template_matching_function
def match_chat_ml(model_path: str):
    if "tinyllama" in model_path.lower():
        return get_chat_template("chatml")


if __name__ == "__main__":
    messages = [
        {"role": "system", "content": None},  # None means default
        # {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "What can you do?"},
        {"role": "assistant", "content": "I can chat with you."},
    ]

    template = get_chat_template("llama-2-chat")
    print(template.get_prompt(messages))
python/sglang/lang/compiler.py  (new file, 237 lines)
@@ -0,0 +1,237 @@
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import List, Union

from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
    SamplingParams,
    SglArgument,
    SglConstantText,
    SglExpr,
    SglVariable,
)


def compile_func(function, backend):
    tracer = function.trace(backend=backend)
    compiler = CompiledFunction(tracer, function)
    return compiler


class CompiledFunction:
    def __init__(self, tracer, function):
        self.function = function

        self.last_node = CompGraphNode(tracer.last_node)
        self.expr_to_node = {}
        self.build_graph(tracer)
        self.topological_sort()

    def build_graph(self, tracer):
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.nodes[-1]

        rename_pid = {}

        visited = set([tracer.last_node])
        head = 0
        while head < len(self.nodes):
            cur_node = self.nodes[head]

            # add the prev node
            prev_node = cur_node.expr.prev_node
            if prev_node is not None:
                if prev_node not in visited:
                    visited.add(prev_node)
                    self.nodes.append(CompGraphNode(prev_node))
                    self.expr_to_node[prev_node] = self.nodes[-1]
                cur_node.prev_node = self.expr_to_node[prev_node]
                self.expr_to_node[prev_node].add_next_node(cur_node)

            # add the source node
            if isinstance(cur_node.expr, SglVariable):
                if cur_node.expr.name in tracer.variables:
                    source = tracer.variables[cur_node.expr.name].source
                else:
                    source = cur_node.expr.source
                if source not in visited:
                    visited.add(source)
                    self.nodes.append(CompGraphNode(source))
                    self.expr_to_node[source] = self.nodes[-1]
                cur_node.source_node = self.expr_to_node[source]
                self.expr_to_node[source].add_next_node(cur_node)
            head += 1

            # rename pid
            if cur_node.expr.pid not in rename_pid:
                rename_pid[cur_node.expr.pid] = len(rename_pid)
            cur_node.expr.pid = rename_pid[cur_node.expr.pid]

    def topological_sort(self):
        prevd = {}
        cand = Queue()
        for x in self.nodes:
            prevd[x] = (x.prev_node is not None) + (x.source_node is not None)
            if prevd[x] == 0:
                cand.put(x)
        new_list = []
        while cand.qsize() > 0:
            head = cand.get()
            new_list.append(head)
            for x in head.next_nodes:
                prevd[x] -= 1
                if prevd[x] == 0:
                    cand.put(x)
        self.nodes = new_list

    def print_graph(
        self,
    ):
        for node in self.nodes:
            print(node)

    def run_internal(
        self,
        backend,
        kwargs,
        default_sampling_para,
    ):
        stream_executor_ids = set([x.expr.pid for x in self.nodes])
        stream_executors = {}
        for x in stream_executor_ids:
            arguments = kwargs if x == self.last_node.expr.pid else {}
            stream_executors[x] = StreamExecutor(
                backend, arguments, default_sampling_para, None, False
            )
        for node in self.nodes:
            se_id = node.expr.pid
            expr = node.expr
            if isinstance(expr, SglVariable):
                # Make a copy for SglVariable
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[
                    node.source_node.expr.pid
                ]
            elif isinstance(expr, SglArgument):
                # Substitute SglArgument
                expr = kwargs[expr.name]
            stream_executors[se_id].submit(expr)
        for stream_executor in stream_executors.values():
            stream_executor.end()
        return ProgramState(stream_executors[self.last_node.expr.pid])

    def run(
        self,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        backend = backend or global_config.default_backend

        kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
        kwargs.update(self.function.bind_arguments)

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        return self.run_internal(backend, kwargs, default_sampling_para)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 16,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        backend = backend or global_config.default_backend

        default_sampling_para = SamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        batch_kwargs = [
            {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
        ]

        # Extract the prefix by tracing and cache it
        if len(batch_kwargs) > 1:
            pin_program(self.function, backend)

        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_kwargs))

        if num_threads == 1:
            rets = []
            for arguments in batch_kwargs:
                rets.append(
                    self.run_internal(backend, arguments, default_sampling_para)
                )
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = []
                for arguments in batch_kwargs:
                    futures.append(
                        executor.submit(
                            self.run_internal, backend, arguments, default_sampling_para
                        )
                    )
                rets = [f.result() for f in futures]
                rets[-1].sync()

        return rets


class CompGraphNode:
    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        self.expr = expr
        self.next_nodes = next_nodes or []
        self.prev_node = prev_node
        self.source_node = source_node

    def add_next_node(self, other):
        self.next_nodes.append(other)

    def __repr__(self):
        re = f"stream {self.expr.pid:2d}: "
        re += f"%{self.expr.node_id} = "
        if self.prev_node is not None:
            re += f"%{self.prev_node.expr.node_id} + "
        re += repr(self.expr)
        return re
python/sglang/lang/interpreter.py  (new file, 697 lines)
@@ -0,0 +1,697 @@
"""The interpreter that executes SGL programs"""

import asyncio
import multiprocessing
import queue
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union

import tqdm
from sglang.global_config import global_config
from sglang.lang.ir import (
    SglArgument,
    SglCommitLazy,
    SglConcateAndAppend,
    SglConstantText,
    SglExpr,
    SglExprList,
    SglFunction,
    SglGen,
    SglImage,
    SglRoleBegin,
    SglRoleEnd,
    SglSelect,
    SglVariable,
    SglVarScopeBegin,
    SglVarScopeEnd,
)
from sglang.utils import encode_image_base64


def run_internal(state, program, func_args, func_kwargs, sync):
    try:
        state.ret_value = program.func(state, *func_args, **func_kwargs)
    except Exception as e:
        raise e
    finally:
        state.stream_executor.end()

    if sync:
        state.stream_executor.sync()

    if global_config.verbosity >= 2:
        print(state.text())


def run_program(
    program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
):
    assert backend is not None, "Please specify a backend"
    func_kwargs.update(program.bind_arguments)
    stream_executor = StreamExecutor(
        backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
    )
    state = ProgramState(stream_executor)

    if stream:
        t = threading.Thread(
            target=run_internal, args=(state, program, func_args, func_kwargs, sync)
        )
        t.start()
        return state
    else:
        run_internal(state, program, func_args, func_kwargs, sync)
        return state


def run_program_batch(
    program,
    backend,
    batch_arguments,
    default_sampling_para,
    num_threads,
    progress_bar,
):
    # Extract the prefix by tracing and cache it
    if len(batch_arguments) > 1:
        pin_program(program, backend)

    # Run all programs
    if num_threads == "auto":
        num_threads = multiprocessing.cpu_count()
    num_threads = min(num_threads, len(batch_arguments))

    if num_threads == 1:
        rets = []
        for arguments in batch_arguments:
            rets.append(
                run_program(
                    program, backend, (), arguments, default_sampling_para, False, False
                )
            )
    else:
        if progress_bar:
            pbar = tqdm.tqdm(total=len(batch_arguments))

        with ThreadPoolExecutor(num_threads) as executor:
            futures = []
            for arguments in batch_arguments:
                futures.append(
                    executor.submit(
                        run_program,
                        program,
                        backend,
                        (),
                        arguments,
                        default_sampling_para,
                        False,
                        False,
                    )
                )
                if progress_bar:
                    futures[-1].add_done_callback(lambda _: pbar.update())

            rets = [f.result() for f in futures]
            rets[-1].sync()

        if progress_bar:
            pbar.close()

    return rets


def pin_program(program, backend):
    if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
        # TODO: handle multiple backends
        from sglang.lang.tracer import extract_prefix_by_tracing

        prefix = extract_prefix_by_tracing(program, backend)
        if prefix and len(prefix) > 64:
            prefix_rid = backend.cache_prefix(prefix)
            program.pin_prefix_rid = prefix_rid
            return prefix_rid
    return None


def unpin_program(program, backend):
    pass


class StreamExecutor:
    """A stream executor that executes SGL expressions in a background thread."""

    def __init__(
        self,
        backend,
        arguments,
        default_sampling_para,
        chat_template,
        stream,
        use_thread=True,
    ):
        self.sid = uuid.uuid4().hex
        self.backend = backend
        self.arguments: Dict[str, Any] = arguments
        self.default_sampling_para = default_sampling_para
        self.stream = stream

        if hasattr(backend, "endpoint"):
            self.backend = backend.endpoint

        self.variables = {}  # Dict[name: str -> value: str]
        self.variable_event = {}  # Dict[name: str -> event: threading.Event]
        self.meta_info = {}  # Dict[name: str -> info: str]
        self.is_finished = False

        # For completion
        self.text_ = ""  # The full text

        # For chat
        self.messages_ = []  # The messages in the OpenAI API format
|
||||
self.chat_template = chat_template or self.backend.get_chat_template()
|
||||
self.cur_role = None
|
||||
self.cur_role_begin_pos = None
|
||||
|
||||
# For vision
|
||||
self.images_ = []
|
||||
self.cur_images = []
|
||||
|
||||
# For fork/join
|
||||
self.fork_start_text_pos = None
|
||||
|
||||
# Worker thread
|
||||
self.use_thread = use_thread
|
||||
if self.use_thread:
|
||||
self.queue = queue.Queue()
|
||||
self.worker = threading.Thread(target=self._thread_worker_func)
|
||||
self.worker.start()
|
||||
|
||||
# For streaming
|
||||
if stream:
|
||||
self.stream_text_event = threading.Event()
|
||||
self.stream_var_event = {}
|
||||
else:
|
||||
self.stream_text_event = None
|
||||
self.stream_var_event = None
|
||||
|
||||
def submit(self, expr: SglExpr):
|
||||
if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
|
||||
self.variable_event[expr.name] = threading.Event()
|
||||
if self.stream:
|
||||
self.stream_var_event[expr.name] = threading.Event()
|
||||
elif isinstance(expr, SglExprList):
|
||||
for e in expr.expr_list:
|
||||
if isinstance(e, (SglGen, SglSelect, SglVarScopeBegin)):
|
||||
self.variable_event[e.name] = threading.Event()
|
||||
if self.stream:
|
||||
self.stream_var_event[e.name] = threading.Event()
|
||||
|
||||
if self.use_thread:
|
||||
self.queue.put(expr)
|
||||
else:
|
||||
self._execute(expr)
|
||||
|
||||
def sync(self):
|
||||
if self.use_thread:
|
||||
self.queue.join()
|
||||
|
||||
def get_var(self, name):
|
||||
if name in self.variable_event:
|
||||
self.variable_event[name].wait()
|
||||
return self.variables[name]
|
||||
|
||||
def get_meta_info(self, name):
|
||||
if name in self.variable_event:
|
||||
self.variable_event[name].wait()
|
||||
ret = self.meta_info.get(name, None)
|
||||
return ret
|
||||
|
||||
def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
|
||||
if number > 1:
|
||||
self.submit(SglCommitLazy())
|
||||
self.sync()
|
||||
|
||||
number = int(number)
|
||||
|
||||
exes = [
|
||||
StreamExecutor(
|
||||
self.backend,
|
||||
self.arguments,
|
||||
self.default_sampling_para,
|
||||
self.chat_template,
|
||||
self.stream,
|
||||
)
|
||||
for _ in range(number)
|
||||
]
|
||||
for i in range(number):
|
||||
exes[i].variables = dict(self.variables)
|
||||
exes[i].text_ = str(self.text_)
|
||||
exes[i].messages_ = list(self.messages_)
|
||||
exes[i].cur_role = self.cur_role
|
||||
exes[i].fork_start_text_pos = len(self.text_)
|
||||
|
||||
return exes
|
||||
|
||||
def text(self):
|
||||
self.sync()
|
||||
return self.text_
|
||||
|
||||
def messages(self):
|
||||
self.sync()
|
||||
return self.messages_
|
||||
|
||||
def end(self):
|
||||
if self.use_thread:
|
||||
if self.worker.is_alive():
|
||||
self.queue.put(None)
|
||||
self.backend.end_program(self)
|
||||
|
||||
def _thread_worker_func(self):
|
||||
while True:
|
||||
expr = self.queue.get()
|
||||
if expr is None:
|
||||
self.queue.task_done()
|
||||
break
|
||||
|
||||
self._execute(expr)
|
||||
self.queue.task_done()
|
||||
if self.stream_text_event:
|
||||
self.stream_text_event.set()
|
||||
|
||||
if self.stream_text_event:
|
||||
self.stream_text_event.set()
|
||||
|
||||
self.is_finished = True
|
||||
|
||||
def _execute(self, other):
|
||||
if isinstance(other, str):
|
||||
other = SglConstantText(other)
|
||||
|
||||
assert isinstance(other, SglExpr), f"{other}"
|
||||
|
||||
if isinstance(other, (SglConstantText, SglArgument)):
|
||||
self._execute_fill(other.value)
|
||||
elif isinstance(other, SglGen):
|
||||
self._execute_gen(other)
|
||||
elif isinstance(other, SglSelect):
|
||||
self._execute_select(other)
|
||||
elif isinstance(other, SglExprList):
|
||||
for x in other.expr_list:
|
||||
self._execute(x)
|
||||
elif isinstance(other, SglRoleBegin):
|
||||
self._execute_role_begin(other)
|
||||
elif isinstance(other, SglRoleEnd):
|
||||
self._execute_role_end(other)
|
||||
elif isinstance(other, SglImage):
|
||||
self._execute_image(other)
|
||||
elif isinstance(other, SglVariable):
|
||||
self._execute_variable(other)
|
||||
elif isinstance(other, SglVarScopeBegin):
|
||||
self._execute_var_scope_begin(other)
|
||||
elif isinstance(other, SglVarScopeEnd):
|
||||
self._execute_var_scope_end(other)
|
||||
elif isinstance(other, SglCommitLazy):
|
||||
self._execute_commit_lazy_operations(other)
|
||||
elif isinstance(other, SglConcateAndAppend):
|
||||
if (
|
||||
global_config.enable_parallel_encoding
|
||||
and self.backend.support_concate_and_append
|
||||
):
|
||||
self._execute_concatenate_and_append_kv_cache(other)
|
||||
else:
|
||||
self._execute_concatenate_and_append_text(other)
|
||||
else:
|
||||
raise ValueError(f"Unknown type: {type(other)}")
|
||||
|
||||
def _execute_fill(self, value: str):
|
||||
value = str(value)
|
||||
self.text_ += value
|
||||
|
||||
def _execute_image(self, expr: SglImage):
|
||||
path = expr.path
|
||||
if isinstance(path, SglArgument):
|
||||
path = path.value
|
||||
|
||||
base64_data = encode_image_base64(path)
|
||||
|
||||
self.images_.append((path, base64_data))
|
||||
self.cur_images.append((path, base64_data))
|
||||
self.text_ += self.chat_template.image_token
|
||||
|
||||
# if global_config.eager_fill_image:
|
||||
# self.backend.fill_image(self)
|
||||
|
||||
def _execute_gen(self, expr: SglGen):
|
||||
sampling_params = self._resolve_sampling_params(expr.sampling_params)
|
||||
name = expr.name
|
||||
|
||||
if not self.stream:
|
||||
comp, meta_info = self.backend.generate(
|
||||
self, sampling_params=sampling_params
|
||||
)
|
||||
self.text_ += comp
|
||||
|
||||
self.variables[name] = comp
|
||||
self.meta_info[name] = meta_info
|
||||
self.variable_event[name].set()
|
||||
else:
|
||||
generator = self.backend.generate_stream(
|
||||
self, sampling_params=sampling_params
|
||||
)
|
||||
|
||||
self.stream_var_event[name].set()
|
||||
|
||||
self.variables[name] = ""
|
||||
for comp, meta_info in generator:
|
||||
self.text_ += comp
|
||||
self.variables[name] += comp
|
||||
self.stream_var_event[name].set()
|
||||
self.stream_text_event.set()
|
||||
|
||||
self.meta_info[name] = meta_info
|
||||
|
||||
self.variable_event[name].set()
|
||||
self.stream_var_event[name].set()
|
||||
|
||||
def _execute_select(self, expr: SglSelect):
|
||||
decision, scores = self.backend.select(self, expr.choices, expr.temperature)
|
||||
if expr.name is not None:
|
||||
name = expr.name
|
||||
self.variables[name] = decision
|
||||
self.variable_event[name].set()
|
||||
self.text_ += decision
|
||||
|
||||
def _execute_variable(self, expr: SglVariable):
|
||||
src_executor = expr.source_stream_executor
|
||||
value = src_executor.get_var(expr.name)
|
||||
self._execute_fill(value)
|
||||
|
||||
def _execute_role_begin(self, expr: SglRoleBegin):
|
||||
assert self.cur_role is None, "Nested roles are not allowed."
|
||||
|
||||
if len(self.messages_) == 0 and expr.role != "system":
|
||||
# Insert the default system message
|
||||
default_system = self.chat_template.default_system_prompt
|
||||
if default_system:
|
||||
self._execute_role_begin(SglRoleBegin("system"))
|
||||
self._execute_fill(default_system)
|
||||
self._execute_role_end(SglRoleEnd("system"))
|
||||
|
||||
self.cur_role = expr.role
|
||||
|
||||
prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
|
||||
|
||||
self._execute_fill(prefix)
|
||||
self.cur_role_begin_pos = len(self.text_)
|
||||
|
||||
def _execute_role_end(self, expr: SglRoleEnd):
|
||||
new_text = self.text_[self.cur_role_begin_pos :].lstrip()
|
||||
|
||||
_, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
|
||||
self._execute_fill(suffix)
|
||||
|
||||
if self.cur_images:
|
||||
# OpenAI vision API format
|
||||
last_msg = {
|
||||
"role": expr.role,
|
||||
"content": [{"type": "text", "text": new_text}],
|
||||
}
|
||||
for (image_path, image_base64_data) in self.cur_images:
|
||||
last_msg["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64_data}"
|
||||
},
|
||||
}
|
||||
)
|
||||
self.messages_.append(last_msg)
|
||||
self.cur_images = []
|
||||
else:
|
||||
self.messages_.append({"role": expr.role, "content": new_text})
|
||||
|
||||
self.cur_role = None
|
||||
|
||||
def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
|
||||
self.variables[expr.name] = int(len(self.text_))
|
||||
|
||||
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
|
||||
self.variables[expr.name] = self.text_[self.variables[expr.name] :]
|
||||
self.variable_event[expr.name].set()
|
||||
|
||||
def _execute_commit_lazy_operations(self, expr: SglCommitLazy):
|
||||
self.backend.commit_lazy_operations(self)
|
||||
|
||||
def _execute_concatenate_and_append_text(self, expr: SglConcateAndAppend):
|
||||
new_text = ""
|
||||
for s in expr.states:
|
||||
exe = s.stream_executor
|
||||
exe.sync()
|
||||
new_text += exe.text_[exe.fork_start_text_pos :]
|
||||
|
||||
self._execute_fill(new_text)
|
||||
|
||||
def _execute_concatenate_and_append_kv_cache(self, expr: SglConcateAndAppend):
|
||||
self_len = len(self.text_)
|
||||
|
||||
for i, s in enumerate(expr.states):
|
||||
exe = s.stream_executor
|
||||
exe.submit(SglCommitLazy())
|
||||
|
||||
for i, s in enumerate(expr.states):
|
||||
exe = s.stream_executor
|
||||
exe.sync()
|
||||
assert exe.fork_start_text_pos == self_len
|
||||
self.text_ += exe.text_[exe.fork_start_text_pos :]
|
||||
|
||||
src_rids = [state.stream_executor.sid for state in expr.states]
|
||||
self.backend.concatenate_and_append(src_rids, self.sid)
|
||||
|
||||
def _resolve_sampling_params(self, sampling_params):
|
||||
clone = None
|
||||
for item in [
|
||||
"max_new_tokens",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"dtype",
|
||||
"regex",
|
||||
]:
|
||||
value = getattr(sampling_params, item, None)
|
||||
if value is not None:
|
||||
if clone is None:
|
||||
clone = self.default_sampling_para.clone()
|
||||
setattr(clone, item, value)
|
||||
return clone or self.default_sampling_para
|
||||
|
||||
def __del__(self):
|
||||
self.end()
|
||||
|
||||
|
||||
class ProgramState:
|
||||
"""The state of an SGL program."""
|
||||
|
||||
def __init__(self, stream_executor: StreamExecutor):
|
||||
self.stream_executor = stream_executor
|
||||
|
||||
def _role_common(self, name: str, expr: Optional[SglExpr] = None):
|
||||
if expr is not None:
|
||||
self.stream_executor.submit(
|
||||
SglExprList([SglRoleBegin(name), expr, SglRoleEnd(name)])
|
||||
)
|
||||
else:
|
||||
|
||||
@contextmanager
|
||||
def role_scope():
|
||||
self.stream_executor.submit(SglRoleBegin(name))
|
||||
yield
|
||||
self.stream_executor.submit(SglRoleEnd(name))
|
||||
|
||||
return role_scope()
|
||||
|
||||
def system(self, expr: Optional[SglExpr] = None):
|
||||
return self._role_common("system", expr)
|
||||
|
||||
def user(self, expr: Optional[SglExpr] = None):
|
||||
return self._role_common("user", expr)
|
||||
|
||||
def assistant(self, expr: Optional[SglExpr] = None):
|
||||
return self._role_common("assistant", expr)
|
||||
|
||||
@contextmanager
|
||||
def var_scope(self, name: str):
|
||||
self.stream_executor.submit(SglVarScopeBegin(name))
|
||||
yield
|
||||
self.stream_executor.submit(SglVarScopeEnd(name))
|
||||
|
||||
def fork(self, number: int = 1, position_ids_offset: Optional[List[int]] = None):
|
||||
stream_executors = self.stream_executor.fork(number, position_ids_offset)
|
||||
states = [ProgramState(x) for x in stream_executors]
|
||||
state_group = ProgramStateGroup(states, self)
|
||||
return state_group
|
||||
|
||||
@contextmanager
|
||||
def copy(self, position_ids_offset: Optional[List[int]] = None):
|
||||
state_group = self.fork(1, position_ids_offset)
|
||||
try:
|
||||
yield state_group[0]
|
||||
finally:
|
||||
state_group.join()
|
||||
|
||||
def text(self):
|
||||
return self.stream_executor.text()
|
||||
|
||||
def messages(self):
|
||||
return self.stream_executor.messages()
|
||||
|
||||
def sync(self):
|
||||
return self.stream_executor.sync()
|
||||
|
||||
def text_iter(self, var_name=None):
|
||||
if self.stream_executor.stream:
|
||||
prev = 0
|
||||
if var_name is None:
|
||||
event = self.stream_executor.stream_text_event
|
||||
while True:
|
||||
event.wait()
|
||||
event.clear()
|
||||
out = str(self.stream_executor.text_[prev:])
|
||||
prev += len(out)
|
||||
if out:
|
||||
yield out
|
||||
if self.stream_executor.is_finished:
|
||||
break
|
||||
else:
|
||||
event = self.stream_executor.stream_var_event[var_name]
|
||||
while True:
|
||||
event.wait()
|
||||
event.clear()
|
||||
out = str(self.stream_executor.variables[var_name][prev:])
|
||||
prev += len(out)
|
||||
if out:
|
||||
yield out
|
||||
if self.stream_executor.variable_event[var_name].is_set():
|
||||
break
|
||||
else:
|
||||
if var_name is None:
|
||||
yield self.text()
|
||||
else:
|
||||
yield self.get_var(var_name)
|
||||
|
||||
async def text_async_iter(self, var_name=None):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
if self.stream_executor.stream:
|
||||
prev = 0
|
||||
if var_name is None:
|
||||
event = self.stream_executor.stream_text_event
|
||||
while True:
|
||||
await loop.run_in_executor(None, event.wait)
|
||||
event.clear()
|
||||
out = str(self.stream_executor.text_[prev:])
|
||||
prev += len(out)
|
||||
if out:
|
||||
yield out
|
||||
if self.stream_executor.is_finished:
|
||||
break
|
||||
else:
|
||||
event = self.stream_executor.stream_var_event[var_name]
|
||||
while True:
|
||||
await loop.run_in_executor(None, event.wait)
|
||||
event.clear()
|
||||
out = str(self.stream_executor.variables[var_name][prev:])
|
||||
prev += len(out)
|
||||
if out:
|
||||
yield out
|
||||
if self.stream_executor.variable_event[var_name].is_set():
|
||||
break
|
||||
else:
|
||||
if var_name is None:
|
||||
yield self.text()
|
||||
else:
|
||||
yield self.get_var(var_name)
|
||||
|
||||
def get_var(self, name):
|
||||
return self.stream_executor.get_var(name)
|
||||
|
||||
def get_meta_info(self, name):
|
||||
return self.stream_executor.get_meta_info(name)
|
||||
|
||||
def __iadd__(self, other):
|
||||
self.stream_executor.submit(other)
|
||||
return self
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.get_var(name)
|
||||
|
||||
def __del__(self):
|
||||
self.stream_executor.end()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msgs = self.messages()
|
||||
ret = ""
|
||||
for msg in msgs:
|
||||
ret += msg["role"] + ":\n" + msg["content"] + "\n"
|
||||
return ret
|
||||
|
||||
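# A hypothetical helper (not used elsewhere in this module) sketching how a
# caller would consume streaming output from a ProgramState created with
# stream=True (see run_program above); the name and printing behavior are
# illustrative only.
def _example_stream_consumer(state: "ProgramState") -> str:
    # text_iter yields text chunks as the background StreamExecutor produces
    # them and stops once the executor reports is_finished.
    for chunk in state.text_iter():
        print(chunk, end="", flush=True)
    # Once streaming ends, the full text is available synchronously.
    return state.text()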
|
||||
class ProgramStateGroup:
|
||||
def __init__(
|
||||
self, states: List[ProgramState], src_state: Optional[ProgramState] = None
|
||||
):
|
||||
self.states = states
|
||||
self.src_state = src_state
|
||||
|
||||
def join(self, mode: str = "gather_variable"):
|
||||
if mode == "gather_variable":
|
||||
# Copy variables back
|
||||
src_vars = self.src_state.stream_executor.variables
|
||||
src_var_set = set(src_vars.keys())
|
||||
for child_state in self.states:
|
||||
child_state.stream_executor.sync()
|
||||
child_vars = child_state.stream_executor.variables
|
||||
new_vars = set(child_vars.keys()) - src_var_set
|
||||
|
||||
for k in new_vars:
|
||||
if k in src_vars:
|
||||
src_vars[k].append(child_vars[k])
|
||||
else:
|
||||
src_vars[k] = [child_vars[k]]
|
||||
elif mode == "concate_and_append":
|
||||
# Concatenate and append KV cache
|
||||
self.src_state += SglConcateAndAppend(self.states)
|
||||
# Need a sync here. Otherwise, `states` can be deleted.
|
||||
self.src_state.stream_executor.sync()
|
||||
else:
|
||||
raise ValueError(f"Invalid join mode: {mode}")
|
||||
|
||||
for s in self.states:
|
||||
s.stream_executor.end()
|
||||
|
||||
def __getitem__(self, i: int):
|
||||
return self.states[i]
|
||||
|
||||
def __setitem__(self, i: int, value):
|
||||
assert self.states[i] == value
|
||||
|
||||
def __iadd__(self, other):
|
||||
if isinstance(other, Callable):
|
||||
# lambda function
|
||||
for i in range(len(self.states)):
|
||||
self.states[i] += other(i)
|
||||
elif isinstance(other, SglExpr):
|
||||
for i in range(len(self.states)):
|
||||
self.states[i] += other
|
||||
elif isinstance(other, (list, tuple)):
|
||||
for i in range(len(self.states)):
|
||||
self.states[i] += other[i]
|
||||
else:
|
||||
raise ValueError(f"Invalid value: {other}")
|
||||
|
||||
return self
|
||||
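A minimal end-to-end sketch of the fork/read-back path implemented above (StreamExecutor.fork, get_var, run_program). It assumes sglang and its dependencies are installed and that a backend has already been registered with set_default_backend; the prompt text, variable names, and topic are illustrative.

import sglang as sgl

@sgl.function
def two_tips(s, topic):
    s += "Here are two tips for " + topic + ".\n"
    forks = s.fork(2)
    for i in range(2):
        f = forks[i]
        f += f"Tip {i + 1}:" + sgl.gen("tip", max_tokens=64, stop="\n")
    # Reading a forked variable blocks until that branch has finished.
    s += "Tip 1:" + forks[0]["tip"] + "\nTip 2:" + forks[1]["tip"] + "\n"
    s += "In summary:" + sgl.gen("summary", max_tokens=64)

# Assumes a default backend was registered, e.g. via sgl.set_default_backend(...).
state = two_tips.run(topic="staying healthy")
print(state.text())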
442
python/sglang/lang/ir.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""The intermediate representation."""
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SamplingParams:
|
||||
max_new_tokens: int = 16
|
||||
stop: Union[str, List[str]] = ()
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = -1 # -1 means disable
|
||||
frequency_penalty: float = 0.0
|
||||
presence_penalty: float = 0.0
|
||||
|
||||
# for constrained generation, not included in to_xxx_kwargs
|
||||
dtype: Optional[str] = None
|
||||
regex: Optional[str] = None
|
||||
|
||||
def clone(self):
|
||||
return SamplingParams(
|
||||
self.max_new_tokens,
|
||||
self.stop,
|
||||
self.temperature,
|
||||
self.top_p,
|
||||
self.top_k,
|
||||
self.frequency_penalty,
|
||||
self.presence_penalty,
|
||||
)
|
||||
|
||||
def to_openai_kwargs(self):
|
||||
# OpenAI does not support top_k, so we drop it here
|
||||
return {
|
||||
"max_tokens": self.max_new_tokens,
|
||||
"stop": self.stop or None,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
}
|
||||
|
||||
def to_anthropic_kwargs(self):
|
||||
# Anthropic does not support frequency_penalty or presence_penalty, so we drop them here
|
||||
return {
|
||||
"max_tokens_to_sample": self.max_new_tokens,
|
||||
"stop_sequences": self.stop,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
}
|
||||
|
||||
def to_srt_kwargs(self):
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"stop": self.stop,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"regex": self.regex,
|
||||
}
|
||||
|
||||
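# A hypothetical helper (not used elsewhere) sketching how per-call overrides
# are layered on top of a default SamplingParams: clone first so the shared
# defaults are never mutated, then set the overridden fields, mirroring the
# clone-then-setattr pattern used by the interpreter's _resolve_sampling_params.
def _example_override_temperature(defaults: SamplingParams, temperature: float) -> SamplingParams:
    clone = defaults.clone()
    clone.temperature = temperature
    return clone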
|
||||
class SglFunction:
|
||||
def __init__(self, func, bind_arguments=None):
|
||||
self.func = func
|
||||
self.bind_arguments = bind_arguments or {}
|
||||
self.pin_prefix_rid = None
|
||||
|
||||
# Parse arguments
|
||||
argspec = inspect.getfullargspec(func)
|
||||
assert argspec.args[0] == "s", 'The first argument must be "s"'
|
||||
self.arg_names = argspec.args[1:]
|
||||
|
||||
def bind(self, **kwargs):
|
||||
assert all(key in self.arg_names for key in kwargs)
|
||||
|
||||
new_bind_dict = {**self.bind_arguments, **kwargs}
|
||||
return SglFunction(self.func, bind_arguments=new_bind_dict)
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args,
|
||||
max_new_tokens: int = 16,
|
||||
stop: Union[str, List[str]] = (),
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
stream: bool = False,
|
||||
backend=None,
|
||||
**kwargs,
|
||||
):
|
||||
from sglang.lang.interpreter import run_program
|
||||
|
||||
default_sampling_para = SamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
)
|
||||
backend = backend or global_config.default_backend
|
||||
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
|
||||
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
|
||||
|
||||
def run_batch(
|
||||
self,
|
||||
batch_kwargs,
|
||||
*,
|
||||
max_new_tokens: int = 16,
|
||||
stop: Union[str, List[str]] = (),
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
backend=None,
|
||||
num_threads: Union[str, int] = "auto",
|
||||
progress_bar: bool = False,
|
||||
):
|
||||
from sglang.lang.interpreter import run_program_batch
|
||||
|
||||
assert isinstance(batch_kwargs, (list, tuple))
|
||||
if len(batch_kwargs) == 0:
|
||||
return []
|
||||
assert isinstance(batch_kwargs[0], dict)
|
||||
|
||||
default_sampling_para = SamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
)
|
||||
backend = backend or global_config.default_backend
|
||||
batch_kwargs = [
|
||||
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
|
||||
]
|
||||
return run_program_batch(
|
||||
self,
|
||||
backend,
|
||||
batch_kwargs,
|
||||
default_sampling_para,
|
||||
num_threads,
|
||||
progress_bar,
|
||||
)
|
||||
|
||||
def trace(self, *, backend=None, **kwargs):
|
||||
from sglang.lang.tracer import trace_program
|
||||
|
||||
backend = backend or global_config.default_backend
|
||||
return trace_program(self, kwargs, backend)
|
||||
|
||||
def pin(self, backend=None):
|
||||
from sglang.lang.interpreter import pin_program
|
||||
|
||||
backend = backend or global_config.default_backend
|
||||
return pin_program(self, backend)
|
||||
|
||||
def unpin(self, backend=None):
|
||||
from sglang.lang.interpreter import unpin_program
|
||||
|
||||
backend = backend or global_config.default_backend
|
||||
return unpin_program(self, backend)
|
||||
|
||||
def compile(self, *, backend=None):
|
||||
from sglang.lang.compiler import compile_func
|
||||
|
||||
return compile_func(self, backend)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
from sglang.lang.tracer import TracingScope
|
||||
|
||||
tracing_scope = TracingScope.get_current_scope()
|
||||
if tracing_scope is None:
|
||||
return self.run(*args, **kwargs)
|
||||
else:
|
||||
kwargs["backend"] = tracing_scope.tracer_state.backend
|
||||
return self.trace(*args, **kwargs)
|
||||
|
||||
|
||||
class SglExpr:
|
||||
node_ct = 0
|
||||
|
||||
def __init__(self):
|
||||
self.node_id = SglExpr.node_ct
|
||||
self.prev_node = None
|
||||
self.pid = None
|
||||
SglExpr.node_ct += 1
|
||||
|
||||
def __add__(self, other):
|
||||
if isinstance(other, str):
|
||||
other = SglConstantText(other)
|
||||
assert isinstance(other, SglExpr)
|
||||
|
||||
return self.concatenate_ir(self, other)
|
||||
|
||||
def __radd__(self, other):
|
||||
if isinstance(other, str):
|
||||
other = SglConstantText(other)
|
||||
assert isinstance(other, SglExpr), f"{other}"
|
||||
|
||||
return self.concatenate_ir(other, self)
|
||||
|
||||
def concatenate_ir(self, a, b):
|
||||
if isinstance(a, SglExprList):
|
||||
if isinstance(b, SglExprList):
|
||||
return SglExprList(a.expr_list + b.expr_list)
|
||||
else:
|
||||
return SglExprList(a.expr_list + [b])
|
||||
elif isinstance(b, SglExprList):
|
||||
return SglExprList([a] + b.expr_list)
|
||||
|
||||
return SglExprList([a, b])
|
||||
|
||||
def print_graph_dfs(self):
|
||||
ret = [""]
|
||||
visited = set()
|
||||
|
||||
def dfs_print(x):
|
||||
if x is None or x in visited:
|
||||
return
|
||||
visited.add(x)
|
||||
|
||||
# Print dependency
|
||||
if x.prev_node is not None:
|
||||
dfs_print(x.prev_node)
|
||||
|
||||
if isinstance(x, SglExprList):
|
||||
for y in x.expr_list:
|
||||
dfs_print(y)
|
||||
# elif isinstance(x, SglRole):
|
||||
# dfs_print(x.expr)
|
||||
elif isinstance(x, SglVariable):
|
||||
dfs_print(x.source)
|
||||
|
||||
# Print the node itself
|
||||
if isinstance(x, (SglFork, SglGetForkItem)):
|
||||
ret[0] += f"%{x.node_id} = {x}\n"
|
||||
else:
|
||||
if x.prev_node is not None:
|
||||
ret[0] += (
|
||||
f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n"
|
||||
)
|
||||
else:
|
||||
ret[0] += f"%{x.node_id} = " + str(x) + "\n"
|
||||
|
||||
dfs_print(self)
|
||||
return ret[0]
|
||||
|
||||
|
||||
class SglExprList(SglExpr):
|
||||
def __init__(self, expr_list: List[SglExpr]):
|
||||
super().__init__()
|
||||
self.expr_list = expr_list
|
||||
|
||||
def __repr__(self):
|
||||
return f"ExprList({self.expr_list})"
|
||||
|
||||
|
||||
class SglArgument(SglExpr):
|
||||
def __init__(self, name: str, value: str):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def __repr__(self):
|
||||
return f"Argument(name={self.name}, value={repr(self.value)})"
|
||||
|
||||
def __len__(self):
|
||||
return len(self.value)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.value[i]
|
||||
|
||||
def __int__(self):
|
||||
return self.value
|
||||
|
||||
def __bool__(self):
|
||||
return self.value
|
||||
|
||||
def __format__(self, *args):
|
||||
raise TypeError(
|
||||
"Cannot put argument inside a f-string. "
|
||||
"This is not compatible with the tracer. "
|
||||
)
|
||||
|
||||
|
||||
class SglImage(SglExpr):
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SglImage({self.path})"
|
||||
|
||||
|
||||
class SglGen(SglExpr):
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
max_new_tokens,
|
||||
stop,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
dtype,
|
||||
regex,
|
||||
):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.sampling_params = SamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
dtype=dtype,
|
||||
regex=regex,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Gen('{self.name}')"
|
||||
|
||||
|
||||
class SglConstantText(SglExpr):
|
||||
def __init__(self, value):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
|
||||
def __repr__(self):
|
||||
return f"Constant({repr(self.value)})"
|
||||
|
||||
|
||||
class SglRoleBegin(SglExpr):
|
||||
def __init__(self, role):
|
||||
super().__init__()
|
||||
self.role = role
|
||||
|
||||
def __repr__(self):
|
||||
return f"RoleBegin({self.role})"
|
||||
|
||||
|
||||
class SglRoleEnd(SglExpr):
|
||||
def __init__(self, role):
|
||||
super().__init__()
|
||||
self.role = role
|
||||
|
||||
def __repr__(self):
|
||||
return f"RoleEnd({self.role})"
|
||||
|
||||
|
||||
class SglSelect(SglExpr):
|
||||
def __init__(self, name, choices, temperature):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.choices = choices
|
||||
self.temperature = temperature
|
||||
|
||||
def __repr__(self):
|
||||
return f"Select({self.name}, choices={self.choices})"
|
||||
|
||||
|
||||
class SglFork(SglExpr):
|
||||
def __init__(self, number, position_ids_offset=None):
|
||||
super().__init__()
|
||||
self.number = number
|
||||
self.position_ids_offset = position_ids_offset
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"Fork(%{self.prev_node.node_id}, number={self.number}, "
|
||||
f"position_ids_offset={self.position_ids_offset})"
|
||||
)
|
||||
|
||||
|
||||
class SglGetForkItem(SglExpr):
|
||||
def __init__(self, index):
|
||||
super().__init__()
|
||||
self.index = index
|
||||
|
||||
def __repr__(self):
|
||||
return f"GetForkItem(%{self.prev_node.node_id}, index={self.index})"
|
||||
|
||||
|
||||
class SglVariable(SglExpr):
|
||||
def __init__(self, name, source):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.source = source
|
||||
|
||||
def __repr__(self):
|
||||
return f"Variable('{self.name}', source=%{self.source.node_id})"
|
||||
|
||||
|
||||
class SglVarScopeBegin(SglExpr):
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
def __repr__(self):
|
||||
return f"VarScopeBegin('{self.name}')"
|
||||
|
||||
|
||||
class SglVarScopeEnd(SglExpr):
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
def __repr__(self):
|
||||
return f"VarScopeEnd('{self.name}')"
|
||||
|
||||
|
||||
class SglConcateAndAppend(SglExpr):
|
||||
def __init__(self, states):
|
||||
super().__init__()
|
||||
self.states = states
|
||||
|
||||
def __repr__(self):
|
||||
return f"ConcatenateAndAppend('{self.states}')"
|
||||
|
||||
|
||||
class SglCommitLazy(SglExpr):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __repr__(self):
|
||||
return f"CommitLazy()"
|
||||
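A short sketch of how the expression classes above compose. Adding a plain string to an expression lifts it to SglConstantText (via __add__/__radd__), and adjacent expressions are flattened into a single SglExprList, which is what the interpreter ultimately executes. The prompt text and parameter values are illustrative.

from sglang.lang.ir import SglExprList, SglGen

prompt = "Q: What is the capital of France?\nA:"
answer = SglGen("answer", 16, "\n", 0.0, 1.0, -1, 0.0, 0.0, None, None)
expr = prompt + answer  # str is lifted to SglConstantText, then flattened
assert isinstance(expr, SglExprList) and len(expr.expr_list) == 2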
279
python/sglang/lang/tracer.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Tracing a program."""
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from sglang.backend.base_backend import BaseBackend
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
|
||||
from sglang.lang.ir import (
|
||||
SglArgument,
|
||||
SglCommitLazy,
|
||||
SglConcateAndAppend,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
SglFork,
|
||||
SglFunction,
|
||||
SglGen,
|
||||
SglGetForkItem,
|
||||
SglRoleBegin,
|
||||
SglRoleEnd,
|
||||
SglSelect,
|
||||
SglVariable,
|
||||
SglVarScopeBegin,
|
||||
SglVarScopeEnd,
|
||||
)
|
||||
|
||||
|
||||
class StopTracing(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def extract_prefix_by_tracing(program, backend):
|
||||
# Create dummy arguments
|
||||
dummy_arguments = {name: SglArgument(name, None) for name in program.arg_names}
|
||||
arguments = dummy_arguments
|
||||
arguments.update(program.bind_arguments)
|
||||
|
||||
# Trace
|
||||
tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
|
||||
try:
|
||||
with TracingScope(tracer):
|
||||
tracer.ret_value = program.func(tracer, **arguments)
|
||||
except StopTracing:
|
||||
pass
|
||||
|
||||
# Run and cache prefix
|
||||
prefix = ""
|
||||
for expr in tracer.flatten_nodes():
|
||||
if isinstance(expr, SglConstantText):
|
||||
prefix += expr.value
|
||||
else:
|
||||
break
|
||||
return prefix
|
||||
|
||||
|
||||
def trace_program(program, arguments, backend):
|
||||
# Create dummy backend
|
||||
if backend is None:
|
||||
backend = BaseBackend()
|
||||
|
||||
# Create dummy arguments
|
||||
dummy_arguments = {
|
||||
name: SglArgument(name, None)
|
||||
for name in program.arg_names
|
||||
if name not in arguments
|
||||
}
|
||||
arguments.update(dummy_arguments)
|
||||
arguments.update(program.bind_arguments)
|
||||
|
||||
# Trace
|
||||
tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
|
||||
with TracingScope(tracer):
|
||||
tracer.ret_value = program.func(tracer, **arguments)
|
||||
return tracer
|
||||
|
||||
|
||||
class TracerProgramState(ProgramState):
|
||||
def __init__(self, backend, arguments, only_trace_prefix):
|
||||
self.pid = uuid.uuid4().hex
|
||||
self.backend = backend
|
||||
self.arguments: Dict[str, Any] = arguments
|
||||
self.only_trace_prefix = only_trace_prefix
|
||||
|
||||
if hasattr(backend, "endpoint"):
|
||||
self.backend = backend.endpoint
|
||||
|
||||
self.nodes = []
|
||||
self.last_node = None
|
||||
self.variables = {}
|
||||
self.ret_value = None
|
||||
|
||||
# For completion
|
||||
|
||||
# For chat
|
||||
self.messages_ = []
|
||||
self.cur_role = None
|
||||
self.chat_template = self.backend.get_chat_template()
|
||||
|
||||
# For multi states
|
||||
self.child_states = []
|
||||
|
||||
cur_scope = TracingScope.get_current_scope()
|
||||
if cur_scope is not None:
|
||||
cur_scope.add_child_state(self)
|
||||
|
||||
##################################
|
||||
########### Public API ###########
|
||||
##################################
|
||||
|
||||
def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
|
||||
if self.only_trace_prefix:
|
||||
raise StopTracing()
|
||||
|
||||
fork_node = SglFork(number)
|
||||
fork_node.prev_node = self.last_node
|
||||
|
||||
states = [
|
||||
TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
|
||||
for _ in range(number)
|
||||
]
|
||||
|
||||
for i in range(number):
|
||||
node = SglGetForkItem(i)
|
||||
node.prev_node = fork_node
|
||||
states[i].last_node = node
|
||||
states[i].variables = dict(self.variables)
|
||||
states[i].messages_ = list(self.messages_)
|
||||
states[i].cur_role = self.cur_role
|
||||
states[i].chat_template = self.chat_template
|
||||
|
||||
state_group = ProgramStateGroup(states, self)
|
||||
|
||||
return state_group
|
||||
|
||||
##################################
|
||||
########## Internal API ##########
|
||||
##################################
|
||||
|
||||
def _append_node(self, other: SglExpr):
|
||||
self.nodes.append(other)
|
||||
other.prev_node = self.last_node
|
||||
self.last_node = other
|
||||
|
||||
def _execute(self, other: SglExpr):
|
||||
if isinstance(other, str):
|
||||
other = SglConstantText(other)
|
||||
|
||||
other.pid = self.pid
|
||||
|
||||
if isinstance(other, SglConstantText):
|
||||
self._execute_fill(other)
|
||||
elif isinstance(other, SglGen):
|
||||
self._execute_gen(other)
|
||||
elif isinstance(other, SglSelect):
|
||||
self._execute_select(other)
|
||||
elif isinstance(other, SglExprList):
|
||||
for x in other.expr_list:
|
||||
self._execute(x)
|
||||
elif isinstance(other, SglRoleBegin):
|
||||
self._execute_role_begin(other)
|
||||
elif isinstance(other, SglRoleEnd):
|
||||
self._execute_role_end(other)
|
||||
elif isinstance(other, SglVarScopeBegin):
|
||||
self._execute_var_scope_begin(other)
|
||||
elif isinstance(other, SglVarScopeEnd):
|
||||
self._execute_var_scope_end(other)
|
||||
else:
|
||||
if self.only_trace_prefix:
|
||||
raise StopTracing()
|
||||
else:
|
||||
self._append_node(other)
|
||||
|
||||
return self
|
||||
|
||||
def __iadd__(self, other):
|
||||
self._execute(other)
|
||||
return self
|
||||
|
||||
def _execute_fill(self, expr: SglConstantText):
|
||||
if isinstance(expr, str):
|
||||
expr = SglConstantText(expr)
|
||||
self._append_node(expr)
|
||||
|
||||
def _execute_gen(self, expr: SglGen):
|
||||
name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
|
||||
new_node = SglVariable(name, source=expr)
|
||||
self.variables[name] = new_node
|
||||
self._append_node(expr)
|
||||
|
||||
def _execute_select(self, expr: SglSelect):
|
||||
name = (
|
||||
expr.name if expr.name is not None else "select_" + str(len(self.variables))
|
||||
)
|
||||
new_node = SglVariable(name, source=expr)
|
||||
self.variables[name] = new_node
|
||||
self._append_node(expr)
|
||||
|
||||
def _execute_role_begin(self, expr: SglRoleBegin):
|
||||
assert self.cur_role is None, "Nested roles are not allowed."
|
||||
|
||||
if len(self.messages_) == 0 and expr.role != "system":
|
||||
# Insert default system message
|
||||
default_system = self.chat_template.default_system_prompt
|
||||
if default_system:
|
||||
self._execute_role_begin(SglRoleBegin("system"))
|
||||
self._execute_fill(default_system)
|
||||
self._execute_role_end(SglRoleEnd("system"))
|
||||
|
||||
self.cur_role = expr.role
|
||||
|
||||
prefix, suffix = self.chat_template.get_prefix_and_suffix(
|
||||
expr.role, self.messages_
|
||||
)
|
||||
|
||||
self._execute_fill(prefix)
|
||||
|
||||
def _execute_role_end(self, expr: SglRoleEnd):
|
||||
prefix, suffix = self.chat_template.get_prefix_and_suffix(
|
||||
expr.role, self.messages_
|
||||
)
|
||||
|
||||
self._execute_fill(suffix)
|
||||
|
||||
self.messages_.append({"role": expr.role, "content": ""})
|
||||
|
||||
self.cur_role = None
|
||||
|
||||
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
|
||||
new_node = SglVariable(expr.name, source=self.last_node)
|
||||
self.variables[expr.name] = new_node
|
||||
|
||||
def get_var(self, name):
|
||||
ret = self.arguments.get(name, None)
|
||||
if ret is not None:
|
||||
return ret
|
||||
|
||||
v = self.variables[name]
|
||||
return SglVariable(v.name, v.source)
|
||||
|
||||
def flatten_nodes(self):
|
||||
def traverse(cur):
|
||||
if isinstance(cur, SglExprList):
|
||||
for child in cur.expr_list:
|
||||
traverse(child)
|
||||
else:
|
||||
ret.append(cur)
|
||||
|
||||
ret = []
|
||||
for x in self.nodes:
|
||||
traverse(x)
|
||||
return ret
|
||||
|
||||
def __del__(self):
|
||||
pass
|
||||
|
||||
|
||||
class TracingScope:
|
||||
cur_scope = None
|
||||
|
||||
def __init__(self, tracer_state: TracerProgramState):
|
||||
self.tracer_state = tracer_state
|
||||
self.last_scope = TracingScope.cur_scope
|
||||
|
||||
def __enter__(self):
|
||||
TracingScope.cur_scope = self
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
TracingScope.cur_scope = self.last_scope
|
||||
|
||||
@staticmethod
|
||||
def get_current_scope():
|
||||
return TracingScope.cur_scope
|
||||
|
||||
def add_child_state(self, state: TracerProgramState):
|
||||
cur_scope = self
|
||||
while cur_scope is not None:
|
||||
cur_scope.tracer_state.child_states.append(state)
|
||||
cur_scope = cur_scope.last_scope
|
||||
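A short sketch of driving the tracer directly. trace_program substitutes a dummy BaseBackend when backend is None, so a program's node graph can be inspected without a model; it is assumed here that the dummy backend supplies a default chat template, as the dummy-backend path in trace_program implies. The function body and sampling values are illustrative.

from sglang.lang.ir import SglFunction, SglGen
from sglang.lang.tracer import trace_program

def _prog(s):
    s += "Hello! "
    s += SglGen("reply", 16, None, 0.7, 1.0, -1, 0.0, 0.0, None, None)

program = SglFunction(_prog)
tracer = trace_program(program, {}, backend=None)
for node in tracer.flatten_nodes():
    print(node)  # e.g. Constant('Hello! '), Gen('reply')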
11
python/sglang/launch_server.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import argparse
|
||||
|
||||
from sglang.srt.server import ServerArgs, launch_server
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
|
||||
launch_server(server_args, None)
|
||||
385
python/sglang/srt/constrained/fsm.py
Normal file
@@ -0,0 +1,385 @@
|
||||
# Adapted from:
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
|
||||
from typing import List, NewType, Protocol
|
||||
|
||||
import interegular
|
||||
from lark import Lark
|
||||
|
||||
# from outlines.fsm.parsing import PartialLark
|
||||
from sglang.srt.constrained.regex import (
|
||||
create_fsm_index_tokenizer,
|
||||
make_deterministic_fsm,
|
||||
)
|
||||
from sglang.srt.constrained.tokenizer import Tokenizer
|
||||
|
||||
FSMState = NewType("FSMState", int)
|
||||
|
||||
|
||||
class FSM(Protocol):
|
||||
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
|
||||
...
|
||||
|
||||
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
|
||||
...
|
||||
|
||||
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
|
||||
...
|
||||
|
||||
def reset(self) -> None:
|
||||
...
|
||||
|
||||
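# A hypothetical helper sketching the decoding loop that any FSM implementation
# below is meant to drive. `sample_from` stands in for the model's sampler
# restricted to the allowed token ids; it is not part of this module.
def _example_constrained_decode(fsm: "FSM", sample_from, max_steps: int = 256) -> List[int]:
    state = FSMState(0)
    output: List[int] = []
    for _ in range(max_steps):
        allowed = fsm.allowed_token_ids(state)   # tokens permitted at this step
        token_id = sample_from(allowed)          # pick one of the allowed ids
        output.append(token_id)
        state = fsm.next_state(state, token_id)  # advance the automaton
        if fsm.is_final_state(state):
            break
    return output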
|
||||
class StopAtTokenFSM(FSM):
|
||||
"""FSM to generate text until a specified token id is generated or
|
||||
a specified number of tokens has been generated.
|
||||
|
||||
Text is usually produced until the EOS token is generated by the
|
||||
model.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: "Tokenizer",
|
||||
stop_token_id: int,
|
||||
):
|
||||
self.stop_token_id = stop_token_id
|
||||
self.num_tokens_generated = 0
|
||||
self.vocabulary = tokenizer.vocabulary.values()
|
||||
self.final_states = {1}
|
||||
|
||||
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
|
||||
"""Generate a list of allowed tokens for the next step.
|
||||
|
||||
When in the initial state we allow every token to be generated.
|
||||
In the final state the only allowed token is `stop_token_id`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state
|
||||
The current state of the FSM.
|
||||
idx
|
||||
The index of the current input in the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list that contains the tokens to mask.
|
||||
|
||||
"""
|
||||
if state == 0:
|
||||
return list(self.vocabulary)
|
||||
else:
|
||||
return [self.stop_token_id]
|
||||
|
||||
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
|
||||
"""Update the state of the FSM.
|
||||
|
||||
The FSM stays in the initial state `0` unless the specified stop token
|
||||
has been generated or the maximum number of tokens has been reached, in
|
||||
which case the FSM moves to the final state `1`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state
|
||||
The current state of the FSM.
|
||||
token_id
|
||||
The id of the token that was just generated.
|
||||
idx
|
||||
The index of the current input in the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The new state of the FSM.
|
||||
|
||||
"""
|
||||
if idx == 0:
|
||||
self.num_tokens_generated += 1
|
||||
|
||||
if token_id == self.stop_token_id:
|
||||
return FSMState(1)
|
||||
|
||||
return FSMState(0)
|
||||
|
||||
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
|
||||
"""Determine whether the current state of the FSM is a final state."""
|
||||
return state in self.final_states
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the FSM to its initial state. Here this only resets the token counter."""
|
||||
self.num_tokens_generated = 0
|
||||
|
||||
|
||||
class RegexFSM(FSM):
|
||||
"""FSM to generate text that is in the language of a regular expression."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
regex_string: str,
|
||||
tokenizer: "Tokenizer",
|
||||
):
|
||||
regex_pattern = interegular.parse_pattern(regex_string)
|
||||
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
|
||||
(
|
||||
self.states_to_token_maps,
|
||||
self.empty_token_ids,
|
||||
) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
|
||||
|
||||
# We make sure that it is possible to generate strings in the language
|
||||
# of the regular expression with the tokens present in the model's
|
||||
# vocabulary.
|
||||
if not any(
|
||||
regex_fsm.finals.intersection(v.values())
|
||||
for v in self.states_to_token_maps.values()
|
||||
):
|
||||
raise ValueError(
|
||||
"The vocabulary does not allow us to build a sequence that matches the input regex"
|
||||
)
|
||||
|
||||
self.final_states = regex_fsm.finals | {
|
||||
-1
|
||||
} # Include the EOS token in final states
|
||||
self.num_tokens_generated = 0
|
||||
self.vocabulary = tokenizer.vocabulary.values()
|
||||
self.end_token_id = tokenizer.eos_token_id
|
||||
|
||||
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
|
||||
"""Generate a list of allowed tokens for the next step.
|
||||
|
||||
The initialization of the FSM builds an index which maps FSM states to a
|
||||
map from authorized tokens to the state in which the FSM needs to move
|
||||
if said token is generated. Therefore the authorized tokens at the
|
||||
current state are the keys of the map returned by the value of the index
|
||||
for the current state.
|
||||
|
||||
If the current state is not contained in the index, this means that we are
|
||||
in a final state of the FSM. We only authorize EOS tokens in the final
|
||||
state.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state
|
||||
The current state of the FSM.
|
||||
idx
|
||||
The index of the current input in the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list that contains the tokens to mask.
|
||||
|
||||
"""
|
||||
next_tokens_to_end_states = self.states_to_token_maps.get(state)
|
||||
|
||||
if next_tokens_to_end_states is None:
|
||||
return [self.end_token_id]
|
||||
else:
|
||||
return list(next_tokens_to_end_states.keys())
|
||||
|
||||
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
|
||||
"""Update the state of the FSM.
|
||||
|
||||
We use the index to determine to which state the FSM should transition
|
||||
given the token that was just generated.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state
|
||||
The current state of the FSM.
|
||||
token_id
|
||||
The id of the token that was just generated.
|
||||
idx
|
||||
The index of the current input in the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The new state of the FSM.
|
||||
|
||||
"""
|
||||
if idx == 0:
|
||||
self.num_tokens_generated += 1
|
||||
|
||||
if token_id == self.end_token_id:
|
||||
return FSMState(-1)
|
||||
|
||||
last_token_to_end_state = self.states_to_token_maps[state]
|
||||
next_state = last_token_to_end_state.get(token_id)
|
||||
if next_state is None:
|
||||
next_state = -1
|
||||
|
||||
return FSMState(next_state)
|
||||
|
||||
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
|
||||
"""Determine whether the current state of the FSM is a final state."""
|
||||
return state in self.final_states
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the FSM to its initial state. Here this only resets the token counter."""
|
||||
self.num_tokens_generated = 0
|
||||
|
||||
|
||||
class CFGFSM(FSM):
|
||||
"""FSM to generate text that is in the language of a context-free grammar."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_string: str,
|
||||
tokenizer: "Tokenizer",
|
||||
):
|
||||
# self.parser = PartialLark(cfg_string, parser="lalr")
|
||||
self.parser = Lark(
|
||||
cfg_string,
|
||||
parser="lalr",
|
||||
lexer="contextual",
|
||||
propagate_positions=False,
|
||||
maybe_placeholders=False,
|
||||
regex=True,
|
||||
)
|
||||
self.terminal_regexps = dict()
|
||||
for terminal in self.parser.terminals:
|
||||
if terminal.pattern is not None:
|
||||
self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
|
||||
self.terminal_regexps["$END"] = tokenizer.eos_token
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.num_tokens_generated = 0
|
||||
self.generations: List[str] = []
|
||||
self.regex_fsms: List[RegexFSM] = []
|
||||
self.reset_state: List[bool] = []
|
||||
self.allow_eos: List[bool] = []
|
||||
self.done: List[bool] = []
|
||||
|
||||
def _set_next_regex_fsm(self, idx: int = 0) -> None:
|
||||
"""Use the CFG incremental parser to set the next regex FSM.
|
||||
|
||||
Check what the CFG incremental parser proposes next.
|
||||
If the only proposal is the EOS token,
|
||||
we set the state to done and return.
|
||||
If there are other proposals,
|
||||
we set a new regex FSM and return.
|
||||
|
||||
"""
|
||||
interactive = self.parser.parse_interactive(self.generations[idx])
|
||||
interactive.exhaust_lexer()
|
||||
options = {self.terminal_regexps[x] for x in interactive.accepts()}
|
||||
|
||||
if self.terminal_regexps["$END"] in options:
|
||||
options.remove(self.terminal_regexps["$END"])
|
||||
if len(options) == 0:
|
||||
self.done[idx] = True
|
||||
return
|
||||
self.allow_eos[idx] = True
|
||||
options.add("")
|
||||
assert len(options) > 1
|
||||
|
||||
regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
|
||||
args = (
|
||||
regex_string,
|
||||
self.tokenizer,
|
||||
)
|
||||
if len(self.regex_fsms) <= idx:
|
||||
self.regex_fsms.append(RegexFSM(*args))
|
||||
else:
|
||||
self.regex_fsms[idx] = RegexFSM(*args)
|
||||
self.reset_state[idx] = True
|
||||
|
||||
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
|
||||
"""Generate a list of allowed tokens for the next step.
|
||||
|
||||
Upon initialization, the CFG incremental parser is used to determine the first regex.
|
||||
|
||||
This regex is used for proposals until either:
|
||||
- the regex is exhausted, and its only remaining option is the EOS token,
|
||||
in which case we always transition to the next regex
|
||||
- the regex can be exhausted, but the EOS token is not the only remaining option,
|
||||
in which case we transition to the next regex with probability P (TODO)
|
||||
or remove the possibility of generating the EOS token and continue with the current regex
|
||||
|
||||
The CFG incremental parser is allowed to propose the EOS token from any final state,
|
||||
and once it is generated, the FSM will continue to always generate the EOS token.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state
|
||||
The current state of the FSM.
|
||||
idx
|
||||
The index of the current input in the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list that contains the tokens to mask.
|
||||
|
||||
"""
|
||||
if len(self.generations) <= idx:
|
||||
self.generations.append("")
|
||||
self.reset_state.append(False)
|
||||
self.allow_eos.append(False)
|
||||
self.done.append(False)
|
||||
|
||||
if len(self.regex_fsms) > idx:
|
||||
proposal = self.regex_fsms[idx].allowed_token_ids(state)
|
||||
if self.tokenizer.eos_token_id not in proposal:
|
||||
return proposal
|
||||
if set(proposal) != {self.tokenizer.eos_token_id}:
|
||||
if False: # TODO: THIS NEEDS TO BE SAMPLED
|
||||
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
|
||||
return proposal
|
||||
|
||||
self._set_next_regex_fsm(idx)
|
||||
|
||||
if self.done[idx]:
|
||||
return [self.tokenizer.eos_token_id]
|
||||
|
||||
if self.reset_state[idx]:
|
||||
state = FSMState(0)
|
||||
|
||||
proposal = self.regex_fsms[idx].allowed_token_ids(state)
|
||||
if self.allow_eos[idx]:
|
||||
self.allow_eos[idx] = False
|
||||
else:
|
||||
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
|
||||
assert len(proposal) > 0
|
||||
return proposal
|
||||
|
||||
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
|
||||
"""Update the state of the FSM.
|
||||
|
||||
Transitions the underlying regex FSM to its next state.
|
||||
If at max tokens or EOS token, transition permanently to the final state.
|
||||
Update stored partial generations for subsequent incremental parsing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state
|
||||
The current state of the FSM.
|
||||
token_id
|
||||
The id of the token that was just generated.
|
||||
idx
|
||||
The index of the current input in the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The new state of the FSM.
|
||||
"""
|
||||
if idx == 0:
|
||||
self.num_tokens_generated += 1
|
||||
if token_id == self.tokenizer.eos_token_id:
|
||||
self.done[idx] = True
|
||||
return FSMState(-1)
|
||||
if self.reset_state[idx]:
|
||||
self.reset_state[idx] = False
|
||||
state = FSMState(0)
|
||||
|
||||
self.generations[idx] += self.tokenizer.decode([token_id])[0]
|
||||
|
||||
return self.regex_fsms[idx].next_state(state, token_id, idx)
|
||||
|
||||
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
|
||||
"""Return whether the current state of the FSM is a final state."""
|
||||
return self.done[idx]
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the FSM to its initial state, so it can be called on a fresh batch on inputs."""
|
||||
self.num_tokens_generated = 0
|
||||
self.generations = []
|
||||
self.regex_fsms = []
|
||||
self.reset_state = []
|
||||
self.done = []
|
||||
41
python/sglang/srt/constrained/fsm_cache.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import threading
|
||||
|
||||
from sglang.srt.constrained.fsm import RegexFSM
|
||||
from sglang.srt.constrained.tokenizer import TransformerTokenizer
|
||||
|
||||
|
||||
def get_fsm(regex, tokenizer, fsm_cache_entry):
|
||||
outlines_tokenizer = TransformerTokenizer(tokenizer)
|
||||
fsm = RegexFSM(regex, outlines_tokenizer)
|
||||
fsm_cache_entry.fsm = fsm
|
||||
fsm_cache_entry.event.set()
|
||||
|
||||
|
||||
class FSMCacheEntry:
|
||||
def __init__(self):
|
||||
self.fsm = None
|
||||
self.event = threading.Event()
|
||||
|
||||
|
||||
class FSMCache:
|
||||
def __init__(self, tokenizer):
|
||||
self.cache = {}
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def init_fsm_in_background(self, regex):
|
||||
if regex not in self.cache:
|
||||
self.cache[regex] = FSMCacheEntry()
|
||||
threading.Thread(
|
||||
target=get_fsm,
|
||||
args=(
|
||||
regex,
|
||||
self.tokenizer,
|
||||
self.cache[regex],
|
||||
),
|
||||
).start()
|
||||
|
||||
def get_fsm(self, regex):
|
||||
self.init_fsm_in_background(regex)
|
||||
entry = self.cache[regex]
|
||||
entry.event.wait()
|
||||
return entry.fsm
|
||||
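A small usage sketch for the cache above. The tokenizer is assumed to be an instantiated HuggingFace tokenizer accepted by TransformerTokenizer (the model id and regex below are illustrative). Calling init_fsm_in_background as soon as a request arrives hides the FSM compilation latency, and get_fsm blocks only if the entry is not ready yet.

from transformers import AutoTokenizer
from sglang.srt.constrained.fsm_cache import FSMCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: any HF tokenizer works here
cache = FSMCache(tokenizer)
cache.init_fsm_in_background(r"(yes|no)")   # start compiling in a background thread
fsm = cache.get_fsm(r"(yes|no)")            # waits for the background thread if needed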
586
python/sglang/srt/constrained/regex.py
Normal file
@@ -0,0 +1,586 @@
|
||||
# Adapted from:
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/regex.py
|
||||
from collections import namedtuple
|
||||
from functools import lru_cache
|
||||
from typing import Dict, Generator, List, Sequence, Set, Tuple
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from interegular.fsm import FSM, Alphabet, OblivionError, anything_else
|
||||
from numba.typed.typedobjectutils import _nonoptional
|
||||
from sglang.srt.constrained.tokenizer import Tokenizer
|
||||
|
||||
|
||||
class BetterAlphabet(Alphabet):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert anything_else in self._symbol_mapping
|
||||
self.anything_value = self._symbol_mapping[anything_else]
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._symbol_mapping.get(item, self.anything_value)
|
||||
|
||||
def copy(self):
|
||||
return BetterAlphabet(self._symbol_mapping.copy())
|
||||
|
||||
|
||||
class BetterFSM(FSM):
|
||||
flat_transition_map: Dict[Tuple[int, int], int]
|
||||
trans_key_to_states: Dict[int, List[int]]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if not isinstance(self.alphabet, BetterAlphabet):
|
||||
self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping)
|
||||
|
||||
flat_transition_map = {}
|
||||
trans_key_to_states = {}
|
||||
for from_state, trans_map in self.map.items():
|
||||
for trans_key, to_state in trans_map.items():
|
||||
flat_transition_map[(from_state, trans_key)] = to_state
|
||||
trans_key_to_states.setdefault(trans_key, set()).add(from_state)
|
||||
|
||||
self.__dict__["trans_key_to_states"] = trans_key_to_states
|
||||
self.__dict__["flat_transition_map"] = flat_transition_map
|
||||
self.__dict__["_fsm_info"] = None
|
||||
|
||||
def copy(self):
|
||||
return BetterFSM(
|
||||
alphabet=self.alphabet.copy(),
|
||||
states=self.states.copy(),
|
||||
initial=self.initial,
|
||||
finals=self.finals.copy(),
|
||||
map=self.map.copy(),
|
||||
__no_validation__=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def fsm_info(self):
|
||||
if self._fsm_info is None:
|
||||
flat_transition_map_items = np.fromiter(
|
||||
((a[0], a[1], b) for a, b in self.flat_transition_map.items()),
|
||||
dtype=np.dtype("i8, i8, i8"),
|
||||
)
|
||||
trans_key_to_states_items = np.fromiter(
|
||||
((k, z) for k, v in self.trans_key_to_states.items() for z in v),
|
||||
dtype=np.dtype("i8, i8"),
|
||||
)
|
||||
alphabet_symbol_mapping_items = np.fromiter(
|
||||
(
|
||||
it
|
||||
for it in self.alphabet._symbol_mapping.items()
|
||||
if it[0] != anything_else
|
||||
),
|
||||
dtype=np.dtype("U1, i8"),
|
||||
)
|
||||
nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8"))
|
||||
self.__dict__["_fsm_info"] = create_fsm_info(
|
||||
self.initial,
|
||||
nb_finals,
|
||||
flat_transition_map_items,
|
||||
trans_key_to_states_items,
|
||||
self.alphabet.anything_value,
|
||||
alphabet_symbol_mapping_items,
|
||||
)
|
||||
|
||||
return self._fsm_info
|
||||
|
||||
|
||||
nb_int_list_type = numba.types.ListType(numba.int64)
|
||||
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
|
||||
nb_unichar_1_type = numba.types.UnicodeCharSeq(1)
|
||||
|
||||
|
||||
@numba.njit(cache=True)
|
||||
def create_fsm_info(
|
||||
py_initial,
|
||||
py_finals,
|
||||
flat_transition_map_items,
|
||||
trans_key_to_states_items,
|
||||
py_anything_value,
|
||||
alphabet_symbol_mapping_items,
|
||||
):
|
||||
trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type)
|
||||
for trans_key_and_state in trans_key_to_states_items:
|
||||
trans_key_to_states.setdefault(
|
||||
trans_key_and_state[0], numba.typed.List.empty_list(numba.int64)
|
||||
).append(trans_key_and_state[1])
|
||||
|
||||
flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64)
|
||||
for trans_key_and_state in flat_transition_map_items:
|
||||
flat_transition_map[
|
||||
(trans_key_and_state[0], trans_key_and_state[1])
|
||||
] = trans_key_and_state[2]
|
||||
|
||||
alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64)
|
||||
for symbol_and_trans_key in alphabet_symbol_mapping_items:
|
||||
alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]
|
||||
|
||||
initial = numba.int64(py_initial)
|
||||
|
||||
finals = set()
|
||||
for final in py_finals:
|
||||
finals.add(final)
|
||||
|
||||
anything_value = numba.int64(py_anything_value)
|
||||
|
||||
return FSMInfo(
|
||||
initial,
|
||||
finals,
|
||||
flat_transition_map,
|
||||
trans_key_to_states,
|
||||
anything_value,
|
||||
alphabet_symbol_map,
|
||||
)
|
||||
|
||||
|
||||
FSMInfo = namedtuple(
|
||||
"FSMInfo",
|
||||
[
|
||||
"initial",
|
||||
"finals",
|
||||
"transitions",
|
||||
"trans_key_to_states",
|
||||
"alphabet_anything_value",
|
||||
"alphabet_symbol_mapping",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
|
||||
"""Construct an equivalent FSM with deterministic state labels."""
|
||||
old_to_new_trans_keys = {
|
||||
trans_key: i
|
||||
for i, (trans_key, _) in enumerate(
|
||||
sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1]))
|
||||
)
|
||||
}
|
||||
|
||||
new_symbol_mapping = {
|
||||
symbol: old_to_new_trans_keys[trans_key]
|
||||
for symbol, trans_key in fsm.alphabet._symbol_mapping.items()
|
||||
}
|
||||
|
||||
new_alphabet = BetterAlphabet(new_symbol_mapping)
|
||||
|
||||
new_map = {
|
||||
from_state: {
|
||||
old_to_new_trans_keys[trans_key]: to_state
|
||||
for trans_key, to_state in trans_map.items()
|
||||
}
|
||||
for from_state, trans_map in fsm.map.items()
|
||||
}
|
||||
|
||||
old_to_new_states = {}
|
||||
old_to_new_states[fsm.initial] = 0
|
||||
|
||||
i = 0
|
||||
seen = {fsm.initial}
|
||||
old_state_queue = [fsm.initial]
|
||||
while old_state_queue:
|
||||
old_state = old_state_queue.pop(-1)
|
||||
transitions = new_map[old_state]
|
||||
sorted_transitions = sorted(transitions.items(), key=lambda v: v[0])
|
||||
for _, old_state in sorted_transitions:
|
||||
if old_state not in seen:
|
||||
old_state_queue.append(old_state)
|
||||
seen.add(old_state)
|
||||
if old_state not in old_to_new_states:
|
||||
i += 1
|
||||
old_to_new_states[old_state] = i
|
||||
|
||||
new_map = dict(
|
||||
sorted(
|
||||
(
|
||||
(
|
||||
old_to_new_states[from_state],
|
||||
dict(
|
||||
sorted(
|
||||
(
|
||||
(trans_key, old_to_new_states[to_state])
|
||||
for trans_key, to_state in trans_map.items()
|
||||
),
|
||||
key=lambda v: v[0],
|
||||
)
|
||||
),
|
||||
)
|
||||
for from_state, trans_map in new_map.items()
|
||||
),
|
||||
key=lambda v: v[0],
|
||||
)
|
||||
)
|
||||
|
||||
new_initial = 0
|
||||
new_finals = frozenset(
|
||||
sorted(old_to_new_states[old_state] for old_state in fsm.finals)
|
||||
)
|
||||
new_states = frozenset(sorted(new_map.keys()))
|
||||
|
||||
new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map)
|
||||
|
||||
return new_fsm, old_to_new_states
|
||||
|
||||
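For context, a small sketch of where make_deterministic_fsm sits in the pipeline (an assumption about typical use, not code from this commit): interegular turns a regex into an FSM, and the relabeled copy is what the token-index construction below consumes.

# Sketch only; assumes interegular's parse_pattern(...).to_fsm() API.
import interegular

raw_fsm = interegular.parse_pattern(r"[a-z]+[0-9]").to_fsm()
det_fsm, state_map = make_deterministic_fsm(raw_fsm)
assert det_fsm.initial == 0 and state_map[raw_fsm.initial] == 0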
|
||||
@numba.njit(nogil=True, cache=True)
|
||||
def _walk_fsm(
|
||||
fsm_transitions: Dict[Tuple[int, int], int],
|
||||
alphabet_symbol_mapping: Dict[str, int],
|
||||
alphabet_anything_value: int,
|
||||
fsm_initial: int,
|
||||
fsm_finals: Set[int],
|
||||
input_string: str,
|
||||
start_state: int,
|
||||
full_match: bool = True,
|
||||
) -> List[int]:
|
||||
state = start_state
|
||||
accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
|
||||
last_final_idx: int = numba.uint64(0)
|
||||
|
||||
for i, symbol in enumerate(input_string):
|
||||
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
|
||||
|
||||
new_state = fsm_transitions.get((state, trans_key))
|
||||
|
||||
if new_state is None:
|
||||
if not full_match and last_final_idx > 0:
|
||||
return accepted_states[:last_final_idx]
|
||||
|
||||
return numba.typed.List.empty_list(numba.int64)
|
||||
|
||||
state = new_state
|
||||
|
||||
if state in fsm_finals:
|
||||
last_final_idx = numba.uint64(i + 1)
|
||||
|
||||
accepted_states.append(_nonoptional(state))
|
||||
|
||||
if full_match and last_final_idx - 1 != i:
|
||||
return numba.typed.List.empty_list(numba.int64)
|
||||
|
||||
return accepted_states
|
||||
|
||||
|
||||
def walk_fsm(
|
||||
fsm: BetterFSM,
|
||||
input_string: str,
|
||||
start_state: int,
|
||||
full_match: bool = True,
|
||||
) -> List[int]:
|
||||
fsm_finals = fsm.finals
|
||||
|
||||
state = start_state
|
||||
accepted_states: List[int] = []
|
||||
last_final_idx: int = 0
|
||||
|
||||
alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
|
||||
alphabet_anything_value = fsm.alphabet.anything_value
|
||||
fsm_transitions = fsm.flat_transition_map
|
||||
|
||||
for i, symbol in enumerate(input_string):
|
||||
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
|
||||
|
||||
new_state = fsm_transitions.get((state, trans_key))
|
||||
|
||||
if new_state is None:
|
||||
if not full_match and last_final_idx > 0:
|
||||
return accepted_states[:last_final_idx]
|
||||
|
||||
return []
|
||||
|
||||
state = new_state
|
||||
|
||||
if state in fsm_finals:
|
||||
last_final_idx = i + 1
|
||||
|
||||
accepted_states.append(state)
|
||||
|
||||
if full_match and last_final_idx - 1 != i:
|
||||
return []
|
||||
|
||||
return accepted_states
|
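A quick illustration of the pure-Python walker above (a sketch, not part of this commit): each accepted symbol contributes one visited state, and an empty list means the string cannot be matched (or, with full_match=True, fully matched) from start_state.

# Sketch only; assumes interegular's parse_pattern(...).to_fsm() API.
import interegular

det_fsm, _ = make_deterministic_fsm(interegular.parse_pattern(r"ab*").to_fsm())
print(walk_fsm(det_fsm, "abb", det_fsm.initial))   # e.g. [1, 2, 2]: one visited state per symbol
print(walk_fsm(det_fsm, "ba", det_fsm.initial))    # []: 'b' is not accepted from the initial state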
||||
|
||||
|
||||
def fsm_union(
|
||||
fsms: Sequence[FSM],
|
||||
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
|
||||
"""Construct an FSM representing the union of the FSMs in `fsms`.
|
||||
|
||||
This is an updated version of `interegular.fsm.FSM.union` made to return an
|
||||
extra map of component FSMs to the sets of state transitions that
|
||||
correspond to them in the new FSM.
|
||||
|
||||
"""
|
||||
|
||||
alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])
|
||||
|
||||
indexed_fsms = tuple(enumerate(fsms))
|
||||
|
||||
initial = {i: fsm.initial for (i, fsm) in indexed_fsms}
|
||||
|
||||
# Dedicated function accepting a "superset" and returning the next
|
||||
# "superset" obtained by following this transition in the new FSM
|
||||
def follow(current_state, new_transition: int):
|
||||
next = {}
|
||||
for i, f in indexed_fsms:
|
||||
old_transition = new_to_old[i][new_transition]
|
||||
if (
|
||||
i in current_state
|
||||
and current_state[i] in f.map
|
||||
and old_transition in f.map[current_state[i]]
|
||||
):
|
||||
next[i] = f.map[current_state[i]][old_transition]
|
||||
if not next:
|
||||
raise OblivionError
|
||||
return next
|
||||
|
||||
states = [initial]
|
||||
finals: Set[int] = set()
|
||||
map: Dict[int, Dict[int, int]] = {}
|
||||
|
||||
# Map component FSMs to their new state-to-state transitions, finals, and a
|
||||
# map translating component FSM states to aggregate FSM states
|
||||
fsms_to_trans_finals: Dict[
|
||||
int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
|
||||
] = {}
|
||||
|
||||
i = 0
|
||||
while i < len(states):
|
||||
state = states[i]
|
||||
|
||||
# Add to the finals of the aggregate FSM whenever we hit a final in a
|
||||
# component FSM
|
||||
if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
|
||||
finals.add(i)
|
||||
|
||||
# Compute the map for this state
|
||||
map[i] = {}
|
||||
for transition in alphabet.by_transition:
|
||||
try:
|
||||
next = follow(state, transition)
|
||||
except OblivionError:
|
||||
# Reached an oblivion state; don't list it
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
# TODO: Seems like this could--and should--be avoided
|
||||
j = states.index(next)
|
||||
except ValueError:
|
||||
j = len(states)
|
||||
states.append(next)
|
||||
|
||||
map[i][transition] = j
|
||||
|
||||
for fsm_id, fsm_state in next.items():
|
||||
(
|
||||
fsm_transitions,
|
||||
fsm_finals,
|
||||
fsm_old_to_new,
|
||||
) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
|
||||
old_from = state[fsm_id]
|
||||
old_to = fsm_state
|
||||
fsm_old_to_new.setdefault(old_from, set()).add(i)
|
||||
fsm_old_to_new.setdefault(old_to, set()).add(j)
|
||||
fsm_transitions.add((i, j))
|
||||
if fsm_state in fsms[fsm_id].finals:
|
||||
fsm_finals.add(j)
|
||||
|
||||
i += 1
|
||||
|
||||
fsm = FSM(
|
||||
alphabet=alphabet,
|
||||
states=range(len(states)),
|
||||
initial=0,
|
||||
finals=finals,
|
||||
map=map,
|
||||
__no_validation__=True,
|
||||
)
|
||||
|
||||
fsm, old_to_new_states = make_deterministic_fsm(fsm)
|
||||
_fsms_to_trans_finals = {
|
||||
fsm_id: (
|
||||
{(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
|
||||
{old_to_new_states[s] for s in finals},
|
||||
{
|
||||
old_state: {old_to_new_states[new_state] for new_state in new_states}
|
||||
for old_state, new_states in old_to_new.items()
|
||||
},
|
||||
)
|
||||
for fsm_id, (transitions, finals, old_to_new) in sorted(
|
||||
fsms_to_trans_finals.items(), key=lambda x: x[0]
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
fsm,
|
||||
_fsms_to_trans_finals,
|
||||
)
|
||||
|
||||
|
||||
def get_sub_fsms_from_seq(
|
||||
state_seq: Sequence[int],
|
||||
fsms_to_trans_finals: Dict[
|
||||
int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
|
||||
],
|
||||
) -> Generator[Tuple[int, bool, bool], None, None]:
|
||||
"""Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state_seq
|
||||
A state sequence.
|
||||
fsms_to_trans_finals
|
||||
A map from FSM indices to tuples containing sets of their state transitions
|
||||
and sets of the final/accept states.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A generator returning tuples containing each sub-FSM index (in the order
|
||||
they were union-ed to construct `fsm`) and booleans indicating whether or
|
||||
not there is another valid transition from the last state in the sequence
|
||||
for the associated sub-FSM (i.e. if the FSM can continue
|
||||
accepting/matching) and whether or not the sequence ends in a final state
|
||||
of the sub-FSM.
|
||||
"""
|
||||
state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
|
||||
last_fsm_state = state_seq[-1]
|
||||
yield from (
|
||||
(
|
||||
# The sub-FSM index
|
||||
fsm_idx,
|
||||
# Is there another possible transition in this sub-FSM?
|
||||
any(last_fsm_state == from_s for (from_s, to_s) in transitions),
|
||||
# Is this sub-FSM in a final state?
|
||||
state_seq[-1] in finals,
|
||||
)
|
||||
for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items()
|
||||
if state_seq_transitions.issubset(transitions)
|
||||
)
|
||||
|
||||
|
||||
@numba.njit(cache=True, nogil=True)
|
||||
def state_scan_tokens(
|
||||
fsm_transitions: Dict[Tuple[int, int], int],
|
||||
alphabet_symbol_mapping: Dict[str, int],
|
||||
alphabet_anything_value: int,
|
||||
fsm_initial: int,
|
||||
fsm_finals: Set[int],
|
||||
vocabulary: Dict[str, List[int]],
|
||||
start_state: int,
|
||||
) -> Set[Tuple[int, int]]:
|
||||
res = set()
|
||||
|
||||
for token, token_ids in vocabulary.items():
|
||||
state_seq = _walk_fsm(
|
||||
fsm_transitions,
|
||||
alphabet_symbol_mapping,
|
||||
alphabet_anything_value,
|
||||
fsm_initial,
|
||||
fsm_finals,
|
||||
token,
|
||||
start_state,
|
||||
False,
|
||||
)
|
||||
|
||||
if state_seq is not None and len(state_seq) < len(token):
|
||||
continue
|
||||
|
||||
for token_id in token_ids:
|
||||
res.add((token_id, state_seq[-1]))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def create_fsm_index_end_to_end(
|
||||
fsm_info: FSMInfo,
|
||||
vocabulary: Dict[str, List[int]],
|
||||
) -> Dict[int, Set[Tuple[int, int]]]:
|
||||
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""
|
||||
|
||||
# TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
|
||||
# code, too.
|
||||
states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {}
|
||||
seen: Set[int] = set()
|
||||
next_states = {fsm_info.initial}
|
||||
|
||||
while next_states:
|
||||
start_state = next_states.pop()
|
||||
|
||||
token_ids_end_states = state_scan_tokens(
|
||||
fsm_info.transitions,
|
||||
fsm_info.alphabet_symbol_mapping,
|
||||
fsm_info.alphabet_anything_value,
|
||||
fsm_info.initial,
|
||||
fsm_info.finals,
|
||||
vocabulary,
|
||||
start_state,
|
||||
)
|
||||
|
||||
for token_id_and_end_state in token_ids_end_states:
|
||||
states_to_token_subsets.setdefault(start_state, set()).add(
|
||||
token_id_and_end_state
|
||||
)
|
||||
end_state = token_id_and_end_state[1]
|
||||
if end_state not in seen:
|
||||
next_states.add(end_state)
|
||||
|
||||
seen.add(start_state)
|
||||
|
||||
return states_to_token_subsets
|
||||
|
||||
|
||||
# TODO: Cannot cache typed collections to disk, yet. See
|
||||
# https://github.com/numba/numba/issues/4698
|
||||
@lru_cache
|
||||
def reduced_vocabulary(tokenizer: "Tokenizer"):
|
||||
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
|
||||
vocabulary = numba.typed.Dict.empty(
|
||||
numba.types.string, numba.types.ListType(numba.int64)
|
||||
)
|
||||
empty_token_ids = set()
|
||||
for token, token_idx in tokenizer.vocabulary.items():
|
||||
if token in tokenizer.special_tokens:
|
||||
continue
|
||||
|
||||
token_str = tokenizer.convert_token_to_string(token)
|
||||
|
||||
if token_str:
|
||||
vocabulary.setdefault(
|
||||
token_str,
|
||||
numba.typed.List.empty_list(numba.int64),
|
||||
).append(numba.int64(token_idx))
|
||||
else:
|
||||
empty_token_ids.add(numba.int64(token_idx))
|
||||
|
||||
return vocabulary, empty_token_ids
|
||||
|
||||
|
||||
def create_fsm_index_tokenizer(
|
||||
fsm: BetterFSM,
|
||||
tokenizer: "Tokenizer",
|
||||
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
|
||||
"""Construct an FMS index from a tokenizer.
|
||||
|
||||
This uses the end-to-end approach of `create_fsm_index_end_to_end`.
|
||||
|
||||
.. warning::
|
||||
|
||||
`fsm` needs to be deterministically ordered so that future caching makes sense.
|
||||
|
||||
"""
|
||||
vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)
|
||||
|
||||
states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)
|
||||
|
||||
# Allow transitions to EOS from all terminal FSM states that are
|
||||
# reachable
|
||||
# TODO: Do we really need this anymore?
|
||||
for state in fsm.fsm_info.finals:
|
||||
subset = states_to_token_subsets.get(state)
|
||||
if subset is not None:
|
||||
subset.add((tokenizer.eos_token_id, state))
|
||||
|
||||
# Convert to token-to-end-state maps
|
||||
states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()}
|
||||
|
||||
return states_to_token_subsets, empty_token_ids
|
||||
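A sketch (an assumption about typical use, not part of this commit) of how the state-to-token index returned above is consumed at decode time: for the current FSM state, only token ids that appear in the index entry for that state are sampleable, so every other logit is masked.

import torch

# Sketch only: states_to_tokens is the first return value of create_fsm_index_tokenizer.
def mask_logits_for_state(logits: torch.Tensor, state: int, states_to_tokens: dict) -> torch.Tensor:
    allowed = torch.tensor(list(states_to_tokens[state].keys()), device=logits.device)
    masked = torch.full_like(logits, float("-inf"))
    masked[allowed] = logits[allowed]
    return masked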
266
python/sglang/srt/constrained/tokenizer.py
Normal file
@@ -0,0 +1,266 @@
|
||||
# Adapted from:
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
|
||||
from abc import abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
Hashable,
|
||||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
class Tokenizer(Protocol, Hashable):
|
||||
eos_token: str
|
||||
eos_token_id: int
|
||||
pad_token_id: int
|
||||
vocabulary: Dict[str, int]
|
||||
special_tokens: Set[int]
|
||||
|
||||
@abstractmethod
|
||||
def encode(
|
||||
self, prompt: Union[str, List[str]]
|
||||
) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
|
||||
"""Translate the input prompts into NumPy arrays of token ids and attention mask."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
|
||||
"""Translate an array of token ids to a string or list of strings."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def convert_token_to_string(self, token: str) -> str:
|
||||
"""Convert a token to its equivalent string.
|
||||
|
||||
This is for instance useful for BPE tokenizers where whitespaces are
|
||||
represented by the special character `Ġ`. This prevents matching a raw
|
||||
token that includes `Ġ` with a string.
|
||||
|
||||
"""
|
||||
...
|
||||
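A concrete illustration of why convert_token_to_string is needed (a sketch, not part of this commit): BPE vocabularies store a leading space as a marker character, so matching raw vocabulary entries against a regex over plain text would miss the space.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
piece = tok.convert_ids_to_tokens(tok.encode(" hello"))[0]   # 'Ġhello' in the raw vocabulary
text = tok.convert_tokens_to_string([piece])                 # ' hello' as plain text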
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
__all__ = ["transformers"]
|
||||
|
||||
|
||||
KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
|
||||
|
||||
|
||||
def get_llama_tokenizer_types():
|
||||
"""Get all the Llama tokenizer types/classes that need work-arounds.
|
||||
|
||||
When they can't be imported, a dummy class is created.
|
||||
|
||||
"""
|
||||
try:
|
||||
from transformers.models.llama import LlamaTokenizer
|
||||
except ImportError:
|
||||
|
||||
class LlamaTokenizer: # type: ignore
|
||||
pass
|
||||
|
||||
try:
|
||||
from transformers.models.llama import LlamaTokenizerFast
|
||||
except ImportError:
|
||||
|
||||
class LlamaTokenizerFast: # type: ignore
|
||||
pass
|
||||
|
||||
try:
|
||||
from transformers.models.code_llama import CodeLlamaTokenizer
|
||||
except ImportError:
|
||||
|
||||
class CodeLlamaTokenizer: # type: ignore
|
||||
pass
|
||||
|
||||
try:
|
||||
from transformers.models.code_llama import CodeLlamaTokenizerFast
|
||||
except ImportError:
|
||||
|
||||
class CodeLlamaTokenizerFast: # type: ignore
|
||||
pass
|
||||
|
||||
return (
|
||||
LlamaTokenizer,
|
||||
LlamaTokenizerFast,
|
||||
CodeLlamaTokenizer,
|
||||
CodeLlamaTokenizerFast,
|
||||
)
|
||||
|
||||
|
||||
class Transformer:
|
||||
"""Represents a `transformers` model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
):
|
||||
self.device = model.device
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.LongTensor,
|
||||
past_key_values: Optional[Tuple] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
|
||||
"""Compute a forward pass through the transformer model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_ids
|
||||
The input token ids. Must be one or two dimensional.
|
||||
attention_mask
|
||||
The attention mask. Must be one or two dimensional.
|
||||
past_key_values
|
||||
A tuple of tuples containing the cached key and value tensors for each
|
||||
attention head.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The computed logits and the new cached key and value tensors.
|
||||
|
||||
"""
|
||||
assert 0 < input_ids.ndim < 3
|
||||
|
||||
if past_key_values:
|
||||
input_ids = input_ids[..., -1].unsqueeze(-1)
|
||||
|
||||
output = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
return output.logits, output.past_key_values
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.LongTensor,
|
||||
past_key_values: Optional[Tuple] = None,
|
||||
) -> torch.FloatTensor:
|
||||
logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
|
||||
next_token_logits = logits[..., -1, :]
|
||||
|
||||
return next_token_logits, kv_cache
|
||||
|
||||
|
||||
class TransformerTokenizer(Tokenizer):
|
||||
"""Represents a tokenizer for models in the `transformers` library."""
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
# TODO: Do something to make this hashable?
|
||||
self.tokenizer = tokenizer
|
||||
self.eos_token_id = self.tokenizer.eos_token_id
|
||||
self.eos_token = self.tokenizer.eos_token
|
||||
|
||||
if self.tokenizer.pad_token_id is None:
|
||||
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
||||
self.pad_token_id = self.eos_token_id
|
||||
else:
|
||||
self.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.pad_token = self.tokenizer.pad_token
|
||||
|
||||
self.special_tokens = set(self.tokenizer.all_special_tokens)
|
||||
|
||||
self.vocabulary = self.tokenizer.get_vocab()
|
||||
self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
|
||||
|
||||
def encode(
|
||||
self, prompt: Union[str, List[str]], **kwargs
|
||||
) -> Tuple[torch.LongTensor, torch.LongTensor]:
|
||||
kwargs["padding"] = True
|
||||
kwargs["return_tensors"] = "pt"
|
||||
output = self.tokenizer(prompt, **kwargs)
|
||||
return output["input_ids"], output["attention_mask"]
|
||||
|
||||
def decode(self, token_ids: torch.LongTensor) -> List[str]:
|
||||
text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
|
||||
return text
|
||||
|
||||
def convert_token_to_string(self, token: str) -> str:
|
||||
from transformers.file_utils import SPIECE_UNDERLINE
|
||||
|
||||
string = self.tokenizer.convert_tokens_to_string([token])
|
||||
|
||||
if self.is_llama:
|
||||
# A hack to handle missing spaces in HF's Llama tokenizers
|
||||
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
|
||||
return " " + string
|
||||
|
||||
return string
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, type(self)):
|
||||
return False
|
||||
# TODO(lsyin): the lru_cache for the TransformerTokenizer is useless?
|
||||
# return other.model_name == self.model_name and other.kwargs == self.kwargs
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self):
|
||||
from datasets.fingerprint import Hasher
|
||||
|
||||
return hash(Hasher.hash(self.tokenizer))
|
||||
|
||||
|
||||
def transformers(
|
||||
model_name: str,
|
||||
device: Optional[str] = None,
|
||||
model_kwargs: dict = {},
|
||||
tokenizer_kwargs: dict = {},
|
||||
):
|
||||
"""Instantiate a model from the `transformers` library and its tokenizer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name
|
||||
The name of the model as listed on Hugging Face's model page.
|
||||
device
|
||||
The device(s) on which the model should be loaded. This overrides
|
||||
the `device_map` entry in `model_kwargs` when provided.
|
||||
model_kwargs
|
||||
A dictionary that contains the keyword arguments to pass to the
|
||||
`from_pretrained` method when loading the model.
|
||||
tokenizer_kwargs
|
||||
A dictionary that contains the keyword arguments to pass to the
|
||||
`from_pretrained` method when loading the tokenizer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A `Transformer` model instance.
|
||||
|
||||
"""
|
||||
try:
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The `transformers` library needs to be installed in order to use `transformers` models."
|
||||
)
|
||||
|
||||
if device is not None:
|
||||
model_kwargs["device_map"] = device
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
|
||||
# TransformerTokenizer expects an already-loaded tokenizer object, not a model name.
tokenizer = TransformerTokenizer(
    AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
)
|
||||
|
||||
return Transformer(model, tokenizer)
|
||||
164
python/sglang/srt/hf_transformers_utils.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Utilities for Huggingface Transformers."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from sglang.srt.utils import is_multimodal_model
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
|
||||
|
||||
def download_from_hf(model_path: str):
|
||||
if os.path.exists(model_path):
|
||||
return model_path
|
||||
|
||||
return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
|
||||
|
||||
|
||||
def get_config_json(model_path: str):
|
||||
with open(os.path.join(model_path, "config.json")) as f:
|
||||
config = json.load(f)
|
||||
return config
|
||||
|
||||
|
||||
def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
|
||||
config = AutoConfig.from_pretrained(
|
||||
model, trust_remote_code=trust_remote_code, revision=revision
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
# Models don't use the same configuration key for determining the maximum
|
||||
# context length. Store them here so we can sanely check them.
|
||||
# NOTE: The ordering here is important. Some models have two of these and we
|
||||
# have a preference for which value gets used.
|
||||
CONTEXT_LENGTH_KEYS = [
|
||||
"max_sequence_length",
|
||||
"seq_length",
|
||||
"max_position_embeddings",
|
||||
"max_seq_len",
|
||||
"model_max_length",
|
||||
]
|
||||
|
||||
|
||||
def get_context_length(config):
|
||||
"""Get the context length of a model from a huggingface model config."""
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling:
|
||||
rope_scaling_factor = config.rope_scaling["factor"]
|
||||
else:
|
||||
rope_scaling_factor = 1
|
||||
|
||||
for key in CONTEXT_LENGTH_KEYS:
|
||||
val = getattr(config, key, None)
|
||||
if val is not None:
|
||||
return int(rope_scaling_factor * val)
|
||||
return 2048
|
||||
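A small worked example of the lookup above (a sketch, not part of this commit): with RoPE scaling, the effective context length is the scaling factor times the base position limit.

from types import SimpleNamespace

cfg = SimpleNamespace(rope_scaling={"factor": 2.0}, max_position_embeddings=4096)
assert get_context_length(cfg) == 8192   # 2.0 * 4096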
|
||||
|
||||
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
|
||||
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
tokenizer_name: str,
|
||||
*args,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
"""Gets a tokenizer for the given model name via Huggingface."""
|
||||
if is_multimodal_model(tokenizer_name):
|
||||
processor = get_processor(
|
||||
tokenizer_name,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tokenizer_revision=tokenizer_revision,
|
||||
**kwargs,
|
||||
)
|
||||
tokenizer = processor.tokenizer
|
||||
return tokenizer
|
||||
|
||||
if tokenizer_mode == "slow":
|
||||
if kwargs.get("use_fast", False):
|
||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
kwargs["use_fast"] = False
|
||||
|
||||
if (
|
||||
"llama" in tokenizer_name.lower()
|
||||
and kwargs.get("use_fast", True)
|
||||
and tokenizer_name != _FAST_LLAMA_TOKENIZER
|
||||
):
|
||||
pass
|
||||
# warnings.warn(
|
||||
# "For some LLaMA V1 models, initializing the fast tokenizer may "
|
||||
# "take a long time. To reduce the initialization time, consider "
|
||||
# f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
|
||||
# "tokenizer."
|
||||
# )
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_name,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tokenizer_revision=tokenizer_revision,
|
||||
**kwargs,
|
||||
)
|
||||
except TypeError as e:
|
||||
# The LLaMA tokenizer causes a protobuf error in some environments.
|
||||
err_msg = (
|
||||
"Failed to load the tokenizer. If you are using a LLaMA V1 model "
|
||||
f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
|
||||
"original tokenizer."
|
||||
)
|
||||
raise RuntimeError(err_msg) from e
|
||||
except ValueError as e:
|
||||
# If the error pertains to the tokenizer class not existing or not
|
||||
# currently being imported, suggest using the --trust-remote-code flag.
|
||||
if not trust_remote_code and (
|
||||
"does not exist or is not currently imported." in str(e)
|
||||
or "requires you to execute the tokenizer file" in str(e)
|
||||
):
|
||||
err_msg = (
|
||||
"Failed to load the tokenizer. If the tokenizer is a custom "
|
||||
"tokenizer not yet available in the HuggingFace transformers "
|
||||
"library, consider setting `trust_remote_code=True` in LLM "
|
||||
"or using the `--trust-remote-code` flag in the CLI."
|
||||
)
|
||||
raise RuntimeError(err_msg) from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
warnings.warn(
|
||||
"Using a slow tokenizer. This might cause a significant "
|
||||
"slowdown. Consider using a fast tokenizer instead."
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_processor(
|
||||
tokenizer_name: str,
|
||||
*args,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
tokenizer_name,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tokenizer_revision=tokenizer_revision,
|
||||
**kwargs,
|
||||
)
|
||||
return processor
|
||||
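A short usage sketch for the helpers above (not part of this commit; the model name is only an example): get_tokenizer returns a fast tokenizer whenever one is available and transparently goes through the processor path for multimodal models.

tokenizer = get_tokenizer("meta-llama/Llama-2-7b-hf", trust_remote_code=False)
input_ids = tokenizer("Hello SGLang").input_ids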
181
python/sglang/srt/layers/context_flashattention_nopad.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# Adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
Out,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
kv_group_num: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :]
|
||||
)
|
||||
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
|
||||
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
|
||||
|
||||
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
|
||||
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
||||
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(
|
||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(
|
||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# initialize pointers to output
|
||||
off_o = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
|
||||
+ cur_head * stride_oh
|
||||
+ offs_d[None, :]
|
||||
)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
||||
|
||||
|
||||
cached_kernel = None
|
||||
|
||||
|
||||
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
||||
BLOCK = 128
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
kv_group_num = q.shape[1] // k.shape[1]
|
||||
|
||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
global cached_kernel
|
||||
if cached_kernel:
|
||||
cached_kernel(
|
||||
grid,
|
||||
num_warps,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
o,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
)
|
||||
return
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
o,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
kv_group_num=kv_group_num,
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
|
||||
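As a sanity check (a sketch, not part of this commit), the kernel above can be compared against a plain causal softmax attention computed in PyTorch for a single sequence, i.e. context_attention_fwd called with b_start_loc=[0] and b_seq_len=[seq_len].

import math

import torch

def reference_causal_attention(q, k, v):
    # q, k, v: (seq_len, num_heads, head_dim); assumes equal q and k/v head counts (kv_group_num == 1)
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale
    causal = torch.tril(torch.ones(q.shape[0], q.shape[0], dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~causal, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v.float())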
371
python/sglang/srt/layers/extend_attention.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q_Extend,
|
||||
K_Extend,
|
||||
V_Extend,
|
||||
O_Extend,
|
||||
K_Buffer,
|
||||
V_Buffer,
|
||||
Req_to_tokens,
|
||||
B_req_idx,
|
||||
B_Seq_Len,
|
||||
B_Start_Loc_Extend,
|
||||
B_Seq_Len_Extend,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_buf_kbs,
|
||||
stride_buf_kh,
|
||||
stride_buf_vbs,
|
||||
stride_buf_vh,
|
||||
stride_req_to_tokens_b,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_seq = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
cur_block_m = tl.program_id(2)
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
cur_seq_len = tl.load(B_Seq_Len + cur_seq)
|
||||
cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
|
||||
cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
|
||||
|
||||
cur_seq_prefix_start_in_loc = 0
|
||||
cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
|
||||
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
|
||||
offs_q = (
|
||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_d[None, :]
|
||||
)
|
||||
q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
|
||||
|
||||
# stage1: compute scores with prefix
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
|
||||
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
mask_n = (start_n + offs_n) < cur_seq_len_prefix
|
||||
offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
|
||||
cur_seq_prefix_start_in_loc + start_n + offs_n
|
||||
)
|
||||
offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
|
||||
|
||||
# load k in transposed way
|
||||
offs_buf_k = (
|
||||
offs_kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_d[:, None]
|
||||
)
|
||||
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
deno = deno * re_scale + tl.sum(p, 1)
|
||||
|
||||
offs_buf_v = (
|
||||
offs_kv_loc[:, None] * stride_buf_vbs
|
||||
+ cur_kv_head * stride_buf_vh
|
||||
+ offs_d[None, :]
|
||||
)
|
||||
v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
|
||||
p = p.to(v.dtype)
|
||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||
|
||||
e_max = n_e_max
|
||||
|
||||
# stage2: compute the triangle part
|
||||
|
||||
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
|
||||
for start_n in range(0, cur_block_m_end, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
mask_n = (start_n + offs_n) < cur_block_m_end
|
||||
|
||||
# load k in transposed way
|
||||
offs_k = (
|
||||
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_d[:, None]
|
||||
)
|
||||
k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
|
||||
start_n + offs_n[None, :]
|
||||
)
|
||||
mask_causual &= mask_m[:, None] & mask_n[None, :]
|
||||
qk = tl.where(mask_causual, qk, float("-inf"))
|
||||
|
||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||
re_scale = tl.exp(e_max - n_e_max)
|
||||
p = tl.exp(qk - n_e_max[:, None])
|
||||
deno = deno * re_scale + tl.sum(p, 1)
|
||||
|
||||
offs_v = (
|
||||
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
|
||||
+ cur_kv_head * stride_vh
|
||||
+ offs_d[None, :]
|
||||
)
|
||||
v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
|
||||
p = p.to(v.dtype)
|
||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||
|
||||
e_max = n_e_max
|
||||
|
||||
offs_o = (
|
||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_obs
|
||||
+ cur_head * stride_oh
|
||||
+ offs_d[None, :]
|
||||
)
|
||||
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
|
||||
|
||||
|
||||
def extend_attention_fwd(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_seq_len_prefix,
|
||||
b_start_loc_extend,
|
||||
b_seq_len_extend,
|
||||
max_len_in_batch,
|
||||
max_len_extend,
|
||||
):
|
||||
"""
|
||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||
|
||||
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
|
||||
"""
|
||||
BLOCK_M, BLOCK_N = 128, 128
|
||||
Lq, Lk, Lv, Lo = (
|
||||
q_extend.shape[-1],
|
||||
k_extend.shape[-1],
|
||||
v_extend.shape[-1],
|
||||
o_extend.shape[-1],
|
||||
)
|
||||
assert Lq == Lk and Lk == Lv and Lv == Lo
|
||||
assert Lq in {16, 32, 64, 128}
|
||||
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
|
||||
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
||||
|
||||
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
num_stages = 1
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
b_start_loc_extend,
|
||||
b_seq_len_extend,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
q_extend.stride(0),
|
||||
q_extend.stride(1),
|
||||
k_extend.stride(0),
|
||||
k_extend.stride(1),
|
||||
v_extend.stride(0),
|
||||
v_extend.stride(1),
|
||||
o_extend.stride(0),
|
||||
o_extend.stride(1),
|
||||
k_buffer.stride(0),
|
||||
k_buffer.stride(1),
|
||||
v_buffer.stride(0),
|
||||
v_buffer.stride(1),
|
||||
req_to_tokens.stride(0),
|
||||
BLOCK_DMODEL=Lq,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
|
||||
|
||||
def redundant_attention(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_seq_len_prefix,
|
||||
max_len_in_batch,
|
||||
):
|
||||
total_token_num = k_buffer.shape[0]
|
||||
B, H_Q, D = b_req_idx.shape[0], q_extend.shape[-2], q_extend.shape[-1]
|
||||
q_buffer = torch.empty(
|
||||
(total_token_num, H_Q, D), dtype=q_extend.dtype, device=q_extend.device
|
||||
)
|
||||
|
||||
pt = 0
|
||||
for i in range(B):
|
||||
cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]
|
||||
pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
|
||||
q_buffer[pl:pr] = q_extend[pt : pt + cur_seq_len_extend]
|
||||
pt += cur_seq_len_extend
|
||||
|
||||
o_buffer = torch.empty_like(q_buffer)
|
||||
context_attention_fwd(
|
||||
q_buffer, k_buffer, v_buffer, o_buffer, b_start_loc, b_seq_len, max_len_in_batch
|
||||
)
|
||||
|
||||
pt = 0
|
||||
for i in range(B):
|
||||
cur_seq_len_extend = b_seq_len[i] - b_seq_len_prefix[i]
|
||||
pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
|
||||
o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
|
||||
pt += cur_seq_len_extend
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(0)
|
||||
|
||||
B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128
|
||||
dtype = torch.float16
|
||||
|
||||
b_seq_len_prefix = torch.randint(
|
||||
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
b_seq_len_extend = torch.randint(
|
||||
1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
b_seq_len = b_seq_len_prefix + b_seq_len_extend
|
||||
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
|
||||
|
||||
b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
|
||||
req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda")
|
||||
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
|
||||
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
|
||||
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
|
||||
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||
for i in range(B):
|
||||
req_to_tokens[i, : b_seq_len[i]] = torch.arange(
|
||||
b_start_loc[i], b_start_loc[i] + b_seq_len[i]
|
||||
)
|
||||
|
||||
total_token_num = torch.sum(b_seq_len).item()
|
||||
extend_token_num = torch.sum(b_seq_len_extend).item()
|
||||
k_buffer = torch.empty(
|
||||
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
|
||||
).normal_(mean=0.1, std=0.2)
|
||||
v_buffer = torch.empty(
|
||||
(total_token_num, H_KV, D), dtype=dtype, device="cuda"
|
||||
).normal_(mean=0.1, std=0.2)
|
||||
|
||||
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
|
||||
v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
|
||||
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
|
||||
for i in range(B):
|
||||
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
|
||||
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
|
||||
extend_start = b_start_loc_extend[i]
|
||||
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
|
||||
k_extend[extend_start:extend_end] = k_buffer[
|
||||
extend_start_in_buffer:extend_end_in_buffer
|
||||
]
|
||||
v_extend[extend_start:extend_end] = v_buffer[
|
||||
extend_start_in_buffer:extend_end_in_buffer
|
||||
]
|
||||
q_extend[extend_start:extend_end] = torch.empty(
|
||||
(b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
|
||||
).normal_(mean=0.1, std=0.2)
|
||||
|
||||
o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
|
||||
o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
|
||||
|
||||
b_seq_len_extend = b_seq_len - b_seq_len_prefix
|
||||
b_start_loc_extend = torch.zeros_like(b_seq_len)
|
||||
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
|
||||
extend_attention_fwd(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_seq_len_prefix,
|
||||
b_start_loc_extend,
|
||||
b_seq_len_extend,
|
||||
max_len_in_batch,
|
||||
max_len_extend,
|
||||
)
|
||||
|
||||
redundant_attention(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_redundant,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_seq_len_prefix,
|
||||
max_len_in_batch,
|
||||
)
|
||||
|
||||
print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant)))
|
||||
print("Max: ", torch.max(torch.abs(o_extend - o_redundant)))
|
||||
|
||||
assert torch.allclose(o_extend, o_redundant, rtol=1e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
79
python/sglang/srt/layers/get_selected_logprob.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_segmented_gather(
|
||||
all_logits,
|
||||
len_add_1,
|
||||
cum_len,
|
||||
input_ids,
|
||||
logprobs,
|
||||
max_seq_len,
|
||||
voc_size: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
cur_req = tl.program_id(0)
|
||||
cur_l = tl.load(len_add_1 + cur_req)
|
||||
cum_l = tl.load(cum_len + cur_req)
|
||||
|
||||
for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
|
||||
off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = off < cur_l - 1
|
||||
|
||||
idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask)
|
||||
data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask)
|
||||
tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask)
|
||||
|
||||
|
||||
cached_kernel = None
|
||||
|
||||
|
||||
def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
|
||||
cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
|
||||
voc_size = all_logits.shape[1]
|
||||
grid = (len_add_1.shape[0], 1, 1)
|
||||
max_seq_len = len_add_1.max().item()
|
||||
|
||||
global cached_kernel
|
||||
if cached_kernel:
|
||||
cached_kernel(
|
||||
grid,
|
||||
4,
|
||||
all_logits,
|
||||
len_add_1,
|
||||
cum_len,
|
||||
input_ids,
|
||||
logprobs,
|
||||
max_seq_len,
|
||||
)
|
||||
return
|
||||
|
||||
_fwd_segmented_gather[grid](
|
||||
all_logits,
|
||||
len_add_1,
|
||||
cum_len,
|
||||
input_ids,
|
||||
logprobs,
|
||||
max_seq_len,
|
||||
voc_size,
|
||||
BLOCK_SIZE=128,
|
||||
)
|
||||
cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
all_logits = torch.tensor(
|
||||
# s s s
|
||||
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
|
||||
input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
|
||||
logprobs = torch.empty((3), dtype=torch.float32, device="cuda")
|
||||
get_selected_logprob(all_logits, len_add_1, input_ids, logprobs)
|
||||
print(logprobs)
|
||||
# assert logprobs == [2, 2, 4]
|
||||
77
python/sglang/srt/layers/logits_processor.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
from sglang.srt.layers.get_selected_logprob import get_selected_logprob
|
||||
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
|
||||
from torch import nn
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
|
||||
|
||||
class LogitsProcessor(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def forward(self, input_ids, hidden_states, weight, input_metadata):
|
||||
if not input_metadata.return_normalized_logprob:
|
||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
last_hidden = hidden_states
|
||||
else:
|
||||
last_index = (
|
||||
torch.cumsum(
|
||||
input_metadata.seq_lens - input_metadata.prefix_lens,
|
||||
dim=0,
|
||||
dtype=torch.long,
|
||||
)
|
||||
- 1
|
||||
)
|
||||
last_hidden = hidden_states[last_index]
|
||||
hidden_states = None
|
||||
|
||||
last_logits = torch.matmul(last_hidden, weight.T)
|
||||
if self.tp_size > 1:
|
||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||
last_logits = last_logits[:, : self.config.vocab_size]
|
||||
return last_logits, None
|
||||
else:
|
||||
assert input_metadata.forward_mode != ForwardMode.DECODE
|
||||
last_index = (
|
||||
torch.cumsum(
|
||||
input_metadata.seq_lens - input_metadata.prefix_lens,
|
||||
dim=0,
|
||||
dtype=torch.long,
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
logits = torch.matmul(hidden_states, weight.T)
|
||||
if self.tp_size > 1:
|
||||
logits = tensor_model_parallel_all_gather(logits)
|
||||
logits = logits[:, : self.config.vocab_size]
|
||||
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
|
||||
|
||||
normalized_logprobs = compute_normalized_logprobs(
|
||||
all_logprobs,
|
||||
input_metadata.seq_lens - input_metadata.prefix_lens,
|
||||
input_ids,
|
||||
)
|
||||
|
||||
last_logits = logits[last_index]
|
||||
return last_logits, normalized_logprobs
|
||||
|
||||
|
||||
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
|
||||
# assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0]
|
||||
logprobs = torch.zeros(
|
||||
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs)
|
||||
cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
||||
end = torch.cumsum(len_add_1.sub_(1), dim=0)
|
||||
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
|
||||
end.sub_(1)
|
||||
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
|
||||
res = sum_logp / len_add_1
|
||||
return res
|
||||
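A small numeric sketch of what the cumsum bookkeeping above computes (an assumption about the intended semantics, not code from this commit): up to the off-by-one handling of the first prompt token, each request's normalized logprob is the mean of the per-token logprobs selected for its prompt.

import torch

logp = torch.tensor([-0.5, -1.0, -2.0])    # selected logprobs for one request's prompt tokens
normalized = logp.sum() / logp.numel()     # tensor(-1.1667), the per-token average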
158
python/sglang/srt/layers/radix_attention.py
Normal file
@@ -0,0 +1,158 @@
from typing import List

import torch
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
from torch import nn
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)


class RadixAttention(nn.Module):
    def __init__(
        self,
        num_heads,
        head_dim,
        scaling,
        num_kv_heads,
        layer_id,
    ):
        super().__init__()

        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
        self.tp_v_head_num = num_kv_heads
        self.head_dim = head_dim
        self.layer_id = layer_id

        from sglang.srt.managers.router.model_runner import global_model_mode

        self.use_flashinfer = "flashinfer" in global_model_mode

        if self.use_flashinfer:
            self.prefill_forward = self.prefill_forward_flashinfer
            self.extend_forward = self.prefill_forward_flashinfer
            self.decode_forward = self.decode_forward_flashinfer
        else:
            self.prefill_forward = self.prefill_forward_triton
            self.extend_forward = self.extend_forward_triton
            self.decode_forward = self.decode_forward_triton

    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        o = torch.empty_like(q)

        context_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.head_dim),
            k,
            v,
            o.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.start_loc,
            input_metadata.seq_lens,
            input_metadata.max_seq_len,
        )
        self.store_kv_cache(k, v, input_metadata)

        return o

    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        o = torch.empty_like(q)
        self.store_kv_cache(k, v, input_metadata)

        extend_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.start_loc,
            input_metadata.seq_lens,
            input_metadata.prefix_lens,
            input_metadata.extend_start_loc,
            input_metadata.extend_seq_lens,
            input_metadata.max_seq_len,
            input_metadata.max_extend_len,
        )

        return o

    def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
        o = torch.empty_like(q)
        self.store_kv_cache(k, v, input_metadata)

        token_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
            o.view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.start_loc,
            input_metadata.seq_lens,
            input_metadata.max_seq_len,
            input_metadata.other_kv_index,
            input_metadata.total_num_tokens,
        )

        return o

    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
        self.store_kv_cache(k, v, input_metadata)

        o = input_metadata.prefill_wrapper.forward(
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.qo_indptr,
            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
            input_metadata.kv_indptr,
            input_metadata.kv_indices,
            input_metadata.kv_last_page_len,
            allow_fp16_qk_reduction=True,
        )

        return o.view(-1, self.tp_q_head_num * self.head_dim)

    def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
        self.store_kv_cache(k, v, input_metadata)

        o = input_metadata.decode_wrapper.forward(
            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
            input_metadata.kv_indptr,
            input_metadata.kv_indices,
            input_metadata.kv_last_page_len,
        )

        return o.view(-1, self.tp_q_head_num * self.head_dim)

    def forward(self, q, k, v, input_metadata: InputMetadata):
        k = k.view(-1, self.tp_k_head_num, self.head_dim)
        v = v.view(-1, self.tp_v_head_num, self.head_dim)

        if input_metadata.forward_mode == ForwardMode.PREFILL:
            return self.prefill_forward(q, k, v, input_metadata)
        elif input_metadata.forward_mode == ForwardMode.EXTEND:
            return self.extend_forward(q, k, v, input_metadata)
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.decode_forward(q, k, v, input_metadata)

    def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
        key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
        if input_metadata.out_cache_loc is not None:
            key_buffer[input_metadata.out_cache_loc] = cache_k
            value_buffer[input_metadata.out_cache_loc] = cache_v
        elif input_metadata.out_cache_cont_start is not None:
            key_buffer[
                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
            ] = cache_k
            value_buffer[
                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
            ] = cache_v
        else:
            raise RuntimeError()
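The RadixAttention module above is one shared attention layer per transformer block: forward() dispatches to a Triton or FlashInfer kernel based on input_metadata.forward_mode, and store_kv_cache writes the new key/value slices into the paged token-to-KV pool. A rough sketch of a possible call site inside a decoder layer; the projection attributes, sizes, and class name are assumptions for illustration and are not part of this file:

import torch
from torch import nn

class ExampleSelfAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, num_kv_heads, head_dim, layer_id):
        super().__init__()
        # assumed fused QKV projection and output projection
        self.qkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size)
        self.q_size = num_heads * head_dim
        self.kv_size = num_kv_heads * head_dim
        self.attn = RadixAttention(num_heads, head_dim, head_dim**-0.5, num_kv_heads, layer_id)

    def forward(self, hidden_states, input_metadata):
        q, k, v = self.qkv_proj(hidden_states).split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )
        out = self.attn(q, k, v, input_metadata)  # picks the prefill/extend/decode kernel
        return self.o_proj(out)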
324
python/sglang/srt/layers/token_attention.py
Normal file
@@ -0,0 +1,324 @@
# Adapted from
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
import torch
import triton
import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher


@triton.jit
def _fwd_kernel_stage1(
    Q,
    K_Buffer,
    sm_scale,
    Req_to_tokens,
    B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    att_stride_h,
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_n = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)

    cur_batch_start_index = 0
    cur_batch_end_index = cur_batch_seq_len

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

    block_stard_index = start_n * BLOCK_N
    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)

    for start_mark in range(0, block_mask, 1):
        q = tl.load(Q + off_q + start_mark)
        offs_n_new = cur_batch_start_index + offs_n
        k_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
            mask=offs_n_new < cur_batch_end_index,
            other=0,
        )
        offs_buf_k = (
            k_loc[:, None] * stride_buf_kbs
            + cur_kv_head * stride_buf_kh
            + offs_d[None, :]
        )
        k = tl.load(
            K_Buffer + offs_buf_k,
            mask=offs_n_new[:, None] < cur_batch_end_index,
            other=0.0,
        )
        att_value = tl.sum(q[None, :] * k, 1)
        att_value *= sm_scale
        off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
        tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)


@triton.jit
def _fwd_kernel_stage2(
    Logics,
    V_Buffer,
    Out,
    Req_to_tokens,
    B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    stride_logic_h,
    stride_buf_vbs,
    stride_buf_vh,
    stride_obs,
    stride_oh,
    stride_req_to_token_b,
    other_kv_index,  # To fix a NAN issue
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)

    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
    v_ptrs = V_Buffer + offs_buf_v

    e_max = float("-inf")
    e_sum = 0.0
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        v_index = tl.load(
            Req_to_tokens
            + cur_batch_req_idx * stride_req_to_token_b
            + (start_n + offs_n),
            mask=(start_n + offs_n) < cur_batch_seq_len,
            other=other_kv_index,
        )

        qk = tl.load(
            Logics
            + cur_head * stride_logic_h
            + (cur_batch_start_loc + start_n + offs_n),
            mask=start_n + offs_n < cur_batch_seq_len,
            other=float("-inf"),
        )

        n_e_max = tl.maximum(tl.max(qk, 0), e_max)
        old_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max)
        e_sum = e_sum * old_scale + tl.sum(p, 0)
        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
        acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
        e_max = n_e_max

    acc = acc / e_sum
    off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)


cached_kernel_stage1 = None
cached_kernel_stage2 = None


def _token_att_m_fwd(
    q,
    k_buffer,
    att_out,
    Req_to_tokens,
    B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    max_len_in_batch,
):
    BLOCK = 32
    # shape constraints
    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
    assert Lq == Lk
    assert Lk in {16, 32, 64, 128}
    sm_scale = 1.0 / (Lk**0.5)

    batch, head_num = B_req_idx.shape[0], q.shape[1]

    grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))
    kv_group_num = q.shape[1] // k_buffer.shape[1]

    if kv_group_num == 1:
        num_warps = 4
    else:
        num_warps = 2

    global cached_kernel_stage1
    if cached_kernel_stage1:
        cached_kernel_stage1(
            grid,
            num_warps,
            q,
            k_buffer,
            sm_scale,
            Req_to_tokens,
            B_req_idx,
            B_Start_Loc,
            B_Seqlen,
            att_out,
            Req_to_tokens.stride(0),
            q.stride(0),
            q.stride(1),
            k_buffer.stride(0),
            k_buffer.stride(1),
            att_out.stride(0),
        )
        return

    _fwd_kernel_stage1[grid](
        q,
        k_buffer,
        sm_scale,
        Req_to_tokens,
        B_req_idx,
        B_Start_Loc,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(0),
        k_buffer.stride(1),
        att_out.stride(0),
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=Lk,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)


def _token_softmax_reducev_fwd(
    logics,
    v_buffer,
    o,
    req_to_tokens,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    other_kv_index,
):
    BLOCK = 64
    batch, head = b_seq_len.shape[0], logics.shape[0]
    grid = (batch, head, 1)
    kv_group_num = logics.shape[0] // v_buffer.shape[1]

    num_warps = 1

    global cached_kernel_stage2
    if cached_kernel_stage2:
        cached_kernel_stage2(
            grid,
            num_warps,
            logics,
            v_buffer,
            o,
            req_to_tokens,
            b_req_idx,
            b_start_loc,
            b_seq_len,
            logics.stride(0),
            v_buffer.stride(0),
            v_buffer.stride(1),
            o.stride(0),
            o.stride(1),
            req_to_tokens.stride(0),
            other_kv_index,
        )
        return

    _fwd_kernel_stage2[grid](
        logics,
        v_buffer,
        o,
        req_to_tokens,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        logics.stride(0),
        v_buffer.stride(0),
        v_buffer.stride(1),
        o.stride(0),
        o.stride(1),
        req_to_tokens.stride(0),
        other_kv_index,
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=v_buffer.shape[-1],
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=3,
    )
    cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)


def token_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    max_len_in_batch,
    other_kv_index,
    total_num_tokens,
    att_m=None,
):
    if att_m is None:
        att_m = torch.empty(
            (q.shape[-2], total_num_tokens), dtype=q.dtype, device="cuda"
        )

    _token_att_m_fwd(
        q,
        k_buffer,
        att_m,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
    )
    _token_softmax_reducev_fwd(
        att_m,
        v_buffer,
        o,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        other_kv_index,
    )
85
python/sglang/srt/managers/detokenizer_manager.py
Normal file
@@ -0,0 +1,85 @@
import asyncio

import uvloop
import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


class DetokenizerManager:
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        context = zmq.asyncio.Context(2)
        self.recv_from_router = context.socket(zmq.PULL)
        self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")

        self.send_to_tokenizer = context.socket(zmq.PUSH)
        self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

        self.tokenizer = get_tokenizer(
            server_args.tokenizer_path,
            tokenizer_mode=server_args.tokenizer_mode,
            trust_remote_code=server_args.trust_remote_code,
        )

    async def handle_loop(self):
        while True:
            recv_obj = await self.recv_from_router.recv_pyobj()

            if isinstance(recv_obj, BatchTokenIDOut):
                output_tokens = recv_obj.output_tokens

                # TODO(lmzheng): handle skip_special_tokens per request
                output_strs = self.tokenizer.batch_decode(
                    output_tokens,
                    skip_special_tokens=recv_obj.skip_special_tokens[0],
                )

                # Trim stop str
                # TODO(lmzheng): handle the case where multiple stop strs are hit
                for i in range(len(output_strs)):
                    if recv_obj.hit_stop_str[i] is not None:
                        pos = output_strs[i].find(recv_obj.hit_stop_str[i])
                        if pos != -1:
                            output_strs[i] = output_strs[i][:pos]

                    if len(output_tokens[i]) > 0:
                        first_token = self.tokenizer.convert_ids_to_tokens(
                            int(output_tokens[i][0])
                        )
                        if first_token.startswith("▁"):
                            output_strs[i] = " " + output_strs[i]

                self.send_to_tokenizer.send_pyobj(
                    BatchStrOut(
                        recv_obj.rids,
                        output_strs,
                        recv_obj.meta_info,
                        recv_obj.finished,
                    )
                )
            else:
                raise ValueError(f"Invalid object: {recv_obj}")


def start_detokenizer_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
):
    try:
        manager = DetokenizerManager(server_args, port_args)
    except Exception as e:
        pipe_writer.send(get_exception_traceback())
        raise
    pipe_writer.send("init ok")
    loop = asyncio.get_event_loop()
    loop.run_until_complete(manager.handle_loop())
88
python/sglang/srt/managers/io_struct.py
Normal file
@@ -0,0 +1,88 @@
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sglang.srt.sampling_params import SamplingParams


@dataclass
class GenerateReqInput:
    text: Union[List[str], str]
    image_data: Optional[Union[List[str], str]] = None
    sampling_params: Union[List[Dict], Dict] = None
    rid: Optional[Union[List[str], str]] = None
    return_normalized_logprob: Optional[Union[List[bool], bool]] = None
    normalized_logprob_start_len: Optional[Union[List[int], int]] = None
    stream: bool = False

    def post_init(self):
        is_single = isinstance(self.text, str)

        if is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_normalized_logprob is None:
                self.return_normalized_logprob = False
            if self.normalized_logprob_start_len is None:
                self.normalized_logprob_start_len = 0
        else:
            num = len(self.text)

            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            if self.sampling_params is None:
                self.sampling_params = [{}] * num
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list)

            if self.return_normalized_logprob is None:
                self.return_normalized_logprob = [False] * num
            elif not isinstance(self.return_normalized_logprob, list):
                self.return_normalized_logprob = [self.return_normalized_logprob] * num

            if self.normalized_logprob_start_len is None:
                self.normalized_logprob_start_len = [0] * num
            elif not isinstance(self.normalized_logprob_start_len, list):
                self.normalized_logprob_start_len = [
                    self.normalized_logprob_start_len
                ] * num


@dataclass
class TokenizedGenerateReqInput:
    rid: str
    input_ids: List[int]
    pixel_values: List[float]
    image_hash: int
    sampling_params: SamplingParams
    return_normalized_logprob: bool
    normalized_logprob_start_len: int
    stream: bool


@dataclass
class BatchTokenIDOut:
    rids: List[str]
    output_tokens: List[List[int]]
    hit_stop_str: List[Optional[str]]
    skip_special_tokens: List[bool]
    meta_info: List[Dict]
    finished: List[bool]


@dataclass
class BatchStrOut:
    rids: List[str]
    output_str: List[str]
    meta_info: List[Dict]
    finished: List[bool]
12
python/sglang/srt/managers/openai_protocol.py
Normal file
@@ -0,0 +1,12 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Union


@dataclass
class CompletionRequest:
    prompt: Union[str, List[Any]]
    model: str = "default"
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 16
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
326
python/sglang/srt/managers/router/infer_batch.py
Normal file
@@ -0,0 +1,326 @@
from enum import Enum, auto
from typing import List

import numpy as np
import torch
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool


class ForwardMode(Enum):
    PREFILL = auto()
    EXTEND = auto()
    DECODE = auto()


class FinishReason(Enum):
    LENGTH = auto()
    EOS_TOKEN = auto()
    STOP_STR = auto()


class Req:
    def __init__(self, rid):
        self.rid = rid
        self.input_ids = []
        self.output_ids = []
        self.pixel_values = None
        self.image_offset = 0
        self.sampling_params = None
        self.return_normalized_logprob = False
        self.normalized_logprob_start_len = 0
        self.stream = False

        self.tokenizer = None
        self.finished = False
        self.finish_reason = None
        self.hit_stop_str = None

        self.adjust_input_len = 0
        self.prefix_indices = []

        self.normalized_logprob = None

        # for constrained decoding
        self.regex_fsm = None
        self.regex_fsm_state = None

    def max_new_tokens(self):
        return self.sampling_params.max_new_tokens

    def check_finished(self):
        if self.finished:
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished = True
            self.finish_reason = FinishReason.LENGTH
            return

        if (
            self.output_ids[-1] == self.tokenizer.eos_token_id
            and self.sampling_params.ignore_eos == False
        ):
            self.finished = True
            self.finish_reason = FinishReason.EOS_TOKEN
            return

        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str:
                    self.finished = True
                    self.finish_reason = FinishReason.STOP_STR
                    self.hit_stop_str = stop_str
                    return

    def __repr__(self):
        return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "


class Batch:
    def __init__(
        self,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: TokenToKVPool,
        tree_cache: RadixCache,
    ):
        self.reqs = reqs
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool = token_to_kv_pool
        self.tree_cache = tree_cache

        self.return_normalized_logprob = any(
            req.return_normalized_logprob for req in reqs
        )

    def is_empty(self):
        return len(self.reqs) == 0

    def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
        device = "cuda"
        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
        prefix_indices = [r.prefix_indices for r in reqs]

        # Handle prefix
        flatten_input_ids = []
        extend_lens = []
        prefix_lens = []
        seq_lens = []

        req_pool_indices = self.req_to_token_pool.alloc(bs)
        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
        for i in range(bs):
            flatten_input_ids.extend(input_ids[i])
            extend_lens.append(len(input_ids[i]))

            if len(prefix_indices[i]) == 0:
                prefix_lens.append(0)
            else:
                prefix_lens.append(len(prefix_indices[i]))
                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
                    : len(prefix_indices[i])
                ] = prefix_indices[i]

            seq_lens.append(prefix_lens[-1] + extend_lens[-1])

        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)

        # Alloc mem
        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
        if out_cache_loc is None:
            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

            if out_cache_loc is None:
                print("Prefill out of memory.")
                self.tree_cache.pretty_print()
                exit()

        pt = 0
        for i in range(bs):
            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
            ] = out_cache_loc[pt : pt + extend_lens[i]]
            pt += extend_lens[i]

        # Handle logit bias
        logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
        for i in range(bs):
            if reqs[i].sampling_params.dtype == "int":
                logit_bias[i] = int_token_logit_bias

        # Set fields
        self.input_ids = torch.tensor(
            flatten_input_ids, dtype=torch.int32, device=device
        )
        self.pixel_values = [r.pixel_values for r in reqs]
        self.image_offsets = [
            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
        ]
        self.req_pool_indices = req_pool_indices
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
        self.position_ids_offsets = position_ids_offsets
        self.extend_num_tokens = extend_num_tokens
        self.out_cache_loc = out_cache_loc

        self.temperatures = torch.tensor(
            [r.sampling_params.temperature for r in reqs],
            dtype=torch.float,
            device=device,
        ).view(-1, 1)
        self.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
        ).view(-1, 1)
        self.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
        ).view(-1, 1)
        self.frequency_penalties = torch.tensor(
            [r.sampling_params.frequency_penalty for r in reqs],
            dtype=torch.float,
            device=device,
        )
        self.presence_penalties = torch.tensor(
            [r.sampling_params.presence_penalty for r in reqs],
            dtype=torch.float,
            device=device,
        )
        self.logit_bias = logit_bias

    def update_for_decode(self, input_ids=None):
        if input_ids is None:
            input_ids = [
                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
            ]
        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
        self.seq_lens.add_(1)
        self.prefix_lens = None

        # Alloc mem
        bs = len(self.reqs)
        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
        if alloc_res is None:
            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

            if self.out_cache_loc is None:
                self.tree_cache.evict(bs, self.token_to_kv_pool.free)
                self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

                if self.out_cache_loc is None:
                    print("Decode out of memory.")
                    self.tree_cache.pretty_print()
                    exit()

            self.out_cache_cont_start = None
            self.out_cache_cont_end = None
        else:
            self.out_cache_loc = alloc_res[0]
            self.out_cache_cont_start = alloc_res[1]
            self.out_cache_cont_end = alloc_res[2]

        self.req_to_token_pool.req_to_token[
            self.req_pool_indices, self.seq_lens - 1
        ] = self.out_cache_loc

    def filter_batch(self, unfinished_indices: List[int]):
        self.reqs = [self.reqs[i] for i in unfinished_indices]
        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
        self.seq_lens = self.seq_lens[new_indices]
        self.input_ids = None
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.prefix_lens = None
        self.position_ids_offsets = self.position_ids_offsets[new_indices]
        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "frequency_penalties",
            "presence_penalties",
            "logit_bias",
        ]:
            setattr(self, item, getattr(self, item)[new_indices])

    def merge(self, other):
        self.reqs.extend(other.reqs)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.prefix_lens = None
        self.position_ids_offsets = torch.concat(
            [self.position_ids_offsets, other.position_ids_offsets]
        )
        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "frequency_penalties",
            "presence_penalties",
            "logit_bias",
        ]:
            setattr(
                self, item, torch.concat([getattr(self, item), getattr(other, item)])
            )

    def sample(self, logits: torch.Tensor):
        # Post process logits
        logits = logits.contiguous()
        logits.div_(self.temperatures)
        logits.add_(self.logit_bias)

        has_regex = any(req.regex_fsm is not None for req in self.reqs)
        if has_regex:
            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    allowed_mask.zero_()
                    allowed_mask[
                        req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
                    ] = 1
                    logits[i].masked_fill_(~allowed_mask, float("-inf"))

        # TODO(lmzheng): apply penalty
        probs = torch.softmax(logits, dim=-1)
        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
            -1
        )
        batch_next_token_probs = torch.gather(
            probs_sort, dim=1, index=sampled_index
        ).view(-1)

        if has_regex:
            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.next_state(
                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
                    )

        return batch_next_token_ids, batch_next_token_probs


def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
    ] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    return probs_sort, probs_idx
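Batch.sample above applies temperature and logit bias, optionally masks tokens disallowed by a regex FSM, then samples through the joint top-p/top-k filter; torch.multinomial accepts unnormalized weights, so rescaling probs_sort by its row maximum is sufficient. A small self-contained example of the filter with made-up numbers (not taken from this file):

import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
top_ps = torch.tensor([[0.8]])  # keep tokens until the cumulative mass before them exceeds 0.8
top_ks = torch.tensor([[3]])    # and keep at most 3 tokens

probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
probs_sort[torch.arange(probs.shape[-1]).view(1, -1) >= top_ks] = 0.0
next_token = torch.gather(probs_idx, 1, torch.multinomial(probs_sort, num_samples=1))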
71
python/sglang/srt/managers/router/manager.py
Normal file
@@ -0,0 +1,71 @@
import asyncio
import logging
from typing import List, Tuple

import uvloop
import zmq
import zmq.asyncio
from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


class RouterManager:
    def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
        # Init communication
        context = zmq.asyncio.Context(2)
        self.recv_from_tokenizer = context.socket(zmq.PULL)
        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")

        self.send_to_detokenizer = context.socket(zmq.PUSH)
        self.send_to_detokenizer.connect(
            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
        )

        # Init status
        self.model_client = model_client
        self.recv_reqs = []

    async def loop_for_forward(self):
        while True:
            next_step_input = list(self.recv_reqs)
            self.recv_reqs = []
            out_pyobjs = await self.model_client.step(next_step_input)

            for obj in out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)

            # await for a while to accept input requests
            await asyncio.sleep(0.001)

    async def loop_for_recv_requests(self):
        while True:
            recv_req = await self.recv_from_tokenizer.recv_pyobj()
            self.recv_reqs.append(recv_req)


def start_router_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
):
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        model_client = ModelRpcClient(server_args, port_args)
        router = RouterManager(model_client, port_args)
    except Exception:
        pipe_writer.send(get_exception_traceback())
        raise

    pipe_writer.send("init ok")

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.create_task(router.loop_for_recv_requests())
    loop.run_until_complete(router.loop_for_forward())
497
python/sglang/srt/managers/router/model_rpc.py
Normal file
@@ -0,0 +1,497 @@
import asyncio
import logging
import multiprocessing
import time
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import BatchTokenIDOut, TokenizedGenerateReqInput
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    get_exception_traceback,
    get_int_token_logit_bias,
    is_multimodal_model,
    set_random_seed,
)

logger = logging.getLogger("model_rpc")


class ModelRpcServer(rpyc.Service):
    def exposed_init_model(
        self,
        tp_rank: int,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        server_args, port_args = [obtain(x) for x in [server_args, port_args]]

        # Copy arguments
        self.model_mode = server_args.model_mode
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_heuristic = server_args.schedule_heuristic

        # Init model and tokenizer
        self.model_config = ModelConfig(
            server_args.model_path, server_args.trust_remote_code
        )
        self.model_runner = ModelRunner(
            self.model_config,
            server_args.mem_fraction_static,
            tp_rank,
            server_args.tp_size,
            port_args.nccl_port,
            server_args.load_format,
            server_args.trust_remote_code,
            server_args.model_mode,
        )
        if is_multimodal_model(server_args.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.eos_token_id = self.tokenizer.eos_token_id
        self.max_total_num_token = self.model_runner.max_total_num_token
        self.max_num_running_seq = self.max_total_num_token // 2
        self.max_prefill_num_token = max(
            self.model_config.context_len, self.max_total_num_token // 6
        )
        self.int_token_logit_bias = torch.tensor(
            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
        )
        set_random_seed(server_args.random_seed)
        logger.info(
            f"Rank {self.tp_rank}: "
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
            f"model_mode={self.model_mode}"
        )

        # Init cache
        self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
        self.scheduler = Scheduler(
            self.schedule_heuristic,
            self.max_num_running_seq,
            self.max_prefill_num_token,
            self.max_total_num_token,
            self.tree_cache,
        )
        self.req_to_token_pool = self.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.model_runner.token_to_kv_pool

        # Init running status
        self.forward_queue: List[Req] = []
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = 2

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(self.tokenizer)

    def exposed_step(self, recv_reqs):
        if self.tp_size != 1:
            recv_reqs = obtain(recv_reqs)

        try:
            # Recv requests
            for recv_req in recv_reqs:
                if isinstance(recv_req, TokenizedGenerateReqInput):
                    self.handle_generate_request(recv_req)
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

            # Forward
            self.forward_step()
        except Exception:
            logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())

        # Return results
        ret = self.out_pyobjs
        self.out_pyobjs = []
        return ret

    @torch.inference_mode()
    def forward_step(self):
        new_batch = self.get_new_fill_batch()

        if new_batch is not None:
            # Run new fill batch
            self.forward_fill_batch(new_batch)

            if not new_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = new_batch
                else:
                    self.running_batch.merge(new_batch)
        else:
            # Run decode batch
            if self.running_batch is not None:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(10):
                    self.forward_decode_batch(self.running_batch)

                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break

            if self.running_batch is not None and self.tp_rank == 0:
                if self.decode_forward_ct >= 20:
                    self.decode_forward_ct = 0
                    num_used = self.max_total_num_token - (
                        self.token_to_kv_pool.available_size()
                        + self.tree_cache.evictable_size()
                    )
                    logger.info(
                        f"#running-req: {len(self.running_batch.reqs)}, "
                        f"#token: {num_used}, "
                        f"token usage: {num_used / self.max_total_num_token:.2f}, "
                        f"#queue-req: {len(self.forward_queue)}"
                    )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(recv_req.rid)
        req.input_ids = recv_req.input_ids
        req.pixel_values = recv_req.pixel_values
        if req.pixel_values is not None:
            pad_value = [
                (recv_req.image_hash) % self.model_config.vocab_size,
                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
            ]
            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                req.input_ids, pad_value
            )
        req.sampling_params = recv_req.sampling_params
        req.return_normalized_logprob = recv_req.return_normalized_logprob
        req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len
        req.stream = recv_req.stream
        req.tokenizer = self.tokenizer

        # init the regex fsm
        if req.sampling_params.regex is not None:
            req.regex_fsm_state = 0
            req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)

        # Truncate long prompts
        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
        req.sampling_params.max_new_tokens = min(
            req.sampling_params.max_new_tokens,
            self.model_config.context_len - 1 - len(req.input_ids),
        )
        self.forward_queue.append(req)

    def get_new_fill_batch(self):
        if (
            self.running_batch is not None
            and len(self.running_batch.reqs) > self.max_num_running_seq
        ):
            return None

        for req in self.forward_queue:
            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
            if req.return_normalized_logprob:
                prefix_indices = prefix_indices[: req.normalized_logprob_start_len]
            req.adjust_input_len = len(req.input_ids) - len(prefix_indices)
            req.prefix_indices = prefix_indices
            req.last_node = last_node

        # Get priority queue
        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

        # Add requests if there is available space
        can_run_list = []
        new_batch_total_tokens = 0
        new_batch_input_tokens = 0
        new_batch_prefix_tokens = 0

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        new_ratio = self.scheduler.new_token_estimation_ratio()
        if self.running_batch:
            available_size -= sum(
                [
                    (r.max_new_tokens() - len(r.output_ids)) * new_ratio
                    for r in self.running_batch.reqs
                ]
            )

        for req in self.forward_queue:
            if req.return_normalized_logprob:
                # Need at least two tokens to compute normalized logprob
                if req.adjust_input_len < 2:
                    delta = 2 - req.adjust_input_len
                    req.adjust_input_len += delta
                    req.prefix_indices = req.prefix_indices[:-delta]
                    if req.image_offset is not None:
                        req.image_offset += delta
            if req.adjust_input_len == 0 and req.max_new_tokens() > 0:
                # Need at least one token to compute logits
                req.adjust_input_len = 1
                req.prefix_indices = req.prefix_indices[:-1]
                if req.image_offset is not None:
                    req.image_offset += 1

            if (
                req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
                < available_size
                and req.adjust_input_len + new_batch_input_tokens
                < self.max_prefill_num_token
            ):
                delta = self.tree_cache.inc_ref_counter(req.last_node)
                available_size += delta

                if not (
                    req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
                    < available_size
                ):
                    delta = self.tree_cache.dec_ref_counter(req.last_node)
                    available_size += delta
                else:
                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                    can_run_list.append(req)
                    new_batch_total_tokens += (
                        req.adjust_input_len + req.max_new_tokens()
                    )
                    new_batch_input_tokens += req.adjust_input_len

        if len(can_run_list) == 0:
            return None

        if self.tp_rank == 0:
            logger.info(
                f"new fill batch. #seq: {len(can_run_list)}. "
                f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. "
                f"#new_token: {new_batch_input_tokens}. "
                f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
                f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )

        new_batch = Batch(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch

    def forward_fill_batch(self, batch: Batch):
        # Build batch tensors
        batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias)
        if batch.extend_num_tokens != 0:
            # Forward
            logits, normalized_logprobs = self.model_runner.forward(
                batch, ForwardMode.EXTEND, batch.return_normalized_logprob
            )
            # print("extend logits", logits)
            if normalized_logprobs is not None:
                normalized_logprobs = normalized_logprobs.cpu().tolist()

            next_token_ids, next_token_probs = batch.sample(logits)
            next_token_ids = next_token_ids.cpu().tolist()
        else:
            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
            normalized_logprobs = None

        # Check finish condition
        reqs = batch.reqs
        for i in range(len(reqs)):
            reqs[i].output_ids = [next_token_ids[i]]
            reqs[i].check_finished()

            if normalized_logprobs is not None:
                reqs[i].normalized_logprob = normalized_logprobs[i]

        self.handle_finished_requests(batch)

    def forward_decode_batch(self, batch: Batch):
        # Update batch tensors
        self.decode_forward_ct += 1
        batch.update_for_decode()

        # Forward
        logits = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids, next_token_probs = batch.sample(logits)
        next_token_ids = next_token_ids.cpu().tolist()

        # Check finish condition
        reqs = batch.reqs
        for i in range(len(reqs)):
            reqs[i].output_ids.append(next_token_ids[i])
            reqs[i].check_finished()

        self.handle_finished_requests(batch)

    def handle_finished_requests(self, batch: Batch):
        output_rids = []
        output_tokens = []
        output_hit_stop_str = []
        output_skip_special_tokens = []
        output_meta_info = []
        output_finished = []
        finished_indices = []
        unfinished_indices = []
        for i, req in enumerate(batch.reqs):
            if req.finished:
                finished_indices.append(i)
            else:
                unfinished_indices.append(i)

            if req.finished or (
                req.stream and self.decode_forward_ct % self.stream_interval == 0
            ):
                output_rids.append(req.rid)
                output_tokens.append(req.output_ids)
                output_hit_stop_str.append(req.hit_stop_str)
                output_skip_special_tokens.append(
                    req.sampling_params.skip_special_tokens
                )
                meta_info = {
                    "prompt_tokens": len(req.input_ids),
                    "completion_tokens": len(req.output_ids),
                }
                if req.return_normalized_logprob:
                    meta_info["normalized_logprob"] = req.normalized_logprob
                output_meta_info.append(meta_info)
                output_finished.append(req.finished)

        # Send to detokenizer
        if output_rids:
            self.out_pyobjs.append(
                BatchTokenIDOut(
                    output_rids,
                    output_tokens,
                    output_hit_stop_str,
                    output_skip_special_tokens,
                    output_meta_info,
                    output_finished,
                )
            )

        # Remove finished reqs
        if finished_indices:
            # Update radix cache
            req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                req_pool_idx = req_pool_indices_cpu[i]
                token_ids = tuple(req.input_ids + req.output_ids)
                seq_len = len(token_ids) - 1
                indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
                prefix_len = self.tree_cache.insert(token_ids, indices.clone())

                self.token_to_kv_pool.free(indices[:prefix_len])
                self.req_to_token_pool.free(req_pool_idx)
                self.tree_cache.dec_ref_counter(req.last_node)

            # Update batch tensors
            if unfinished_indices:
                batch.filter_batch(unfinished_indices)
            else:
                batch.reqs = []


class ModelRpcClient:
    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
        tp_size = server_args.tp_size

        if tp_size == 1:
            # Init model
            self.model_server = ModelRpcServer()
            self.model_server.exposed_init_model(0, server_args, port_args)

            # Wrap functions
            def async_wrap(f):
                async def _func(*args, **kwargs):
                    return f(*args, **kwargs)

                return _func

            self.step = async_wrap(self.model_server.exposed_step)
        else:
            with ThreadPoolExecutor(tp_size) as executor:
                # Launch model processes
                rets = executor.map(start_model_process, port_args.model_rpc_ports)
                self.model_servers = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]

                # Init model
                def init_model(i):
                    return self.model_servers[i].init_model(i, server_args, port_args)

                rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]

                # Wrap functions
                def async_wrap(func_name):
                    fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]

                    async def _func(*args, **kwargs):
                        tasks = [f(*args, **kwargs) for f in fs]
                        await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
                        return obtain(tasks[0].value)

                    return _func

                self.step = async_wrap("step")


def start_model_process(port):
    def _init_service(port):
        t = ThreadedServer(
            ModelRpcServer(),
            port=port,
            protocol_config={"allow_pickle": True, "sync_request_timeout": 600},
        )
        t.start()

    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    time.sleep(1)

    repeat_count = 0
    while repeat_count < 20:
        try:
            con = rpyc.connect(
                "localhost",
                port,
                config={"allow_pickle": True, "sync_request_timeout": 600},
            )
            break
        except ConnectionRefusedError:
            time.sleep(1)
        repeat_count += 1
    if repeat_count == 20:
        raise RuntimeError("init rpc env error!")

    assert proc.is_alive()
    return con.root, proc
458
python/sglang/srt/managers/router/model_runner.py
Normal file
@@ -0,0 +1,458 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
from sglang.srt.utils import is_multimodal_model
|
||||
from sglang.utils import get_available_gpu_memory
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
||||
|
||||
# for model_mode
|
||||
global_model_mode: List[str] = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputMetadata:
|
||||
model_runner: "ModelRunner"
|
||||
forward_mode: ForwardMode
|
||||
batch_size: int
|
||||
total_num_tokens: int
|
||||
max_seq_len: int
|
||||
req_pool_indices: torch.Tensor
|
||||
start_loc: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
prefix_lens: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
req_to_token_pool: ReqToTokenPool
|
||||
token_to_kv_pool: TokenToKVPool
|
||||
|
||||
# for extend
|
||||
extend_seq_lens: torch.Tensor = None
|
||||
extend_start_loc: torch.Tensor = None
|
||||
max_extend_len: int = 0
|
||||
|
||||
out_cache_loc: torch.Tensor = None
|
||||
out_cache_cont_start: torch.Tensor = None
|
||||
out_cache_cont_end: torch.Tensor = None
|
||||
|
||||
other_kv_index: torch.Tensor = None
|
||||
return_normalized_logprob: bool = False
|
||||
|
||||
# for flashinfer
|
||||
use_flashinfer: bool = False
|
||||
qo_indptr: torch.Tensor = None
|
||||
kv_indptr: torch.Tensor = None
|
||||
kv_indices: torch.Tensor = None
|
||||
kv_last_page_len: torch.Tensor = None
|
||||
prefill_wrapper = None
|
||||
decode_wrapper = None
|
||||
|
||||
def init_flashinfer_args(self, tp_size):
|
||||
self.kv_indptr = torch.zeros(
|
||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
|
||||
self.kv_indices = torch.cat(
|
||||
[
|
||||
self.req_to_token_pool.req_to_token[
|
||||
self.req_pool_indices[i].item(), : self.seq_lens[i].item()
|
||||
]
|
||||
for i in range(self.batch_size)
|
||||
],
|
||||
dim=0,
|
||||
).contiguous()
|
||||
self.kv_last_page_len = torch.ones(
|
||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
from flashinfer.ops import (
|
||||
BatchDecodeWithPagedKVCacheWrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
)
|
||||
|
||||
if (
|
||||
self.forward_mode == ForwardMode.PREFILL
|
||||
or self.forward_mode == ForwardMode.EXTEND
|
||||
):
|
||||
self.qo_indptr = torch.zeros(
|
||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
|
||||
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
|
||||
self.prefill_wrapper.begin_forward(
|
||||
self.qo_indptr,
|
||||
self.batch_size,
|
||||
self.model_runner.model_config.num_attention_heads // tp_size,
|
||||
self.model_runner.model_config.num_key_value_heads // tp_size,
|
||||
)
|
||||
else:
|
||||
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
|
||||
self.decode_wrapper.begin_forward(
|
||||
self.kv_indptr,
|
||||
self.kv_last_page_len,
|
||||
self.batch_size,
|
||||
self.model_runner.model_config.num_attention_heads // tp_size,
|
||||
self.model_runner.model_config.num_key_value_heads // tp_size,
|
||||
self.model_runner.model_config.head_dim,
|
||||
1,
|
||||
"NONE",
|
||||
"float16",
|
||||
)
|
||||
|
||||
def init_extend_args(self):
|
||||
self.extend_seq_lens = self.seq_lens - self.prefix_lens
|
||||
self.extend_start_loc = torch.zeros_like(self.seq_lens)
|
||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
|
||||
self.max_extend_len = int(torch.max(self.extend_seq_lens))
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
model_runner,
|
||||
tp_size,
|
||||
forward_mode,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
position_ids_offsets,
|
||||
out_cache_loc,
|
||||
out_cache_cont_start=None,
|
||||
out_cache_cont_end=None,
|
||||
return_normalized_logprob=False,
|
||||
):
|
||||
batch_size = len(req_pool_indices)
|
||||
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
||||
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
|
||||
total_num_tokens = int(torch.sum(seq_lens))
|
||||
max_seq_len = int(torch.max(seq_lens))
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
|
||||
other_kv_index = model_runner.req_to_token_pool.req_to_token[
|
||||
req_pool_indices[0], seq_lens[0] - 1
|
||||
].item()
|
||||
else:
|
||||
seq_lens_np = seq_lens.cpu().numpy()
|
||||
prefix_lens_np = prefix_lens.cpu().numpy()
|
||||
position_ids_offsets_np = position_ids_offsets.cpu().numpy()
|
||||
positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
np.arange(
|
||||
prefix_lens_np[i] + position_ids_offsets_np[i],
|
||||
seq_lens_np[i] + position_ids_offsets_np[i],
|
||||
)
|
||||
for i in range(batch_size)
|
||||
],
|
||||
axis=0,
|
||||
),
|
||||
device="cuda",
|
||||
)
|
||||
other_kv_index = None
|
||||
|
||||
ret = cls(
|
||||
model_runner=model_runner,
|
||||
forward_mode=forward_mode,
|
||||
batch_size=batch_size,
|
||||
total_num_tokens=total_num_tokens,
|
||||
max_seq_len=max_seq_len,
|
||||
req_pool_indices=req_pool_indices,
|
||||
start_loc=start_loc,
|
||||
seq_lens=seq_lens,
|
||||
prefix_lens=prefix_lens,
|
||||
positions=positions,
|
||||
req_to_token_pool=model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||
out_cache_loc=out_cache_loc,
|
||||
out_cache_cont_start=out_cache_cont_start,
|
||||
out_cache_cont_end=out_cache_cont_end,
|
||||
return_normalized_logprob=return_normalized_logprob,
|
||||
other_kv_index=other_kv_index,
|
||||
)
|
||||
|
||||
if forward_mode == ForwardMode.EXTEND:
|
||||
ret.init_extend_args()
|
||||
|
||||
ret.use_flashinfer = "flashinfer" in model_runner.model_mode
|
||||
if ret.use_flashinfer:
|
||||
ret.init_flashinfer_args(tp_size)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
def __init__(
|
||||
self,
|
||||
model_config,
|
||||
mem_fraction_static,
|
||||
tp_rank,
|
||||
tp_size,
|
||||
nccl_port,
|
||||
load_format="auto",
|
||||
trust_remote_code=True,
|
||||
model_mode: List[str] = (),
|
||||
):
|
||||
self.model_config = model_config
|
||||
self.mem_fraction_static = mem_fraction_static
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.nccl_port = nccl_port
|
||||
self.load_format = load_format
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.model_mode = model_mode
|
||||
|
||||
global global_model_mode
|
||||
global_model_mode = model_mode
|
||||
|
||||
# Init torch distributed
|
||||
torch.cuda.set_device(self.tp_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl",
|
||||
world_size=self.tp_size,
|
||||
rank=self.tp_rank,
|
||||
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
|
||||
)
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
if self.tp_size > 1:
|
||||
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||
|
||||
total_gpu_memory = get_available_gpu_memory(
|
||||
self.tp_rank, distributed=self.tp_size > 1
|
||||
) * (1 << 30)
|
||||
self.load_model()
|
||||
self.init_memory_pool(total_gpu_memory)
|
||||
|
||||
self.is_multimodal_model = is_multimodal_model(self.model_config)
|
||||
|
||||
def load_model(self):
|
||||
"""See also vllm/model_executor/model_loader.py::get_model"""
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||
from sglang.srt.models.llava import LlavaLlamaForCausalLM
|
||||
from sglang.srt.models.mixtral import MixtralForCausalLM
|
||||
|
||||
# Select model class
|
||||
architectures = getattr(self.model_config.hf_config, "architectures", [])
|
||||
|
||||
model_class = None
|
||||
for arch in architectures:
|
||||
if arch == "LlamaForCausalLM":
|
||||
model_class = LlamaForCausalLM
|
||||
break
|
||||
if arch == "MistralForCausalLM":
|
||||
model_class = LlamaForCausalLM
|
||||
break
|
||||
if arch == "LlavaLlamaForCausalLM":
|
||||
model_class = LlavaLlamaForCausalLM
|
||||
break
|
||||
if arch == "MixtralForCausalLM":
|
||||
model_class = MixtralForCausalLM
|
||||
break
|
||||
if model_class is None:
|
||||
raise ValueError(f"Unsupported architectures: {architectures}")
|
||||
|
||||
# Load weights
|
||||
linear_method = None
|
||||
with _set_default_torch_dtype(torch.float16):
|
||||
with torch.device("cuda"):
|
||||
hf_quant_config = getattr(
|
||||
self.model_config.hf_config, "quantization_config", None
|
||||
)
|
||||
if hf_quant_config is not None:
|
||||
# TODO: support quantization configs other than AWQ
|
||||
quant_config = AWQConfig.from_config(hf_quant_config)
|
||||
print(f"quant_config: {quant_config}")
|
||||
linear_method = quant_config.get_linear_method()
|
||||
model = model_class(
|
||||
config=self.model_config.hf_config, linear_method=linear_method
|
||||
)
|
||||
model.load_weights(
|
||||
self.model_config.path,
|
||||
cache_dir=None,
|
||||
load_format=self.load_format,
|
||||
revision=None,
|
||||
)
|
||||
self.model = model
|
||||
|
||||
def profile_max_num_token(self, total_gpu_memory):
|
||||
available_gpu_memory = get_available_gpu_memory(
|
||||
self.tp_rank, distributed=self.tp_size > 1
|
||||
) * (1 << 30)
|
||||
head_dim = (
|
||||
self.model_config.hidden_size // self.model_config.num_attention_heads
|
||||
)
|
||||
head_num = self.model_config.num_key_value_heads // self.tp_size
|
||||
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
|
||||
rest_memory = available_gpu_memory - total_gpu_memory * (
|
||||
1 - self.mem_fraction_static
|
||||
)
|
||||
max_num_token = int(rest_memory // cell_size)
|
||||
return max_num_token
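For intuition, cell_size is the number of KV-cache bytes one token occupies across all layers: K and V tensors (x2) stored in float16 (x2 bytes). A back-of-the-envelope sketch with made-up numbers for a 7B-style model on an 80 GB GPU (all values below are illustrative, not taken from this code):

hidden_size, num_heads, num_kv_heads, num_layers, tp_size = 4096, 32, 32, 32, 1
head_dim = hidden_size // num_heads                     # 128
head_num = num_kv_heads // tp_size                      # 32
cell_size = head_num * head_dim * num_layers * 2 * 2    # K+V in fp16 -> 524288 bytes/token

total_gpu_memory = 80e9
available_gpu_memory = 66e9                             # after loading ~14 GB of fp16 weights
mem_fraction_static = 0.85
rest_memory = available_gpu_memory - total_gpu_memory * (1 - mem_fraction_static)
max_num_token = int(rest_memory // cell_size)           # roughly 1e5 cacheable tokens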
|
||||
|
||||
def init_memory_pool(self, total_gpu_memory):
|
||||
self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
|
||||
self.req_to_token_pool = ReqToTokenPool(
|
||||
int(self.max_total_num_token / self.model_config.context_len * 256),
|
||||
self.model_config.context_len + 8,
|
||||
)
|
||||
self.token_to_kv_pool = TokenToKVPool(
|
||||
self.max_total_num_token,
|
||||
dtype=torch.float16,
|
||||
head_num=self.model_config.num_key_value_heads // self.tp_size,
|
||||
head_dim=self.model_config.hidden_size
|
||||
// self.model_config.num_attention_heads,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_prefill(
|
||||
self,
|
||||
input_ids,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
position_ids_offsets,
|
||||
out_cache_loc,
|
||||
return_normalized_logprob,
|
||||
):
|
||||
input_metadata = InputMetadata.create(
|
||||
self,
|
||||
forward_mode=ForwardMode.PREFILL,
|
||||
tp_size=self.tp_size,
|
||||
req_pool_indices=req_pool_indices,
|
||||
seq_lens=seq_lens,
|
||||
prefix_lens=prefix_lens,
|
||||
position_ids_offsets=position_ids_offsets,
|
||||
out_cache_loc=out_cache_loc,
|
||||
return_normalized_logprob=return_normalized_logprob,
|
||||
)
|
||||
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_extend(
|
||||
self,
|
||||
input_ids,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
position_ids_offsets,
|
||||
out_cache_loc,
|
||||
return_normalized_logprob,
|
||||
):
|
||||
input_metadata = InputMetadata.create(
|
||||
self,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
tp_size=self.tp_size,
|
||||
req_pool_indices=req_pool_indices,
|
||||
seq_lens=seq_lens,
|
||||
prefix_lens=prefix_lens,
|
||||
position_ids_offsets=position_ids_offsets,
|
||||
out_cache_loc=out_cache_loc,
|
||||
return_normalized_logprob=return_normalized_logprob,
|
||||
)
|
||||
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_decode(
|
||||
self,
|
||||
input_ids,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
position_ids_offsets,
|
||||
out_cache_loc,
|
||||
out_cache_cont_start,
|
||||
out_cache_cont_end,
|
||||
):
|
||||
input_metadata = InputMetadata.create(
|
||||
self,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
tp_size=self.tp_size,
|
||||
req_pool_indices=req_pool_indices,
|
||||
seq_lens=seq_lens,
|
||||
prefix_lens=prefix_lens,
|
||||
position_ids_offsets=position_ids_offsets,
|
||||
out_cache_loc=out_cache_loc,
|
||||
out_cache_cont_start=out_cache_cont_start,
|
||||
out_cache_cont_end=out_cache_cont_end,
|
||||
)
|
||||
return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
|
||||
0
|
||||
]
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_extend_multi_modal(
|
||||
self,
|
||||
input_ids,
|
||||
pixel_values,
|
||||
image_offsets,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
position_ids_offsets,
|
||||
out_cache_loc,
|
||||
return_normalized_logprob,
|
||||
):
|
||||
input_metadata = InputMetadata.create(
|
||||
self,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
tp_size=self.tp_size,
|
||||
req_pool_indices=req_pool_indices,
|
||||
seq_lens=seq_lens,
|
||||
prefix_lens=prefix_lens,
|
||||
position_ids_offsets=position_ids_offsets,
|
||||
out_cache_loc=out_cache_loc,
|
||||
return_normalized_logprob=return_normalized_logprob,
|
||||
)
|
||||
return self.model.forward(
|
||||
input_ids,
|
||||
input_metadata.positions,
|
||||
input_metadata,
|
||||
pixel_values,
|
||||
image_offsets,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False
|
||||
):
|
||||
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
||||
kwargs = {
|
||||
"input_ids": batch.input_ids,
|
||||
"pixel_values": batch.pixel_values,
|
||||
"image_offsets": batch.image_offsets,
|
||||
"req_pool_indices": batch.req_pool_indices,
|
||||
"seq_lens": batch.seq_lens,
|
||||
"prefix_lens": batch.prefix_lens,
|
||||
"position_ids_offsets": batch.position_ids_offsets,
|
||||
"out_cache_loc": batch.out_cache_loc,
|
||||
}
|
||||
kwargs["return_normalized_logprob"] = return_normalized_logprob
|
||||
return self.forward_extend_multi_modal(**kwargs)
|
||||
else:
|
||||
kwargs = {
|
||||
"input_ids": batch.input_ids,
|
||||
"req_pool_indices": batch.req_pool_indices,
|
||||
"seq_lens": batch.seq_lens,
|
||||
"prefix_lens": batch.prefix_lens,
|
||||
"position_ids_offsets": batch.position_ids_offsets,
|
||||
"out_cache_loc": batch.out_cache_loc,
|
||||
}
|
||||
|
||||
if forward_mode == ForwardMode.DECODE:
|
||||
kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
|
||||
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
|
||||
return self.forward_decode(**kwargs)
|
||||
elif forward_mode == ForwardMode.EXTEND:
|
||||
kwargs["return_normalized_logprob"] = return_normalized_logprob
|
||||
return self.forward_extend(**kwargs)
|
||||
elif forward_mode == ForwardMode.PREFILL:
|
||||
kwargs["return_normalized_logprob"] = return_normalized_logprob
|
||||
return self.forward_prefill(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
||||
220
python/sglang/srt/managers/router/radix_cache.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TreeNode:
|
||||
def __init__(self):
|
||||
self.children = defaultdict(TreeNode)
|
||||
self.parent = None
|
||||
self.value = None
|
||||
self.ref_counter = 0
|
||||
self.last_access_time = time.time()
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
def match(key, seq):
|
||||
i = 0
|
||||
for k, w in zip(key, seq):
|
||||
if k != w:
|
||||
break
|
||||
i += 1
|
||||
return i
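match returns the length of the common prefix of two sequences, stopping at the first mismatch or at the end of the shorter one. A few quick checks, assuming the match function defined above is in scope:

assert match("Hello_L.A.!", "Hello_world") == 6   # shared prefix "Hello_"
assert match([1, 2, 3], [1, 2, 3, 4]) == 3        # stops at the end of the shorter sequence
assert match("abc", "xyz") == 0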
|
||||
|
||||
|
||||
class RadixCache:
|
||||
def __init__(self, disable=False):
|
||||
self.root_node = TreeNode()
|
||||
self.root_node.value = []
|
||||
self.root_node.ref_counter = 1
|
||||
self.evictable_size_ = 0
|
||||
|
||||
self.disable = disable
|
||||
|
||||
##### Public API #####
|
||||
def match_prefix(self, key):
|
||||
if self.disable:
|
||||
return [], self.root_node
|
||||
|
||||
value = []
|
||||
last_node = [self.root_node]
|
||||
self._match_prefix_helper(self.root_node, key, value, last_node)
|
||||
if value:
|
||||
value = torch.concat(value)
|
||||
return value, last_node[0]
|
||||
|
||||
def insert(self, key, value=None):
|
||||
if self.disable:
|
||||
return len(key)
|
||||
|
||||
if value is None:
|
||||
value = [x for x in key]
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def pretty_print(self):
|
||||
self._print_helper(self.root_node, 0)
|
||||
print(f"#tokens: {self.total_size()}")
|
||||
|
||||
def total_size(self):
|
||||
return self._total_size_helper(self.root_node)
|
||||
|
||||
def evict(self, num_tokens, evict_callback):
|
||||
if self.disable:
|
||||
raise RuntimeError()
|
||||
|
||||
leaves = self._collect_leaves()
|
||||
heapq.heapify(leaves)
|
||||
|
||||
num_evicted = 0
|
||||
while num_evicted < num_tokens and len(leaves):
|
||||
x = heapq.heappop(leaves)
|
||||
|
||||
if x == self.root_node:
|
||||
break
|
||||
if x.ref_counter > 0:
|
||||
continue
|
||||
|
||||
num_evicted += evict_callback(x.value)
|
||||
self._delete_leaf(x)
|
||||
|
||||
if len(x.parent.children) == 0:
|
||||
heapq.heappush(leaves, x.parent)
|
||||
|
||||
def inc_ref_counter(self, node):
|
||||
delta = 0
|
||||
while node != self.root_node:
|
||||
if node.ref_counter == 0:
|
||||
self.evictable_size_ -= len(node.value)
|
||||
delta -= len(node.value)
|
||||
node.ref_counter += 1
|
||||
node = node.parent
|
||||
return delta
|
||||
|
||||
def dec_ref_counter(self, node):
|
||||
delta = 0
|
||||
while node != self.root_node:
|
||||
if node.ref_counter == 1:
|
||||
self.evictable_size_ += len(node.value)
|
||||
delta += len(node.value)
|
||||
node.ref_counter -= 1
|
||||
node = node.parent
|
||||
return delta
|
||||
|
||||
def evictable_size(self):
|
||||
return self.evictable_size_
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
def _match_prefix_helper(self, node, key, value, last_node):
|
||||
node.last_access_time = time.time()
|
||||
|
||||
for c_key, child in node.children.items():
|
||||
prefix_len = match(c_key, key)
|
||||
if prefix_len != 0:
|
||||
if prefix_len == len(key) and prefix_len != len(c_key):
|
||||
new_node = self._split_node(c_key, child, prefix_len)
|
||||
value.append(new_node.value)
|
||||
last_node[0] = new_node
|
||||
else:
|
||||
value.append(child.value[:prefix_len])
|
||||
last_node[0] = child
|
||||
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
|
||||
break
|
||||
|
||||
def _split_node(self, key, child, split_len):
|
||||
# new_node -> child
|
||||
new_node = TreeNode()
|
||||
new_node.children = {key[split_len:]: child}
|
||||
new_node.parent = child.parent
|
||||
new_node.ref_counter = child.ref_counter
|
||||
new_node.value = child.value[:split_len]
|
||||
child.parent = new_node
|
||||
child.value = child.value[split_len:]
|
||||
new_node.parent.children[key[:split_len]] = new_node
|
||||
del new_node.parent.children[key]
|
||||
return new_node
|
||||
|
||||
def _insert_helper(self, node, key, value):
|
||||
node.last_access_time = time.time()
|
||||
|
||||
for c_key, child in node.children.items():
|
||||
prefix_len = match(c_key, key)
|
||||
|
||||
if prefix_len == len(c_key):
|
||||
if prefix_len == len(key):
|
||||
return prefix_len
|
||||
else:
|
||||
key = key[prefix_len:]
|
||||
value = value[prefix_len:]
|
||||
return prefix_len + self._insert_helper(child, key, value)
|
||||
|
||||
if prefix_len:
|
||||
new_node = self._split_node(c_key, child, prefix_len)
|
||||
return prefix_len + self._insert_helper(
|
||||
new_node, key[prefix_len:], value[prefix_len:]
|
||||
)
|
||||
|
||||
if len(key):
|
||||
new_node = TreeNode()
|
||||
new_node.parent = node
|
||||
new_node.value = value
|
||||
node.children[key] = new_node
|
||||
self.evictable_size_ += len(value)
|
||||
return 0
|
||||
|
||||
def _print_helper(self, node, indent):
|
||||
for key, child in node.children.items():
|
||||
print(" " * indent, len(key), key[:10], f"r={child.ref_counter}")
|
||||
self._print_helper(child, indent=indent + 2)
|
||||
|
||||
def _delete_leaf(self, node):
|
||||
for k, v in node.parent.children.items():
|
||||
if v == node:
|
||||
break
|
||||
del node.parent.children[k]
|
||||
self.evictable_size_ -= len(k)
|
||||
|
||||
def _total_size_helper(self, node):
|
||||
x = len(node.value)
|
||||
for child in node.children.values():
|
||||
x += self._total_size_helper(child)
|
||||
return x
|
||||
|
||||
def _collect_leaves(self):
|
||||
ret_list = []
|
||||
|
||||
def dfs_(cur_node):
|
||||
if len(cur_node.children) == 0:
|
||||
ret_list.append(cur_node)
|
||||
|
||||
for x in cur_node.children.values():
|
||||
dfs_(x)
|
||||
|
||||
dfs_(self.root_node)
|
||||
return ret_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tree = RadixCache(disable=False)
|
||||
|
||||
tree.insert("Hello")
|
||||
tree.insert("Hello")
|
||||
tree.insert("Hello_L.A.!")
|
||||
# tree.insert("Hello_world! Happy")
|
||||
# tree.insert("I love you!")
|
||||
tree.pretty_print()
|
||||
|
||||
# print(tree.match_prefix("I love you! aha"))
|
||||
|
||||
# def evict_callback(x):
|
||||
# print("evict", x)
|
||||
# return len(x)
|
||||
|
||||
# tree.evict(5, evict_callback)
|
||||
# tree.evict(10, evict_callback)
|
||||
# tree.pretty_print()
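The commented-out demo above can be made concrete. A minimal, illustrative sketch of an evict callback (not part of the released file): the callback must return the number of tokens it actually freed, which evict() accumulates against the requested budget.

tree = RadixCache(disable=False)
tree.insert("Hello_world")
tree.insert("Hello_L.A.!")

def evict_callback(token_ids):
    # In the server this would free KV-cache slots, e.g. token_to_kv_pool.free(...).
    print("evict", token_ids)
    return len(token_ids)

tree.evict(5, evict_callback)   # pops least-recently-used leaves until >= 5 tokens are freed
tree.pretty_print()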
|
||||
73
python/sglang/srt/managers/router/scheduler.py
Normal file
@@ -0,0 +1,73 @@
import random
from collections import defaultdict


class Scheduler:
    def __init__(
        self,
        schedule_heuristic,
        max_running_seq,
        max_prefill_num_token,
        max_total_num_token,
        tree_cache,
    ):
        self.schedule_heuristic = schedule_heuristic
        self.max_running_seq = max_running_seq
        self.max_prefill_num_token = max_prefill_num_token
        self.max_total_num_token = max_total_num_token
        self.tree_cache = tree_cache

    def new_token_estimation_ratio(self):
        return 0.4 if self.schedule_heuristic != "fcfs" else 0.5

    def get_priority_queue(self, forward_queue):
        if self.schedule_heuristic == "lpm":
            # longest prefix match
            forward_queue.sort(key=lambda x: -len(x.prefix_indices))
            return forward_queue
        elif self.schedule_heuristic == "random":
            random.shuffle(forward_queue)
            return forward_queue
        elif self.schedule_heuristic == "fcfs":
            return forward_queue
        elif self.schedule_heuristic == "weight":
            last_node_to_reqs = defaultdict(list)
            for req in forward_queue:
                last_node_to_reqs[req.last_node].append(req)
            for node in last_node_to_reqs:
                last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))

            node_to_weight = defaultdict(int)
            self._calc_weight_recursive(
                self.tree_cache.root_node, last_node_to_reqs, node_to_weight
            )

            tmp_queue = []
            self._get_weight_priority_recursive(
                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue
            )
            assert len(tmp_queue) == len(forward_queue)
            return tmp_queue
        else:
            raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")

    def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight):
        node_to_weight[cur_node] = 1
        if cur_node in last_node_to_reqs:
            node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
        for child in cur_node.children.values():
            self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight)
            node_to_weight[cur_node] += node_to_weight[child]

    def _get_weight_priority_recursive(
        self, cur_node, node_to_weight, last_node_to_reqs, tmp_queue
    ):
        visit_list = [child for child in cur_node.children.values()]
        visit_list.sort(key=lambda x: -node_to_weight[x])
        # for node in visit_list:
        #     print(f"{node_to_weight[node]} {len(node.value) if node.value is not None else 0}")
        for child in visit_list:
            self._get_weight_priority_recursive(
                child, node_to_weight, last_node_to_reqs, tmp_queue
            )
        tmp_queue.extend(last_node_to_reqs[cur_node])
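A small sketch of the "lpm" (longest-prefix-match) heuristic, using a hypothetical stand-in for the router's request objects; the only field the scheduler touches here is prefix_indices:

from dataclasses import dataclass, field

@dataclass
class FakeReq:                      # stand-in for the router's Req class
    rid: str
    prefix_indices: list = field(default_factory=list)

queue = [
    FakeReq("a", prefix_indices=[1, 2]),
    FakeReq("b", prefix_indices=[1, 2, 3, 4, 5]),
    FakeReq("c", prefix_indices=[]),
]
scheduler = Scheduler("lpm", max_running_seq=16, max_prefill_num_token=4096,
                      max_total_num_token=100_000, tree_cache=None)
assert [r.rid for r in scheduler.get_priority_queue(queue)] == ["b", "a", "c"]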
|
||||
219
python/sglang/srt/managers/tokenizer_manager.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import transformers
|
||||
import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
get_config,
|
||||
get_context_length,
|
||||
get_processor,
|
||||
get_tokenizer,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchStrOut,
|
||||
GenerateReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ReqState:
|
||||
out_list: List
|
||||
finished: bool
|
||||
event: asyncio.Event
|
||||
lock: asyncio.Lock
|
||||
|
||||
|
||||
global global_processor
|
||||
|
||||
|
||||
def init_global_processor(server_args: ServerArgs):
|
||||
global global_processor
|
||||
transformers.logging.set_verbosity_error()
|
||||
global_processor = get_processor(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
def get_pixel_values(image_data, processor=None):
|
||||
try:
|
||||
processor = processor or global_processor
|
||||
image = load_image(image_data)
|
||||
image_hash = hash(image_data)
|
||||
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
||||
pixel_values = pixel_values.astype(np.float16)
|
||||
return pixel_values, image_hash
|
||||
except Exception:
|
||||
print("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
|
||||
class TokenizerManager:
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
):
|
||||
context = zmq.asyncio.Context(2)
|
||||
self.recv_from_detokenizer = context.socket(zmq.PULL)
|
||||
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
|
||||
|
||||
self.send_to_router = context.socket(zmq.PUSH)
|
||||
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
|
||||
|
||||
self.model_path = server_args.model_path
|
||||
self.hf_config = get_config(
|
||||
self.model_path, trust_remote_code=server_args.trust_remote_code
|
||||
)
|
||||
self.context_len = get_context_length(self.hf_config)
|
||||
|
||||
if is_multimodal_model(self.model_path):
|
||||
self.processor = get_processor(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
)
|
||||
self.tokenizer = self.processor.tokenizer
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor, initargs=(server_args,)
|
||||
)
|
||||
else:
|
||||
self.tokenizer = get_tokenizer(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
)
|
||||
|
||||
self.to_create_loop = True
|
||||
self.rid_to_state = {} # Dict[str -> ReqState]
|
||||
|
||||
async def get_pixel_values(self, image_data):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor, get_pixel_values, image_data
|
||||
)
|
||||
else:
|
||||
return get_pixel_values(image_data, self.processor)
|
||||
|
||||
async def generate_request(self, obj: GenerateReqInput):
|
||||
if self.to_create_loop:
|
||||
await self.create_handle_loop()
|
||||
|
||||
is_single = isinstance(obj.text, str)
|
||||
|
||||
if is_single:
|
||||
rid = obj.rid
|
||||
input_ids = self.tokenizer.encode(obj.text)
|
||||
sampling_params = SamplingParams(**obj.sampling_params)
|
||||
if sampling_params.max_new_tokens != 0:
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
sampling_params.verify()
|
||||
if obj.image_data is None:
|
||||
pixel_values, image_hash = None, None
|
||||
else:
|
||||
pixel_values, image_hash = await self.get_pixel_values(obj.image_data)
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
rid=rid,
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_hash=image_hash,
|
||||
sampling_params=sampling_params,
|
||||
return_normalized_logprob=obj.return_normalized_logprob,
|
||||
normalized_logprob_start_len=obj.normalized_logprob_start_len,
|
||||
stream=obj.stream,
|
||||
)
|
||||
self.send_to_router.send_pyobj(tokenized_obj)
|
||||
|
||||
lock = asyncio.Lock()
|
||||
event = asyncio.Event()
|
||||
state = ReqState([], False, event, lock)
|
||||
self.rid_to_state[rid] = state
|
||||
|
||||
while True:
|
||||
await event.wait()
|
||||
yield state.out_list[-1]
|
||||
state.out_list = []
|
||||
if state.finished:
|
||||
del self.rid_to_state[rid]
|
||||
break
|
||||
event.clear()
|
||||
else:
|
||||
assert obj.stream is False
|
||||
bs = len(obj.text)
|
||||
for i in range(bs):
|
||||
rid = obj.rid[i]
|
||||
input_ids = self.tokenizer.encode(obj.text[i])
|
||||
sampling_params = SamplingParams(**obj.sampling_params[i])
|
||||
if sampling_params.max_new_tokens != 0:
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
sampling_params.verify()
|
||||
if obj.image_data[i] is None:
|
||||
pixel_values, image_hash = None, None
|
||||
else:
|
||||
pixel_values, image_hash = await self.get_pixel_values(
|
||||
obj.image_data[i]
|
||||
)
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
rid=rid,
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_hash=image_hash,
|
||||
sampling_params=sampling_params,
|
||||
return_normalized_logprob=obj.return_normalized_logprob[i],
|
||||
normalized_logprob_start_len=obj.normalized_logprob_start_len[i],
|
||||
stream=obj.stream,
|
||||
)
|
||||
self.send_to_router.send_pyobj(tokenized_obj)
|
||||
|
||||
lock = asyncio.Lock()
|
||||
event = asyncio.Event()
|
||||
state = ReqState([], False, event, lock)
|
||||
self.rid_to_state[rid] = state
|
||||
|
||||
output_list = []
|
||||
for i in range(bs):
|
||||
rid = obj.rid[i]
|
||||
state = self.rid_to_state[rid]
|
||||
await state.event.wait()
|
||||
output_list.append(state.out_list[-1])
|
||||
assert state.finished
|
||||
del self.rid_to_state[rid]
|
||||
|
||||
yield output_list
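generate_request is an async generator: for a single streaming request it yields incremental output dicts (with "text" and "meta_info") until the request finishes, and for a batched request it yields the complete output list once. A minimal consumption sketch, assuming tm is an already-constructed TokenizerManager and obj is a GenerateReqInput with a single string in .text and .stream set to True (the exact GenerateReqInput fields live in io_struct.py and are not shown in this diff):

async def consume_stream(tm, obj):
    # Each yielded item is the latest detokenized output for this request.
    async for out in tm.generate_request(obj):
        print(out["text"], out["meta_info"])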
|
||||
|
||||
async def create_handle_loop(self):
|
||||
self.to_create_loop = False
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(self.handle_loop())
|
||||
|
||||
async def handle_loop(self):
|
||||
while True:
|
||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||
|
||||
if isinstance(recv_obj, BatchStrOut):
|
||||
for i, rid in enumerate(recv_obj.rids):
|
||||
recv_obj.meta_info[i]["id"] = rid
|
||||
out_dict = {
|
||||
"text": recv_obj.output_str[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
}
|
||||
state = self.rid_to_state[rid]
|
||||
state.out_list.append(out_dict)
|
||||
state.finished = recv_obj.finished[i]
|
||||
state.event.set()
|
||||
else:
|
||||
raise ValueError(f"Invalid object: {recv_obj}")
|
||||
111
python/sglang/srt/memory_pool.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Memory pool."""
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReqToTokenPool:
|
||||
def __init__(self, size, max_context_len):
|
||||
self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
|
||||
self.can_use_mem_size = size
|
||||
self.req_to_token = torch.empty(
|
||||
(size, max_context_len), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
def alloc(self, need_size):
|
||||
if need_size > self.can_use_mem_size:
|
||||
return None
|
||||
|
||||
select_index = torch.nonzero(self.mem_state).squeeze(1)[:need_size]
|
||||
self.mem_state[select_index] = 0
|
||||
self.can_use_mem_size -= need_size
|
||||
return select_index.to(torch.int32)
|
||||
|
||||
def free(self, free_index):
|
||||
if isinstance(free_index, (int,)):
|
||||
self.can_use_mem_size += 1
|
||||
else:
|
||||
self.can_use_mem_size += free_index.shape[0]
|
||||
self.mem_state[free_index] = 1
|
||||
|
||||
# if self.can_use_mem_size == len(self.mem_state):
|
||||
# print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
|
||||
|
||||
def clear(self):
|
||||
self.mem_state.fill_(1)
|
||||
self.can_use_mem_size = len(self.mem_state)
|
||||
|
||||
|
||||
class TokenToKVPool:
|
||||
def __init__(self, size, dtype, head_num, head_dim, layer_num):
|
||||
self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
|
||||
self.alloc_ct = 0
|
||||
|
||||
# [size, key/value, head_num, head_dim] for each layer
|
||||
self.kv_data = [
|
||||
torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
def get_key_buffer(self, layer_id):
|
||||
return self.kv_data[layer_id][:, 0]
|
||||
|
||||
def get_value_buffer(self, layer_id):
|
||||
return self.kv_data[layer_id][:, 1]
|
||||
|
||||
def alloc(self, need_size):
|
||||
select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
|
||||
if select_index.shape[0] < need_size:
|
||||
return None
|
||||
|
||||
self.add_refs(select_index)
|
||||
return select_index.to(torch.int32)
|
||||
|
||||
def alloc_contiguous(self, need_size):
|
||||
empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
|
||||
if empty_index.shape[0] < need_size:
|
||||
return None
|
||||
empty_size = len(empty_index)
|
||||
loc_sum = (
|
||||
empty_index[need_size - 1 :] - empty_index[: empty_size - (need_size - 1)]
|
||||
)
|
||||
can_used_loc = empty_index[: empty_size - (need_size - 1)][
|
||||
loc_sum == need_size - 1
|
||||
]
|
||||
if can_used_loc.shape[0] == 0:
|
||||
return None
|
||||
|
||||
start_loc = can_used_loc[0].item()
|
||||
select_index = torch.arange(start_loc, start_loc + need_size, device="cuda")
|
||||
self.add_refs(select_index)
|
||||
return select_index.to(torch.int32), start_loc, start_loc + need_size
|
||||
|
||||
def free(self, free_index):
|
||||
return self.decrease_refs(free_index)
|
||||
|
||||
def used_size(self):
|
||||
return len(torch.nonzero(self.mem_state).squeeze(1))
|
||||
|
||||
def available_size(self):
|
||||
return torch.sum(self.mem_state == 0).item()
|
||||
|
||||
def add_refs(self, token_index: torch.Tensor):
|
||||
self.alloc_ct += len(token_index)
|
||||
self.mem_state[token_index] += 1
|
||||
|
||||
def decrease_refs(self, token_index: torch.Tensor):
|
||||
self.alloc_ct -= len(token_index)
|
||||
self.mem_state[token_index] -= 1
|
||||
|
||||
num_freed = torch.sum(self.mem_state[token_index] == 0)
|
||||
|
||||
# if self.alloc_ct == 0:
|
||||
# print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")
|
||||
|
||||
return num_freed
|
||||
|
||||
def clear(self):
|
||||
self.mem_state.fill_(0)
|
||||
self.alloc_ct = 0
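A minimal sketch of how the two pools cooperate (sizes are made up and a CUDA device is assumed, since both pools allocate on the GPU): each request gets one row in ReqToTokenPool, and every cached token gets a ref-counted slot in TokenToKVPool.

import torch

req_pool = ReqToTokenPool(size=8, max_context_len=64)
kv_pool = TokenToKVPool(size=1024, dtype=torch.float16, head_num=8, head_dim=128, layer_num=4)

req_idx = req_pool.alloc(1)        # one request slot, or None if the pool is full
token_loc = kv_pool.alloc(16)      # 16 KV slots for the prompt, or None if full
req_pool.req_to_token[req_idx[0], :16] = token_loc

kv_pool.free(token_loc)            # drop the ref counts when the request finishes
req_pool.free(req_idx)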
|
||||
27
python/sglang/srt/model_config.py
Normal file
@@ -0,0 +1,27 @@
import os
from typing import Optional, Union

import torch
from sglang.srt.hf_transformers_utils import get_config, get_context_length


class ModelConfig:
    def __init__(
        self,
        path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
    ) -> None:
        self.path = path
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.hf_config = get_config(self.path, trust_remote_code, revision)

        # Unify the config keys for hf_config
        self.context_len = get_context_length(self.hf_config)
        self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
        self.num_key_value_heads = self.hf_config.num_key_value_heads
        self.num_attention_heads = self.hf_config.num_attention_heads
        self.hidden_size = self.hf_config.hidden_size
        self.num_hidden_layers = self.hf_config.num_hidden_layers
        self.vocab_size = self.hf_config.vocab_size
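For example (the path below is a placeholder for any local HF-format checkpoint), the derived fields feed directly into the memory profiler and the attention layers:

config = ModelConfig("/models/llama-2-7b-hf")  # hypothetical local checkpoint path
print(config.context_len, config.head_dim, config.num_hidden_layers)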
|
||||
316
python/sglang/srt/models/llama2.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
)
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size, hidden_size, bias=False, linear_method=linear_method
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
layer_id: int = 0,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
layer_id: int = 0,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
self.self_attn = LlamaAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
layer_id=layer_id,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
LlamaDecoderLayer(config, i, linear_method)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
skip_embed: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if not skip_embed:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
else:
|
||||
hidden_states = input_ids
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
input_metadata,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LlamaForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = LlamaModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
skip_embed: bool = False,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision
|
||||
):
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
continue
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
213
python/sglang/srt/models/llava.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sglang.srt.managers.router.infer_batch import ForwardMode
|
||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||
from sglang.srt.models.llama2 import LlamaForCausalLM
|
||||
from torch import nn
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel, LlavaConfig
|
||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.weight_utils import (
|
||||
default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
)
|
||||
|
||||
|
||||
class LlavaLlamaForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vision_tower = None
|
||||
self.config.vision_config.hidden_size = config.mm_hidden_size
|
||||
self.config.text_config.hidden_size = config.hidden_size
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
||||
self.language_model = LlamaForCausalLM(config, linear_method)
|
||||
|
||||
def pad_input_ids(self, input_ids, pad_value):
|
||||
pad_ids = pad_value * (
|
||||
(self.image_feature_len + len(pad_value)) // len(pad_value)
|
||||
)
|
||||
offset = input_ids.index(self.config.image_token_index)
|
||||
# new length = old_len + pad_len - 1, because the single image_token_id is replaced by the pad ids
|
||||
new_input_ids = (
|
||||
input_ids[:offset]
|
||||
+ pad_ids[: self.image_feature_len]
|
||||
+ input_ids[offset + 1 :]
|
||||
)
|
||||
return new_input_ids, offset
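A worked example of the padding arithmetic above, with made-up token ids and a 336x336 / patch-14 CLIP tower (image_feature_len = (336/14)^2 = 576): the single image token is replaced by 576 pad ids, and the returned offset marks where the projected image features will later be written into the embedding sequence.

# All values below are illustrative.
image_feature_len = 576
image_token_index = 32000
pad_value = [9999]                         # sentinel id(s) chosen by the caller
input_ids = [1, 306, 32000, 263, 1234]     # "<s> I <image> a ..." with made-up ids

offset = input_ids.index(image_token_index)                                    # 2
pad_ids = pad_value * ((image_feature_len + len(pad_value)) // len(pad_value))
new_input_ids = input_ids[:offset] + pad_ids[:image_feature_len] + input_ids[offset + 1:]

assert len(new_input_ids) == len(input_ids) - 1 + image_feature_len            # 580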
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
pixel_values: Optional[List[Optional[np.array]]] = None,
|
||||
image_offsets: Optional[List[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||
bs = input_metadata.batch_size
|
||||
|
||||
# Embed text input
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
|
||||
# Embed vision input
|
||||
need_vision = (
|
||||
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
# FIXME: We need to subtract the length of the system prompt
|
||||
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
|
||||
need_vision = need_vision & has_pixel
|
||||
|
||||
if need_vision.any():
|
||||
pixel_values = torch.tensor(
|
||||
np.array([pixel_values[i] for i in range(bs) if need_vision[i]]),
|
||||
device=self.vision_tower.device,
|
||||
)
|
||||
|
||||
image_outputs = self.vision_tower(
|
||||
pixel_values, output_hidden_states=True
|
||||
)
|
||||
# NOTE: This is not memory efficient; output_hidden_states=True keeps all the hidden states.
|
||||
|
||||
selected_image_feature = image_outputs.hidden_states[
|
||||
self.vision_feature_layer
|
||||
]
|
||||
if self.vision_feature_select_strategy in ["default", "patch"]:
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif self.vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
|
||||
)
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
if not need_vision[i]:
|
||||
continue
|
||||
|
||||
start_idx = extend_start_loc_cpu[i]
|
||||
pad_len, pad_dim = image_features[pt].shape
|
||||
dim = input_embeds.shape[1]
|
||||
assert (
|
||||
pad_dim == dim
|
||||
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
|
||||
# Fill in the placeholder for the image
|
||||
try:
|
||||
input_embeds[
|
||||
start_idx
|
||||
+ image_offsets[i] : start_idx
|
||||
+ image_offsets[i]
|
||||
+ pad_len
|
||||
] = image_features[pt]
|
||||
except RuntimeError as e:
|
||||
print(f"RuntimeError in llava image encoding: {e}")
|
||||
print(input_embeds.shape)
|
||||
print(start_idx, image_offsets[i])
|
||||
pt += 1
|
||||
|
||||
return self.language_model(
|
||||
input_embeds, positions, input_metadata, skip_embed=True
|
||||
)
|
||||
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
return self.language_model(
|
||||
input_ids, positions, input_metadata, skip_embed=False
|
||||
)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
# load clip vision model by cfg['mm_vision_tower']:
|
||||
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
||||
vision_path = self.config.mm_vision_tower
|
||||
self.vision_tower = CLIPVisionModel.from_pretrained(
|
||||
vision_path, torch_dtype=torch.float16
|
||||
).cuda()
|
||||
self.vision_tower.eval()
|
||||
|
||||
self.vision_feature_layer = self.config.mm_vision_select_layer
|
||||
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
|
||||
self.image_size = self.vision_tower.config.image_size
|
||||
self.patch_size = self.vision_tower.config.patch_size
|
||||
self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
|
||||
if self.vision_feature_select_strategy == "patch":
|
||||
pass
|
||||
elif self.vision_feature_select_strategy == "cls_patch":
|
||||
self.image_feature_len += 1
|
||||
else:
|
||||
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
||||
|
||||
# load mm_projector
|
||||
# TODO: support TP?
|
||||
projector_weights = {
|
||||
"model.mm_projector.0": "multi_modal_projector.linear_1",
|
||||
"model.mm_projector.2": "multi_modal_projector.linear_2",
|
||||
}
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision
|
||||
):
|
||||
# FIXME: why are the projector weights read twice?
|
||||
if "projector" in name:
|
||||
for weight_name, param_name in projector_weights.items():
|
||||
if weight_name in name:
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# load language model
|
||||
self.language_model.load_weights(
|
||||
model_name_or_path, cache_dir, load_format, revision
|
||||
)
|
||||
|
||||
monkey_path_clip_vision_embed_forward()
|
||||
|
||||
|
||||
first_call = True
|
||||
|
||||
|
||||
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
||||
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
|
||||
global first_call
|
||||
if first_call:
|
||||
self.patch_embedding.cpu().float()
|
||||
first_call = False
|
||||
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
|
||||
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
|
||||
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
def monkey_path_clip_vision_embed_forward():
|
||||
import transformers
|
||||
|
||||
setattr(
|
||||
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
|
||||
"forward",
|
||||
clip_vision_embed_forward,
|
||||
)
|
||||
378
python/sglang/srt/models/mixtral.py
Normal file
@@ -0,0 +1,378 @@
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.router.model_runner import InputMetadata
|
||||
from torch import nn
|
||||
from transformers import MixtralConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
)
|
||||
|
||||
|
||||
class MixtralMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.ffn_dim = intermediate_size
|
||||
self.hidden_dim = hidden_size
|
||||
|
||||
self.w1 = ReplicatedLinear(
|
||||
self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
|
||||
)
|
||||
self.w2 = ReplicatedLinear(
|
||||
self.ffn_dim, self.hidden_dim, bias=False, linear_method=linear_method
|
||||
)
|
||||
self.w3 = ReplicatedLinear(
|
||||
self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
|
||||
)
|
||||
|
||||
# TODO: Use vllm's SiluAndMul
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
w1_out, _ = self.w1(hidden_states)
|
||||
w1_out = self.act_fn(w1_out)
|
||||
w3_out, _ = self.w3(hidden_states)
|
||||
current_hidden_states = w1_out * w3_out
|
||||
current_hidden_states, _ = self.w2(current_hidden_states)
|
||||
return current_hidden_states
|
||||
|
||||
|
||||
class MixtralMoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_total_experts = config.num_local_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
if self.tp_size > self.num_total_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {self.num_total_experts}."
|
||||
)
|
||||
# Split experts equally between ranks
|
||||
self.expert_indicies = np.array_split(
|
||||
range(self.num_total_experts), self.tp_size
|
||||
)[self.rank].tolist()
|
||||
if not self.expert_indicies:
|
||||
raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
|
||||
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
MixtralMLP(
|
||||
self.num_total_experts,
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
if idx in self.expert_indicies
|
||||
else None
|
||||
for idx in range(self.num_total_experts)
|
||||
]
|
||||
)
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size, self.num_total_experts, bias=False, linear_method=None
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(
|
||||
routing_weights, self.top_k, dim=-1
|
||||
)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
final_hidden_states = None
|
||||
for expert_idx in self.expert_indicies:
|
||||
expert_layer = self.experts[expert_idx]
|
||||
expert_mask = selected_experts == expert_idx
|
||||
expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
|
||||
|
||||
current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
|
||||
if final_hidden_states is None:
|
||||
final_hidden_states = current_hidden_states
|
||||
else:
|
||||
final_hidden_states.add_(current_hidden_states)
|
||||
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
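The routing above is standard top-2 gating: softmax over the gate logits, keep the two largest weights per token, renormalize them, and let each rank run only its local experts before tensor_model_parallel_all_reduce sums the partial results. A CPU-only sketch of just the weight computation (no model weights involved; the logits are made up):

import torch
import torch.nn.functional as F

router_logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])   # one token, four experts
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

print(selected_experts)   # tensor([[0, 2]]): experts 0 and 2 handle this token
print(routing_weights)    # approximately tensor([[0.7311, 0.2689]])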
|
||||
|
||||
|
||||
class MixtralAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
layer_id: int = 0,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
base=int(self.rope_theta),
|
||||
is_neox_style=True,
|
||||
)
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, input_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class MixtralDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
layer_id: int = 0,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
self.self_attn = MixtralAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
layer_id=layer_id,
|
||||
rope_theta=rope_theta,
|
||||
sliding_window=config.sliding_window,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
input_metadata=input_metadata,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.block_sparse_moe(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class MixtralModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
# config.num_hidden_layers=16
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MixtralDecoderLayer(config, i, linear_method=linear_method)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
skip_embed: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if not skip_embed:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
else:
|
||||
hidden_states = input_ids
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, input_metadata, residual
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MixtralForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = MixtralModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
skip_embed: bool = False,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision, fall_back_to_pt=False
|
||||
):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip experts that are not assigned to this worker.
|
||||
if "block_sparse_moe.experts." in name and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
81
python/sglang/srt/sampling_params.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
from typing import List, Optional, Union
|
||||
|
||||
_SAMPLING_EPS = 1e-6
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
def __init__(
|
||||
self,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_new_tokens: int = 16,
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: bool = True,
|
||||
dtype: Optional[str] = None,
|
||||
regex: Optional[str] = None,
|
||||
) -> None:
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.stop_strs = stop
|
||||
self.max_new_tokens = max_new_tokens
|
||||
self.ignore_eos = ignore_eos
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
self.dtype = dtype
|
||||
self.regex = regex
|
||||
|
||||
# Process some special cases
|
||||
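# A near-zero temperature is treated as greedy decoding: force top-1 sampling.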
if self.temperature < _SAMPLING_EPS:
|
||||
self.temperature = 1.0
|
||||
self.top_k = 1
|
||||
if self.top_k == -1:
|
||||
self.top_k = 1 << 30 # whole vocabulary
|
||||
if self.dtype == "int":
|
||||
self.stop_strs = [" ", "\n"]
|
||||
|
||||
def verify(self):
|
||||
if self.temperature < 0.0:
|
||||
raise ValueError(
|
||||
f"temperature must be non-negative, got {self.temperature}."
|
||||
)
|
||||
if not 0.0 < self.top_p <= 1.0:
|
||||
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
|
||||
if self.top_k < -1 or self.top_k == 0:
|
||||
raise ValueError(
|
||||
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
|
||||
)
|
||||
if not -2.0 <= self.frequency_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
"frequency_penalty must be in [-2, 2], got "
|
||||
f"{self.frequency_penalty}."
|
||||
)
|
||||
if not -2.0 <= self.presence_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
"presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
|
||||
)
|
||||
if self.max_new_tokens < 0:
|
||||
raise ValueError(
|
||||
f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
|
||||
)
|
||||
|
||||
def normalize(self, tokenizer):
|
||||
# Process stop strings
|
||||
if self.stop_strs is None:
|
||||
self.stop_strs = []
|
||||
self.stop_str_max_len = 0
|
||||
else:
|
||||
if isinstance(self.stop_strs, str):
|
||||
self.stop_strs = [self.stop_strs]
|
||||
|
||||
stop_str_max_len = 0
|
||||
for stop_str in self.stop_strs:
|
||||
stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
|
||||
stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
|
||||
self.stop_str_max_len = stop_str_max_len
|
||||
222
python/sglang/srt/server.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""SRT: SGLang Runtime"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
# Fix a Python bug
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
|
||||
import psutil
|
||||
import requests
|
||||
import uvicorn
|
||||
import uvloop
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.managers.openai_protocol import CompletionRequest
|
||||
from sglang.srt.managers.router.manager import start_router_process
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import alloc_usable_network_port
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
tokenizer_manager = None
|
||||
|
||||
|
||||
@app.get("/get_model_info")
|
||||
async def get_model_info():
|
||||
result = {
|
||||
"model_path": tokenizer_manager.model_path,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate_request(obj: GenerateReqInput):
|
||||
obj.post_init()
|
||||
result_generator = tokenizer_manager.generate_request(obj)
|
||||
|
||||
if obj.stream:
|
||||
|
||||
async def stream_results():
|
||||
async for out in result_generator:
|
||||
yield (json.dumps(out) + "\0").encode("utf-8")
|
||||
|
||||
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
||||
else:
|
||||
ret = await result_generator.__anext__()
|
||||
return ret
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def v1_completions(obj: CompletionRequest):
|
||||
assert obj.n == 1
|
||||
obj = GenerateReqInput(
|
||||
text=obj.prompt,
|
||||
sampling_params={
|
||||
"temperature": obj.temperature,
|
||||
"max_new_tokens": obj.max_tokens,
|
||||
"stop": obj.stop,
|
||||
},
|
||||
)
|
||||
ret = await generate_request(obj)
|
||||
return {
|
||||
"choices": [{"text": ret["text"]}],
|
||||
}
|
||||
|
||||
|
||||
def launch_server(server_args, pipe_finish_writer):
|
||||
global tokenizer_manager
|
||||
|
||||
# Allocate ports
|
||||
can_use_ports = alloc_usable_network_port(
|
||||
num=4 + server_args.tp_size, used_list=(server_args.port,)
|
||||
)
|
||||
port_args = PortArgs(
|
||||
tokenizer_port=can_use_ports[0],
|
||||
router_port=can_use_ports[1],
|
||||
detokenizer_port=can_use_ports[2],
|
||||
nccl_port=can_use_ports[3],
|
||||
model_rpc_ports=can_use_ports[4:],
|
||||
)
|
||||
|
||||
# Launch processes
|
||||
tokenizer_manager = TokenizerManager(server_args, port_args)
|
||||
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
|
||||
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
||||
|
||||
proc_router = mp.Process(
|
||||
target=start_router_process,
|
||||
args=(
|
||||
server_args,
|
||||
port_args,
|
||||
pipe_router_writer,
|
||||
),
|
||||
)
|
||||
proc_router.start()
|
||||
proc_detoken = mp.Process(
|
||||
target=start_detokenizer_process,
|
||||
args=(
|
||||
server_args,
|
||||
port_args,
|
||||
pipe_detoken_writer,
|
||||
),
|
||||
)
|
||||
proc_detoken.start()
|
||||
|
||||
# Wait for the model to finish loading
|
||||
router_init_state = pipe_router_reader.recv()
|
||||
detoken_init_state = pipe_detoken_reader.recv()
|
||||
|
||||
if router_init_state != "init ok" or detoken_init_state != "init ok":
|
||||
proc_router.kill()
|
||||
proc_detoken.kill()
|
||||
print("router init state:", router_init_state)
|
||||
print("detoken init state:", detoken_init_state)
|
||||
sys.exit(1)
|
||||
|
||||
assert proc_router.is_alive() and proc_detoken.is_alive()
|
||||
|
||||
def _launch_api_server():
|
||||
# Launch api server
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=server_args.host,
|
||||
port=server_args.port,
|
||||
log_level=server_args.log_level,
|
||||
timeout_keep_alive=5,
|
||||
loop="uvloop",
|
||||
)
|
||||
|
||||
t = threading.Thread(target=_launch_api_server)
|
||||
t.start()
|
||||
|
||||
if pipe_finish_writer:
|
||||
url = server_args.url()
|
||||
|
||||
success = False
last_exception = None
|
||||
for i in range(60):
|
||||
try:
|
||||
res = requests.get(url + "/get_model_info", timeout=5)
|
||||
success = True
|
||||
break
|
||||
except requests.exceptions.RequestException as e:
    last_exception = e
    time.sleep(1)
|
||||
|
||||
if success:
|
||||
pipe_finish_writer.send("init ok")
|
||||
else:
|
||||
pipe_finish_writer.send(str(last_exception))
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = True,
|
||||
mem_fraction_static: float = 0.9,
|
||||
tp_size: int = 1,
|
||||
model_mode: List[str] = (),
|
||||
schedule_heuristic: str = "lpm",
|
||||
random_seed: int = 42,
|
||||
log_level: str = "warning",
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
port = alloc_usable_network_port(1)[0]
|
||||
server_args = ServerArgs(
|
||||
model_path=model_path,
|
||||
tokenizer_path=tokenizer_path,
|
||||
host=host,
|
||||
port=port,
|
||||
load_format=load_format,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=trust_remote_code,
|
||||
mem_fraction_static=mem_fraction_static,
|
||||
tp_size=tp_size,
|
||||
model_mode=model_mode,
|
||||
schedule_heuristic=schedule_heuristic,
|
||||
random_seed=random_seed,
|
||||
log_level=log_level,
|
||||
)
|
||||
self.url = server_args.url()
|
||||
|
||||
self.pid = None
|
||||
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
|
||||
proc = mp.Process(target=launch_server, args=(server_args, pipe_writer))
|
||||
proc.start()
|
||||
self.pid = proc.pid
|
||||
|
||||
init_state = pipe_reader.recv()
|
||||
if init_state != "init ok":
|
||||
self.shutdown()
|
||||
raise RuntimeError("Launch failed")
|
||||
|
||||
self.endpoint = RuntimeEndpoint(self.url)
|
||||
|
||||
def shutdown(self):
|
||||
if self.pid is not None:
|
||||
parent = psutil.Process(self.pid)
|
||||
children = parent.children(recursive=True)
|
||||
for child in children:
|
||||
child.kill()
|
||||
psutil.wait_procs(children, timeout=5)
|
||||
parent.kill()
|
||||
parent.wait(timeout=5)
|
||||
self.pid = None
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
138
python/sglang/srt/server_args.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ServerArgs:
|
||||
model_path: str
|
||||
tokenizer_path: Optional[str] = None
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
load_format: str = "auto"
|
||||
tokenizer_mode: str = "auto"
|
||||
trust_remote_code: bool = True
|
||||
mem_fraction_static: float = 0.91
|
||||
tp_size: int = 1
|
||||
model_mode: List[str] = ()
|
||||
schedule_heuristic: str = "lpm"
|
||||
random_seed: int = 42
|
||||
disable_log_stats: bool = False
|
||||
log_stats_interval: int = 10
|
||||
log_level: str = "info"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_path is None:
|
||||
self.tokenizer_path = self.model_path
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer-path",
|
||||
type=str,
|
||||
default=ServerArgs.tokenizer_path,
|
||||
help="The path of the tokenizer.",
|
||||
)
|
||||
parser.add_argument("--host", type=str, default=ServerArgs.host)
|
||||
parser.add_argument("--port", type=int, default=ServerArgs.port)
|
||||
parser.add_argument(
|
||||
"--load-format",
|
||||
type=str,
|
||||
default=ServerArgs.load_format,
|
||||
choices=["auto", "pt", "safetensors", "npcache", "dummy"],
|
||||
help="The format of the model weights to load. "
|
||||
'"auto" will try to load the weights in the safetensors format '
|
||||
"and fall back to the pytorch bin format if safetensors format "
|
||||
"is not available. "
|
||||
'"pt" will load the weights in the pytorch bin format. '
|
||||
'"safetensors" will load the weights in the safetensors format. '
|
||||
'"npcache" will load the weights in pytorch format and store '
|
||||
"a numpy cache to speed up the loading. "
|
||||
'"dummy" will initialize the weights with random values, '
|
||||
"which is mainly for profiling.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer-mode",
|
||||
type=str,
|
||||
default=ServerArgs.tokenizer_mode,
|
||||
choices=["auto", "slow"],
|
||||
help="Tokenizer mode. 'auto' will use the fast "
|
||||
"tokenizer if available, and 'slow' will "
|
||||
"always use the slow tokenizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-fraction-static",
|
||||
type=float,
|
||||
default=ServerArgs.mem_fraction_static,
|
||||
help="The fraction of the memory used for static allocation (model weights and KV cache memory pool)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-size",
|
||||
type=int,
|
||||
default=ServerArgs.tp_size,
|
||||
help="Tensor parallelism degree.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-mode",
|
||||
type=str,
|
||||
default=[],
|
||||
nargs="+",
|
||||
help="Model mode: [flashinfer, no-cache, aggressive-new-fill]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--schedule-heuristic",
|
||||
type=str,
|
||||
default=ServerArgs.schedule_heuristic,
|
||||
help="Schudule mode: [lpm, weight, random, fcfs]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-seed",
|
||||
type=int,
|
||||
default=ServerArgs.random_seed,
|
||||
help="Random seed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
default=ServerArgs.log_level,
|
||||
help="Log level",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-log-stats",
|
||||
action="store_true",
|
||||
help="Disable logging throughput stats.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-stats-interval",
|
||||
type=int,
|
||||
default=ServerArgs.log_stats_interval,
|
||||
help="Log stats interval in second.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||
|
||||
def url(self):
|
||||
return f"http://{self.host}:{self.port}"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PortArgs:
|
||||
tokenizer_port: int
|
||||
router_port: int
|
||||
detokenizer_port: int
|
||||
nccl_port: int
|
||||
model_rpc_ports: List[int]
|
||||
217
python/sglang/srt/utils.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import base64
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
is_show_cost_time = False
|
||||
|
||||
|
||||
def mark_cost_time(func_name):
|
||||
def inner_func(func):
|
||||
def time_func(*args, **kwargs):
|
||||
if dist.get_rank() in [0, 1] and is_show_cost_time:
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
ans = func(*args, **kwargs)
|
||||
torch.cuda.synchronize()
|
||||
print(func_name, "cost time:", (time.time() - start_time) * 1000)
|
||||
return ans
|
||||
else:
|
||||
torch.cuda.synchronize()
|
||||
ans = func(*args, **kwargs)
|
||||
torch.cuda.synchronize()
|
||||
return ans
|
||||
|
||||
return time_func
|
||||
|
||||
return inner_func
|
||||
|
||||
|
||||
time_mark = {}
|
||||
|
||||
|
||||
def mark_start(key):
|
||||
torch.cuda.synchronize()
|
||||
global time_mark
|
||||
time_mark[key] = time.time()
|
||||
return
|
||||
|
||||
|
||||
def mark_end(key, print_min_cost=0.0):
|
||||
torch.cuda.synchronize()
|
||||
global time_mark
|
||||
cost_time = (time.time() - time_mark[key]) * 1000
|
||||
if cost_time > print_min_cost:
|
||||
print(f"cost {key}:", cost_time)
|
||||
|
||||
|
||||
def calculate_time(show=False, min_cost_ms=0.0):
|
||||
def wrapper(func):
|
||||
def inner_func(*args, **kwargs):
|
||||
torch.cuda.synchronize()
|
||||
if show:
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
torch.cuda.synchronize()
|
||||
if show:
|
||||
cost_time = (time.time() - start_time) * 1000
|
||||
if cost_time > min_cost_ms:
|
||||
print(f"Function {func.__name__} took {cost_time} ms to run.")
|
||||
return result
|
||||
|
||||
return inner_func
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def alloc_usable_network_port(num, used_list=()):
|
||||
port_list = []
|
||||
for port in range(10000, 65536):
|
||||
if port in used_list:
|
||||
continue
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
try:
|
||||
s.bind(("", port))
|
||||
port_list.append(port)
|
||||
except socket.error:
|
||||
pass
|
||||
|
||||
if len(port_list) == num:
|
||||
return port_list
|
||||
return None
|
||||
|
||||
|
||||
def get_exception_traceback():
|
||||
etype, value, tb = sys.exc_info()
|
||||
err_str = "".join(traceback.format_exception(etype, value, tb))
|
||||
return err_str
|
||||
|
||||
|
||||
def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||
from transformers import LlamaTokenizer, LlamaTokenizerFast
|
||||
|
||||
logit_bias = np.zeros(vocab_size, dtype=np.float32)
|
||||
for t_id in range(vocab_size):
|
||||
ss = tokenizer.decode(t_id).strip()
|
||||
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
|
||||
logit_bias[t_id] = -1e5
|
||||
# else:
|
||||
# print(ss, t_id)
|
||||
|
||||
return logit_bias
|
||||
|
||||
|
||||
def wrap_kernel_launcher(kernel):
|
||||
"""A faster launcher for triton kernels."""
|
||||
import torch.distributed as dist
|
||||
|
||||
if dist.is_initialized():
|
||||
rank = dist.get_rank()
|
||||
else:
|
||||
rank = 0
|
||||
|
||||
kernels = kernel.cache[rank].values()
|
||||
kernel = next(iter(kernels))
|
||||
|
||||
# Different Triton versions use different low-level names
|
||||
if hasattr(kernel, "cu_function"):
|
||||
kfunction = kernel.cu_function
|
||||
else:
|
||||
kfunction = kernel.function
|
||||
|
||||
if hasattr(kernel, "c_wrapper"):
|
||||
run = kernel.c_wrapper
|
||||
else:
|
||||
run = kernel.run
|
||||
|
||||
add_cluster_dim = True
|
||||
|
||||
def ret_func(grid, num_warps, *args):
|
||||
nonlocal add_cluster_dim
|
||||
|
||||
try:
|
||||
if add_cluster_dim:
|
||||
run(
|
||||
grid[0],
|
||||
grid[1],
|
||||
grid[2],
|
||||
num_warps,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
kernel.shared,
|
||||
0,
|
||||
kfunction,
|
||||
None,
|
||||
None,
|
||||
kernel,
|
||||
*args,
|
||||
)
|
||||
else:
|
||||
run(
|
||||
grid[0],
|
||||
grid[1],
|
||||
grid[2],
|
||||
num_warps,
|
||||
kernel.shared,
|
||||
0,
|
||||
kfunction,
|
||||
None,
|
||||
None,
|
||||
kernel,
|
||||
*args,
|
||||
)
|
||||
except TypeError:
|
||||
add_cluster_dim = not add_cluster_dim
|
||||
ret_func(grid, num_warps, *args)
|
||||
|
||||
return ret_func
|
||||
|
||||
|
||||
def is_multimodal_model(model):
|
||||
if isinstance(model, str):
|
||||
return "llava" in model
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
|
||||
if isinstance(model, ModelConfig):
|
||||
return "llava" in model.path.lower()
|
||||
raise Exception("unrecognized type")
|
||||
|
||||
|
||||
def load_image(image_file):
|
||||
from PIL import Image
|
||||
|
||||
image = None
|
||||
|
||||
if image_file.startswith("http://") or image_file.startswith("https://"):
|
||||
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
|
||||
response = requests.get(image_file, timeout=timeout)
|
||||
image = Image.open(BytesIO(response.content))
|
||||
elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
|
||||
image = Image.open(image_file)
|
||||
elif image_file.startswith("data:"):
|
||||
image_file = image_url.split(",")[1]
|
||||
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
||||
else:
|
||||
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
||||
|
||||
return image
|
||||
324
python/sglang/test/test_programs.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
This file contains the SGL programs used for unit testing.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
def test_few_shot_qa():
|
||||
@sgl.function
|
||||
def few_shot_qa(s, question):
|
||||
s += "The following are questions with answers.\n\n"
|
||||
s += "Q: What is the capital of France?\n"
|
||||
s += "A: Paris\n"
|
||||
s += "Q: What is the capital of Germany?\n"
|
||||
s += "A: Berlin\n"
|
||||
s += "Q: What is the capital of Italy?\n"
|
||||
s += "A: Rome\n"
|
||||
s += "Q: " + question + "\n"
|
||||
s += "A:" + sgl.gen("answer", stop="\n", temperature=0)
|
||||
|
||||
ret = few_shot_qa.run(question="What is the capital of the United States?")
|
||||
assert "washington" in ret["answer"].strip().lower(), f"answer: {ret['answer']}"
|
||||
|
||||
rets = few_shot_qa.run_batch(
|
||||
[
|
||||
{"question": "What is the capital of Japan?"},
|
||||
{"question": "What is the capital of the United Kingdom?"},
|
||||
{"question": "What is the capital city of China?"},
|
||||
],
|
||||
temperature=0.1,
|
||||
)
|
||||
answers = [x["answer"].strip().lower() for x in rets]
|
||||
assert answers == ["tokyo", "london", "beijing"], f"answers: {answers}"
|
||||
|
||||
|
||||
def test_mt_bench():
|
||||
@sgl.function
|
||||
def answer_mt_bench(s, question_1, question_2):
|
||||
s += sgl.system("You are a helpful assistant.")
|
||||
s += sgl.user(question_1)
|
||||
s += sgl.assistant(sgl.gen("answer_1"))
|
||||
with s.user():
|
||||
s += question_2
|
||||
with s.assistant():
|
||||
s += sgl.gen("answer_2")
|
||||
|
||||
question_1 = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."
|
||||
question_2 = (
|
||||
"Rewrite your previous response. Start every sentence with the letter A."
|
||||
)
|
||||
ret = answer_mt_bench.run(
|
||||
question_1=question_1, question_2=question_2, temperature=0.7, max_new_tokens=64
|
||||
)
|
||||
assert len(ret.messages()) in [4, 5]
|
||||
|
||||
|
||||
def test_select(check_answer):
|
||||
@sgl.function
|
||||
def true_or_false(s, statement):
|
||||
s += "Determine whether the statement below is True, False, or Unknown.\n"
|
||||
s += "Statement: The capital of France is Pairs.\n"
|
||||
s += "Answer: True\n"
|
||||
s += "Statement: " + statement + "\n"
|
||||
s += "Answer:" + sgl.select("answer", ["True", "False", "Unknown"])
|
||||
|
||||
ret = true_or_false.run(
|
||||
statement="The capital of Germany is Berlin.",
|
||||
)
|
||||
if check_answer:
|
||||
assert ret["answer"] == "True", ret.text
|
||||
else:
|
||||
assert ret["answer"] in ["True", "False", "Unknown"]
|
||||
|
||||
ret = true_or_false.run(
|
||||
statement="The capital of Canada is Tokyo.",
|
||||
)
|
||||
if check_answer:
|
||||
assert ret["answer"] == "False", ret.text
|
||||
else:
|
||||
assert ret["answer"] in ["True", "False", "Unknown"]
|
||||
|
||||
ret = true_or_false.run(
|
||||
statement="Purple is a better color than green.",
|
||||
)
|
||||
if check_answer:
|
||||
assert ret["answer"] == "Unknown", ret.text
|
||||
else:
|
||||
assert ret["answer"] in ["True", "False", "Unknown"]
|
||||
|
||||
|
||||
def test_decode_int():
|
||||
@sgl.function
|
||||
def decode_int(s):
|
||||
s += "The number of hours in a day is " + sgl.gen_int("hours") + "\n"
|
||||
s += "The number of days in a year is " + sgl.gen_int("days") + "\n"
|
||||
|
||||
ret = decode_int.run(temperature=0.1)
|
||||
assert int(ret["hours"]) == 24, ret.text
|
||||
assert int(ret["days"]) == 365, ret.text
|
||||
|
||||
|
||||
def test_decode_json():
|
||||
@sgl.function
|
||||
def decode_json(s):
|
||||
s += "Generate a JSON object to describe the basic information of a city.\n"
|
||||
|
||||
with s.var_scope("json_output"):
|
||||
s += "{\n"
|
||||
s += ' "name": ' + sgl.gen_string() + ",\n"
|
||||
s += ' "population": ' + sgl.gen_int() + ",\n"
|
||||
s += ' "area": ' + sgl.gen(dtype=int) + ",\n"
|
||||
s += ' "country": ' + sgl.gen_string() + ",\n"
|
||||
s += ' "timezone": ' + sgl.gen(dtype=str) + "\n"
|
||||
s += "}"
|
||||
|
||||
ret = decode_json.run()
|
||||
js_obj = json.loads(ret["json_output"])
|
||||
assert isinstance(js_obj["name"], str)
|
||||
assert isinstance(js_obj["population"], int)
|
||||
|
||||
|
||||
def test_expert_answer():
|
||||
@sgl.function
|
||||
def expert_answer(s, question):
|
||||
s += "Question: " + question + "\n"
|
||||
s += (
|
||||
"A good person to answer this question is"
|
||||
+ sgl.gen("expert", stop=[".", "\n"])
|
||||
+ ".\n"
|
||||
)
|
||||
s += (
|
||||
"For example,"
|
||||
+ s["expert"]
|
||||
+ " would answer that "
|
||||
+ sgl.gen("answer", stop=".")
|
||||
+ "."
|
||||
)
|
||||
|
||||
ret = expert_answer.run(question="What is the capital of France?", temperature=0.1)
|
||||
assert "paris" in ret.text().lower()
|
||||
|
||||
|
||||
def test_tool_use():
|
||||
def calculate(expression):
|
||||
return f"{eval(expression)}"
|
||||
|
||||
@sgl.function
|
||||
def tool_use(s, lhs, rhs):
|
||||
s += "Please perform computations using a calculator. You can use calculate(expression) to get the results.\n"
|
||||
s += "For example,\ncalculate(1+2)=3\ncalculate(3*4)=12\n"
|
||||
s += "Question: What is the product of " + lhs + " and " + rhs + "?\n"
|
||||
s += (
|
||||
"Answer: The answer is calculate("
|
||||
+ sgl.gen("expression", stop=")")
|
||||
+ ") = "
|
||||
)
|
||||
with s.var_scope("answer"):
|
||||
s += calculate(s["expression"])
|
||||
|
||||
lhs, rhs = 257, 983
|
||||
ret = tool_use(lhs=lhs, rhs=rhs, temperature=0)
|
||||
assert int(ret["answer"]) == lhs * rhs
|
||||
|
||||
|
||||
def test_react():
|
||||
@sgl.function
|
||||
def react(s, question):
|
||||
s += """
|
||||
Question: Which country does the founder of Microsoft live in?
|
||||
Thought 1: I need to search for the founder of Microsoft.
|
||||
Action 1: Search [Founder of Microsoft].
|
||||
Observation 1: The founder of Microsoft is Bill Gates.
|
||||
Thought 2: I need to search for the country where Bill Gates lives in.
|
||||
Action 2: Search [Where does Bill Gates live].
|
||||
Observation 2: Bill Gates lives in the United States.
|
||||
Thought 3: The answer is the United States.
|
||||
Action 3: Finish [United States].\n
|
||||
"""
|
||||
|
||||
s += "Question: " + question + "\n"
|
||||
|
||||
for i in range(1, 5):
|
||||
s += f"Thought {i}:" + sgl.gen(stop=[".", "\n"]) + ".\n"
|
||||
s += f"Action {i}: " + sgl.select(f"action_{i}", ["Search", "Finish"])
|
||||
|
||||
if s[f"action_{i}"] == "Search":
|
||||
s += " [" + sgl.gen(stop="]") + "].\n"
|
||||
s += f"Observation {i}:" + sgl.gen(stop=[".", "\n"]) + ".\n"
|
||||
else:
|
||||
s += " [" + sgl.gen("answer", stop="]") + "].\n"
|
||||
break
|
||||
|
||||
ret = react.run(
|
||||
question="What country does the creator of Linux live in?",
|
||||
temperature=0.1,
|
||||
)
|
||||
answer = ret["answer"].lower()
|
||||
assert "finland" in answer or "states" in answer
|
||||
|
||||
|
||||
def test_parallel_decoding():
|
||||
max_tokens = 64
|
||||
number = 5
|
||||
|
||||
@sgl.function
|
||||
def parallel_decoding(s, topic):
|
||||
s += "Act as a helpful assistant.\n"
|
||||
s += "USER: Give some tips for " + topic + ".\n"
|
||||
s += (
|
||||
"ASSISTANT: Okay. Here are "
|
||||
+ str(number)
|
||||
+ " concise tips, each under 8 words:\n"
|
||||
)
|
||||
|
||||
# Generate skeleton
|
||||
for i in range(1, 1 + number):
|
||||
s += f"{i}." + sgl.gen(max_tokens=16, stop=[".", "\n"]) + ".\n"
|
||||
|
||||
# Generate detailed tips
|
||||
forks = s.fork(number)
|
||||
for i in range(number):
|
||||
forks[
|
||||
i
|
||||
] += f"Now, I expand tip {i+1} into a detailed paragraph:\nTip {i+1}:"
|
||||
forks[i] += sgl.gen("detailed_tip", max_tokens, stop=["\n\n"])
|
||||
forks.join()
|
||||
|
||||
# Concatenate tips and summarize
|
||||
s += "Here are these tips with detailed explanation:\n"
|
||||
for i in range(number):
|
||||
s += f"Tip {i+1}:" + forks[i]["detailed_tip"] + "\n"
|
||||
|
||||
s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
|
||||
|
||||
ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
|
||||
|
||||
|
||||
def test_parallel_encoding(check_answer=True):
|
||||
max_tokens = 64
|
||||
|
||||
@sgl.function
|
||||
def parallel_encoding(s, question, context_0, context_1, context_2):
|
||||
s += "USER: I will ask a question based on some statements.\n"
|
||||
s += "ASSISTANT: Sure. I will give the answer.\n"
|
||||
s += "USER: Please memorize these statements.\n"
|
||||
|
||||
contexts = [context_0, context_1, context_2]
|
||||
|
||||
forks = s.fork(len(contexts))
|
||||
forks += lambda i: f"Statement {i}: " + contexts[i] + "\n"
|
||||
forks.join(mode="concate_and_append")
|
||||
|
||||
s += "Now, please answer the following question. " "Do not list options."
|
||||
s += "\nQuestion: " + question + "\n"
|
||||
s += "ASSISTANT:" + sgl.gen("answer", max_tokens=max_tokens)
|
||||
|
||||
ret = parallel_encoding.run(
|
||||
question="Who is the father of Julian?",
|
||||
context_0="Ethan is the father of Liam.",
|
||||
context_1="Noah is the father of Julian.",
|
||||
context_2="Oliver is the father of Carlos.",
|
||||
temperature=0,
|
||||
)
|
||||
answer = ret["answer"]
|
||||
|
||||
if check_answer:
|
||||
assert "Noah" in answer
|
||||
|
||||
|
||||
def test_image_qa():
|
||||
@sgl.function
|
||||
def image_qa(s, question):
|
||||
s += sgl.user(sgl.image("image.png") + question)
|
||||
s += sgl.assistant(sgl.gen("answer"))
|
||||
|
||||
state = image_qa.run(
|
||||
question="Please describe this image in simple words.",
|
||||
temperature=0,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
assert "taxi" in state.messages()[-1]["content"]
|
||||
|
||||
|
||||
def test_stream():
|
||||
@sgl.function
|
||||
def qa(s, question):
|
||||
s += sgl.user(question)
|
||||
s += sgl.assistant(sgl.gen("answer"))
|
||||
|
||||
ret = qa(
|
||||
question="Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.",
|
||||
stream=True,
|
||||
)
|
||||
out = ""
|
||||
for chunk in ret.text_iter():
|
||||
out += chunk
|
||||
|
||||
ret = qa(
|
||||
question="Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.",
|
||||
stream=True,
|
||||
)
|
||||
out = ""
|
||||
for chunk in ret.text_iter("answer"):
|
||||
out += chunk
|
||||
|
||||
|
||||
def test_regex():
|
||||
regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
|
||||
@sgl.function
|
||||
def regex_gen(s):
|
||||
s += "Q: What is the IP address of the Google DNS servers?\n"
|
||||
s += "A: " + sgl.gen(
|
||||
"answer",
|
||||
temperature=0,
|
||||
regex=regex,
|
||||
)
|
||||
|
||||
state = regex_gen.run()
|
||||
answer = state["answer"]
|
||||
assert re.match(regex, answer)
|
||||
141
python/sglang/test/test_utils.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Common utilities for testing and benchmarking"""
|
||||
import numpy as np
|
||||
import requests
|
||||
from sglang.backend.openai import OpenAI
|
||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.global_config import global_config
|
||||
|
||||
|
||||
def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"temperature": temperature,
|
||||
"max_new_tokens": max_tokens,
|
||||
"stop_sequences": stop,
|
||||
},
|
||||
}
|
||||
res = requests.post(url, json=data)
|
||||
assert res.status_code == 200
|
||||
pred = res.json()["generated_text"][0]
|
||||
return pred
|
||||
|
||||
|
||||
def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stop": stop,
|
||||
"n": n,
|
||||
}
|
||||
res = requests.post(url, json=data)
|
||||
assert res.status_code == 200
|
||||
if n == 1:
|
||||
pred = res.json()["text"][0][len(prompt) :]
|
||||
else:
|
||||
pred = [x[len(prompt) :] for x in res.json()["text"]]
|
||||
return pred
|
||||
|
||||
|
||||
def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
|
||||
data = {
|
||||
"text": prompt,
|
||||
"sampling_params": {
|
||||
"temperature": temperature,
|
||||
"max_new_tokens": max_tokens,
|
||||
"stop": stop,
|
||||
},
|
||||
}
|
||||
res = requests.post(url, json=data)
|
||||
assert res.status_code == 200
|
||||
obj = res.json()
|
||||
pred = obj["text"]
|
||||
return pred
|
||||
|
||||
|
||||
def call_select_lightllm(context, choices, url):
|
||||
scores = []
|
||||
for i in range(len(choices)):
|
||||
data = {
|
||||
"inputs": context + choices[i],
|
||||
"parameters": {
|
||||
"max_new_tokens": 1,
|
||||
},
|
||||
}
|
||||
res = requests.post(url, json=data)
|
||||
assert res.status_code == 200
|
||||
scores.append(0)
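# NOTE: no score is parsed from the response, so a placeholder 0 is appended for every choice.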
|
||||
return np.argmax(scores)
|
||||
|
||||
|
||||
def call_select_vllm(context, choices, url):
|
||||
scores = []
|
||||
for i in range(len(choices)):
|
||||
data = {
|
||||
"prompt": context + choices[i],
|
||||
"max_tokens": 1,
|
||||
"prompt_logprobs": 1,
|
||||
}
|
||||
res = requests.post(url, json=data)
|
||||
assert res.status_code == 200
|
||||
scores.append(res.json()["prompt_score"])
|
||||
return np.argmax(scores)
|
||||
|
||||
"""
|
||||
Modify vllm/entrypoints/api_server.py
|
||||
|
||||
if final_output.prompt_logprobs is not None:
|
||||
score = np.mean([prob[t_id] for t_id, prob in zip(final_output.prompt_token_ids[1:], final_output.prompt_logprobs[1:])])
|
||||
ret["prompt_score"] = score
|
||||
"""
|
||||
|
||||
|
||||
def add_common_other_args_and_parse(parser):
|
||||
parser.add_argument("--parallel", type=int, default=96)
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=None)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["vllm", "lightllm", "guidance", "lmql", "srt-raw", "llama.cpp"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
|
||||
)
|
||||
parser.add_argument("--result-file", type=str, default="result.jsonl")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.port is None:
|
||||
default_port = {
|
||||
"vllm": 21000,
|
||||
"lightllm": 22000,
|
||||
"lmql": 23000,
|
||||
"srt-raw": 30000,
|
||||
}
|
||||
args.port = default_port.get(args.backend, None)
|
||||
return args
|
||||
|
||||
|
||||
def add_common_sglang_args_and_parse(parser):
|
||||
parser.add_argument("--parallel", type=int, default=64)
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
parser.add_argument("--backend", type=str, default="srt")
|
||||
parser.add_argument("--result-file", type=str, default="result.jsonl")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def select_sglang_backend(args):
|
||||
if args.backend.startswith("srt"):
|
||||
if args.backend == "srt-no-parallel":
|
||||
global_config.enable_parallel_decoding = False
|
||||
global_config.enable_parallel_encoding = False
|
||||
backend = RuntimeEndpoint(f"{args.host}:{args.port}")
|
||||
elif args.backend.startswith("gpt"):
|
||||
backend = OpenAI(args.backend)
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
return backend
|
||||
179
python/sglang/utils.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Common utilities."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import threading
|
||||
import urllib.request
|
||||
from io import BytesIO
|
||||
from json import dumps
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def get_available_gpu_memory(gpu_id, distributed=True):
|
||||
"""
|
||||
Get available memory for cuda:gpu_id device.
|
||||
When distributed is True, the available memory is the minimum available memory of all GPUs.
|
||||
"""
|
||||
import torch
|
||||
|
||||
num_gpus = torch.cuda.device_count()
|
||||
assert gpu_id < num_gpus
|
||||
|
||||
if torch.cuda.current_device() != gpu_id:
|
||||
print(
|
||||
f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
|
||||
"which may cause useless memory allocation for torch CUDA context.",
|
||||
)
|
||||
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
|
||||
|
||||
if distributed:
|
||||
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
|
||||
torch.device("cuda", gpu_id)
|
||||
)
|
||||
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
|
||||
free_gpu_memory = tensor.item()
|
||||
|
||||
return free_gpu_memory / (1 << 30)
|
||||
|
||||
|
||||
def is_same_type(values):
|
||||
"""Return whether the elements in values are of the same type."""
|
||||
if len(values) <= 1:
|
||||
return True
|
||||
else:
|
||||
t = type(values[0])
|
||||
return all(isinstance(v, t) for v in values[1:])
|
||||
|
||||
|
||||
def read_jsonl(filename: str):
|
||||
"""Read a JSONL file."""
|
||||
rets = []
|
||||
with open(filename) as fin:
|
||||
for line in fin:
|
||||
if line.startswith("#"):
|
||||
continue
|
||||
rets.append(json.loads(line))
|
||||
return rets
|
||||
|
||||
|
||||
def dump_state_text(filename, states, mode="w"):
|
||||
"""Dump program state in a text file."""
|
||||
from sglang.lang.interpreter import ProgramState
|
||||
|
||||
with open(filename, mode) as fout:
|
||||
for i, s in enumerate(states):
|
||||
if isinstance(s, str):
|
||||
pass
|
||||
elif isinstance(s, ProgramState):
|
||||
s = s.text().strip()
|
||||
else:
|
||||
s = str(s)
|
||||
|
||||
fout.write(
|
||||
"=" * 40 + f" {i} " + "=" * 40 + "\n" + s + "\n" + "=" * 80 + "\n\n"
|
||||
)
|
||||
|
||||
|
||||
class HttpResponse:
|
||||
def __init__(self, resp):
|
||||
self.resp = resp
|
||||
|
||||
def json(self):
|
||||
return json.loads(self.resp.read())
|
||||
|
||||
@property
|
||||
def status_code(self):
|
||||
return self.resp.status
|
||||
|
||||
|
||||
def http_request(url, json=None, stream=False):
|
||||
"""A faster version of requests.post with low-level urllib API."""
|
||||
if stream:
|
||||
return requests.post(url, json=json, stream=True)
|
||||
else:
|
||||
req = urllib.request.Request(url)
|
||||
req.add_header("Content-Type", "application/json; charset=utf-8")
|
||||
if json is None:
|
||||
data = None
|
||||
else:
|
||||
data = bytes(dumps(json), encoding="utf-8")
|
||||
resp = urllib.request.urlopen(req, data=data)
|
||||
return HttpResponse(resp)
|
||||
|
||||
|
||||
def encode_image_base64(image_path):
|
||||
"""Encode an image in base64."""
|
||||
if isinstance(image_path, str):
|
||||
with open(image_path, "rb") as image_file:
|
||||
data = image_file.read()
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
elif isinstance(image_path, bytes):
|
||||
return base64.b64encode(image_path).decode("utf-8")
|
||||
else:
|
||||
# image_path is PIL.WebPImagePlugin.WebPImageFile
|
||||
image = image_path
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def _is_chinese_char(cp):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like all of the other languages.
|
||||
if (
|
||||
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
||||
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
||||
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
||||
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
||||
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
||||
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
||||
): #
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def find_printable_text(text):
|
||||
"""Returns the longest printable substring of text that contains only entire words."""
|
||||
# Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99
|
||||
|
||||
# After the symbol for a new line, we flush the cache.
|
||||
if text.endswith("\n"):
|
||||
return text
|
||||
# If the last token is a CJK character, we print the characters.
|
||||
elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
|
||||
return text
|
||||
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
|
||||
# which may change with the subsequent token -- there are probably smarter ways to do this!)
|
||||
else:
|
||||
return text[: text.rfind(" ") + 1]
|
||||
|
||||
|
||||
def run_with_timeout(func, args=(), kwargs=None, timeout=None):
|
||||
"""Run a function with timeout."""
|
||||
ret_value = []
|
||||
|
||||
def _target_func():
|
||||
ret_value.append(func(*args, **(kwargs or {})))
|
||||
|
||||
t = threading.Thread(target=_target_func)
|
||||
t.start()
|
||||
t.join(timeout=timeout)
|
||||
if t.is_alive():
|
||||
raise TimeoutError()
|
||||
|
||||
if not ret_value:
|
||||
raise RuntimeError()
|
||||
|
||||
return ret_value[0]
|
||||