sglang v0.5.2 & support for Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

167
python/pyproject.toml Executable file
View File

@@ -0,0 +1,167 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "sglang"
version = "0.5.2"
description = "SGLang is a fast serving framework for large language models and vision language models."
readme = "README.md"
requires-python = ">=3.10"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle"]
[project.optional-dependencies]
runtime_common = [
"blobfile==3.0.0",
"build",
"compressed-tensors",
"datasets",
"einops",
"fastapi",
"hf_transfer",
"huggingface_hub",
"interegular",
"llguidance>=0.7.11,<0.8.0",
"modelscope",
"msgspec",
"ninja",
"openai==1.99.1",
"openai-harmony==0.0.4",
"orjson",
"outlines==0.1.11",
"packaging",
"partial_json_parser",
"pillow",
"prometheus-client>=0.20.0",
"psutil",
"pybase64",
"pydantic",
"pynvml",
"python-multipart",
"pyzmq>=25.1.2",
"sentencepiece",
"soundfile==0.13.1",
"scipy",
"timm==1.0.16",
"tiktoken",
"torchao==0.9.0",
"transformers==4.56.1",
"uvicorn",
"uvloop",
"xgrammar==0.1.24",
]
srt = [
"sglang[runtime_common]",
"sgl-kernel==0.3.9.post2",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
"flashinfer_python==0.3.1",
]
blackwell = [
"sglang[runtime_common]",
"sgl-kernel==0.3.9.post2",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
"flashinfer_python==0.3.1",
"nvidia-cutlass-dsl==4.1.0",
]
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
srt_hip = [
"sglang[runtime_common]",
"torch",
"petit_kernel==0.0.2",
"wave-lang==3.7.0",
]
# https://docs.sglang.ai/platforms/cpu_server.html
srt_cpu = ["sglang[runtime_common]", "intel-openmp"]
# https://docs.sglang.ai/platforms/ascend_npu.html
srt_npu = ["sglang[runtime_common]"]
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
srt_xpu = ["sglang[runtime_common]"]
# For Intel Gaudi (device: hpu) follow the installation guide
# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
srt_hpu = ["sglang[runtime_common]"]
openai = ["openai==1.99.1", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]
torch_memory_saver = ["torch_memory_saver==0.0.8"]
decord = ["decord"]
test = [
"accelerate",
"expecttest",
"jsonlines",
"matplotlib",
"pandas",
"peft",
"sentence_transformers",
"pytest",
"tabulate",
]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[torch_memory_saver]", "sglang[decord]"]
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
all_npu = ["sglang[srt_npu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
dev = ["sglang[all]", "sglang[test]"]
dev_hip = ["sglang[all_hip]", "sglang[test]"]
dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
dev_hpu = ["sglang[all_hpu]", "sglang[test]"]
dev_cpu = ["sglang[all_cpu]", "sglang[test]"]
[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang"
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.setuptools.package-data]
"sglang" = [
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
"srt/layers/quantization/configs/*.json",
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
]
[tool.setuptools.packages.find]
exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]
[tool.wheel]
exclude = [
"assets*",
"benchmark*",
"docs*",
"dist*",
"playground*",
"scripts*",
"tests*",
]
[tool.codespell]
ignore-words-list = "ans, als, hel, boostrap, childs, te, vas, hsa, ment"
skip = "*.json,*.jsonl,*.patch,*.txt"

16
python/sglang/README.md Normal file
View File

@@ -0,0 +1,16 @@
# Code Structures
- `eval`: The evaluation utilities.
- `lang`: The frontend language.
- `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
- `test`: The test utilities.
- `api.py`: The public APIs.
- `bench_offline_throughput.py`: Benchmark the performance in the offline mode.
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
- `bench_serving.py`: Benchmark online serving with dynamic requests.
- `check_env.py`: Check the environment variables and dependencies.
- `global_config.py`: The global configs and constants.
- `launch_server.py`: The entry point for launching the local server.
- `utils.py`: Common utilities.
- `version.py`: Version info.

83
python/sglang/__init__.py Normal file
View File

@@ -0,0 +1,83 @@
# SGLang public APIs
# Frontend Language APIs
from sglang.global_config import global_config
from sglang.lang.api import (
Engine,
Runtime,
assistant,
assistant_begin,
assistant_end,
flush_cache,
function,
gen,
gen_int,
gen_string,
get_server_info,
image,
select,
separate_reasoning,
set_default_backend,
system,
system_begin,
system_end,
user,
user_begin,
user_end,
video,
)
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.choices import (
greedy_token_selection,
token_length_normalized,
unconditional_likelihood_normalized,
)
# Lazy import some libraries
from sglang.utils import LazyImport
from sglang.version import __version__
# Third-party backends are loaded lazily so that importing `sglang` does not
# require the anthropic/litellm/openai/vertexai SDKs to be installed.
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
# Runtime Engine APIs
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
# NOTE(review): this rebinds `Engine`, shadowing the `Engine` imported from
# sglang.lang.api above — confirm the lazy srt Engine is the intended export.
Engine = LazyImport("sglang.srt.entrypoints.engine", "Engine")
# Public API surface of the `sglang` package.
__all__ = [
    "Engine",
    "Runtime",
    "assistant",
    "assistant_begin",
    "assistant_end",
    "flush_cache",
    "function",
    "gen",
    "gen_int",
    "gen_string",
    "get_server_info",
    "image",
    "select",
    "separate_reasoning",
    "set_default_backend",
    "system",
    "system_begin",
    "system_end",
    "user",
    "user_begin",
    "user_end",
    "video",
    "RuntimeEndpoint",
    "greedy_token_selection",
    "token_length_normalized",
    "unconditional_likelihood_normalized",
    "ServerArgs",
    "Anthropic",
    "LiteLLM",
    "OpenAI",
    "VertexAI",
    "global_config",
    "__version__",
]

View File

@@ -0,0 +1,452 @@
"""
Benchmark the throughput in the offline mode.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
# Usage
## Sharegpt dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
## Random dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
"""
import argparse
import asyncio
import dataclasses
import inspect
import json
import logging
import os
import random
import time
from typing import Dict, List, Optional
import numpy as np
from sglang.bench_serving import (
DatasetRow,
get_dataset,
get_tokenizer,
sample_random_requests,
set_ulimit,
)
from sglang.lang.backend.runtime_endpoint import Runtime
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.server_args import ServerArgs
@dataclasses.dataclass
class BenchArgs:
    """Benchmark-specific CLI arguments for the offline throughput benchmark.

    Server arguments (model path, parallelism, ...) are handled separately by
    `ServerArgs`; this dataclass only owns the dataset/measurement knobs.
    """

    backend: str = "engine"
    result_filename: str = ""
    dataset_name: str = "sharegpt"
    dataset_path: str = ""
    num_prompts: int = 1000
    sharegpt_output_len: Optional[int] = None
    sharegpt_context_len: Optional[int] = None
    random_input_len: int = 1024
    random_output_len: int = 1024
    random_range_ratio: float = 0.0
    gsp_num_groups: int = 64
    gsp_prompts_per_group: int = 16
    gsp_system_prompt_len: int = 2048
    gsp_question_len: int = 128
    gsp_output_len: int = 256
    seed: int = 1
    disable_ignore_eos: bool = False
    extra_request_body: Optional[str] = None
    apply_chat_template: bool = False
    profile: bool = False
    skip_warmup: bool = False
    do_not_exit: bool = False
    prompt_suffix: str = ""

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all benchmark arguments on `parser`.

        All defaults reference the dataclass fields so the CLI and the
        dataclass defaults cannot drift apart (previously `--dataset-name`,
        `--dataset-path`, `--seed` and `--prompt-suffix` hard-coded them).
        Several help strings also had missing spaces from implicit string
        concatenation ("usedonly", "specifyadditional"); fixed here.
        """
        parser.add_argument("--backend", type=str, default=BenchArgs.backend)
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument(
            "--dataset-name",
            type=str,
            default=BenchArgs.dataset_name,
            choices=["sharegpt", "random", "generated-shared-prefix"],
            help="Name of the dataset to benchmark on.",
        )
        parser.add_argument(
            "--dataset-path",
            type=str,
            default=BenchArgs.dataset_path,
            help="Path to the dataset.",
        )
        parser.add_argument(
            "--num-prompts",
            type=int,
            default=BenchArgs.num_prompts,
            help="Number of prompts to process. Default is 1000.",
        )
        parser.add_argument(
            "--sharegpt-output-len",
            type=int,
            default=BenchArgs.sharegpt_output_len,
            help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
        )
        parser.add_argument(
            "--sharegpt-context-len",
            type=int,
            default=BenchArgs.sharegpt_context_len,
            help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
        )
        parser.add_argument(
            "--random-input-len",
            type=int,
            default=BenchArgs.random_input_len,
            help="Number of input tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-output-len",
            type=int,
            default=BenchArgs.random_output_len,
            help="Number of output tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-range-ratio",
            type=float,
            default=BenchArgs.random_range_ratio,
            help="Range of sampled ratio of input/output length, "
            "used only for random dataset.",
        )
        parser.add_argument(
            "--gsp-num-groups",
            type=int,
            default=BenchArgs.gsp_num_groups,
            help="Number of groups with shared prefix, used "
            "only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-prompts-per-group",
            type=int,
            default=BenchArgs.gsp_prompts_per_group,
            help="Number of prompts per group of shared prefix, used "
            "only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-system-prompt-len",
            type=int,
            default=BenchArgs.gsp_system_prompt_len,
            help="System prompt length, used only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-question-len",
            type=int,
            default=BenchArgs.gsp_question_len,
            help="Question length, used only for generate-shared-prefix",
        )
        parser.add_argument(
            "--gsp-output-len",
            type=int,
            default=BenchArgs.gsp_output_len,
            help="Target length in tokens for outputs in generated-shared-prefix dataset",
        )
        parser.add_argument(
            "--seed", type=int, default=BenchArgs.seed, help="The random seed."
        )
        parser.add_argument(
            "--disable-ignore-eos",
            action="store_true",
            help="Disable ignore EOS token",
        )
        parser.add_argument(
            "--extra-request-body",
            metavar='{"key1": "value1", "key2": "value2"}',
            type=str,
            default=BenchArgs.extra_request_body,
            help="Append given JSON object to the request payload. You can use this to specify "
            "additional generate params like sampling params.",
        )
        parser.add_argument(
            "--apply-chat-template",
            action="store_true",
            help="Apply chat template",
        )
        parser.add_argument(
            "--profile",
            action="store_true",
            help="Use Torch Profiler. The endpoint must be launched with "
            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
        )
        parser.add_argument(
            "--skip-warmup",
            action="store_true",
            help="Skip the warmup batches.",
        )
        parser.add_argument(
            "--do-not-exit",
            action="store_true",
            help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
        )
        parser.add_argument(
            "--prompt-suffix",
            type=str,
            default=BenchArgs.prompt_suffix,
            help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from a parsed namespace, picking only the
        attributes that correspond to dataclass fields."""
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})
def throughput_test_once(
    backend_name: str,
    backend,
    reqs: List["DatasetRow"],
    ignore_eos: bool,
    extra_request_body: Dict,
    profile: bool,
):
    """Generate all `reqs` in one batched call and compute throughput metrics.

    Args:
        backend_name: "engine" or "runtime"; the runtime endpoint returns the
            batch as a JSON string and is decoded here.
        backend: object exposing `generate()`, `get_server_info()` and, when
            `profile` is set, `start_profile()` / `stop_profile()`.
        reqs: dataset rows carrying `prompt`, `prompt_len` and `output_len`.
        ignore_eos: forwarded into every request's sampling params.
        extra_request_body: extra key/values merged into each sampling dict.
        profile: wrap the generation call with the torch profiler; requires
            the SGLANG_TORCH_PROFILER_DIR environment variable.

    Returns:
        Dict of measured latency and throughput values.
    """
    measurement_results = {
        "backend": backend_name,
        "successful_requests": len(reqs),
        "total_latency": -1,
        "total_input_tokens": sum(r.prompt_len for r in reqs),
        "total_output_tokens": -1,
        "request_throughput": -1,
        "input_throughput": -1,
        "output_throughput": -1,
        "total_throughput": -1,
    }
    prompt = [r.prompt for r in reqs]
    sampling_params = [
        {
            "temperature": 0,
            "max_new_tokens": r.output_len,
            "ignore_eos": ignore_eos,
            **extra_request_body,
        }
        for r in reqs
    ]
    if profile:
        assert (
            "SGLANG_TORCH_PROFILER_DIR" in os.environ
        ), "Please set SGLANG_TORCH_PROFILER_DIR."
        os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
        backend.start_profile()
    st = time.perf_counter()
    gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
    latency = time.perf_counter() - st
    if profile:
        # Renamed from `dir` to avoid shadowing the builtin of the same name.
        trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
        known_files = set(os.listdir(trace_dir))
        backend.stop_profile()
        # Wait until the asynchronously-written trace file stops growing.
        monitor_trace_file(known_files, trace_dir)
    if backend_name == "runtime":
        gen_out = json.loads(gen_out)
    server_info = backend.get_server_info()
    measurement_results["total_latency"] = latency
    measurement_results["total_output_tokens"] = sum(
        o["meta_info"]["completion_tokens"] for o in gen_out
    )
    measurement_results["request_throughput"] = (
        measurement_results["successful_requests"] / latency
    )
    measurement_results["input_throughput"] = (
        measurement_results["total_input_tokens"] / latency
    )
    measurement_results["output_throughput"] = (
        measurement_results["total_output_tokens"] / latency
    )
    measurement_results["total_throughput"] = (
        measurement_results["total_input_tokens"]
        + measurement_results["total_output_tokens"]
    ) / latency
    # `get_server_info()` may return a coroutine when the engine is async.
    if inspect.isawaitable(server_info):
        server_info = asyncio.run(server_info)
    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
        "last_gen_throughput"
    ]
    return measurement_results
def monitor_trace_file(known_files, directory, interval=1):
    """Block until a new trace file in `directory` has finished being written.

    The torch profiler flushes its chrome trace asynchronously after
    `stop_profile()`. This polls `directory` every `interval` seconds for
    files not present in `known_files` and waits until one of them stops
    growing between two polls, which is taken to mean the write completed.
    """
    print(f"Monitoring {directory} for new trace files...")
    while True:
        flag = False  # set once a new file's size has stabilized
        time.sleep(interval)
        current_files = set(os.listdir(directory))
        new_files = current_files - known_files
        for new_file in new_files:
            new_file_path = os.path.join(directory, new_file)
            print(f"New file detected: {new_file}")
            previous_size = 0
            while True:
                try:
                    current_size = os.path.getsize(new_file_path)
                except FileNotFoundError:
                    # The writer may remove/rename the file; stop watching it.
                    print(f"File {new_file} is no longer accessible.")
                    break
                # Still growing -> keep polling; unchanged -> writer is done.
                if current_size > previous_size:
                    previous_size = current_size
                else:
                    flag = True
                    break
                time.sleep(interval)
        if flag:
            break
def throughput_test(
    server_args: ServerArgs,
    bench_args: BenchArgs,
):
    """Run the offline throughput benchmark end to end.

    Builds the backend (`engine` or `runtime`), loads the dataset, runs an
    optional warmup pass, measures one benchmark pass with
    `throughput_test_once`, optionally appends the result to
    `bench_args.result_filename` (JSONL), prints a summary table and returns
    the result dict.

    Raises:
        ValueError: if `bench_args.backend` is not "engine" or "runtime".
    """
    if bench_args.backend == "engine":
        backend = Engine(**dataclasses.asdict(server_args))
        if not backend:
            raise ValueError("Please provide valid engine arguments")
    elif bench_args.backend == "runtime":
        backend = Runtime(**dataclasses.asdict(server_args))
    else:
        raise ValueError('Please set backend to either "engine" or "runtime"')
    tokenizer_id = server_args.tokenizer_path or server_args.model_path
    tokenizer = get_tokenizer(tokenizer_id)
    # Set global environments
    set_ulimit()
    random.seed(bench_args.seed)
    np.random.seed(bench_args.seed)
    # Parse args
    extra_request_body = {}
    if bench_args.extra_request_body:
        # BUGFIX: this previously read the module-global `args`, which only
        # exists when the module is executed as a script and raised NameError
        # when throughput_test() was called as a library function.
        extra_request_body = json.loads(bench_args.extra_request_body)
    # Read dataset
    input_requests = get_dataset(bench_args, tokenizer)
    warmup_requests = sample_random_requests(
        input_len=256,
        output_len=16,
        num_prompts=min(bench_args.num_prompts, 16),
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path=bench_args.dataset_path,
    )
    # Warm up
    if not bench_args.skip_warmup:
        logging.info("\nWarmup...")
        throughput_test_once(
            backend_name=bench_args.backend,
            backend=backend,
            reqs=warmup_requests,
            ignore_eos=not bench_args.disable_ignore_eos,
            extra_request_body=extra_request_body,
            profile=False,
        )
        time.sleep(0.5)
    logging.info("\nBenchmark...")
    result = throughput_test_once(
        backend_name=bench_args.backend,
        backend=backend,
        reqs=input_requests,
        ignore_eos=not bench_args.disable_ignore_eos,
        extra_request_body=extra_request_body,
        profile=bench_args.profile,
    )
    backend.shutdown()
    if bench_args.result_filename:
        # Append as JSONL so repeated runs accumulate in one file.
        with open(bench_args.result_filename, "a") as fout:
            fout.write(json.dumps(result) + "\n")
    print(
        "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
    )
    print("{:<40} {:<10}".format("Backend:", result["backend"]))
    print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
    print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
    print(
        "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Last generation throughput (tok/s):", result["last_gen_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", result["request_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", result["input_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", result["output_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total token throughput (tok/s):", result["total_throughput"]
        )
    )
    print("=" * 50)
    return result
if __name__ == "__main__":
    # Parse both server and benchmark arguments from a single CLI.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    # handling ModelScope model downloads
    if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
        if os.path.exists(args.model_path):
            print(f"Using local model path: {args.model_path}")
        else:
            try:
                from modelscope import snapshot_download

                print(f"Using ModelScope to download model: {args.model_path}")
                # download the model and replace args.model_path
                args.model_path = snapshot_download(
                    args.model_path,
                )
                print(f"Model downloaded to: {args.model_path}")
            except Exception as e:
                print(f"ModelScope download failed: {str(e)}")
                raise e
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )
    throughput_test(server_args, bench_args)
    # Busy-wait so external profilers (e.g. nsys with --duration/--delay) can
    # keep observing the live process; interrupt with Ctrl-C to exit.
    while bench_args.do_not_exit:
        pass

View File

@@ -0,0 +1,665 @@
"""
Benchmark the latency of running a single static batch without a server.
This script does not launch a server and uses the low-level APIs.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
# Usage (latency test)
## with dummy weights:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
## sweep through multiple data points and store (append) the results in a jsonl file:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
## run with profiling:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
# Usage (correctness test):
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
## Reference output (of the correctness test above, can be gpu dependent):
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
[ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
device='cuda:0')
prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
[-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
[-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
device='cuda:0')
========== Prompt 0 ==========
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.
========== Prompt 1 ==========
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the
========== Prompt 2 ==========
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park
"""
import argparse
import copy
import dataclasses
import itertools
import json
import logging
import multiprocessing
import os
import time
from typing import Tuple
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
configure_logger,
get_bool_env_var,
kill_process_tree,
require_mlp_sync,
require_mlp_tp_gather,
set_gpu_proc_affinity,
suppress_other_loggers,
)
@dataclasses.dataclass
class BenchArgs:
    """CLI arguments controlling the single-batch latency benchmark sweep."""

    run_name: str = "default"
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    prompt_filename: str = ""
    result_filename: str = "result.jsonl"
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
    log_decode_step: int = 0
    profile: bool = False
    profile_record_shapes: bool = False
    profile_filename_prefix: str = "profile"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register the latency-benchmark arguments on `parser`."""
        add = parser.add_argument
        add("--run-name", type=str, default=BenchArgs.run_name)
        add("--batch-size", type=int, nargs="+", default=BenchArgs.batch_size)
        add("--input-len", type=int, nargs="+", default=BenchArgs.input_len)
        add("--output-len", type=int, nargs="+", default=BenchArgs.output_len)
        add("--prompt-filename", type=str, default=BenchArgs.prompt_filename)
        add("--result-filename", type=str, default=BenchArgs.result_filename)
        add("--correctness-test", action="store_true")
        add("--cut-len", type=int, default=BenchArgs.cut_len)
        add(
            "--log-decode-step",
            type=int,
            default=BenchArgs.log_decode_step,
            help="Log decode latency by step, default is set to zero to disable.",
        )
        add("--profile", action="store_true", help="Use Torch Profiler.")
        add(
            "--profile-record-shapes",
            action="store_true",
            help="Record tensor shapes in profiling results.",
        )
        add(
            "--profile-filename-prefix",
            type=str,
            default=BenchArgs.profile_filename_prefix,
            help="Prefix of the profiling file names. The full profiling result file(s) be "
            '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from parsed args, casting each value to the type
        of its dataclass default (so nargs="+" lists become tuples)."""
        kwargs = {}
        for field in dataclasses.fields(cls):
            kwargs[field.name] = type(field.default)(getattr(args, field.name))
        return cls(**kwargs)
def load_model(server_args, port_args, tp_rank):
    """Create the ModelRunner and tokenizer for one tensor-parallel rank.

    Only rank 0 prints; all ranks synchronize on a barrier before returning
    when tensor parallelism is enabled.
    """
    suppress_other_loggers()
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    # Expert-parallel rank is derived from the TP rank; assumes ep_size
    # divides tp_size.
    moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
    model_config = ModelConfig.from_server_args(server_args)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,  # one GPU per TP rank
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        moe_ep_rank=moe_ep_rank,
        moe_ep_size=server_args.ep_size,
        pp_rank=0,  # pipeline parallelism is not used in this benchmark
        pp_size=1,
        nccl_port=port_args.nccl_port,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    if server_args.tp_size > 1:
        # Wait for every rank to finish loading before benchmarking.
        dist.barrier()
    return model_runner, tokenizer
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
    """Build the truncated prefill requests for the correctness test.

    Each prompt is tokenized and cut to its first `bench_args.cut_len` tokens;
    the remainder is appended later by
    `prepare_extend_inputs_for_correctness_test`. Returns the full token id
    lists and the truncated `Req` objects.
    """
    prompts = (
        custom_prompts
        if custom_prompts
        else [
            "The capital of France is",
            "The capital of the United Kindom is",
            "Today is a sunny day and I like",
        ]
    )
    input_ids = [tokenizer.encode(p) for p in prompts]
    # NOTE(review): BenchArgs.output_len is the class-level default tuple
    # (16,), not an int from bench_args — confirm SamplingParams accepts it.
    # The actual decode length is driven by the explicit loop in
    # correctness_test().
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )
    reqs = []
    for i in range(len(prompts)):
        # Every prompt must be longer than the cut point.
        assert len(input_ids[i]) > bench_args.cut_len
        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(
            rid=i,
            origin_input_text=prompts[i],
            origin_input_ids=tmp_input_ids,
            sampling_params=sampling_params,
        )
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)
    return input_ids, reqs
def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    """Extend each truncated request with the tokens after the cut point and
    point its prefix at the KV cache entries written during the first prefill.
    """
    cut = bench_args.cut_len
    req_to_token = model_runner.req_to_token_pool.req_to_token
    for idx, req in enumerate(reqs):
        req.fill_ids += input_ids[idx][cut:]
        req.prefix_indices = req_to_token[idx, :cut]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
    return reqs
def prepare_synthetic_inputs_for_latency_test(
    batch_size, input_len, custom_inputs=None
):
    """Build `batch_size` requests of `input_len` random token ids for the
    latency sweep, or wrap `custom_inputs` (a list of token-id lists) when
    provided.
    """
    input_ids = (
        custom_inputs
        if custom_inputs
        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
    )
    # NOTE(review): BenchArgs.output_len is the class-level default tuple
    # (16,), not an int — the decode length is actually controlled by the
    # caller's loop; confirm SamplingParams tolerates this value.
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )
    reqs = []
    for i in range(len(input_ids)):
        req = Req(
            rid=i,
            origin_input_text="",
            origin_input_ids=list(input_ids[i]),
            sampling_params=sampling_params,
        )
        req.prefix_indices = []  # no cached prefix for synthetic inputs
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)
    return reqs
@torch.no_grad
def extend(reqs, model_runner):
    """Run one prefill (extend) forward pass over `reqs` and sample tokens.

    Returns `(next_token_ids, next_token_logits, batch)`; `batch` is kept and
    reused by `decode()` for subsequent steps.
    """
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
        tree_cache=None,  # no radix/prefix cache in this micro-benchmark
        model_config=model_runner.model_config,
        enable_overlap=False,
        spec_algorithm=SpeculativeAlgorithm.NONE,
    )
    batch.prepare_for_extend()
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits, batch
@torch.no_grad
def decode(input_token_ids, batch, model_runner):
    """Run one decode step: feed the previously sampled tokens back into
    `batch` (created by `extend()`) and sample the next token for each request.
    """
    batch.output_ids = input_token_ids
    batch.prepare_for_decode()
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
    """Mirror the scheduler's MLP sync preparation when the server config
    requires it (per `require_mlp_sync`); no-op otherwise."""
    if require_mlp_sync(model_runner.server_args):
        Scheduler.prepare_mlp_sync_batch_raw(
            batch,
            dp_size=model_runner.server_args.dp_size,
            attn_tp_size=1,  # this benchmark never splits attention over TP
            tp_group=model_runner.tp_group,
            get_idle_batch=None,
            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            speculative_num_draft_tokens=None,
            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
            disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
        )
def _read_prompts_from_file(prompt_file, rank_print):
"""Read custom prompts from the file specified by `--prompt-filename`."""
if not prompt_file:
return []
if not os.path.exists(prompt_file):
rank_print(
f"Custom prompt file {prompt_file} not found. Using default inputs..."
)
return []
with open(prompt_file, "r") as pf:
return pf.readlines()
def _save_profile_trace_results(profiler, filename):
    """Export `profiler`'s chrome trace to `filename` (creating parent
    directories as needed) and print a CPU-time summary table."""
    parent_dir = os.path.dirname(os.path.abspath(filename))
    os.makedirs(parent_dir, exist_ok=True)
    profiler.export_chrome_trace(filename)
    print(
        profiler.key_averages(group_by_input_shape=True).table(
            sort_by="self_cpu_time_total"
        )
    )
def correctness_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    """Run the correctness check: prefill the first `cut_len` tokens, extend
    with the rest (prefill with KV cache), greedily decode, and print the
    generated text for comparison with the reference in the module docstring.
    """
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
    # Prepare inputs
    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    input_ids, reqs = prepare_inputs_for_correctness_test(
        bench_args, tokenizer, custom_prompts
    )
    rank_print(f"\n{input_ids=}\n")
    if bench_args.cut_len > 0:
        # Prefill
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")
    # Prepare extend inputs
    reqs = prepare_extend_inputs_for_correctness_test(
        bench_args, input_ids, reqs, model_runner
    )
    # Extend (prefill w/ KV cache)
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")
    # Decode
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        next_token_ids_list = next_token_ids.tolist()
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids_list[i])
    # Print output texts
    for i in range(len(reqs)):
        rank_print(f"========== Prompt {i} ==========")
        rank_print(tokenizer.decode(output_ids[i]), "\n")
def synchronize(device):
torch.get_device_module(device).synchronize()
def latency_test_run_once(
    run_name,
    model_runner,
    rank_print,
    reqs,
    batch_size,
    input_len,
    output_len,
    device,
    log_decode_step,
    profile,
    profile_record_shapes,
    profile_filename_prefix,
):
    """Time one (batch_size, input_len, output_len) configuration.

    Runs a single prefill followed by `output_len - 1` decode steps, printing
    per-phase latency/throughput. Returns a dict of measurements, or None when
    the configuration exceeds the KV-cache capacity. When `profile` is set,
    separate torch-profiler traces are captured for the prefill and for one
    decode step at the midpoint of decoding.
    """
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return
    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool_allocator.clear()
    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }
    tot_latency = 0
    profiler = None
    if profile:
        profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            with_stack=True,
            record_shapes=profile_record_shapes,
        )
        profiler.start()
    # Prefill
    synchronize(device)
    tic = time.perf_counter()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.perf_counter() - tic
    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput
    if profile:
        profiler.stop()
        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
        _save_profile_trace_results(profiler, profile_filename)
        rank_print(
            f"torch profiler chrome trace for prefill saved to {profile_filename}"
        )
    # Decode
    decode_latencies = []
    for i in range(output_len - 1):
        synchronize(device)
        # Start a fresh profiler for a single decode step at the midpoint so
        # the trace reflects steady-state decoding.
        if profile and i == output_len / 2:
            profiler = None
            profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                record_shapes=profile_record_shapes,
            )
            profiler.start()
        tic = time.perf_counter()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.perf_counter() - tic
        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
            rank_print(
                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )
        if profile and i == output_len / 2:
            profiler.stop()
            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
            _save_profile_trace_results(profiler, profile_filename)
            rank_print(
                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
            )
    # Record decode timing from 2nd output
    if output_len > 1:
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput
    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["overall_throughput"] = throughput
    return measurement_results
def latency_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    """Per-rank benchmark entry point.

    Loads the model, runs a short warmup, then sweeps every
    (batch_size, input_len, output_len) combination and records the results.

    Args:
        server_args: server/model configuration (device, tp_size, nnodes, ...).
        port_args: inter-process communication ports for the model runner.
        bench_args: benchmark sweep configuration (batch sizes, lengths, ...).
        tp_rank: tensor-parallel rank of this process. Only rank 0 prints,
            profiles, and writes the result file.
    """
    initialize_moe_config(server_args)
    # Set CPU affinity
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank)
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    # Non-zero ranks get a no-op printer so console output is not duplicated.
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size[0], bench_args.input_len[0]
    )
    # Warm up
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        min(32, bench_args.output_len[0]),  # shorter decoding to speed up the warmup
        server_args.device,
        log_decode_step=0,
        profile=False,
        profile_record_shapes=False,
        profile_filename_prefix="",  # not used
    )
    rank_print("Benchmark ...")
    # Optional custom prompts from a file; each prompt is tokenized up front.
    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
    custom_input_len = len(custom_inputs)
    # Run the sweep
    result_list = []
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        # Align the custom prompt list with the requested batch size:
        # truncate when too many, pad by repeating the last prompt when too few.
        bs_aligned_inputs = []
        if custom_inputs:
            if custom_input_len == bs:
                bs_aligned_inputs = custom_inputs
            elif custom_input_len > bs:
                rank_print(
                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
                    f"Using the first {bs} prompts."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
            else:
                rank_print(
                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
                    f"Pad to the desired batch_size with the last prompt."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs)
                bs_aligned_inputs.extend(
                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
                )
        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
        # Profiling arguments are only passed on rank 0; other ranks get None.
        ret = latency_test_run_once(
            bench_args.run_name,
            model_runner,
            rank_print,
            reqs,
            bs,
            il,
            ol,
            server_args.device,
            bench_args.log_decode_step,
            bench_args.profile if tp_rank == 0 else None,
            bench_args.profile_record_shapes if tp_rank == 0 else None,
            bench_args.profile_filename_prefix,
        )
        if ret is not None:
            result_list.append(ret)
    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            for result in result_list:
                fout.write(json.dumps(result) + "\n")
    if server_args.tp_size > 1:
        destroy_distributed_environment()
def main(server_args, bench_args):
    """Dispatch the benchmark: run in-process when tp_size == 1, otherwise
    spawn one process per tensor-parallel rank.

    Raises:
        ValueError: if no --model-path was provided.
    """
    # CUDA graphs must be captured for the largest batch size in the sweep.
    server_args.cuda_graph_max_bs = max(bench_args.batch_size)
    _set_envs_and_config(server_args)
    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    else:
        # NOTE(review): the message mentions --result-filename plotting, but no
        # plotting path exists in this dispatch -- confirm intended wording.
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )
    port_args = PortArgs.init_new(server_args)
    if server_args.tp_size == 1:
        work_func(server_args, port_args, bench_args, 0)
    else:
        # One worker process per TP rank; all share the same port_args.
        workers = []
        for tp_rank in range(server_args.tp_size):
            proc = multiprocessing.Process(
                target=work_func,
                args=(
                    server_args,
                    port_args,
                    bench_args,
                    tp_rank,
                ),
            )
            proc.start()
            workers.append(proc)
        for proc in workers:
            proc.join()
            # terminate() after a successful join() is a no-op safeguard.
            proc.terminate()
if __name__ == "__main__":
    # Combine server and benchmark CLI options into a single parser.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )
    try:
        main(server_args, bench_args)
    finally:
        # Multi-process runs may leave worker processes behind; clean them up.
        if server_args.tp_size != 1:
            kill_process_tree(os.getpid(), include_parent=False)

View File

@@ -0,0 +1,439 @@
"""
Benchmark the latency of running a single batch with a server.
This script launches a server and uses the HTTP interface.
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
Usage:
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
"""
import argparse
import dataclasses
import itertools
import json
import multiprocessing
import os
import time
from typing import List, Tuple
import requests
from sglang.bench_serving import get_tokenizer, sample_random_requests
from sglang.profiler import run_profile
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_blackwell, kill_process_tree
from sglang.test.test_utils import is_in_ci, write_github_step_summary
@dataclasses.dataclass
class BenchArgs:
    """CLI-configurable parameters for the one-batch server benchmark."""

    run_name: str = "default"
    # Sweep axes: the benchmark runs the cartesian product of these tuples.
    batch_size: Tuple[int, ...] = (1,)
    input_len: Tuple[int, ...] = (1024,)
    output_len: Tuple[int, ...] = (16,)
    temperature: float = 0.0
    return_logprob: bool = False
    client_stream_interval: int = 1
    input_len_step_percentage: float = 0.0
    result_filename: str = "result.jsonl"
    # When non-empty, reuse an already-running server instead of launching one.
    base_url: str = ""
    skip_warmup: bool = False
    show_report: bool = False
    profile: bool = False
    profile_steps: int = 3
    profile_by_stage: bool = False
    dataset_path: str = ""

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all benchmark options; defaults mirror the dataclass fields."""
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
        parser.add_argument("--return-logprob", action="store_true")
        parser.add_argument(
            "--client-stream-interval",
            type=int,
            default=BenchArgs.client_stream_interval,
        )
        parser.add_argument(
            "--input-len-step-percentage",
            type=float,
            default=BenchArgs.input_len_step_percentage,
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
        parser.add_argument("--skip-warmup", action="store_true")
        parser.add_argument("--show-report", action="store_true")
        parser.add_argument("--profile", action="store_true")
        parser.add_argument(
            "--profile-steps", type=int, default=BenchArgs.profile_steps
        )
        parser.add_argument("--profile-by-stage", action="store_true")
        parser.add_argument(
            "--dataset-path",
            type=str,
            default=BenchArgs.dataset_path,
            help="Path to the dataset.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from parsed CLI args, casting to the field types."""
        # use the default value's type to cast the args into correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )
def launch_server_internal(server_args):
    """Run the HTTP server in this process; always tear down on exit.

    The whole process tree is killed in ``finally`` so that worker
    subprocesses do not outlive a crashed or finished server.
    """
    # The original `except Exception as e: raise e` was a no-op re-raise and
    # has been removed; try/finally alone preserves the cleanup semantics.
    try:
        launch_server(server_args)
    finally:
        kill_process_tree(os.getpid(), include_parent=False)
def launch_server_process(server_args: ServerArgs):
    """Start the server in a child process and poll until it is healthy.

    Returns:
        (proc, base_url) once GET /v1/models answers with HTTP 200.

    Raises:
        TimeoutError: if the server is not reachable within 600 seconds.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = 600
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            response = requests.get(f"{base_url}/v1/models", headers=headers)
            if response.status_code == 200:
                return proc, base_url
        except requests.RequestException:
            # Server not accepting connections yet; retry after a delay.
            pass
        time.sleep(10)
    raise TimeoutError("Server failed to start within the timeout period.")
def run_one_case(
    url: str,
    batch_size: int,
    input_len: int,
    output_len: int,
    temperature: float,
    return_logprob: bool,
    stream_interval: int,
    input_len_step_percentage: float,
    run_name: str,
    result_filename: str,
    tokenizer,
    profile: bool = False,
    profile_steps: int = 3,
    profile_by_stage: bool = False,
    dataset_path: str = "",
):
    """Run one (batch_size, input_len, output_len) case against a server.

    Flushes the server cache, sends one streaming /generate request for the
    whole batch, and measures TTFT / latency / throughput.

    Returns a tuple:
        (batch_size, latency, ttft, input_throughput, output_throughput,
         overall_throughput, last_gen_throughput, acc_length, profile_link).

    NOTE(review): ``input_len_step_percentage`` is accepted but unused in this
    body -- confirm whether it should vary per-request input lengths.
    """
    requests.post(url + "/flush_cache")
    input_requests = sample_random_requests(
        input_len=input_len,
        output_len=output_len,
        num_prompts=batch_size,
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path=dataset_path,
        random_sample=True,
        return_text=False,
    )
    # Debug toggle: exercise the structured-output (json_schema) path instead
    # of random prompts. Off by default.
    use_structured_outputs = False
    if use_structured_outputs:
        texts = []
        for _ in range(batch_size):
            texts.append(
                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
                * 50
                + "Assistant:"
            )
        json_schema = "$$ANY$$"
    else:
        json_schema = None
    profile_link = None
    if profile:
        # Start a server-side profiler run covering this request.
        profile_link: str = run_profile(
            url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
        )
    tic = time.perf_counter()
    response = requests.post(
        url + "/generate",
        json={
            "input_ids": [req.prompt for req in input_requests],
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
                "ignore_eos": True,
                "json_schema": json_schema,
                "stream_interval": stream_interval,
            },
            "return_logprob": return_logprob,
            "stream": True,
        },
        stream=True,
    )
    # The TTFT of the last request in the batch
    ttft = 0.0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            if "error" in data:
                raise RuntimeError(f"Request has failed. {data}.")
            # With ignore_eos, the only legitimate finish reason is "length".
            assert (
                data["meta_info"]["finish_reason"] is None
                or data["meta_info"]["finish_reason"]["type"] == "length"
            )
            if data["meta_info"]["completion_tokens"] == 1:
                ttft = time.perf_counter() - tic
    latency = time.perf_counter() - tic
    input_throughput = batch_size * input_len / ttft
    output_throughput = batch_size * output_len / (latency - ttft)
    overall_throughput = batch_size * (input_len + output_len) / latency
    # Pull server-side stats (spec-decode accept length, last gen throughput).
    server_info = requests.get(url + "/get_server_info").json()
    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
    print(f"batch size: {batch_size}")
    print(f"input_len: {input_len}")
    print(f"output_len: {output_len}")
    print(f"latency: {latency:.2f} s")
    print(f"ttft: {ttft:.2f} s")
    print(f"last generation throughput: {last_gen_throughput:.2f} tok/s")
    print(f"input throughput: {input_throughput:.2f} tok/s")
    if output_len != 1:
        print(f"output throughput: {output_throughput:.2f} tok/s")
    # Append one jsonlines record when a result file is configured.
    if result_filename:
        with open(result_filename, "a") as fout:
            res = {
                "run_name": run_name,
                "batch_size": batch_size,
                "input_len": input_len,
                "output_len": output_len,
                "latency": round(latency, 4),
                "output_throughput": round(output_throughput, 2),
                "overall_throughput": round(overall_throughput, 2),
                "last_gen_throughput": round(last_gen_throughput, 2),
            }
            fout.write(json.dumps(res) + "\n")
    return (
        batch_size,
        latency,
        ttft,
        input_throughput,
        output_throughput,
        overall_throughput,
        last_gen_throughput,
        acc_length,
        profile_link if profile else None,
    )
def get_report_summary(
    result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
):
    """Render the collected per-case results as a GitHub-flavored markdown table.

    Args:
        result: tuples as returned by run_one_case.
        server_args: used for tp_size in the cost estimate.
        bench_args: used for the header line and the profile column toggle.

    Returns:
        The formatted summary string.
    """
    import tabulate
    summary = (
        f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
    )
    headers = [
        "batch size",
        "latency (s)",
        "input throughput (tok/s)",
        "output throughput (tok/s)",
        "acc length",
        "ITL (ms)",
        "input cost ($/1M)",
        "output cost ($/1M)",
    ]
    if bench_args.profile:
        headers.append("profile")
    rows = []
    for (
        batch_size,
        latency,
        ttft,
        input_throughput,
        output_throughput,
        _,
        _,
        acc_length,
        trace_link,
    ) in result:
        # Rough $/token estimate from an assumed hourly GPU price.
        if is_blackwell():
            hourly_cost_per_gpu = 4  # $4/hour for one B200
        else:
            hourly_cost_per_gpu = 2  # $2/hour for one H100
        hourly_cost = hourly_cost_per_gpu * server_args.tp_size
        # Assume prefill hardware is only ~70% utilized when costing input tokens.
        input_util = 0.7
        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
        # Inter-token latency: seconds per token per request, in milliseconds.
        itl = 1 / (output_throughput / batch_size) * 1000
        input_cost = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost
        output_cost = 1e6 / output_throughput / 3600 * hourly_cost
        row = [
            batch_size,
            latency,
            input_throughput,
            output_throughput,
            accept_length,
            itl,
            input_cost,
            output_cost,
        ]
        if trace_link:
            row.append(f"[Profile]({trace_link})")
        rows.append(row)
    summary += tabulate.tabulate(
        rows, headers=headers, tablefmt="github", floatfmt=".2f"
    )
    return summary
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
    """Top-level driver: (optionally) launch a server, warm up, run the sweep,
    optionally re-run it under the profiler, and print/report the results.
    """
    # Either reuse an existing server (base_url) or launch one locally.
    if bench_args.base_url:
        proc, base_url = None, bench_args.base_url
    else:
        proc, base_url = launch_server_process(server_args)
    server_info = requests.get(base_url + "/get_server_info").json()
    # NOTE(review): if neither key is present, tokenizer_path stays unbound and
    # the next line raises NameError -- confirm whether an explicit error is wanted.
    if "tokenizer_path" in server_info:
        tokenizer_path = server_info["tokenizer_path"]
    elif "prefill" in server_info:
        # Disaggregated deployment: take the prefill worker's tokenizer path.
        tokenizer_path = server_info["prefill"][0]["tokenizer_path"]
    tokenizer = get_tokenizer(tokenizer_path)
    # warmup
    if not bench_args.skip_warmup:
        print("=" * 8 + " Warmup Begin " + "=" * 8)
        run_one_case(
            base_url,
            batch_size=16,
            input_len=1024,
            output_len=16,
            temperature=bench_args.temperature,
            return_logprob=bench_args.return_logprob,
            stream_interval=bench_args.client_stream_interval,
            input_len_step_percentage=bench_args.input_len_step_percentage,
            run_name="",
            result_filename="",
            tokenizer=tokenizer,
            dataset_path=bench_args.dataset_path,
        )
        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
    # benchmark
    result = []
    bench_result = []
    try:
        # First pass: measurements without profiling overhead.
        for bs, il, ol in itertools.product(
            bench_args.batch_size, bench_args.input_len, bench_args.output_len
        ):
            result.append(
                run_one_case(
                    base_url,
                    bs,
                    il,
                    ol,
                    temperature=bench_args.temperature,
                    return_logprob=bench_args.return_logprob,
                    stream_interval=bench_args.client_stream_interval,
                    input_len_step_percentage=bench_args.input_len_step_percentage,
                    run_name=bench_args.run_name,
                    result_filename=bench_args.result_filename,
                    tokenizer=tokenizer,
                )
            )
        if bench_args.profile:
            # Second pass: same sweep under the profiler, keeping only the
            # trace link and splicing it into the first-pass results.
            try:
                for bs, il, ol in itertools.product(
                    bench_args.batch_size, bench_args.input_len, bench_args.output_len
                ):
                    bench_result.append(
                        (
                            run_one_case(
                                base_url,
                                bs,
                                il,
                                ol,
                                temperature=bench_args.temperature,
                                return_logprob=bench_args.return_logprob,
                                stream_interval=bench_args.client_stream_interval,
                                input_len_step_percentage=bench_args.input_len_step_percentage,
                                run_name=bench_args.run_name,
                                result_filename=bench_args.result_filename,
                                tokenizer=tokenizer,
                                profile=bench_args.profile,
                                profile_steps=bench_args.profile_steps,
                                profile_by_stage=bench_args.profile_by_stage,
                            )[-1],
                        )
                    )
                result = [t1[:-1] + t2 for t1, t2 in zip(result, bench_result)]
            except Exception as e:
                print(f"Error profiling, there will be no profile trace dump: {e}")
    finally:
        # Only kill the server if this process launched it.
        if proc:
            kill_process_tree(proc.pid)
    print(f"\nResults are saved to {bench_args.result_filename}")
    if not bench_args.show_report:
        return
    summary = get_report_summary(result, server_args, bench_args)
    print(summary)
    # Publish the table to the GitHub Actions step summary when running in CI.
    if is_in_ci():
        write_github_step_summary(summary)
def main():
    """Parse combined server + benchmark CLI arguments and run the benchmark."""
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)
    run_benchmark(server_args, bench_args)
if __name__ == "__main__":
    main()

File diff suppressed because it is too large Load Diff

305
python/sglang/check_env.py Normal file
View File

@@ -0,0 +1,305 @@
"""Check environment configurations and dependency versions."""
import importlib.metadata
import os
import resource
import subprocess
import sys
from collections import OrderedDict, defaultdict
import torch
from sglang.srt.utils import is_hip
def is_cuda_v2():
    """True iff this torch build was compiled with CUDA support."""
    cuda_version = torch.version.cuda
    return cuda_version is not None
# List of packages to check versions (duplicate "torchao" entry removed;
# get_package_versions dedupes via dict keys, so output is unchanged).
PACKAGE_LIST = [
    "sglang",
    "sgl_kernel",
    "flashinfer_python",
    "triton",
    "transformers",
    "torchao",
    "numpy",
    "aiohttp",
    "fastapi",
    "hf_transfer",
    "huggingface_hub",
    "interegular",
    "modelscope",
    "orjson",
    "outlines",
    "packaging",
    "psutil",
    "pydantic",
    "python-multipart",
    "pyzmq",
    "uvicorn",
    "uvloop",
    "vllm",
    "xgrammar",
    "openai",
    "tiktoken",
    "anthropic",
    "litellm",
    "decord",
]
def get_package_versions(packages):
    """
    Get versions of specified packages.

    Args:
        packages: iterable of package names, optionally carrying a version
            specifier (e.g. "foo==1.0"); the specifier is stripped first.

    Returns:
        dict mapping each bare package name to its installed version string,
        or "Module Not Found" when the distribution is not installed.
    """
    versions = {}
    for package in packages:
        # Strip a trailing ==/>=/<= version specifier, keeping only the name.
        package_name = package.split("==")[0].split(">=")[0].split("<=")[0]
        try:
            versions[package_name] = importlib.metadata.version(package_name)
        except importlib.metadata.PackageNotFoundError:
            # PackageNotFoundError is the documented exception for missing
            # distributions; the previous ModuleNotFoundError only worked
            # because PackageNotFoundError happens to subclass it.
            versions[package_name] = "Module Not Found"
    return versions
def get_cuda_info():
    """
    Get CUDA-related information if available.

    Returns a dict for CUDA or ROCm torch builds; implicitly returns None on
    CPU-only builds (neither branch taken).
    """
    if is_cuda_v2():
        cuda_info = {"CUDA available": torch.cuda.is_available()}
        if cuda_info["CUDA available"]:
            cuda_info.update(_get_gpu_info())
            cuda_info.update(_get_cuda_version_info())
        return cuda_info
    elif is_hip():
        # ROCm builds expose the same torch.cuda API surface via HIP.
        cuda_info = {"ROCM available": torch.cuda.is_available()}
        if cuda_info["ROCM available"]:
            cuda_info.update(_get_gpu_info())
            cuda_info.update(_get_cuda_version_info())
        return cuda_info
def _get_gpu_info():
"""
Get information about available GPUs.
"""
devices = defaultdict(list)
capabilities = defaultdict(list)
for k in range(torch.cuda.device_count()):
devices[torch.cuda.get_device_name(k)].append(str(k))
capability = torch.cuda.get_device_capability(k)
capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))
gpu_info = {}
for name, device_ids in devices.items():
gpu_info[f"GPU {','.join(device_ids)}"] = name
if len(capabilities) == 1:
# All GPUs have the same compute capability
cap, gpu_ids = list(capabilities.items())[0]
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
else:
# GPUs have different compute capabilities
for cap, gpu_ids in capabilities.items():
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
return gpu_info
def _get_cuda_version_info():
    """
    Get CUDA version information.

    Returns a dict with the toolkit home plus, when the toolkit directory
    exists, compiler and driver versions.
    """
    if is_cuda_v2():
        from torch.utils.cpp_extension import CUDA_HOME
        cuda_info = {"CUDA_HOME": CUDA_HOME}
        # Only probe nvcc/driver when the toolkit directory actually exists.
        if CUDA_HOME and os.path.isdir(CUDA_HOME):
            cuda_info.update(_get_nvcc_info())
            cuda_info.update(_get_cuda_driver_version())
        return cuda_info
    elif is_hip():
        from torch.utils.cpp_extension import ROCM_HOME as ROCM_HOME
        cuda_info = {"ROCM_HOME": ROCM_HOME}
        if ROCM_HOME and os.path.isdir(ROCM_HOME):
            cuda_info.update(_get_nvcc_info())
            cuda_info.update(_get_cuda_driver_version())
        return cuda_info
    else:
        # CPU-only build: report an empty toolkit home.
        cuda_info = {"CUDA_HOME": ""}
        return cuda_info
def _get_nvcc_info():
    """
    Get NVCC version information.

    Shells out to nvcc (CUDA) or hipcc (ROCm) and extracts the version line
    from the compiler banner; "Not Available" on any subprocess failure.
    """
    if is_cuda_v2():
        from torch.utils.cpp_extension import CUDA_HOME
        try:
            nvcc = os.path.join(CUDA_HOME, "bin/nvcc")
            nvcc_output = (
                subprocess.check_output(f'"{nvcc}" -V', shell=True)
                .decode("utf-8")
                .strip()
            )
            # Keep only the "Cuda compilation tools ..." part of the banner.
            return {
                "NVCC": nvcc_output[
                    nvcc_output.rfind("Cuda compilation tools") : nvcc_output.rfind(
                        "Build"
                    )
                ].strip()
            }
        except subprocess.SubprocessError:
            return {"NVCC": "Not Available"}
    elif is_hip():
        from torch.utils.cpp_extension import ROCM_HOME
        try:
            hipcc = os.path.join(ROCM_HOME, "bin/hipcc")
            hipcc_output = (
                subprocess.check_output(f'"{hipcc}" --version', shell=True)
                .decode("utf-8")
                .strip()
            )
            # Keep only the "HIP version ..." part of the banner.
            return {
                "HIPCC": hipcc_output[
                    hipcc_output.rfind("HIP version") : hipcc_output.rfind("AMD clang")
                ].strip()
            }
        except subprocess.SubprocessError:
            return {"HIPCC": "Not Available"}
    else:
        return {"NVCC": "Not Available"}
def _get_cuda_driver_version():
    """
    Get CUDA driver version.

    Queries nvidia-smi (CUDA) or rocm-smi (ROCm). Multiple distinct driver
    versions across GPUs are reported joined with commas.
    """
    versions = set()
    if is_cuda_v2():
        try:
            output = subprocess.check_output(
                [
                    "nvidia-smi",
                    "--query-gpu=driver_version",
                    "--format=csv,noheader,nounits",
                ]
            )
            # One line per GPU; a set collapses identical versions.
            versions = set(output.decode().strip().split("\n"))
            if len(versions) == 1:
                return {"CUDA Driver Version": versions.pop()}
            else:
                return {"CUDA Driver Versions": ", ".join(sorted(versions))}
        except subprocess.SubprocessError:
            return {"CUDA Driver Version": "Not Available"}
    elif is_hip():
        try:
            output = subprocess.check_output(
                [
                    "rocm-smi",
                    "--showdriverversion",
                    "--csv",
                ]
            )
            versions = set(output.decode().strip().split("\n"))
            # Drop the CSV header line, then unwrap the remaining value.
            versions.discard("name, value")
            ver = versions.pop()
            ver = ver.replace('"Driver version", ', "").replace('"', "")
            return {"ROCM Driver Version": ver}
        except subprocess.SubprocessError:
            return {"ROCM Driver Version": "Not Available"}
    else:
        return {"CUDA Driver Version": "Not Available"}
def get_gpu_topology():
    """
    Get GPU topology information.

    Returns the raw `nvidia-smi topo -m` / `rocm-smi --showtopotype` output
    prefixed with a newline, or None when unavailable.
    """
    if is_cuda_v2():
        try:
            result = subprocess.run(
                ["nvidia-smi", "topo", "-m"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                check=True,
            )
            # check=True already raised on failure, so returncode is 0 here.
            return "\n" + result.stdout if result.returncode == 0 else None
        except subprocess.SubprocessError:
            return None
    elif is_hip():
        try:
            result = subprocess.run(
                ["rocm-smi", "--showtopotype"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                check=True,
            )
            return "\n" + result.stdout if result.returncode == 0 else None
        except subprocess.SubprocessError:
            return None
    else:
        return None
def get_hypervisor_vendor():
    """Return the hypervisor vendor reported by ``lscpu``, or None.

    None means lscpu is unavailable/failed, or no "Hypervisor vendor:" line
    was reported (bare metal).
    """
    try:
        output = subprocess.check_output(["lscpu"], text=True)
    except (OSError, subprocess.SubprocessError):
        # Narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. OSError covers a missing lscpu binary.
        return None
    for line in output.split("\n"):
        if "Hypervisor vendor:" in line:
            return line.split(":")[1].strip()
    return None
def check_env():
    """
    Check and print environment information.
    """
    env_info = OrderedDict()
    env_info["Python"] = sys.version.replace("\n", "")
    # NOTE(review): get_cuda_info() returns None on CPU-only torch builds,
    # which would make this update() raise TypeError -- confirm intended.
    env_info.update(get_cuda_info())
    env_info["PyTorch"] = torch.__version__
    env_info.update(get_package_versions(PACKAGE_LIST))
    gpu_topo = get_gpu_topology()
    if gpu_topo:
        if is_cuda_v2():
            env_info["NVIDIA Topology"] = gpu_topo
        elif is_hip():
            env_info["AMD Topology"] = gpu_topo
    hypervisor_vendor = get_hypervisor_vendor()
    if hypervisor_vendor:
        env_info["Hypervisor vendor"] = hypervisor_vendor
    # Soft limit on open file descriptors (can bottleneck servers with many sockets).
    ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
    env_info["ulimit soft"] = ulimit_soft
    for k, v in env_info.items():
        print(f"{k}: {v}")
if __name__ == "__main__":
    # Allow running directly, e.g. `python -m sglang.check_env`.
    check_env()

View File

@@ -0,0 +1,184 @@
"""
Compile DeepGEMM Kernels for a model with specify server arguments
This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
It accepts server arguments (the same as launch_server.py).
Usage:
python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
"""
import argparse
import dataclasses
import multiprocessing
import os
import time
import requests
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
from sglang.srt.warmup import warmup
# "spawn" avoids forking after CUDA initialization in the parent process.
multiprocessing.set_start_method("spawn", force=True)
# Reduce warning
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
# Force enable deep gemm
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
@dataclasses.dataclass
class CompileArgs:
    """CLI options specific to the DeepGEMM precompilation run."""

    # Overall time budget (seconds) for server startup + kernel compilation.
    timeout: int = 3600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register the compile-specific options on *parser*."""
        parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a CompileArgs from parsed CLI args, casting to the field types."""
        # use the default value's type to cast the args into correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )
@warmup("compile-deep-gemm")
async def warm_up_compile(
    disaggregation_mode: str, tokenizer_manager: TokenizerManager
):
    """Warmup hook: issue one tiny generation so DeepGEMM kernels get JIT-compiled.

    Registered as "compile-deep-gemm" via the @warmup decorator; the server is
    pointed at it through server_args.warmups.
    """
    print("\nGenerate warm up request for compiling DeepGEMM...\n")
    generate_req_input = GenerateReqInput(
        input_ids=[0, 1, 2, 3],
        sampling_params={
            "temperature": 0.0,
            "max_new_tokens": 8,
            "ignore_eos": True,
        },
    )
    # Disaggregated (prefill/decode split) deployments need bootstrap metadata.
    if disaggregation_mode != "null":
        generate_req_input.bootstrap_room = 0
        generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
    # Drive the async generator one step; the single response is discarded.
    await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
def launch_server_internal(server_args):
    """Run the HTTP server in this process; always tear down on exit.

    The whole process tree is killed in ``finally`` so that worker
    subprocesses do not outlive a crashed or finished server.
    """
    # The original `except Exception as e: raise e` was a no-op re-raise and
    # has been removed; try/finally alone preserves the cleanup semantics.
    try:
        launch_server(server_args)
    finally:
        kill_process_tree(os.getpid(), include_parent=False)
def launch_server_process_and_send_one_request(
    server_args: ServerArgs, compile_args: CompileArgs
):
    """Launch the server in a child process and trigger compilation.

    Rank-0 nodes wait for /v1/models, then send one generation request (which
    drives the warmup hook and thus DeepGEMM compilation). Other nodes wait on
    /health and then block until the rank-0-driven server process exits.

    Returns:
        The server process once compilation has finished.

    Raises:
        TimeoutError: if the server does not become healthy, or a non-rank-0
            node waits longer than the configured timeout.
        RuntimeError: if the rank-0 sync request fails.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = compile_args.timeout
    start_time = time.perf_counter()
    while time.perf_counter() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            if server_args.node_rank == 0:
                response = requests.get(f"{base_url}/v1/models", headers=headers)
            else:
                # This http api is created by launch_dummy_health_check_server for non-rank-0 nodes.
                response = requests.get(f"{base_url}/health", headers=headers)
            if response.status_code == 200:
                # Rank-0 node send a request to sync with other node and then return.
                if server_args.node_rank == 0:
                    response = requests.post(
                        f"{base_url}/generate",
                        json={
                            "input_ids": [0, 1, 2, 3],
                            "sampling_params": {
                                "max_new_tokens": 8,
                                "temperature": 0,
                            },
                        },
                        timeout=600,
                    )
                    if response.status_code != 200:
                        error = response.json()
                        raise RuntimeError(f"Sync request failed: {error}")
                # Other nodes should wait for the exit signal from Rank-0 node.
                else:
                    start_time_waiting = time.perf_counter()
                    while proc.is_alive():
                        if time.perf_counter() - start_time_waiting < timeout:
                            time.sleep(10)
                        else:
                            raise TimeoutError("Waiting for main node timeout!")
                return proc
        except requests.RequestException:
            # Server not up yet; retry after a delay.
            pass
        time.sleep(10)
    raise TimeoutError(
        "DeepGEMM Kernels compilation timeout."
        "\n\nFeel free and please restart the command."
    )
def refine_server_args(server_args: "ServerArgs", compile_args: "CompileArgs"):
    """Mutate *server_args* in place to suit a DeepGEMM precompile-only run.

    Disables CUDA graphs / torch.compile (startup-time cost we don't need),
    widens the watchdog to the compile timeout, and selects the
    "compile-deep-gemm" warmup hook.
    """
    # Disable cuda graph and torch compile to save time
    server_args.disable_cuda_graph = True
    server_args.enable_torch_compile = False
    # Fixed: needless f-string prefix removed (no placeholders).
    print("Disable CUDA Graph and Torch Compile to save time...")
    # Set watchdog timeout to compile_args.timeout because compilation will take a long time
    server_args.watchdog_timeout = compile_args.timeout
    server_args.warmups = "compile-deep-gemm"
def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
    """Drive the full compile flow: launch the server, wait for the warmup
    request to finish (which compiles the kernels), then shut everything down.
    """
    print(
        "Begin DeepGEMM Kernels compilation...\n"
        "It may take a long time and timeout maybe raised "
        "while the compilation is still in progress.\n"
        "Just feel free to restart the command "
        "until the compilation is fully finished.\n"
    )
    proc = launch_server_process_and_send_one_request(server_args, compile_args)
    print("\nDeepGEMM Kernels compilation finished successfully.")
    # Sleep for safety
    time.sleep(10)
    if proc.is_alive():
        # This is the rank0 node.
        kill_process_tree(proc.pid)
    else:
        # Non-rank0 node: the process may already be gone; best-effort cleanup.
        try:
            kill_process_tree(proc.pid)
        except Exception:
            pass
if __name__ == "__main__":
    # Combine server and compile CLI options into a single parser.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    CompileArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    compile_args = CompileArgs.from_cli_args(args)
    # Adjust server settings for a compile-only run, then execute it.
    refine_server_args(server_args, compile_args)
    run_compile(server_args, compile_args)

View File

@@ -0,0 +1,315 @@
# Adapt from https://github.com/fw-ai/llm_eval_meta
import argparse
import asyncio
import os
import pickle
import re
import shutil
from collections import defaultdict
from dataclasses import dataclass
import httpx
import numpy as np
import openai
from datasets import load_dataset
from openai import AsyncOpenAI
from tqdm import tqdm
# Mapping providers to their clients and models.
# Keys: "b10" = Baseten, "oai" = local OpenAI-compatible server on :8000,
# "sgl" = local SGLang server on :30000 (see get_client below).
provider_to_models = {
    "b10": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "oai": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "sgl": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
}
async def fetch_responses(
    client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
):
    """Send one completion request and pickle the response to *output_dir*.

    Skips work if response_{index}.pkl already exists, making runs resumable.
    Concurrency is bounded by *semaphore*.
    """
    output_file = os.path.join(output_dir, f"response_{index}.pkl")
    if os.path.exists(output_file):
        print(f"File {output_file} already exists, skipping.")
        return
    async with semaphore:
        response = await client.completions.create(
            model=provider_to_models[provider][model_size],
            prompt=prompt,
            temperature=0.0,
            max_tokens=max_tokens,
        )
        if isinstance(response, openai.BadRequestError):
            with open(output_file, "wb") as f:
                pickle.dump("bad_response", f)
            # Bug fix: return after writing the sentinel. Previously control
            # fell through to the assert below, which would always fail here.
            return
        assert isinstance(response, openai.types.completion.Completion)
        # Save response to a file
        with open(output_file, "wb") as f:
            pickle.dump(response, f)
# Max generation length per eval set (1 token for plain MMLU letter answers,
# longer budgets for chain-of-thought style tasks).
TASK_TO_MAX_TOKENS = {
    "evals__mmlu__details": 1,
    "evals__mmlu__0_shot__cot__details": 1024,
    # Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing
    "evals__mmlu_pro__details": 2048,
    "evals__gsm8k__details": 1024,
}
# Short CLI task name -> dataset config suffix used by load_dataset.
TASK_TO_EVAL_SET = {
    "mmlu": "evals__mmlu__details",
    "mmlu_cot": "evals__mmlu__0_shot__cot__details",
    "mmlu_pro": "evals__mmlu_pro__details",
    "gsm8k": "evals__gsm8k__details",
}
class CustomAsyncHTTPXClient(httpx.AsyncClient):
    """HTTPX client that rewrites every request URL to the Baseten endpoint.

    Baseten exposes a single predict URL per model (identified by $MODEL_ID),
    so the path chosen by the OpenAI SDK is discarded and replaced wholesale.
    """

    async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
        request.url = httpx.URL(
            f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict"
        )
        return await super().send(request, *args, **kwargs)
def get_client(provider):
    """Build an AsyncOpenAI client for *provider* ("oai", "b10", or "sgl").

    Non-Baseten providers get a dummy OPENAI_API_KEY when none is set, since
    the local OpenAI-compatible servers do not require authentication.
    """
    # Bug fix: the original used `provider not in "b10"`, a substring test
    # that happens to work for the three known providers but misbehaves for
    # inputs like "b" or "10". An equality check expresses the intent.
    if provider != "b10":
        if os.getenv("OPENAI_API_KEY") is None:
            os.environ["OPENAI_API_KEY"] = "EMPTY"
    return {
        "oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
        "b10": AsyncOpenAI(
            api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
            base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
            http_client=CustomAsyncHTTPXClient(),
        ),
        "sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
    }[provider]
# Main benchmark driver: fan out one completion request per prompt.
async def benchmark(args):
    """Load the eval dataset and fetch a response for each prompt.

    Responses are pickled to args.output_dir (one file per prompt index), so
    interrupted runs can be resumed.
    """
    ds = load_dataset(
        "meta-llama/Llama-3.1-405B-Instruct-evals",
        f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
    )
    semaphore = asyncio.Semaphore(args.concurrency)  # cap in-flight requests
    if args.num_examples is None:
        args.num_examples = len(ds["latest"]["input_final_prompts"])
    prompts = ds["latest"]["input_final_prompts"][: args.num_examples]
    # Create the output directory if it does not exist
    os.makedirs(args.output_dir, exist_ok=True)
    tasks = []
    # Create the tasks with tqdm progress bar
    max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
    client = get_client(args.provider)
    for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
        tasks.append(
            asyncio.create_task(
                fetch_responses(
                    client,
                    f"<|begin_of_text|>{prompt[0]}",
                    semaphore,
                    idx,
                    args.provider,
                    args.model_size,
                    args.output_dir,
                    max_tokens=max_tokens,
                )
            )
        )
    # Run the tasks with tqdm progress bar
    for future in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
    ):
        await future
def get_mmlu_answer(response):
    """Extract the single-letter MMLU answer from a completion response.

    Normalizes by trimming whitespace, upper-casing, and dropping periods.
    Returns None when *response* is None (e.g. a failed request).
    """
    if response is not None:
        # .strip() replaces the redundant .lstrip().rstrip() chain.
        return response.choices[0].text.strip().upper().replace(".", "")
    return None
def get_mmlu_cot_answer(response):
    """Extract the answer from a chain-of-thought MMLU completion.

    Tries the known answer phrasings in order; the first match wins. The
    "The best answer is" phrasing additionally strips '*' (markdown bold).
    Returns None when no phrasing matches.
    """
    text = response.choices[0].text
    phrasings = [
        (r"The best answer is (.+)\.?", ".*"),
        (r"the best answer is (.+)\.?", "."),
        (r"The correct answer is (.+)\.?", "."),
        (r"the correct answer is (.+)\.?", "."),
    ]
    for pattern, chars_to_drop in phrasings:
        found = re.search(pattern, text)
        if found:
            answer = found.group(1)
            for ch in chars_to_drop:
                answer = answer.replace(ch, "")
            return answer
def get_answer_gsm8k(response):
    """Extract the numeric GSM8K answer, dropping '%' and '$' symbols.

    Returns None when the "The final answer is ..." phrasing is absent.
    """
    found = re.search(r"The final answer is (.+)\.?", response.choices[0].text)
    if not found:
        return None
    answer = found.group(1)
    return answer.replace("%", "").replace("$", "")
# Eval set -> function that parses the model's answer out of a completion.
TASK_TO_ANSWER_EXTRACTOR = {
    "evals__mmlu__details": get_mmlu_answer,
    "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
    "evals__gsm8k__details": get_answer_gsm8k,
    "evals__mmlu_pro__details": get_mmlu_cot_answer,
}
def get_dataset_from_task(task, response_path, model_size):
    """Load the reference eval dataset for *task*, row-aligned to 405B order.

    Prompts were issued in the 405B dataset's order, so when scoring a
    70B/8B model its reference rows are reordered (via the shared prompt
    hashes) to match. `response_path` is unused but kept for signature
    compatibility with callers.
    """
    # Removed pointless f-prefixes on constant strings and deduplicated the
    # two near-identical load_dataset branches.
    ds_405b = load_dataset(
        "meta-llama/Llama-3.1-405B-Instruct-evals",
        f"Llama-3.1-405B-Instruct-{task}",
    )
    ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]
    if "70b" in model_size or "8b" in model_size:
        ref_size = "70B" if "70" in model_size else "8B"
        ref_model_ds = load_dataset(
            f"meta-llama/Llama-3.1-{ref_size}-Instruct-evals",
            f"Llama-3.1-{ref_size}-Instruct-{task}",
        )
        hash_to_row = {
            row["input_final_prompts_hash"][0]: row for row in ref_model_ds["latest"]
        }
        # Reorder the reference rows to the 405B prompt order used at request time.
        ref_model_ds["latest"] = [hash_to_row[h] for h in ds_405b_hash_order]
        return ref_model_ds
    return ds_405b
def analyze(task, response_path, model_size):
    """Score cached responses against the reference dataset and print
    macro/micro accuracy plus Meta's reported ("meta") accuracy.

    Responses are read from `response_path/response_{i}.pkl`, one per row.
    """
    ds = get_dataset_from_task(task, response_path, model_size)
    responses = []
    total = len(ds["latest"])
    for i in range(total):
        # Use a context manager so file handles are closed promptly
        # (the original leaked one open file per response).
        with open(os.path.join(response_path, f"response_{i}.pkl"), "rb") as f:
            responses.append(pickle.load(f))

    @dataclass
    class Stats:
        correct: int = 0  # our model's correct count for this subtask
        total: int = 0
        meta_correct: int = 0  # reference dataset's own is_correct count
        average: float = None  # filled in after counting
        meta_average: float = None  # declared here; original assigned it dynamically

    subtask_name_to_stats = defaultdict(Stats)
    for response, ds_row in zip(responses, ds["latest"]):
        model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)
        subtask = ds_row["subtask_name"]
        is_eval_correct = model_answer in ds_row["input_correct_responses"]
        if is_eval_correct:
            subtask_name_to_stats[subtask].correct += 1
        if ds_row["is_correct"]:
            subtask_name_to_stats[subtask].meta_correct += 1
        subtask_name_to_stats[subtask].total += 1
    micro_stats = Stats()
    for subtask, stats in subtask_name_to_stats.items():
        stats.average = stats.correct / stats.total
        stats.meta_average = stats.meta_correct / stats.total
        micro_stats.correct += stats.correct
        micro_stats.total += stats.total
        micro_stats.meta_correct += stats.meta_correct
    micro_stats.average = micro_stats.correct / micro_stats.total
    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total
    print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
    print(
        "Meta Macro average",
        np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
    )
    print("Micro average", micro_stats.average)
    print("Meta Micro average", micro_stats.meta_average)
# Entry point for the script
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Script to run model with specified parameters."
    )
    parser.add_argument(
        "--model-size",
        type=str,
        default="8b",
        help="Size of the model (e.g., 8b or 70b)",
    )
    parser.add_argument(
        "--provider",
        type=str,
        default="sgl",
        help="Provider name (e.g., sgl, oai, b10)",
    )
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        help="Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)",
    )
    parser.add_argument(
        "--num-examples", type=int, default=None, help="Number of examples to process"
    )
    parser.add_argument("--concurrency", type=int, default=16)
    parser.add_argument(
        "--output-dir",
        type=str,
        default="tmp-output-dir",
        help="Directory to save responses",
    )
    args = parser.parse_args()
    asyncio.run(benchmark(args))
    analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
    # Clean up the response cache. Bug fix: the original removed the literal
    # "tmp-output-dir", leaving stale files behind whenever --output-dir was set.
    shutil.rmtree(args.output_dir, ignore_errors=True)

View File

@@ -0,0 +1,164 @@
import argparse
import asyncio
import os
import pickle
from pathlib import Path
from typing import List
import openai
import torch
from bert_score import BERTScorer
from datasets import load_dataset
from tqdm import tqdm
def get_client(api_url: str) -> openai.AsyncOpenAI:
    """Build an async OpenAI client pointed at *api_url*.

    The local server ignores the key, but the SDK requires one to exist.
    """
    os.environ.setdefault("OPENAI_API_KEY", "EMPTY")
    return openai.AsyncOpenAI(base_url=api_url)
def get_dataset():
    """Load the long-dependency QA test split of the LooGLE benchmark."""
    return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")
async def fetch_response(
    client: openai.AsyncOpenAI,
    context: str,
    question: str,
    semaphore: asyncio.Semaphore,
    index: int,
    model: str,
    output_dir: Path,
):
    """Query the model once and cache the result to output_dir/response_{index}.pkl.

    A pre-existing cache file short-circuits the call; rejected requests are
    cached as {"error": ...} so reruns do not retry them.
    """
    output_file = output_dir / f"response_{index}.pkl"
    if output_file.exists():
        return  # already answered on a previous run
    prompt = (
        "Please answer the question based on the long texts below.\n"
        f"{context}\n"
        f"Question: {question}\n"
        "Answer:"
    )
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    async with semaphore:
        try:
            payload = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                max_tokens=512,
            )
        except openai.BadRequestError as err:
            payload = {"error": str(err)}
    with open(output_file, "wb") as f:
        pickle.dump(payload, f)
async def benchmark(args):
    """Fan out cached LooGLE queries with bounded concurrency."""
    dataset = get_dataset()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    client = get_client(args.api_url)
    semaphore = asyncio.Semaphore(args.max_concurrency)

    tasks: List[asyncio.Task] = []
    for idx, example in enumerate(dataset):
        if idx >= args.num_prompts:
            break  # only the first --num-prompts examples
        coro = fetch_response(
            client,
            example["context"],
            example["question"],
            semaphore,
            idx,
            args.model,
            output_dir,
        )
        tasks.append(asyncio.create_task(coro))

    # Drain completions as they finish so the progress bar stays live.
    for finished in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
    ):
        await finished
def analyse(args):
    """Score cached responses with BERTScore F1 and print the average."""
    dataset = get_dataset()
    output_dir = Path(args.output_dir)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = BERTScorer(lang="en", device=device)
    hyps: List[str] = []
    refs: List[str] = []
    for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
        if idx >= args.num_prompts:
            break
        pkl_file = output_dir / f"response_{idx}.pkl"
        if not pkl_file.exists():
            raise FileNotFoundError(pkl_file)
        # Close the pickle promptly (the original leaked the file handle).
        with open(pkl_file, "rb") as f:
            response = pickle.load(f)
        # Rejected requests were cached as {"error": ...}; skip them.
        if isinstance(response, dict) and "error" in response:
            continue
        hyps.append(response.choices[0].message.content.strip())
        refs.append(ex["answer"])
    if not hyps:
        print("No valid responses to score!")
        return
    # Score in batches to bound peak GPU memory.
    batch_size = 64
    all_f1: List[float] = []
    for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
        h_batch = hyps[i : i + batch_size]
        r_batch = refs[i : i + batch_size]
        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
        all_f1.extend(float(x) for x in f1_scores)
    avg = sum(all_f1) / len(all_f1)
    print(f"Average BERTScore (F1): {avg:.2%}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run benchmark and evaluation in one go."
    )
    parser.add_argument(
        "--api-url",
        default="http://127.0.0.1:30000/v1",
        # Fixed garbled help text: a dash was lost in transit ("OpenAIcompatible").
        help="OpenAI-compatible API base URL",
    )
    parser.add_argument(
        "--model",
        default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        help="Model name or ID, only used for model name",
    )
    parser.add_argument(
        "--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
    )
    parser.add_argument(
        "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
    )
    parser.add_argument(
        "--num-prompts", type=int, default=10000, help="Number of prompts to run"
    )
    args = parser.parse_args()
    # Issue (or re-use cached) requests, then score everything on disk.
    asyncio.run(benchmark(args))
    analyse(args)

View File

@@ -0,0 +1,53 @@
"""Global configurations"""
import os
class GlobalConfig:
    """
    Store some global constants.

    See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
    many global runtime arguments as well.
    """

    def __init__(self):
        # Verbosity level
        # 0: do not output anything
        # 2: output final text after every run
        self.verbosity = 0

        # Default backend of the language
        self.default_backend = None

        # Runtime constants: New generation token ratio estimation
        self.default_init_new_token_ratio = float(
            os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
        )
        self.default_min_new_token_ratio_factor = float(
            os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
        )
        self.default_new_token_ratio_decay_steps = float(
            os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
        )
        self.torch_empty_cache_interval = float(
            os.environ.get(
                "SGLANG_EMPTY_CACHE_INTERVAL", -1
            )  # in seconds. Set if you observe high memory accumulation over a long serving period.
        )
        # Runtime constants: others
        self.retract_decode_steps = 20
        # Bug fix: coerce to int — os.environ.get returns a *string* when the
        # variable is set, while the default is an int; downstream size math
        # needs a number either way.
        self.flashinfer_workspace_size = int(
            os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
        )

        # Output tokenization configs
        self.skip_special_tokens_in_output = True
        self.spaces_between_special_tokens_in_out = True

        # Language frontend interpreter optimization configs
        self.enable_precache_with_tracing = True
        self.enable_parallel_encoding = True


global_config = GlobalConfig()

286
python/sglang/lang/api.py Normal file
View File

@@ -0,0 +1,286 @@
"""Public APIs of the language."""
import re
from typing import Callable, List, Optional, Union
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
from sglang.lang.ir import (
SglExpr,
SglExprList,
SglFunction,
SglGen,
SglImage,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglSeparateReasoning,
SglVideo,
)
def function(
    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
):
    """Turn a plain callable into an SglFunction.

    Usable both bare (`@function`) and parameterized
    (`@function(num_api_spec_tokens=...)`).
    """

    def decorator(fn):
        return SglFunction(fn, num_api_spec_tokens=num_api_spec_tokens)

    # Bare usage: `func` is the decorated callable itself.
    return decorator(func) if func else decorator
def Runtime(*args, **kwargs):
    """Construct a runtime-endpoint backend (lazy import keeps startup light)."""
    # Avoid importing unnecessary dependency
    from sglang.lang.backend.runtime_endpoint import Runtime

    return Runtime(*args, **kwargs)


def Engine(*args, **kwargs):
    """Construct an in-process inference engine (lazy import avoids heavy deps)."""
    # Avoid importing unnecessary dependency
    from sglang.srt.entrypoints.engine import Engine

    return Engine(*args, **kwargs)
def _resolve_backend(backend):
    """Fall back to the global default and unwrap Runtime wrappers.

    Runtime objects expose the real backend via `.endpoint`; both
    flush_cache and get_server_info previously duplicated this unwrap.
    """
    backend = backend or global_config.default_backend
    if backend is not None and hasattr(backend, "endpoint"):
        backend = backend.endpoint
    return backend


def set_default_backend(backend: BaseBackend):
    """Install *backend* as the process-wide default for sgl calls."""
    global_config.default_backend = backend


def flush_cache(backend: Optional[BaseBackend] = None):
    """Flush the prefix cache of *backend* (or the default).

    Returns False when no backend is configured.
    """
    backend = _resolve_backend(backend)
    if backend is None:
        return False
    return backend.flush_cache()


def get_server_info(backend: Optional[BaseBackend] = None):
    """Fetch server info from *backend* (or the default); None if unset."""
    backend = _resolve_backend(backend)
    if backend is None:
        return None
    return backend.get_server_info()
def gen(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    min_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
    dtype: Optional[Union[type, str]] = None,
    choices: Optional[List[str]] = None,
    choices_method: Optional[ChoicesSamplingMethod] = None,
    regex: Optional[str] = None,
    json_schema: Optional[str] = None,
):
    """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""
    # A non-empty `choices` list turns the call into a constrained selection.
    if choices:
        return SglSelect(
            name,
            choices,
            0.0 if temperature is None else temperature,
            token_length_normalized if choices_method is None else choices_method,
        )

    # Validate the regex eagerly so a malformed pattern fails at call time;
    # re.error propagates to the caller exactly as before. (The original
    # wrapped this in a try/except that merely re-raised the same exception.)
    if regex is not None:
        re.compile(regex)

    return SglGen(
        name,
        max_tokens,
        min_tokens,
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        dtype,
        regex,
        json_schema,
    )
def gen_int(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
):
    """Generate constrained to an integer (gen() with dtype=int, no regex)."""
    # Positional args must mirror SglGen's constructor order exactly.
    return SglGen(
        name,
        max_tokens,
        None,  # min_tokens
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        int,  # dtype: constrain decoding to an integer
        None,  # regex
    )
def gen_string(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
):
    """Generate constrained to a string literal (gen() with dtype=str, no regex)."""
    # Positional args must mirror SglGen's constructor order exactly.
    return SglGen(
        name,
        max_tokens,
        None,  # min_tokens
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        str,  # dtype: constrain decoding to a quoted string
        None,  # regex
    )
def image(expr: SglExpr):
    """Wrap an expression as an image input."""
    return SglImage(expr)


def video(path: str, num_frames: int):
    """Reference a video file by path, sampled to `num_frames` frames."""
    return SglVideo(path, num_frames)
def select(
    name: Optional[str] = None,
    choices: Optional[List[str]] = None,
    temperature: float = 0.0,
    choices_method: ChoicesSamplingMethod = token_length_normalized,
):
    """Constrain generation to one of *choices*.

    Raises ValueError when `choices` is missing. (The original used `assert`,
    which is silently stripped under `python -O`.)
    """
    if choices is None:
        raise ValueError("select() requires a non-None `choices` list")
    return SglSelect(name, choices, temperature, choices_method)
def _role_common(name: str, expr: Optional[SglExpr] = None):
    """Wrap *expr* (when given) between begin/end markers for role *name*."""
    parts = [SglRoleBegin(name)]
    if expr is not None:
        parts.append(expr)
    parts.append(SglRoleEnd(name))
    return SglExprList(parts)
def system(expr: Optional[SglExpr] = None):
    """A system turn; with no expr, returns just the begin/end marker pair."""
    return _role_common("system", expr)


def user(expr: Optional[SglExpr] = None):
    """A user turn; with no expr, returns just the begin/end marker pair."""
    return _role_common("user", expr)


def assistant(expr: Optional[SglExpr] = None):
    """An assistant turn; with no expr, returns just the begin/end marker pair."""
    return _role_common("assistant", expr)


def system_begin():
    """Open a system turn (pair with system_end())."""
    return SglRoleBegin("system")


def system_end():
    """Close a system turn opened by system_begin()."""
    return SglRoleEnd("system")


def user_begin():
    """Open a user turn (pair with user_end())."""
    return SglRoleBegin("user")


def user_end():
    """Close a user turn opened by user_begin()."""
    return SglRoleEnd("user")


def assistant_begin():
    """Open an assistant turn (pair with assistant_end())."""
    return SglRoleBegin("assistant")


def assistant_end():
    """Close an assistant turn opened by assistant_begin()."""
    return SglRoleEnd("assistant")
def separate_reasoning(
    expr: Optional[SglExpr] = None, model_type: Optional[str] = None
):
    """Append a reasoning-separation marker after *expr*.

    NOTE(review): when called with expr=None the returned list still begins
    with None — confirm SglExprList tolerates that, or guard upstream.
    """
    return SglExprList([expr, SglSeparateReasoning(model_type, expr=expr)])

View File

@@ -0,0 +1,73 @@
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import anthropic
except ImportError as e:
anthropic = e
class Anthropic(BaseBackend):
    """Anthropic messages-API backend for the sgl frontend."""

    def __init__(self, model_name, *args, **kwargs):
        super().__init__()
        # `anthropic` holds the ImportError when the package is missing.
        if isinstance(anthropic, Exception):
            raise anthropic
        self.model_name = model_name
        self.chat_template = get_chat_template("claude")
        self.client = anthropic.Anthropic(*args, **kwargs)

    def get_chat_template(self):
        return self.chat_template

    @staticmethod
    def _split_system(s: StreamExecutor):
        """Return (system_prompt, messages) in Anthropic's expected shape.

        Previously duplicated verbatim in generate() and generate_stream().
        NOTE(review): when the first message is a system turn it is pop()ed
        off `s.messages_` in place, mutating the executor's history — this
        preserves the original behavior; confirm it is intended.
        """
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]
        if messages and messages[0]["role"] == "system":
            system = messages.pop(0)["content"]
        else:
            system = ""
        return system, messages

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Single-shot completion; returns (text, meta_info)."""
        system, messages = self._split_system(s)
        ret = self.client.messages.create(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        )
        comp = ret.content[0].text
        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Streaming completion; yields (text_delta, meta_info) pairs."""
        system, messages = self._split_system(s)
        with self.client.messages.stream(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        ) as stream:
            for text in stream.text_stream:
                yield text, {}

View File

@@ -0,0 +1,82 @@
from typing import List, Optional, Union
from sglang.lang.chat_template import get_chat_template
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
class BaseBackend:
    """Abstract interface every sgl language backend implements.

    Most hooks are no-ops by default; concrete backends override the subset
    they support. `support_concate_and_append` advertises whether
    concatenate_and_append() is available.
    """

    def __init__(self) -> None:
        # Backends that can merge cached prefixes set this to True.
        self.support_concate_and_append = False
        self.chat_template = get_chat_template("default")

    def get_model_name(self):
        """Return the served model's name/path."""
        raise NotImplementedError()

    def get_chat_template(self):
        """Return the chat template chosen for this backend."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Pre-warm the backend's cache with a prompt prefix (optional)."""
        pass

    def uncache_prefix(self, rid: str):
        """Release a cached prefix by request id (optional)."""
        pass

    def end_request(self, rid: Union[str, List[str]]):
        """Signal that the request(s) identified by *rid* are finished (optional)."""
        pass

    def begin_program(self, s: StreamExecutor):
        """Hook called when a program starts executing (optional)."""
        pass

    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
        """Hook called when a program (or a batch of them) finishes (optional)."""
        pass

    def commit_lazy_operations(self, s: StreamExecutor):
        """Flush any buffered operations for executor *s* (optional)."""
        pass

    def fork_program(
        self,
        src: StreamExecutor,
        dst: List[StreamExecutor],
        position_ids_offset: Optional[List[int]] = None,
    ):
        """Fork executor *src* into the *dst* executors (optional)."""
        pass

    def fill_image(self, s: StreamExecutor):
        """Upload/attach pending image inputs for executor *s* (optional)."""
        pass

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Produce a completion; must return (text, meta_info)."""
        raise NotImplementedError()

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Produce a streaming completion; must yield (text, meta_info) pairs."""
        raise NotImplementedError()

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: Optional[ChoicesSamplingMethod] = None,
    ) -> ChoicesDecision:
        """Pick one of *choices*; must return a ChoicesDecision."""
        raise NotImplementedError()

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Merge cached requests *src_rids* into *dst_rid* (optional capability)."""
        raise NotImplementedError()

    def shutdown(self):
        """Release backend resources (optional)."""
        pass

    def flush_cache(self):
        """Drop the backend's prefix cache (optional)."""
        pass

    def get_server_info(self):
        """Return backend/server metadata (optional)."""
        pass

View File

@@ -0,0 +1,90 @@
from typing import Mapping, Optional
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
    import litellm
except ImportError as e:
    # Defer the failure: LiteLLM.__init__ re-raises this on first use.
    litellm = e
# NOTE: this assignment is load-bearing even when the import failed — setting
# the attribute on the stored exception keeps the class-body default
# `max_retries=litellm.num_retries` below from raising at import time.
litellm.num_retries = 1
class LiteLLM(BaseBackend):
    """LiteLLM-proxied chat backend for the sgl frontend."""

    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = litellm.num_retries,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        super().__init__()
        # `litellm` holds the ImportError when the package is missing.
        if isinstance(litellm, Exception):
            raise litellm
        self.model_name = model_name
        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )
        # Forwarded verbatim to every litellm.completion call.
        self.client_params = {
            "api_key": api_key,
            "organization": organization,
            "base_url": base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
        }

    def get_chat_template(self):
        return self.chat_template

    @staticmethod
    def _chat_messages(s: StreamExecutor):
        """Executor state -> chat `messages` list (falls back to the raw text).

        Previously duplicated in generate() and generate_stream().
        """
        if s.messages_:
            return s.messages_
        return [{"role": "user", "content": s.text_}]

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Single-shot completion; returns (text, meta_info)."""
        ret = litellm.completion(
            model=self.model_name,
            messages=self._chat_messages(s),
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        comp = ret.choices[0].message.content
        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Streaming completion; yields (text_delta, meta_info) pairs."""
        ret = litellm.completion(
            model=self.model_name,
            messages=self._chat_messages(s),
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in ret:
            text = chunk.choices[0].delta.content
            if text is not None:
                yield text, {}

View File

@@ -0,0 +1,475 @@
import dataclasses
import logging
import time
import warnings
from typing import List, Optional, Union
import numpy as np
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import openai
import tiktoken
except ImportError as e:
openai = tiktoken = e
logger = logging.getLogger(__name__)
def create_logit_bias_int(tokenizer):
    """Get logit bias for integer numbers.

    Scans the tokenizer's merge table for tokens that decode to pure digits
    (or a lone space) and biases them — plus <|endoftext|> — to +100.
    """
    int_token_ids = []
    for _token, token_id in tokenizer._mergeable_ranks.items():
        decoded = tokenizer.decode([token_id])
        if decoded == " " or all(c.isdigit() for c in decoded):
            int_token_ids.append(token_id)
        if len(int_token_ids) >= 300:  # OpenAI API limit
            break
    mask = {t: 100 for t in int_token_ids[:299]}
    # Keep one slot for the end-of-text token so generation can stop.
    mask[tokenizer._special_tokens["<|endoftext|>"]] = 100
    return mask
# Completion-style (non-chat) OpenAI models; any other name defaults to chat.
INSTRUCT_MODEL_NAMES = [
    "gpt-3.5-turbo-instruct",
]
@dataclasses.dataclass
class TokenUsage:
    """Running prompt/completion token counters for one backend instance."""

    prompt_tokens: int
    completion_tokens: int

    def reset(self):
        """Zero both counters."""
        self.prompt_tokens = 0
        self.completion_tokens = 0
class OpenAI(BaseBackend):
    """OpenAI (or Azure OpenAI) backend for the sgl frontend.

    Supports chat and completion models, constrained int/str decoding for
    completion models, and "API speculative execution", where several
    consecutive gen calls are served by one speculated completion.

    Bug fix vs. the original: select() now accumulates completion_tokens
    with `+=` instead of overwriting the counter with `=`.
    """

    def __init__(
        self,
        model_name: str,
        is_chat_model: Optional[bool] = None,
        chat_template: Optional[ChatTemplate] = None,
        is_azure: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        # `openai`/`tiktoken` hold the ImportError when the packages are missing.
        if isinstance(openai, Exception):
            raise openai

        if is_azure:
            self.client = openai.AzureOpenAI(*args, **kwargs)
        else:
            self.client = openai.OpenAI(*args, **kwargs)

        self.model_name = model_name
        try:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Unknown model name: fall back to the common GPT-3.5/4 encoding.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.logit_bias_int = create_logit_bias_int(self.tokenizer)

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        if is_chat_model is not None:
            self.is_chat_model = is_chat_model
        else:
            if model_name in INSTRUCT_MODEL_NAMES:
                self.is_chat_model = False
            else:
                self.is_chat_model = True

        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]

        # Usage
        self.token_usage = TokenUsage(0, 0)

        # API speculative execution
        # TODO(ying): This does not support multi-threading (run_batch)
        self.spec_kwargs = {}
        self.spec_format = []
        self.spec_max_num_tries = 3

    def get_chat_template(self):
        return self.chat_template

    def _prepare_spec_execution(
        self,
        sampling_params: SglSamplingParams,
        num_api_spec_tokens: int,
        spec_var_name: str,
    ):
        """Queue a gen call for speculative execution instead of issuing it now.

        Sampling params must stay consistent across the queued calls; the
        actual API request happens later in role_end_generate().
        """
        if "max_tokens" not in self.spec_kwargs:
            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
        else:
            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens

        params = sampling_params.to_openai_kwargs()
        for key, value in params.items():
            if key in ["stop"]:
                continue
            if key in ["max_tokens"]:
                warnings.warn(
                    "The parameter max_tokens will be overwritten by speculated number of tokens."
                )
                continue
            if key not in self.spec_kwargs:
                self.spec_kwargs[key] = value
            else:
                assert (
                    value == self.spec_kwargs[key]
                ), "sampling parameters should be consistent if turn on api speculative execution."
        self.spec_format.append(
            {"text": "", "stop": params["stop"], "name": spec_var_name}
        )
        return "", {}

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
        spec_var_name: Optional[str] = None,
    ):
        """Produce one completion; returns (text-or-list, meta_info).

        dtype=None is free-form generation; dtype str/int constrain the
        output (completion models only).
        """
        if sampling_params.dtype is None:
            if self.is_chat_model:
                if s.num_api_spec_tokens is None:
                    if not s.text_.endswith(self.chat_prefix):
                        raise RuntimeError(
                            "This use case is not supported if api speculative execution is off. "
                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
                            "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
                        )
                    prompt = s.messages_
                else:
                    return self._prepare_spec_execution(
                        sampling_params, s.num_api_spec_tokens, spec_var_name
                    )
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            # Reasoning models take max_completion_tokens; others take max_tokens.
            if (
                self.model_name.startswith("o1")
                or self.model_name.startswith("o3")
                or "o1" in self.model_name
            ):
                kwargs.pop("max_tokens", None)
            else:
                kwargs.pop("max_completion_tokens", None)

            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
            # Keep the returned list (or string) as is.
        elif sampling_params.dtype in [str, "str", "string"]:
            assert (
                not self.is_chat_model
            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            # Open a quote and stop at the closing quote to force a string literal.
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_ + '"',
                stop='"',
                **kwargs,
            )
            # Wrap each element in quotes if we have a list.
            if isinstance(comp, list):
                comp = ['"' + x + '"' for x in comp]
            else:
                comp = '"' + comp + '"'
        elif sampling_params.dtype in [int, "int"]:
            assert (
                not self.is_chat_model
            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            # Bias digit tokens and stop on whitespace to force an integer.
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_,
                logit_bias=self.logit_bias_int,
                stop=[" "],
                **kwargs,
            )
            # Leave as a list if that's what is returned.
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

        return comp, {}

    def spec_fill(self, value: str):
        """Record a literal text segment in the speculative format."""
        assert self.is_chat_model
        self.spec_format.append({"text": value, "stop": None, "name": None})

    def spec_pattern_match(self, comp):
        """Match a speculated completion against the recorded format.

        On success, fills each variable term's "text" in place and returns
        True; returns False when the completion does not fit the format.
        """
        for i, term in enumerate(self.spec_format):
            text = term["text"]
            if text != "":
                if comp.startswith(text):
                    comp = comp[len(text) :]
                else:
                    return False
            else:
                pos = comp.find(term["stop"])
                if pos != -1:
                    term["text"] = comp[:pos]
                    comp = comp[pos:]
                else:
                    if i == len(self.spec_format) - 1:
                        term["text"] = comp
                    else:
                        return False
        return True

    def role_end_generate(
        self,
        s: StreamExecutor,
    ):
        """Issue the speculated completion at the end of an assistant turn
        and bind the matched variable values back into the executor."""
        if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
            return

        comp = ""
        if not all(x["name"] is None for x in self.spec_format):
            # TODO(ying): throw errors or warnings
            for i in range(self.spec_max_num_tries):
                comp = openai_completion(
                    client=self.client,
                    token_usage=self.token_usage,
                    is_chat=self.is_chat_model,
                    model=self.model_name,
                    prompt=s.messages_,
                    **self.spec_kwargs,
                )
                # Use a string for pattern matching.
                comp_for_match = comp[0] if isinstance(comp, list) else comp
                if self.spec_pattern_match(comp_for_match):
                    break

        for term in self.spec_format:
            s.text_ += term["text"]
            name = term["name"]
            if name is not None:
                s.variables[name] = term["text"]
                s.meta_info[name] = {}
                s.variable_event[name].set()

        self.spec_kwargs = {}
        self.spec_format = []

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Streaming variant of generate(); only free-form dtype is supported."""
        if sampling_params.dtype is None:
            if self.is_chat_model:
                if not s.text_.endswith(self.chat_prefix):
                    raise RuntimeError(
                        "This use case is not supported. "
                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
                    )
                prompt = s.messages_
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            generator = openai_completion_stream(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
            return generator
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        """Note: `choices_method` is not used by the OpenAI backend."""
        if self.is_chat_model:
            raise NotImplementedError(
                "select/choices is not supported for chat models. "
                "Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
            )

        n_choices = len(choices)
        token_ids = [self.tokenizer.encode(x) for x in choices]
        scores = [0] * n_choices
        valid = [len(x) > 0 for x in token_ids]
        prompt_tokens = self.tokenizer.encode(s.text_)

        max_len = max([len(x) for x in token_ids])
        for step in range(max_len):
            # Build logit bias
            logit_bias = {}
            for i in range(n_choices):
                if valid[i]:
                    logit_bias[token_ids[i][step]] = 100

            # Call API
            ret = self.client.completions.create(
                model=self.model_name,
                prompt=prompt_tokens,
                logit_bias=logit_bias,
                max_tokens=1,
                temperature=temperature,
            )
            ret_str = ret.choices[0].text
            ret_token = self.tokenizer.encode(ret_str)[0]
            self.token_usage.prompt_tokens += ret.usage.prompt_tokens
            # Bug fix: accumulate (the original overwrote the counter with `=`).
            self.token_usage.completion_tokens += ret.usage.completion_tokens

            # TODO:
            # 1. return logits as the scores
            # 2. compute logits of the full choice
            # 3. consider chunk-based decoding

            # Update valid
            hit = False
            for i in range(n_choices):
                if valid[i]:
                    if step == len(token_ids[i]) - 1:
                        valid[i] = False

                    if ret_token == token_ids[i][step]:
                        scores[i] += 1
                        hit = True
                    else:
                        valid[i] = False
            assert hit

            if np.sum(valid) <= 1:
                break

            prompt_tokens.append(ret_token)

        return ChoicesDecision(
            decision=choices[np.argmax(scores)],
            meta_info={"scores": scores},
        )
def openai_completion(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
) -> Union[str, List[str]]:
    """Blocking OpenAI completion with retry.

    Returns one string, or a list of strings for n>1 / batched prompts, and
    accumulates token counts into *token_usage*. Transient API errors are
    retried up to *retries* times with a 5s pause; the final failure is
    re-raised immediately (the original also slept 5s before re-raising).
    """
    # if "ebnf" is in kwargs, warn and remove
    if "ebnf" in kwargs:
        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
        del kwargs["ebnf"]

    for attempt in range(retries):
        try:
            if is_chat:
                # The chat endpoint rejects an explicit null stop.
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                ret = client.chat.completions.create(messages=prompt, **kwargs)
                if len(ret.choices) == 1:
                    comp = ret.choices[0].message.content
                else:
                    comp = [c.message.content for c in ret.choices]
            else:
                ret = client.completions.create(prompt=prompt, **kwargs)
                if isinstance(prompt, (list, tuple)):
                    comp = [c.text for c in ret.choices]
                else:
                    comp = ret.choices[0].text
                    if len(ret.choices) > 1:
                        comp = [c.text for c in ret.choices]
            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            if attempt == retries - 1:
                raise e
            time.sleep(5)
        except Exception as e:
            logger.error(f"RuntimeError {e}.")
            raise e

    return comp
def openai_completion_stream(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
    """Streaming variant of openai_completion(); yields (text, meta) pairs.

    Requests usage via stream_options so the final chunk carries token
    totals, which are folded into *token_usage* after the stream ends.
    Transient errors retry with a 5s pause; the final failure re-raises
    immediately (the original also slept 5s before re-raising).
    """
    # if "ebnf" is in kwargs, warn and remove
    if "ebnf" in kwargs:
        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
        del kwargs["ebnf"]

    for attempt in range(retries):
        try:
            if is_chat:
                # The chat endpoint rejects an explicit null stop.
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                generator = client.chat.completions.create(
                    messages=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    if len(ret.choices) == 0:
                        continue
                    try:
                        content = ret.choices[0].delta.content
                    except IndexError:
                        content = None
                    yield content or "", {}
            else:
                generator = client.completions.create(
                    prompt=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    if len(ret.choices) == 0:
                        continue
                    content = ret.choices[0].text
                    yield content or "", {}
            # `ret` is the final chunk; with include_usage it carries totals.
            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            if attempt == retries - 1:
                raise e
            time.sleep(5)
        except Exception as e:
            logger.error(f"RuntimeError {e}.")
            raise e

View File

@@ -0,0 +1,527 @@
import atexit
import json
import multiprocessing
import warnings
from typing import Dict, List, Optional, Union
import aiohttp
import requests
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import (
REGEX_BOOL,
REGEX_FLOAT,
REGEX_INT,
REGEX_STR,
SglSamplingParams,
)
from sglang.utils import http_request
class RuntimeEndpoint(BaseBackend):
    """Backend that talks to a running SGLang server over HTTP."""

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        verify: Optional[str] = None,
        chat_template_name: Optional[str] = None,
    ):
        super().__init__()
        # This backend can concatenate cached prefixes and append to them.
        self.support_concate_and_append = True
        self.base_url = base_url
        self.api_key = api_key
        # presumably a TLS verification flag/cert path forwarded to requests — confirm
        self.verify = verify
        # Fetch model metadata up front; fails fast if the server is unreachable.
        res = http_request(
            self.base_url + "/get_model_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        self.model_info = res.json()
        # Prefer an explicit chat template; otherwise infer one from the model path.
        if chat_template_name:
            self.chat_template = get_chat_template(chat_template_name)
        else:
            self.chat_template = get_chat_template_by_model_path(
                self.model_info["model_path"]
            )
def get_model_name(self):
return self.model_info["model_path"]
def flush_cache(self):
res = http_request(
self.base_url + "/flush_cache",
api_key=self.api_key,
verify=self.verify,
method="POST",
)
self._assert_success(res)
def get_server_info(self):
res = http_request(
self.base_url + "/get_server_info",
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
return res.json()
def get_chat_template(self):
return self.chat_template
def cache_prefix(self, prefix_str: str):
res = http_request(
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
def start_profile(self):
res = http_request(
self.base_url + "/start_profile",
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
def stop_profile(self):
res = http_request(
self.base_url + "/stop_profile",
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
def commit_lazy_operations(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
if sampling_params.dtype is None:
return
if sampling_params.stop == ():
sampling_params.stop = []
dtype_regex = None
if sampling_params.dtype in ["int", int]:
dtype_regex = REGEX_INT
sampling_params.stop.extend([" ", "\n"])
elif sampling_params.dtype in ["float", float]:
dtype_regex = REGEX_FLOAT
sampling_params.stop.extend([" ", "\n"])
elif sampling_params.dtype in ["str", str]:
dtype_regex = REGEX_STR
elif sampling_params.dtype in ["bool", bool]:
dtype_regex = REGEX_BOOL
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
if dtype_regex is not None and sampling_params.regex is not None:
warnings.warn(
f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
)
sampling_params.regex = dtype_regex
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
self._handle_dtype_to_regex(sampling_params)
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
**sampling_params.to_srt_kwargs(),
},
}
for item in [
"return_logprob",
"logprob_start_len",
"top_logprobs_num",
"return_text_in_logprobs",
]:
value = getattr(sampling_params, item, None)
if value is not None:
data[item] = value
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
obj = res.json()
comp = obj["text"]
return comp, obj["meta_info"]
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
self._handle_dtype_to_regex(sampling_params)
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
**sampling_params.to_srt_kwargs(),
},
}
for item in [
"return_logprob",
"logprob_start_len",
"top_logprobs_num",
"return_text_in_logprobs",
]:
value = getattr(sampling_params, item, None)
if value is not None:
data[item] = value
data["stream"] = True
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
stream=True,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
pos = 0
for chunk in res.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
chunk_text = data["text"][pos:]
meta_info = data["meta_info"]
pos += len(chunk_text)
yield chunk_text, meta_info
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
choices_method: ChoicesSamplingMethod,
) -> ChoicesDecision:
assert temperature <= 1e-5
# Cache common prefix
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
obj = self._generate_http_request(s, data)
prompt_len = obj["meta_info"]["prompt_tokens"]
logprob_start_len = max(prompt_len - 2, 0) # For token healing
# Compute logprob
data = {
"text": [s.text_ + c for c in choices],
"sampling_params": {
"max_new_tokens": 0,
"temperature": 0,
},
"return_logprob": True,
"return_text_in_logprobs": True,
"logprob_start_len": logprob_start_len,
}
obj = self._generate_http_request(s, data)
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
normalized_prompt_logprobs = [
compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
for r in obj
]
# Remove extra token if no token healing occurred
for i in range(len(input_token_logprobs)):
healed_token_str = input_token_logprobs[i][0][-1]
if s.text_.endswith(healed_token_str):
healed_token_logprob = input_token_logprobs[i][0][0]
normalized_prompt_logprobs[i] = (
normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
- healed_token_logprob
) / (len(input_token_logprobs[i]) - 1)
input_token_logprobs[i] = input_token_logprobs[i][1:]
# Compute unconditional logprobs if required
if choices_method.requires_unconditional_logprobs:
input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
data = {
"input_ids": input_ids,
"sampling_params": {"max_new_tokens": 0},
"return_logprob": True,
}
obj = self._generate_http_request(s, data)
unconditional_token_logprobs = [
r["meta_info"]["input_token_logprobs"] for r in obj
]
else:
unconditional_token_logprobs = None
return choices_method(
choices=choices,
normalized_prompt_logprobs=normalized_prompt_logprobs,
input_token_logprobs=input_token_logprobs,
output_token_logprobs=output_token_logprobs,
unconditional_token_logprobs=unconditional_token_logprobs,
)
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
res = http_request(
self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid},
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
def _generate_http_request(self, s: StreamExecutor, data):
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
return res.json()
def _add_images(self, s: StreamExecutor, data):
if s.images_:
assert len(s.images_) == 1, "Only support one image."
data["image_data"] = s.images_[0][1]
def _assert_success(self, res):
if res.status_code != 200:
try:
content = res.json()
except json.JSONDecodeError:
content = res.text
raise RuntimeError(content)
def compute_normalized_prompt_logprobs(input_logprobs):
    """Return the length-normalized (mean) logprob of a prompt.

    Args:
        input_logprobs: Sequence of ``(logprob, token_id, ...)`` entries as
            found in ``meta_info["input_token_logprobs"]``. The first token's
            logprob may be ``None`` (there is no context to condition on).

    Returns:
        The arithmetic mean of all non-``None`` logprobs, or ``0.0`` when no
        entry carries a logprob.
    """
    # Previous implementation filtered with a bare truthiness test
    # (``if x[0]``), which also discarded legitimate 0.0 logprobs
    # (probability-1 tokens) and raised ZeroDivisionError when everything was
    # filtered out. Filter only None and guard against an empty result.
    values = [x[0] for x in input_logprobs if x[0] is not None]
    if not values:
        return 0.0
    return sum(values) / len(values)
class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a python program without
    using the command line interface.
    It is mainly used for the frontend language.
    You should use the Engine class if you want to do normal offline processing without the frontend language.
    """
    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
        # client code without installing SRT server and its dependency if they want.
        from sglang.srt.entrypoints.http_server import launch_server
        from sglang.srt.server_args import ServerArgs
        from sglang.srt.utils import is_port_available
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
        # Pre-allocate ports
        # NOTE(review): if no port in [server_args.port, 40000) is free, the
        # loop falls through and the last probed port is used anyway -- confirm
        # whether this should raise instead.
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port
        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"
        # NOTE: We store pid instead of proc to fix some issues during __delete__
        self.pid = None
        pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)
        # "spawn" avoids inheriting CUDA/threads state from the parent process.
        ctx = multiprocessing.get_context("spawn")
        proc = ctx.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid
        # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
        atexit.register(self.shutdown)
        # TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""
        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )
        self.endpoint = RuntimeEndpoint(self.url)
    def shutdown(self):
        """Kill the server process tree; safe to call multiple times."""
        from sglang.srt.utils import kill_process_tree
        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None
    def start_profile(self):
        """Start the server-side profiler."""
        self.endpoint.start_profile()
    def stop_profile(self):
        """Stop the server-side profiler."""
        self.endpoint.stop_profile()
    def cache_prefix(self, prefix: str):
        """Warm the server's prefix cache with *prefix*."""
        self.endpoint.cache_prefix(prefix)
    def get_tokenizer(self):
        """Load the tokenizer matching the server's configuration."""
        from sglang.srt.hf_transformers_utils import get_tokenizer
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )
    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        """Stream generation for *prompt*; yields text deltas, or raw dicts
        when the server's chunk has no "text" field."""
        if self.server_args.skip_tokenizer_init:
            # With tokenizer init skipped, the server expects token ids.
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0
        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            # "text" is cumulative; emit only the new suffix.
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data
    # Alias kept for API compatibility with other engine frontends.
    add_request = async_generate
    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        """Blocking generation; returns the server response as a JSON string."""
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        # A per-prompt lora_path list must line up with the prompt list.
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())
    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        """Embed *prompt* via the /encode endpoint; returns a JSON string."""
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())
    async def get_server_info(self):
        """Fetch the server info dict, raising RuntimeError on failure."""
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )
    def __del__(self):
        # NOTE(review): if __init__ raised before self.pid was assigned,
        # shutdown() would hit an AttributeError here -- confirm and guard
        # with getattr if that path is reachable.
        self.shutdown()

View File

@@ -0,0 +1,148 @@
import os
import warnings
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import vertexai
from vertexai.preview.generative_models import (
GenerationConfig,
GenerativeModel,
Image,
)
except ImportError as e:
GenerativeModel = e
class VertexAI(BaseBackend):
    """Frontend-language backend for Google Vertex AI generative models."""
    def __init__(self, model_name, safety_settings=None):
        """Initialize the Vertex AI SDK from GCP_PROJECT_ID / GCP_LOCATION env vars."""
        super().__init__()
        # The import guard at module top stores the ImportError in
        # GenerativeModel; surface it here at construction time.
        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel
        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)
        self.model_name = model_name
        self.chat_template = get_chat_template("default")
        self.safety_settings = safety_settings
    def get_chat_template(self):
        """Return the chat template used to render prompts."""
        return self.chat_template
    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming request; returns ``(text, {})``."""
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )
        comp = ret.text
        return comp, {}
    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a request; yields ``(text_delta, {})`` pairs."""
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )
        for ret in generator:
            yield ret.text, {}
    def text_to_vertexai_input(self, text, images):
        """Interleave text segments and Image parts for the Vertex AI API.

        NOTE(review): the pop-based walk assumes exactly one image token per
        image plus one trailing segment; mismatched counts raise IndexError.
        """
        input = []
        # split with image token
        text_segs = text.split(self.chat_template.image_token)
        for image_path, image_base64_data in images:
            text_seg = text_segs.pop(0)
            if text_seg != "":
                input.append(text_seg)
            input.append(Image.from_bytes(image_base64_data))
        text_seg = text_segs.pop(0)
        if text_seg != "":
            input.append(text_seg)
        return input
    def messages_to_vertexai_input(self, messages):
        """Convert OpenAI-style messages to the Vertex AI message format."""
        vertexai_message = []
        # from openai message format to vertexai message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
            else:
                text = msg["content"][0]["text"]
            if msg["role"] == "system":
                # Vertex AI has no system role: emulate it with a user turn
                # followed by a canned model acknowledgement.
                warnings.warn("Warning: system prompt is not supported in VertexAI.")
                vertexai_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                vertexai_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue
            if msg["role"] == "user":
                vertexai_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
                vertexai_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }
            # NOTE(review): a role other than system/user/assistant leaves
            # vertexai_msg unbound (NameError) or re-appends the previous
            # message -- confirm inputs are always restricted to these roles.
            # images
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
                    vertexai_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )
            vertexai_message.append(vertexai_msg)
        return vertexai_message

View File

@@ -0,0 +1,662 @@
import re
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Sequence, Tuple
class ChatTemplateStyle(Enum):
    """How role prefixes/suffixes are assembled into a prompt."""
    PLAIN = auto()  # simple per-role prefix + content + suffix concatenation
    LLAMA2 = auto()  # Llama-2 style: system prompt folded into the first user turn
@dataclass
class ChatTemplate:
    """Declarative description of a model's chat prompt format.

    Attributes:
        name: Registry key of the template.
        default_system_prompt: Used when a system message arrives with
            ``content=None``; ``None`` means there is no default.
        role_prefix_and_suffix: Maps a role ("system"/"user"/"assistant") to
            the strings emitted before and after that role's content.
        stop_str: Stop sequences to use when generating with this template.
        image_token: Placeholder inserted where an image appears.
        audio_token: Placeholder inserted where an audio clip appears.
        style: Assembly style; LLAMA2 folds the system prompt into the first
            user turn.
    """
    name: str
    # Fixed annotations: registrations pass None for default_system_prompt,
    # and stop_str defaults to a tuple while callers pass lists or tuples.
    default_system_prompt: Optional[str]
    role_prefix_and_suffix: Dict[str, Tuple[str, str]]
    stop_str: Sequence[str] = ()
    image_token: str = "<image>"
    audio_token: str = "<audio>"
    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
    def get_prefix_and_suffix(
        self, role: str, hist_messages: List[Dict]
    ) -> Tuple[str, str]:
        """Return the (prefix, suffix) pair for *role* given prior messages."""
        prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))
        if self.style == ChatTemplateStyle.LLAMA2:
            # Llama-2 embeds the system prompt inside the first user turn:
            # [INST] <<SYS>>...<</SYS>> user text [/INST]
            if role == "system" and not hist_messages:
                user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
                system_prefix, system_suffix = self.role_prefix_and_suffix.get(
                    "system", ("", "")
                )
                return (user_prefix + system_prefix, system_suffix)
            elif (
                role == "user"
                and len(hist_messages) == 1
                and hist_messages[0]["content"] is not None
            ):
                # The opening "[INST]" was already emitted with the system
                # turn, so the first user message skips its prefix.
                return ("", suffix)
        return prefix, suffix
    def get_prompt(self, messages: List[Dict]) -> str:
        """Render OpenAI-style *messages* into a single prompt string."""
        prompt = ""
        for i, message in enumerate(messages):
            role, content = message["role"], message["content"]
            if role == "system" and content is None:
                content = self.default_system_prompt
            # Messages that still have no content are skipped entirely.
            if content is None:
                continue
            prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
            prompt += f"{prefix}{content}{suffix}"
        return prompt
# Global registries: templates by name, plus an ordered chain of
# model-path matching functions (first non-None answer wins).
chat_template_registry: Dict[str, "ChatTemplate"] = {}
matching_function_registry: List[Callable] = []
def register_chat_template(template):
    """Register *template* under its own name (last registration wins)."""
    chat_template_registry[template.name] = template
def register_chat_template_matching_function(func):
    """Decorator: append *func* to the model-path matching chain.

    Fix: previously returned None, so every decorated matcher's module-level
    name was rebound to None. Return the function as a decorator should.
    """
    matching_function_registry.append(func)
    return func
def get_chat_template(name):
    """Look up a registered template by name.

    Raises:
        KeyError: if *name* was never registered; the message lists the
            currently known template names.
    """
    try:
        return chat_template_registry[name]
    except KeyError:
        raise KeyError(
            f"Unknown chat template {name!r}. "
            f"Available: {sorted(chat_template_registry)}"
        ) from None
def get_chat_template_by_model_path(model_path):
    """Return the first template whose matcher accepts *model_path*,
    falling back to the "default" template."""
    for matching_func in matching_function_registry:
        template_name = matching_func(model_path)
        if template_name is not None:
            return get_chat_template(template_name)
    return get_chat_template("default")
register_chat_template(
ChatTemplate(
name="default",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("SYSTEM:", "\n"),
"user": ("USER:", "\n"),
"assistant": ("ASSISTANT:", "\n"),
},
)
)
register_chat_template(
ChatTemplate(
name="claude",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", ""),
"user": ("\n\nHuman: ", ""),
"assistant": ("\n\nAssistant:", ""),
},
)
)
register_chat_template(
ChatTemplate(
name="chatml",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
)
)
register_chat_template(
ChatTemplate(
name="chatml-llava",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token="<image>\n",
)
)
# There is default system prompt for qwen
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
register_chat_template(
ChatTemplate(
name="qwen",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
)
)
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_chat_template(
ChatTemplate(
name="qwen2-vl",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token="<|vision_start|><|image_pad|><|vision_end|>",
)
)
# Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
register_chat_template(
ChatTemplate(
name="vicuna_v1.1",
default_system_prompt=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
role_prefix_and_suffix={
"system": ("", " "),
"user": ("USER:", " "),
"assistant": ("ASSISTANT:", "</s>"),
},
image_token=" <image>\n",
)
)
register_chat_template(
ChatTemplate(
name="llama-2-chat",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
"user": ("[INST] ", " [/INST]"),
"assistant": ("", " </s><s>"),
},
style=ChatTemplateStyle.LLAMA2,
)
)
# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
register_chat_template(
ChatTemplate(
name="mistral",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
"user": ("[INST] ", " [/INST]"),
"assistant": ("", " </s><s>"),
},
stop_str=("</s>",),
image_token="[IMG]",
)
)
register_chat_template(
ChatTemplate(
name="llama-3-instruct",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
"<|start_header_id|>system<|end_header_id|>\n\n",
"<|eot_id|>",
),
"user": (
"<|start_header_id|>user<|end_header_id|>\n\n",
"<|eot_id|>",
),
"assistant": (
"<|start_header_id|>assistant<|end_header_id|>\n\n",
"<|eot_id|>",
),
},
stop_str=("<|eot_id|>",),
image_token="<|image|>",
)
)
# https://huggingface.co/openbmb/MiniCPM-V-2_6
register_chat_template(
ChatTemplate(
name="minicpmv",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", " "),
"user": ("user:", " "),
"assistant": ("assistant:", "</s>"),
},
stop_str=("<|im_end|>", "<|endoftext|>"),
image_token="(<image>./</image>)",
)
)
# DeepSeek Janus-Pro multimodal chat format.
register_chat_template(
    ChatTemplate(
        name="janus-pro",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "",
                "",
            ),
            # Fix: this key was "User" (capitalized), which never matched the
            # lowercase "user" role of OpenAI-style messages, so the "<User>"
            # prefix was silently dropped (compare the "janus" template below).
            "user": (
                "<User>",
                "",
            ),
            "assistant": (
                "<Assistant>",
                "<end▁of▁sentence>",
            ),
        },
        stop_str=("<end▁of▁sentence>",),
        image_token="<image_placeholder>\n",
    )
)
# https://huggingface.co/openbmb/MiniCPM-o-2_6
register_chat_template(
ChatTemplate(
name="minicpmo",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", " "),
"user": ("user:", " "),
"assistant": ("assistant:", "</s>"),
},
stop_str=("<|im_end|>", "<|endoftext|>"),
image_token="(<image>./</image>)",
audio_token="(<audio>./</audio>)",
)
)
register_chat_template(
ChatTemplate(
name="janus",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
"",
"",
),
"user": (
"<User>",
"",
),
"assistant": (
"<Assistant>",
"<end▁of▁sentence>",
),
},
stop_str=("<end▁of▁sentence>",),
image_token="<image_placeholder>\n",
)
)
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
ChatTemplate(
name="llama-3-instruct-llava",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
"<|start_header_id|>system<|end_header_id|>\n\n",
"<|eot_id|>",
),
"user": (
"<|start_header_id|>user<|end_header_id|>\n\n",
"<|eot_id|>",
),
"assistant": (
"<|start_header_id|>assistant<|end_header_id|>\n\n",
"<|eot_id|>",
),
},
stop_str=("<|eot_id|>",),
image_token="<image>\n",
)
)
# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
register_chat_template(
ChatTemplate(
name="llama-4",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
"<|header_start|>system<|header_end|>\n\n",
"<|eot|>",
),
"user": (
"<|header_start|>user<|header_end|>\n\n",
"<|eot|>",
),
"assistant": (
"<|header_start|>assistant<|header_end|>\n\n",
"<|eot|>",
),
},
stop_str=("<|eot|>",),
image_token="<|image|>",
)
)
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
register_chat_template(
ChatTemplate(
name="yi-1.5",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", ""),
"user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
"assistant": ("", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
)
)
# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
register_chat_template(
ChatTemplate(
name="yi-vl",
default_system_prompt=(
"This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
),
role_prefix_and_suffix={
"system": ("", "\n\n"),
"user": ("### Human:", "\n"),
"assistant": ("### Assistant:", "\n"),
},
image_token=" <image_placeholder>\n",
)
)
register_chat_template(
ChatTemplate(
name="gemma-it",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", ""),
"user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
"assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
},
style=ChatTemplateStyle.PLAIN,
)
)
register_chat_template(
ChatTemplate(
name="dbrx-instruct",
default_system_prompt="You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>"),
"user": ("\n<|im_start|>user\n", "<|im_end|>"),
"assistant": ("\n<|im_start|>assistant\n", "<|im_end|>"),
},
stop_str=("<|im_end|>",),
)
)
register_chat_template(
ChatTemplate(
name="c4ai-command-r",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
"<|END_OF_TURN_TOKEN|>",
),
"user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
"assistant": (
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
"<|END_OF_TURN_TOKEN|>",
),
},
style=ChatTemplateStyle.PLAIN,
)
)
# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
register_chat_template(
ChatTemplate(
name="internvl-2-5",
default_system_prompt="你是书生·万象英文名是InternVL是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
stop_str=["<|im_end|>", "<|action_end|>"],
)
)
register_chat_template(
ChatTemplate(
name="interns1",
default_system_prompt="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
stop_str=["<|im_end|>", "<|action_end|>"],
)
)
register_chat_template(
ChatTemplate(
name="granite-3-instruct",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
"<|start_of_role|>system<|end_of_role|>",
"<|end_of_text|>",
),
"user": (
"<|start_of_role|>user<|end_of_role|>",
"<|end_of_text|>",
),
"assistant": (
"<|start_of_role|>assistant<|end_of_role|>",
"<|end_of_text|>",
),
},
stop_str=("<|end_of_text|>",),
)
)
register_chat_template(
ChatTemplate(
name="deepseek-v3",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
"",
"",
),
"user": (
"<User>",
"",
),
"assistant": (
"<Assistant>",
"<end▁of▁sentence>",
),
},
stop_str=("<end▁of▁sentence>",),
)
)
# Reference: https://huggingface.co/docs/transformers/main/model_doc/glm4_v#usage-example
register_chat_template(
ChatTemplate(
name="glm-4v",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<|system|>\n", "\n"),
"user": ("<|user|>\n", "\n"),
"assistant": ("<|assistant|>\n", "\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=["<|user|>", "<|endoftext|>", "<|observation|>"],
image_token="<|image|>",
)
)
@register_chat_template_matching_function
def match_deepseek(model_path: str):
    """DeepSeek-V3/R1 chat models, excluding base checkpoints."""
    path = model_path.lower()
    is_chat_model = re.search(r"deepseek-(v3|r1)", path) is not None
    if is_chat_model and "base" not in path:
        return "deepseek-v3"
@register_chat_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    """Any Janus model maps to the janus-pro template."""
    return "janus-pro" if "janus" in model_path.lower() else None
@register_chat_template_matching_function
def match_dbrx(model_path: str):
    """Databricks DBRX instruct checkpoints."""
    path = model_path.lower()
    if "dbrx" in path and "instruct" in path:
        return "dbrx-instruct"
@register_chat_template_matching_function
def match_vicuna(model_path: str):
    """Vicuna and the LLaVA variants that ship with the Vicuna template."""
    pattern = re.compile(r"vicuna|llava-v1\.5|llava-next-video-7b", re.IGNORECASE)
    return "vicuna_v1.1" if pattern.search(model_path) else None
@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
    """Llama-2 chat and CodeLlama instruct checkpoints."""
    path = model_path.lower()
    if re.search(r"llama-2.*chat", path) or re.search(r"codellama.*instruct", path):
        return "llama-2-chat"
@register_chat_template_matching_function
def match_mistral(model_path: str):
    """Pixtral and Mistral/Mixtral instruct checkpoints."""
    path = model_path.lower()
    if "pixtral" in path or re.search(r"(mistral|mixtral).*instruct", path):
        return "mistral"
@register_chat_template_matching_function
def match_llama3_instruct(model_path: str):
    """Llama-3 instruct checkpoints."""
    found = re.search(r"llama-3.*instruct", model_path.lower())
    return "llama-3-instruct" if found else None
@register_chat_template_matching_function
def match_chat_ml(model_path: str):
    """Dispatch ChatML-family templates; order matters (most specific first)."""
    path = model_path.lower()
    if "tinyllama" in path:
        return "chatml"
    if re.search(r"qwen.*vl", path):
        return "qwen2-vl"
    if re.search(r"glm[-_]?4(\.\d+)?v", path):
        return "glm-4v"
    if re.search(r"qwen.*(chat|instruct)", path) and "llava" not in path:
        return "qwen"
    llava_chatml_models = (
        r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2"
    )
    if re.search(llava_chatml_models, path):
        return "chatml-llava"
@register_chat_template_matching_function
def match_chat_yi(model_path: str):
    """Yi-VL (non-LLaVA) and Yi-1.5 chat checkpoints."""
    path = model_path.lower()
    if "yi-vl" in path and "llava" not in path:
        return "yi-vl"
    if re.search(r"yi-1\.5.*chat", path):
        return "yi-1.5"
@register_chat_template_matching_function
def match_gemma_it(model_path: str):
    """Gemma instruction-tuned checkpoints."""
    return "gemma-it" if re.search(r"gemma.*it", model_path.lower()) else None
@register_chat_template_matching_function
def match_openbmb_minicpm(model_path: str):
    """OpenBMB MiniCPM-V (vision) and MiniCPM-o (omni) checkpoints."""
    path = model_path.lower()
    if "minicpm-v" in path:
        return "minicpmv"
    if "minicpm-o" in path:
        return "minicpmo"
@register_chat_template_matching_function
def match_c4ai_command_r(model_path: str):
    """Cohere C4AI Command-R checkpoints."""
    return "c4ai-command-r" if "c4ai-command-r" in model_path.lower() else None
@register_chat_template_matching_function
def match_granite_instruct(model_path: str):
    """IBM Granite instruct checkpoints."""
    found = re.search(r"granite.*instruct", model_path.lower())
    return "granite-3-instruct" if found else None
@register_chat_template_matching_function
def match_gemma3_instruct(model_path: str):
    """Gemma-3 models reuse the gemma-it template."""
    return "gemma-it" if "gemma-3" in model_path.lower() else None
@register_chat_template_matching_function
def match_internvl_chat(model_path: str):
    """Route InternVL 2.5 checkpoints to the internvl-2-5 template."""
    found = re.search(r"internvl2_5", model_path, re.IGNORECASE)
    if found:
        return "internvl-2-5"
@register_chat_template_matching_function
def match_interns1_chat(model_path: str):
    """Route Intern-S1 checkpoints to the interns1 template.

    Both the "intern-s1" and "interns1" spellings are accepted. The original
    implementation had two duplicated branches returning the same value; a
    single pattern with an optional hyphen covers both.
    """
    if re.search(r"intern-?s1", model_path, re.IGNORECASE):
        return "interns1"
if __name__ == "__main__":
    # Smoke test: render a short multi-turn conversation with the llama-2
    # chat template and print the resulting prompt text.
    demo_messages = [
        {"role": "system", "content": None},  # None means default
        # {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "What can you do?"},
        {"role": "assistant", "content": "I can chat with you."},
    ]
    chat_template = get_chat_template("llama-2-chat")
    print(chat_template.get_prompt(demo_messages))

View File

@@ -0,0 +1,164 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import numpy as np
@dataclass
class ChoicesDecision:
    """Result of a choices sampling method: the selected choice plus
    optional diagnostic info (logprobs etc.) for inspection/debugging."""
    # The chosen option, one of the `choices` strings passed to the method.
    decision: str
    meta_info: Optional[Dict[str, Any]] = None
class ChoicesSamplingMethod(ABC):
    """Strategy interface for selecting one option out of a list of choices
    based on per-option logprob information."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        # Subclasses that normalize against unconditional prompt logprobs
        # override this to True so callers know to compute the extra pass.
        return False
    @abstractmethod
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision: ...
class TokenLengthNormalized(ChoicesSamplingMethod):
    """Pick the choice whose token-length-normalized prompt logprob is largest."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest token length normalized prompt logprob."""
        winner = int(np.argmax(normalized_prompt_logprobs))
        return ChoicesDecision(
            decision=choices[winner],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
            },
        )


token_length_normalized = TokenLengthNormalized()
class GreedyTokenSelection(ChoicesSamplingMethod):
    """Select a choice by greedily comparing per-position token logprobs,
    eliminating options position by position until one remains."""
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option based on greedy logprob selection. For overlapping options
        where one option is a subset of a longer option, extend the shorter option using
        its average logprob for comparison against the longer option."""
        num_options = len(choices)
        # Longest option, in tokens; shorter options get padded (see below).
        max_tokens = max(len(option) for option in input_token_logprobs)
        logprob_matrix = self._build_logprob_matrix(
            input_token_logprobs, max_tokens, num_options
        )
        remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens)
        # If several options tie at every position, the first (lowest index) wins.
        best_choice = choices[remaining[0]]
        meta_info = {
            "normalized_prompt_logprobs": normalized_prompt_logprobs,
            "input_token_logprobs": input_token_logprobs,
            "output_token_logprobs": output_token_logprobs,
            "greedy_logprob_matrix": logprob_matrix.tolist(),
        }
        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
        # One row per option; each row holds that option's per-token logprobs,
        # padded on the right with the option's average logprob so short
        # options stay comparable against longer ones at every position.
        logprob_matrix = np.zeros((num_options, max_tokens))
        for i, option in enumerate(input_token_logprobs):
            # token[0] is the logprob entry of each (logprob, token_id, ...) tuple.
            actual_logprobs = [token[0] for token in option]
            avg_logprob = np.mean(actual_logprobs)
            logprob_matrix[i, : len(option)] = actual_logprobs
            if len(option) < max_tokens:
                logprob_matrix[i, len(option) :] = avg_logprob
        return logprob_matrix
    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):
        # At each token position keep only the options with the maximal
        # logprob; stop as soon as a single option survives.
        remaining = np.arange(num_options)
        for j in range(max_tokens):
            max_logprob = np.max(logprob_matrix[remaining, j])
            remaining = remaining[logprob_matrix[remaining, j] == max_logprob]
            if len(remaining) == 1:
                break
        return remaining
# Module-level singleton used as the default instance of this strategy.
greedy_token_selection = GreedyTokenSelection()
class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):
    """Select a choice by its average token logprob after subtracting the
    unconditional (prompt-free) logprob of the same tokens."""
    @property
    def requires_unconditional_logprobs(self) -> bool:
        # Callers must supply unconditional_token_logprobs for this method.
        return True
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest average token logprob once normalized by
        the unconditional token logprobs.
        The first unconditional token logprob is assumed to be None. If so, it is
        replaced with 0 for the purposes of normalization."""
        if unconditional_token_logprobs is None:
            raise ValueError(
                "Unconditional token logprobs are required for this method."
            )
        normalized_unconditional_prompt_logprobs = self._normalize_logprobs(
            input_token_logprobs, unconditional_token_logprobs
        )
        best_choice = choices[np.argmax(normalized_unconditional_prompt_logprobs)]
        meta_info = {
            "normalized_prompt_logprobs": normalized_prompt_logprobs,
            "input_token_logprobs": input_token_logprobs,
            "output_token_logprobs": output_token_logprobs,
            "unconditional_token_logprobs": unconditional_token_logprobs,
            "normalized_unconditional_prompt_logprobs": normalized_unconditional_prompt_logprobs,
        }
        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):
        # Returns, per option, mean(input_logprob - unconditional_logprob).
        normalized_unconditional_prompt_logprobs = []
        for inputs, unconditionals in zip(
            input_token_logprobs, unconditional_token_logprobs
        ):
            inputs_logprobs = np.array([token[0] for token in inputs])
            unconditionals_logprobs = np.array([token[0] for token in unconditionals])
            # First unconditional logprob may be None (no context before the
            # first token); `or 0` maps None (and falsy 0.0) to 0 so the
            # subtraction below is well defined.
            unconditionals_logprobs[0] = unconditionals_logprobs[0] or 0
            normalized_unconditional_prompt_logprobs.append(
                float(np.mean(inputs_logprobs - unconditionals_logprobs))
            )
        return normalized_unconditional_prompt_logprobs
# Module-level singleton used as the default instance of this strategy.
unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()

View File

@@ -0,0 +1,231 @@
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
def compile_func(function, backend):
    """Trace ``function`` against ``backend`` and compile the trace into a
    dataflow graph ready for execution."""
    traced = function.trace(backend=backend)
    return CompiledFunction(traced, function)
class CompiledFunction:
    """A traced SGL program compiled into a dataflow graph.

    The graph is built backwards from the tracer's last node, then
    topologically sorted so nodes can be submitted to stream executors in
    dependency order at run time.
    """
    def __init__(self, tracer, function):
        self.function = function
        self.last_node = CompGraphNode(tracer.last_node)
        # Maps each traced SglExpr to its wrapping CompGraphNode.
        self.expr_to_node = {}
        self.build_graph(tracer)
        self.topological_sort()
    def build_graph(self, tracer):
        """BFS backwards from the last node, wrapping every reachable expr
        in a CompGraphNode and recording prev/source edges."""
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.nodes[-1]
        # Renumber stream pids densely (0, 1, 2, ...) in discovery order.
        rename_pid = {}
        visited = set([tracer.last_node])
        head = 0
        while head < len(self.nodes):
            cur_node = self.nodes[head]
            # add prev node: the sequential predecessor in the same stream
            prev_node = cur_node.expr.prev_node
            if prev_node is not None:
                if prev_node not in visited:
                    visited.add(prev_node)
                    self.nodes.append(CompGraphNode(prev_node))
                    self.expr_to_node[prev_node] = self.nodes[-1]
                cur_node.prev_node = self.expr_to_node[prev_node]
                self.expr_to_node[prev_node].add_next_node(cur_node)
            # add source node: the expr whose value an SglVariable reads,
            # preferring the tracer's resolved variable table when available
            if isinstance(cur_node.expr, SglVariable):
                if cur_node.expr.name in tracer.variables:
                    source = tracer.variables[cur_node.expr.name].source
                else:
                    source = cur_node.expr.source
                if source not in visited:
                    visited.add(source)
                    self.nodes.append(CompGraphNode(source))
                    self.expr_to_node[source] = self.nodes[-1]
                cur_node.source_node = self.expr_to_node[source]
                self.expr_to_node[source].add_next_node(cur_node)
            head += 1
            # rename pid so every stream gets a small dense id
            if cur_node.expr.pid not in rename_pid:
                rename_pid[cur_node.expr.pid] = len(rename_pid)
            cur_node.expr.pid = rename_pid[cur_node.expr.pid]
    def topological_sort(self):
        """Kahn's algorithm: reorder self.nodes so every node appears after
        both its prev and source dependencies."""
        prevd = {}
        cand = Queue()
        for x in self.nodes:
            # In-degree: a node depends on at most a prev node and a source node.
            prevd[x] = (x.prev_node is not None) + (x.source_node is not None)
            if prevd[x] == 0:
                cand.put(x)
        new_list = []
        while cand.qsize() > 0:
            head = cand.get()
            new_list.append(head)
            for x in head.next_nodes:
                prevd[x] -= 1
                if prevd[x] == 0:
                    cand.put(x)
        self.nodes = new_list
    def print_graph(
        self,
    ):
        """Debug helper: print every node in topological order."""
        for node in self.nodes:
            print(node)
    def run_internal(
        self,
        backend,
        kwargs,
        default_sampling_para,
    ):
        """Execute the compiled graph once and return the final ProgramState.

        One StreamExecutor is created per stream pid; only the stream that
        contains the last node receives the user's keyword arguments.
        """
        stream_executor_ids = set([x.expr.pid for x in self.nodes])
        stream_executors = {}
        for x in stream_executor_ids:
            arguments = kwargs if x == self.last_node.expr.pid else {}
            stream_executors[x] = StreamExecutor(
                backend, arguments, default_sampling_para, None, False
            )
        for node in self.nodes:
            se_id = node.expr.pid
            expr = node.expr
            if isinstance(expr, SglVariable):
                # Make a copy for SglVariable so the shared graph node is not
                # mutated, then point it at the producing stream's executor.
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[
                    node.source_node.expr.pid
                ]
            elif isinstance(expr, SglArgument):
                # Substitute SglArgument with the concrete value from kwargs.
                expr = kwargs[expr.name]
            stream_executors[se_id].submit(expr)
        for stream_executor in stream_executors.values():
            stream_executor.end()
        return ProgramState(stream_executors[self.last_node.expr.pid])
    def run(
        self,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        """Run the compiled function once with the given sampling parameters.

        Extra **kwargs are the program's named arguments; bound arguments
        from .bind() take precedence via the update below.
        """
        backend = backend or global_config.default_backend
        kwargs.update(self.function.bind_arguments)
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        return self.run_internal(backend, kwargs, default_sampling_para)
    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        """Run the compiled function for each kwargs dict in batch_kwargs,
        optionally in parallel threads, and return a list of ProgramStates."""
        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)
        backend = backend or global_config.default_backend
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )
        # Extract prefix by tracing and cache it, so parallel runs share it.
        if len(batch_kwargs) > 1:
            cache_program(self.function, backend)
        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_kwargs))
        if num_threads == 1:
            rets = []
            for arguments in batch_kwargs:
                rets.append(
                    self.run_internal(backend, arguments, default_sampling_para)
                )
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = []
                for arguments in batch_kwargs:
                    futures.append(
                        executor.submit(
                            self.run_internal, backend, arguments, default_sampling_para
                        )
                    )
                rets = [f.result() for f in futures]
        # NOTE(review): only the last state is synced here; presumably earlier
        # states are already complete once their futures resolve — confirm.
        rets[-1].sync()
        return rets
class CompGraphNode:
    """One node of the compiled dataflow graph, wrapping a traced SglExpr."""

    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        self.expr = expr
        self.next_nodes = next_nodes or []
        self.prev_node = prev_node
        self.source_node = source_node

    def add_next_node(self, other):
        """Record ``other`` as a downstream consumer of this node."""
        self.next_nodes.append(other)

    def __repr__(self):
        text = f"stream {self.expr.pid:2d}: %{self.expr.node_id} = "
        if self.prev_node is not None:
            text += f"%{self.prev_node.expr.node_id} + "
        return text + repr(self.expr)

File diff suppressed because it is too large Load Diff

635
python/sglang/lang/ir.py Normal file
View File

@@ -0,0 +1,635 @@
"""The intermediate representation."""
import dataclasses
import inspect
import warnings
from typing import List, Optional, Union
from sglang.global_config import global_config
from sglang.lang.choices import ChoicesSamplingMethod
REGEX_INT = r"[-+]?[0-9]+[ \n]*"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
REGEX_BOOL = r"(True|False)"
REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
@dataclasses.dataclass
class SglSamplingParams:
    """Sampling parameters attached to a generation call.

    Bug fix: ``logprob_start_len``, ``top_logprobs_num`` and
    ``return_text_in_logprobs`` previously defaulted to the one-element tuple
    ``(None,)`` (a stray trailing comma), which is truthy and violates the
    declared ``Optional`` types. They now default to ``None``.
    """

    max_new_tokens: int = 128
    min_new_tokens: int = 0
    n: int = 1
    stop: Union[str, List[str]] = ()
    stop_token_ids: Optional[List[int]] = ()
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1  # -1 means disable
    min_p: float = 0.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    ignore_eos: bool = False
    return_logprob: Optional[bool] = None
    logprob_start_len: Optional[int] = None
    top_logprobs_num: Optional[int] = None
    return_text_in_logprobs: Optional[bool] = None
    json_schema: Optional[str] = None

    # for constrained generation, not included in to_xxx_kwargs
    dtype: Optional[str] = None
    regex: Optional[str] = None

    def clone(self):
        """Return a copy of these parameters.

        NOTE(review): ``dtype`` and ``regex`` are intentionally not copied —
        they are marked above as excluded from the kwargs conversions;
        confirm callers rely on the clone dropping them.
        """
        return SglSamplingParams(
            self.max_new_tokens,
            self.min_new_tokens,
            self.n,
            self.stop,
            self.stop_token_ids,
            self.temperature,
            self.top_p,
            self.top_k,
            self.min_p,
            self.frequency_penalty,
            self.presence_penalty,
            self.ignore_eos,
            self.return_logprob,
            self.logprob_start_len,
            self.top_logprobs_num,
            self.return_text_in_logprobs,
            self.json_schema,
        )

    def to_openai_kwargs(self):
        """Convert to OpenAI API keyword arguments."""
        # OpenAI does not support top_k, so we drop it here
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the OpenAI backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "max_completion_tokens": self.max_new_tokens,
            "n": self.n,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_vertexai_kwargs(self):
        """Convert to VertexAI API keyword arguments."""
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the VertexAI backend."
            )
        return {
            "candidate_count": 1,
            "max_output_tokens": self.max_new_tokens,
            "stop_sequences": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k if self.top_k > 0 else None,
        }

    def to_anthropic_kwargs(self):
        """Convert to Anthropic API keyword arguments."""
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the Anthropic backend."
            )
        return {
            "max_tokens": self.max_new_tokens,
            "stop_sequences": (
                self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
            ),
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_litellm_kwargs(self):
        """Convert to LiteLLM API keyword arguments."""
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_srt_kwargs(self):
        """Convert to the SGLang runtime's sampling-params keyword arguments."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "min_new_tokens": self.min_new_tokens,
            "n": self.n,
            "stop": self.stop,
            "stop_token_ids": self.stop_token_ids,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "min_p": self.min_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "ignore_eos": self.ignore_eos,
            "regex": self.regex,
            "json_schema": self.json_schema,
        }
class SglFunction:
    """A user-defined SGL program (a Python function whose first parameter is
    the program state ``s``), with helpers to run, batch, trace, cache and
    compile it."""
    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
        self.func = func
        self.num_api_spec_tokens = num_api_spec_tokens
        # Arguments pre-bound via .bind(); merged into every call.
        self.bind_arguments = bind_arguments or {}
        self.pin_prefix_rid = None
        # Parse arguments: the state parameter "s" is mandatory and excluded
        # from the user-facing argument list.
        argspec = inspect.getfullargspec(func)
        assert argspec.args[0] == "s", 'The first argument must be "s"'
        self.arg_names = argspec.args[1:]
        self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
    def bind(self, **kwargs):
        """Return a new SglFunction with ``kwargs`` pre-bound as arguments."""
        assert all(key in self.arg_names for key in kwargs)
        new_bind_dict = {**self.bind_arguments, **kwargs}
        return SglFunction(self.func, bind_arguments=new_bind_dict)
    def run(
        self,
        *args,
        max_new_tokens: int = 128,
        n: int = 1,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        stream: bool = False,
        backend=None,
        use_thread: bool = True,
    **kwargs,
    ):
        """Execute the program once with the given sampling parameters.

        Positional *args and **kwargs are the program's own arguments;
        sampling parameters are keyword-only.
        """
        from sglang.lang.interpreter import run_program
        # avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
        if stop is None:
            stop = []
        if stop_token_ids is None:
            stop_token_ids = []
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
        )
        backend = backend or global_config.default_backend
        return run_program(
            self,
            backend,
            args,
            kwargs,
            default_sampling_para,
            stream,
            use_thread=use_thread,
        )
    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        n: int = 1,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        backend=None,
        num_threads: Union[str, int] = "auto",
        progress_bar: bool = False,
        generator_style: bool = False,
    ):
        """Execute the program once per element of ``batch_kwargs``.

        Each element is either a dict of argument name -> value, or a
        positional list/tuple which is converted to a dict below. Raises if
        any positional element does not match the function signature.
        """
        from sglang.lang.interpreter import run_program_batch
        if stop is None:
            stop = []
        if stop_token_ids is None:
            stop_token_ids = []
        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        if not isinstance(batch_kwargs[0], dict):
            num_programs = len(batch_kwargs)
            # change the list of argument values to dict of arg_name -> arg_value
            batch_kwargs = [
                {self.arg_names[i]: v for i, v in enumerate(arg_values)}
                for arg_values in batch_kwargs
                if isinstance(arg_values, (list, tuple))
                and len(self.arg_names) - len(self.arg_defaults)
                <= len(arg_values)
                <= len(self.arg_names)
            ]
            # Ensure to raise an exception if the number of arguments mismatch
            if len(batch_kwargs) != num_programs:
                raise Exception("Given arguments mismatch the SGL function signature")
        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
        )
        backend = backend or global_config.default_backend
        return run_program_batch(
            self,
            backend,
            batch_kwargs,
            default_sampling_para,
            num_threads,
            progress_bar,
            generator_style=generator_style,
        )
    def trace(self, *, backend=None, **kwargs):
        """Trace the program symbolically and return the tracer state."""
        from sglang.lang.tracer import trace_program
        backend = backend or global_config.default_backend
        return trace_program(self, kwargs, backend)
    def cache(self, backend=None):
        """Extract and cache the program's constant prefix on the backend."""
        from sglang.lang.interpreter import cache_program
        backend = backend or global_config.default_backend
        return cache_program(self, backend)
    def compile(self, *, backend=None):
        """Compile the program into a dataflow graph (see lang/compiler.py)."""
        from sglang.lang.compiler import compile_func
        return compile_func(self, backend)
    def __call__(self, *args, **kwargs):
        # Inside an active tracing scope, calls are traced instead of run so
        # nested SGL functions become part of the outer program's graph.
        from sglang.lang.tracer import TracingScope
        tracing_scope = TracingScope.get_current_scope()
        if tracing_scope is None:
            return self.run(*args, **kwargs)
        else:
            kwargs["backend"] = tracing_scope.tracer_state.backend
            return self.trace(*args, **kwargs)
class SglExpr:
    """Base class of all IR expressions.

    Each instance gets a globally unique ``node_id`` from the class-level
    counter; ``prev_node`` links expressions sequentially within a stream and
    ``pid`` identifies the stream the expression belongs to.
    """
    # Global monotonically increasing counter used to assign node ids.
    node_ct = 0
    def __init__(self):
        self.node_id = SglExpr.node_ct
        self.prev_node = None
        self.pid = None
        SglExpr.node_ct += 1
    def __add__(self, other):
        # expr + "text" promotes the string to a constant-text expression.
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr)
        return self.concatenate_ir(self, other)
    def __radd__(self, other):
        # "text" + expr, mirroring __add__.
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr), f"{other}"
        return self.concatenate_ir(other, self)
    def concatenate_ir(self, a, b):
        # Flatten nested SglExprList's so concatenation stays a single
        # flat list rather than a deep binary tree.
        if isinstance(a, SglExprList):
            if isinstance(b, SglExprList):
                return SglExprList(a.expr_list + b.expr_list)
            else:
                return SglExprList(a.expr_list + [b])
        elif isinstance(b, SglExprList):
            return SglExprList([a] + b.expr_list)
        return SglExprList([a, b])
    def print_graph_dfs(self):
        """Render the expression graph reachable from self as text, one
        "%id = ..." line per node, dependencies first."""
        ret = [""]
        visited = set()
        def dfs_print(x):
            if x is None or x in visited:
                return
            visited.add(x)
            # Print dependency
            if x.prev_node is not None:
                dfs_print(x.prev_node)
            if isinstance(x, SglExprList):
                for y in x.expr_list:
                    dfs_print(y)
            # elif isinstance(x, SglRole):
            #     dfs_print(x.expr)
            elif isinstance(x, SglVariable):
                dfs_print(x.source)
            # Print the node itself
            if isinstance(x, (SglFork, SglGetForkItem)):
                ret[0] += f"%{x.node_id} = {x}\n"
            else:
                if x.prev_node is not None:
                    ret[0] += (
                        f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n"
                    )
                else:
                    ret[0] += f"%{x.node_id} = " + str(x) + "\n"
        dfs_print(self)
        return ret[0]
class SglExprList(SglExpr):
    """A flat sequence of expressions executed in order."""

    def __init__(self, expr_list: List[SglExpr]):
        super().__init__()
        self.expr_list = expr_list

    def __repr__(self):
        return "ExprList({})".format(self.expr_list)
class SglArgument(SglExpr):
    """A named program argument captured during tracing."""

    def __init__(self, name: str, value: str):
        super().__init__()
        self.name = name
        self.value = value

    def __repr__(self):
        return "Argument(name={}, value={})".format(self.name, repr(self.value))

    # Delegate container/truthiness protocols to the wrapped value so traced
    # code can treat the argument like the value itself.
    def __len__(self):
        return len(self.value)

    def __getitem__(self, i):
        return self.value[i]

    def __int__(self):
        # NOTE(review): returns the raw value rather than coercing via int();
        # presumably the value is already an int here — confirm.
        return self.value

    def __bool__(self):
        return self.value

    def __format__(self, *args):
        raise TypeError(
            "Cannot put argument inside a f-string. "
            "This is not compatible with the tracer. "
        )
class SglImage(SglExpr):
    """Expression that inserts an image (referenced by file path) into the prompt."""

    def __init__(self, path: str):
        # Bug fix: initialize the SglExpr bookkeeping state (node_id,
        # prev_node, pid). Every other SglExpr subclass calls
        # super().__init__(); skipping it left these attributes unset and
        # broke graph operations that read node_id/pid.
        super().__init__()
        self.path = path

    def __repr__(self) -> str:
        return f"SglImage({self.path})"
class SglVideo(SglExpr):
    """Expression that inserts a video (path + frame count) into the prompt."""

    def __init__(self, path: str, num_frames: int):
        # Bug fix: initialize the SglExpr bookkeeping state (node_id,
        # prev_node, pid), consistent with all other SglExpr subclasses.
        super().__init__()
        self.path = path
        self.num_frames = num_frames

    def __repr__(self) -> str:
        return f"SglVideo({self.path}, {self.num_frames})"
class SglGen(SglExpr):
    """IR node for a model generation call; stores the capture name and the
    per-call sampling parameters."""
    def __init__(
        self,
        name: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        min_p: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        ignore_eos: Optional[bool] = None,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        dtype: Optional[type] = None,
        regex: Optional[str] = None,
        json_schema: Optional[str] = None,
    ):
        """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""
        super().__init__()
        # Name under which the generated text is stored in the program state;
        # None lets the interpreter assign an auto-generated name.
        self.name = name
        # All parameters are forwarded verbatim; None means "use the
        # program-level default" at execution time.
        self.sampling_params = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
            dtype=dtype,
            regex=regex,
            json_schema=json_schema,
        )
    def __repr__(self):
        return f"Gen('{self.name}')"
class SglConstantText(SglExpr):
    """A literal piece of text appended to the prompt verbatim."""

    def __init__(self, value: str):
        super().__init__()
        self.value = value

    def __repr__(self):
        return "Constant({})".format(repr(self.value))
class SglRoleBegin(SglExpr):
    """Marker expression opening a chat-role section (e.g. user/assistant)."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleBegin({})".format(self.role)
class SglRoleEnd(SglExpr):
    """Marker expression closing a chat-role section."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleEnd({})".format(self.role)
class SglSelect(SglExpr):
    """Ask the model to pick one of ``choices`` using ``choices_method``."""

    def __init__(
        self,
        name: str,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ):
        super().__init__()
        self.name = name
        self.choices = choices
        self.temperature = temperature
        self.choices_method = choices_method

    def __repr__(self):
        return "Select({}, choices={}, choices_method={})".format(
            self.name, self.choices, self.choices_method
        )
class SglFork(SglExpr):
    """Split the current stream into ``number`` parallel branches."""

    def __init__(self, number: int, position_ids_offset=None):
        super().__init__()
        self.number = number
        self.position_ids_offset = position_ids_offset

    def __repr__(self):
        return "Fork(%{}, number={}, position_ids_offset={})".format(
            self.prev_node.node_id, self.number, self.position_ids_offset
        )
class SglGetForkItem(SglExpr):
    """Select branch ``index`` from a preceding fork expression."""

    def __init__(self, index: int):
        super().__init__()
        self.index = index

    def __repr__(self):
        return "GetForkItem(%{}, index={})".format(self.prev_node.node_id, self.index)
class SglVariable(SglExpr):
    """A reference to the value produced by another expression (``source``)."""

    def __init__(self, name: str, source):
        super().__init__()
        self.name = name
        self.source = source

    def __repr__(self):
        return "Variable('{}', source=%{})".format(self.name, self.source.node_id)
class SglVarScopeBegin(SglExpr):
    """Marker opening a named variable-capture scope."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeBegin('{}')".format(self.name)
class SglVarScopeEnd(SglExpr):
    """Marker closing a named variable-capture scope."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeEnd('{}')".format(self.name)
class SglConcateAndAppend(SglExpr):
    """Concatenate the given program states and append the result."""

    def __init__(self, states):
        super().__init__()
        self.states = states

    def __repr__(self):
        return "ConcatenateAndAppend('{}')".format(self.states)
class SglCommitLazy(SglExpr):
    """Marker that commits previously buffered lazy operations."""

    def __repr__(self):
        return "CommitLazy()"
class SglSeparateReasoning(SglExpr):
    """Wrap a gen/select expression so its reasoning content is captured
    separately under ``<name>_reasoning_content``."""

    def __init__(self, model_type: str, expr: SglExpr):
        super().__init__()
        self.model_type = model_type
        self.expr = expr
        self.name = None
        self._process_expr(expr)

    def process_name_for_reasoning(self, name):
        """Derive the capture name for the reasoning portion of ``name``."""
        if not name:
            raise ValueError("name must be provided")
        return f"{name}_reasoning_content"

    def _process_expr(self, expr):
        # Pick up the capture name from the wrapped gen/select; recurse into
        # expression lists so the last named child wins.
        if isinstance(expr, (SglGen, SglSelect)):
            self.name = self.process_name_for_reasoning(expr.name)
        elif isinstance(expr, SglExprList):
            for child in expr.expr_list:
                self._process_expr(child)

    def __repr__(self):
        return "SeparateReasoning(model_type={}, name={})".format(
            self.model_type, self.name
        )

View File

@@ -0,0 +1,279 @@
"""Tracing a program."""
import uuid
from typing import Any, Dict, List, Optional
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
SglArgument,
SglConstantText,
SglExpr,
SglExprList,
SglFork,
SglGen,
SglGetForkItem,
SglRoleBegin,
SglRoleEnd,
SglSelect,
SglVariable,
SglVarScopeBegin,
SglVarScopeEnd,
)
class StopTracing(Exception):
    """Raised internally to abort tracing once the traceable prefix ends."""
def extract_prefix_by_tracing(program, backend):
    """Trace ``program`` with placeholder arguments and return the longest
    leading run of constant text (the cacheable prompt prefix)."""
    # Create dummy arguments: every program argument becomes a placeholder
    # SglArgument with value None.
    dummy_arguments = {name: SglArgument(name, None) for name in program.arg_names}
    arguments = dummy_arguments
    arguments.update(program.bind_arguments)
    # Trace in prefix-only mode; the tracer raises StopTracing as soon as a
    # non-prefix expression is encountered.
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
    except (StopTracing, TypeError, AttributeError):
        # Some exceptions may not be caught: user code operating on the
        # placeholder arguments can fail in arbitrary ways; best effort only.
        pass
    # Run and cache prefix: concatenate leading constant-text nodes.
    prefix = ""
    for expr in tracer.flatten_nodes():
        if isinstance(expr, SglConstantText):
            prefix += expr.value
        else:
            break
    return prefix
def trace_program(program, arguments, backend):
    """Symbolically execute ``program`` and return the resulting tracer state.

    Arguments not supplied by the caller are filled with placeholder
    SglArgument values so the full program body can be traced.
    """
    # Create dummy backend when none is given; BaseBackend is a no-op stub.
    if backend is None:
        backend = BaseBackend()
    # Create dummy arguments for the parameters the caller did not provide.
    dummy_arguments = {
        name: SglArgument(name, None)
        for name in program.arg_names
        if name not in arguments
    }
    arguments.update(dummy_arguments)
    arguments.update(program.bind_arguments)
    # Trace the whole program (not just the prefix).
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
    with TracingScope(tracer):
        tracer.ret_value = program.func(tracer, **arguments)
    return tracer
class TracerProgramState(ProgramState):
def __init__(self, backend, arguments, only_trace_prefix):
self.pid = uuid.uuid4().hex
self.backend = backend
self.arguments: Dict[str, Any] = arguments
self.only_trace_prefix = only_trace_prefix
if hasattr(backend, "endpoint"):
self.backend = backend.endpoint
self.nodes = []
self.last_node = None
self.variables = {}
self.ret_value = None
# For completion
# For chat
self.messages_ = []
self.cur_role = None
self.chat_template = self.backend.get_chat_template()
# For multi states
self.child_states = []
cur_scope = TracingScope.get_current_scope()
if cur_scope is not None:
cur_scope.add_child_state(self)
##################################
########### Public API ###########
##################################
def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
assert size >= 1
if self.only_trace_prefix:
raise StopTracing()
fork_node = SglFork(size)
fork_node.prev_node = self.last_node
states = [
TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
for _ in range(size)
]
for i in range(size):
node = SglGetForkItem(i)
node.prev_node = fork_node
states[i].last_node = node
states[i].variables = dict(self.variables)
states[i].messages_ = list(self.messages_)
states[i].cur_role = self.cur_role
states[i].chat_template = self.chat_template
state_group = ProgramStateGroup(states, self)
return state_group
##################################
########## Internal API ##########
##################################
def _append_node(self, other: SglExpr):
self.nodes.append(other)
other.prev_node = self.last_node
self.last_node = other
def _execute(self, other: SglExpr):
if isinstance(other, str):
other = SglConstantText(other)
other.pid = self.pid
if isinstance(other, SglConstantText):
self._execute_fill(other)
elif isinstance(other, SglGen):
self._execute_gen(other)
elif isinstance(other, SglSelect):
self._execute_select(other)
elif isinstance(other, SglExprList):
for x in other.expr_list:
self._execute(x)
elif isinstance(other, SglRoleBegin):
self._execute_role_begin(other)
elif isinstance(other, SglRoleEnd):
self._execute_role_end(other)
elif isinstance(other, SglVarScopeBegin):
self._execute_var_scope_begin(other)
elif isinstance(other, SglVarScopeEnd):
self._execute_var_scope_end(other)
else:
if self.only_trace_prefix:
raise StopTracing()
else:
self._append_node(other)
return self
def __iadd__(self, other):
self._execute(other)
return self
def _execute_fill(self, expr: SglConstantText):
if isinstance(expr, str):
expr = SglConstantText(expr)
self._append_node(expr)
def _execute_gen(self, expr: SglGen):
name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
new_node = SglVariable(name, source=expr)
self.variables[name] = new_node
self._append_node(expr)
def _execute_select(self, expr: SglSelect):
name = (
expr.name if expr.name is not None else "select_" + str(len(self.variables))
)
new_node = SglVariable(name, source=expr)
self.variables[name] = new_node
self._append_node(expr)
def _execute_role_begin(self, expr: SglRoleBegin):
assert self.cur_role is None, "Nested roles are not allowed."
if len(self.messages_) == 0 and expr.role != "system":
# Insert default system message
default_system = self.chat_template.default_system_prompt
if default_system:
self._execute_role_begin(SglRoleBegin("system"))
self._execute_fill(default_system)
self._execute_role_end(SglRoleEnd("system"))
self.cur_role = expr.role
prefix, suffix = self.chat_template.get_prefix_and_suffix(
expr.role, self.messages_
)
self._execute_fill(prefix)
def _execute_role_end(self, expr: SglRoleEnd):
prefix, suffix = self.chat_template.get_prefix_and_suffix(
expr.role, self.messages_
)
self._execute_fill(suffix)
self.messages_.append({"role": expr.role, "content": ""})
self.cur_role = None
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
new_node = SglVariable(expr.name, source=self.last_node)
self.variables[expr.name] = new_node
def get_var(self, name):
ret = self.arguments.get(name, None)
if ret is not None:
return ret
v = self.variables[name]
return SglVariable(v.name, v.source)
def flatten_nodes(self):
def traverse(cur):
if isinstance(cur, SglExprList):
for child in cur.expr_list:
traverse(child)
else:
ret.append(cur)
ret = []
for x in self.nodes:
traverse(x)
return ret
    def __del__(self):
        # NOTE(review): intentional no-op — possibly overriding a parent
        # class finalizer; confirm before removing.
        pass
class TracingScope:
    """Tracks the innermost tracer state while a program is being traced.

    Scopes form a stack through ``last_scope`` links; the class attribute
    ``cur_scope`` always points at the innermost active scope.
    """

    cur_scope = None

    def __init__(self, tracer_state: TracerProgramState):
        self.tracer_state = tracer_state
        self.last_scope = TracingScope.cur_scope

    def __enter__(self):
        TracingScope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        TracingScope.cur_scope = self.last_scope

    @staticmethod
    def get_current_scope():
        """Return the innermost active scope (or None outside tracing)."""
        return TracingScope.cur_scope

    def add_child_state(self, state: TracerProgramState):
        """Register a child state with this scope and every enclosing one."""
        scope = self
        while scope is not None:
            scope.tracer_state.child_states.append(state)
            scope = scope.last_scope

View File

@@ -0,0 +1,16 @@
"""Launch the inference server."""
import os
import sys
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree
if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])
try:
launch_server(server_args)
finally:
kill_process_tree(os.getpid(), include_parent=False)

166
python/sglang/profiler.py Normal file
View File

@@ -0,0 +1,166 @@
"""
Run live profiling.
Usage:
python3 -m sglang.profiler
"""
import argparse
import json
import os
import time
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional
import requests
PARENT_FOLDER = "/tmp/sglang-profile"
def _run_profile(
    url: Optional[str],
    num_steps: int,
    activities: List[str],
    output_dir: Optional[str] = None,
    profile_name: Optional[str] = None,
    profile_by_stage: bool = False,
) -> str:
    """Start a server-side profiling run and return the trace directory."""
    if output_dir is None:
        output_dir = PARENT_FOLDER
    # Normalize to an absolute path, then nest "profile_name/timestamp".
    output_dir = Path(os.path.abspath(os.path.normpath(output_dir)))
    if profile_name:
        output_dir = output_dir / profile_name
    output_dir = output_dir / str(time.time())
    output_dir.mkdir(exist_ok=True, parents=True)

    print(f"Dump profiling traces to {output_dir}")
    print(
        f"Waiting for {num_steps} steps and the trace to be flushed.... ({profile_by_stage=})"
    )

    # Snapshot the server configuration next to the traces (only once).
    file_path = Path(output_dir) / "server_args.json"
    if not file_path.exists():
        response = requests.get(url + "/get_server_info")
        response.raise_for_status()
        server_args_data = response.json()
        with open(file_path, "w") as file:
            file.write(json.dumps(server_args_data))

    # Ask the server to start profiling; the API replies only after all
    # steps were processed and the trace files were written.
    json_data = {
        "output_dir": str(output_dir),
        "num_steps": str(num_steps),
        "activities": activities,
        "profile_by_stage": profile_by_stage,
    }
    response = requests.post(url=url + "/start_profile", json=json_data)
    response.raise_for_status()
    return str(output_dir)
def run_profile(
    url: Optional[str],
    num_steps: int,
    activities: List[str],
    output_dir: Optional[str] = None,
    profile_name: Optional[str] = None,
    profile_by_stage: bool = False,
):
    """Public entry point; step-based profiling self-terminates on num_steps."""
    return _run_profile(
        url, num_steps, activities, output_dir, profile_name, profile_by_stage
    )
if __name__ == "__main__":
parser = ArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
"--url",
type=str,
default="http://localhost:30000",
help="Server or API base url if not using http host and port.",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Profile directory to dump profile traces.",
)
parser.add_argument(
"--profile-name",
type=str,
default=None,
help="The name of this profile run.",
)
parser.add_argument(
"--num-steps",
type=int,
default=5,
help="The number of forward steps to profile.",
)
parser.add_argument(
"--profile-by-stage",
action=argparse.BooleanOptionalAction,
type=bool,
default=False,
help="The number of forward steps to profile.",
)
parser.add_argument(
"--cpu",
action=argparse.BooleanOptionalAction,
type=bool,
default=True,
help="Whether to profile CPU activity",
)
parser.add_argument(
"--gpu",
action=argparse.BooleanOptionalAction,
type=bool,
default=True,
help="Whether to profile GPU activity",
)
parser.add_argument(
"--mem",
action=argparse.BooleanOptionalAction,
type=bool,
default=False,
help="Whether to memory usage (https://pytorch.org/memory_viz)",
)
parser.add_argument(
"--rpd",
action=argparse.BooleanOptionalAction,
type=bool,
default=False,
help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
)
args = parser.parse_args()
activities = []
if args.cpu:
activities.append("CPU")
if args.gpu:
activities.append("GPU")
if args.mem:
activities.append("MEM")
if args.rpd:
activities.append("RPD")
run_profile(
args.url,
args.num_steps,
activities,
args.output_dir,
args.profile_name,
args.profile_by_stage,
)

View File

@@ -0,0 +1,177 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
import logging
from typing import List, Optional, Tuple
import torch
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
logger = logging.getLogger(__name__)
# Select which compiled backend provides the custom all-reduce ops:
# vllm._C when USE_VLLM_CUSTOM_ALLREDUCE is set, sgl_kernel otherwise.
use_vllm_custom_allreduce = get_bool_env_var(
    "USE_VLLM_CUSTOM_ALLREDUCE", default="false"
)
if not is_hpu():
    # ROCm does not use vllm custom allreduce
    if use_vllm_custom_allreduce and not is_hip():
        try:
            import vllm._C
        except ImportError as e:
            # Tolerated here; the wrapper calls below will fail loudly
            # if the ops are actually used without the extension.
            logger.warning("Failed to import from vllm._C with %r", e)
    else:
        try:
            import sgl_kernel
        except ImportError as e:
            logger.warning("Failed to import from custom_ar with %r", e)
if not is_hip() and not is_npu():
    # CUDA path: both backends expose the same op surface, so pick the
    # namespace once and let the thin wrappers below delegate to it.
    if use_vllm_custom_allreduce:
        custom_op = torch.ops._C_custom_ar
    else:
        custom_op = sgl_kernel.allreduce
    # custom allreduce
    def init_custom_ar(
        ipc_tensors: List[torch.Tensor],
        rank_data: torch.Tensor,
        rank: int,
        full_nvlink: bool,
    ) -> int:
        """Create a custom all-reduce context; returns an opaque handle."""
        return custom_op.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)
    def all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        reg_buffer: int,
        reg_buffer_sz_bytes: int,
    ) -> None:
        """Run an all-reduce on the context `fa`, writing into `out`."""
        custom_op.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
    def dispose(fa: int) -> None:
        # Free the context created by init_custom_ar.
        custom_op.dispose(fa)
    def meta_size() -> int:
        return custom_op.meta_size()
    def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
        return custom_op.register_buffer(fa, ipc_tensors)
    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
        return custom_op.get_graph_buffer_ipc_meta(fa)
    def register_graph_buffers(
        fa: int, handles: List[List[int]], offsets: List[List[int]]
    ) -> None:
        custom_op.register_graph_buffers(fa, handles, offsets)
else:
    # ROCM custom allreduce
    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        """Create a ROCm custom all-reduce context; returns an opaque handle."""
        return sgl_kernel.allreduce.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )
    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        # All-reduce using a pre-registered input buffer.
        sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)
    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        # All-reduce staging through `reg_buffer` for unregistered inputs.
        sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
    def dispose(fa: int) -> None:
        sgl_kernel.allreduce.dispose(fa)
    def meta_size() -> int:
        return sgl_kernel.allreduce.meta_size()
    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)
    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)
    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)
    def allocate_meta_buffer(size: int) -> torch.Tensor:
        return sgl_kernel.allreduce.allocate_meta_buffer(size)
    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
    # ROCM custom quick allreduce
    def init_custom_qr(
        rank: int, world_size: int, qr_max_size: Optional[int] = None
    ) -> int:
        # NOTE(review): argument order is swapped when forwarding
        # (world_size before rank) — this mirrors the kernel's signature;
        # confirm against sgl_kernel before changing.
        return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)
    def qr_get_handle(fa: int) -> torch.Tensor:
        return sgl_kernel.allreduce.qr_get_handle(fa)
    def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
        sgl_kernel.allreduce.qr_open_handles(fa, handles)
    def qr_all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        quant_level: int,
        cast_bf2half: bool,
    ) -> None:
        """Quantized quick all-reduce; `quant_level` selects compression."""
        sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
    def qr_destroy(fa: int) -> None:
        sgl_kernel.allreduce.qr_destroy(fa)
    def qr_max_size() -> int:
        return sgl_kernel.allreduce.qr_max_size()
    def mscclpp_generate_unique_id() -> bytes:
        # Unique id shared out-of-band so all ranks join the same context.
        return sgl_kernel.allreduce.mscclpp_generate_unique_id()
    def mscclpp_init_context(
        unique_id: bytes,
        rank: int,
        world_size: int,
        scratch: torch.Tensor,
        put_buffer: torch.Tensor,
        nranks_per_node: int,
        rank_to_node: List[int],
        rank_to_ib: List[int],
        context_selection: int,
    ) -> int:
        """Create an MSCCL++ context; returns an opaque handle."""
        return sgl_kernel.allreduce.mscclpp_init_context(
            unique_id,
            rank,
            world_size,
            scratch,
            put_buffer,
            nranks_per_node,
            rank_to_node,
            rank_to_ib,
            context_selection,
        )
    def mscclpp_allreduce(
        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
    ) -> None:
        return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)

View File

@@ -0,0 +1,100 @@
import asyncio
class RWLock:
    """An asyncio readers-writer lock with writer preference.

    Multiple readers may hold the lock concurrently; a writer holds it
    exclusively.  New readers block as soon as any writer is waiting,
    which prevents writer starvation.
    """

    def __init__(self):
        # Protects internal state
        self._lock = asyncio.Lock()
        # Condition variable used to wait for state changes
        self._cond = asyncio.Condition(self._lock)
        # Number of readers currently holding the lock
        self._readers = 0
        # Whether a writer is currently holding the lock
        self._writer_active = False
        # How many writers are queued waiting for a turn
        self._waiting_writers = 0
    @property
    def reader_lock(self):
        """
        A context manager for acquiring a shared (reader) lock.
        Example:
            async with rwlock.reader_lock:
                # read-only access
        """
        return _ReaderLock(self)
    @property
    def writer_lock(self):
        """
        A context manager for acquiring an exclusive (writer) lock.
        Example:
            async with rwlock.writer_lock:
                # exclusive access
        """
        return _WriterLock(self)
    async def acquire_reader(self):
        """Acquire shared access; blocks while a writer is active or queued."""
        async with self._lock:
            # Wait until there is no active writer or waiting writer
            # to ensure fairness.
            while self._writer_active or self._waiting_writers > 0:
                await self._cond.wait()
            self._readers += 1
    async def release_reader(self):
        """Release shared access; wakes waiters when the last reader leaves."""
        async with self._lock:
            self._readers -= 1
            # If this was the last reader, wake up anyone waiting
            # (potentially a writer or new readers).
            if self._readers == 0:
                self._cond.notify_all()
    async def acquire_writer(self):
        """Acquire exclusive access; blocks until no reader/writer holds it."""
        async with self._lock:
            # Increment the count of writers waiting
            self._waiting_writers += 1
            try:
                # Wait while either a writer is active or readers are present
                while self._writer_active or self._readers > 0:
                    await self._cond.wait()
                self._writer_active = True
            finally:
                # Decrement waiting writers only after we've acquired the writer lock
                self._waiting_writers -= 1
    async def release_writer(self):
        """Release exclusive access and wake all waiters."""
        async with self._lock:
            self._writer_active = False
            # Wake up anyone waiting (readers or writers)
            self._cond.notify_all()
class _ReaderLock:
    """Awaitable context manager granting shared (read) access to an RWLock."""

    def __init__(self, rwlock: RWLock):
        # Parent lock all operations delegate to.
        self._rwlock = rwlock

    async def __aenter__(self):
        await self._rwlock.acquire_reader()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self._rwlock.release_reader()
class _WriterLock:
    """Awaitable context manager granting exclusive (write) access to an RWLock."""

    def __init__(self, rwlock: RWLock):
        # Parent lock all operations delegate to.
        self._rwlock = rwlock

    async def __aenter__(self):
        await self._rwlock.acquire_writer()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self._rwlock.release_writer()

View File

@@ -0,0 +1,137 @@
import os
import sys
from contextlib import nullcontext
import torch
# NOTE copied and modified from DeepGEMM
class suppress_stdout_stderr:
    """Silence both Python-level and file-descriptor-level output.

    Points fd 1/2 at /dev/null and swaps sys.stdout/sys.stderr, so output
    from Python code AND from native libraries is suppressed while the
    context is active.  Everything is restored on exit.
    """

    def __enter__(self):
        # Sinks the redirected output is sent to.
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")
        # Raw fd numbers currently backing stdout/stderr (usually 1 and 2).
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()
        # Duplicates so the original streams can be restored on exit.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())
        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr
        # Redirect at the fd level (catches native code output) ...
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
        # ... and at the Python level.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr
        # Restore the original fds, then release the duplicates and sinks.
        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)
        self.outnull_file.close()
        self.errnull_file.close()
# NOTE copied and modified from DeepGEMM
def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: str = None,
    flush_l2: bool = True,
    with_multiple_kernels: bool = False,
):
    """Benchmark CUDA kernels launched by `fn` via the kineto profiler.

    Args:
        fn: Zero-argument callable that launches the kernel(s) under test.
        kernel_names: A kernel-name substring, or tuple of substrings, to
            look up in the profiler's averages table.
        num_tests: Number of invocations per profiling pass.
        suppress_kineto_output: Hide profiler chatter on stdout/stderr.
        trace_path: If set, export a chrome trace to this path.
        flush_l2: Zero a large buffer between runs to flush the L2 cache.
        with_multiple_kernels: Allow one name to match several table rows.

    Returns:
        Average seconds per kernel — a scalar when `kernel_names` is a
        string, a tuple when it is a tuple of names.  Returns 1 when
        running under Nsight Systems (profiling is skipped there).
    """
    # Conflict with Nsight Systems
    using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0))
    # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)
    # For some auto-tuning kernels with prints
    fn()
    # Profile
    suppress = (
        suppress_stdout_stderr
        if suppress_kineto_output and not using_nsys
        else nullcontext
    )
    with suppress():
        schedule = (
            torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
            if not using_nsys
            else None
        )
        profiler = (
            torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
            )
            if not using_nsys
            else nullcontext()
        )
        with profiler:
            # Two scheduler steps: a warmup pass, then the measured pass.
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(
                            flush_l2_size, dtype=torch.int, device="cuda"
                        ).zero_()
                    fn()
                if not using_nsys:
                    profiler.step()
    # Return 1 if using Nsight Systems
    if using_nsys:
        return 1
    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)
    prof_lines = (
        profiler.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    if not with_multiple_kernels:
        for name in kernel_names:
            # Each requested name must match exactly one table row, or the
            # averaged time would be ambiguous.
            assert (
                sum([name in line for line in prof_lines]) == 1
            ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"
    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)
    # Return average kernel times
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                # NOTE(review): assumes the kineto table layout puts the
                # average time and call count in the last two columns —
                # confirm if the torch version changes.
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += (
                            float(time_str.replace(unit, "")) / scale * int(num_str)
                        )
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num)
    return tuple(kernel_times) if is_tuple else kernel_times[0]

View File

@@ -0,0 +1,29 @@
# Central re-exports so callers can `from sglang.srt.configs import <Config>`.
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.configs.step3_vl import (
    Step3TextConfig,
    Step3VisionEncoderConfig,
    Step3VLConfig,
)
# Public API of this package.
__all__ = [
    "ExaoneConfig",
    "ChatGLMConfig",
    "DbrxConfig",
    "DeepseekVL2Config",
    "LongcatFlashConfig",
    "MultiModalityConfig",
    "KimiVLConfig",
    "MoonViTConfig",
    "Step3VLConfig",
    "Step3TextConfig",
    "Step3VisionEncoderConfig",
    "Qwen3NextConfig",
]

View File

@@ -0,0 +1,78 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/chatglm.py
# ChatGLM2 and ChatGLM3 share the same config.
# ChatGLM4 is officially supported by Huggingface
# transformers >= 4.46.0 is required
# https://huggingface.co/docs/transformers/en/model_doc/glm
from transformers import PretrainedConfig
class ChatGLMConfig(PretrainedConfig):
    """Configuration for ChatGLM2/ChatGLM3 models.

    ChatGLM4 is supported natively by Huggingface transformers (>= 4.46),
    so this class only needs to cover the GLM2/GLM3 layout.
    """

    model_type = "chatglm"
    attribute_map = {
        "num_hidden_layers": "num_layers",
        "n_head_kv": "multi_query_group_num",
    }

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        interleaved_qkv=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        **kwargs
    ):
        # Core transformer geometry.
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        # Mirror seq_length under the common name so long-LoRA style code
        # can read it.
        self.max_position_embeddings = seq_length
        # Dropout / normalization options.
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.post_layer_norm = post_layer_norm
        # Linear-layer bias options.
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        # Multi-query attention options.
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        # Quantization / prefix-tuning extras.
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        self.interleaved_qkv = interleaved_qkv
        super().__init__(**kwargs)

View File

@@ -0,0 +1,279 @@
# Adapted from
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/dbrx.py
"""Dbrx configuration."""
from typing import Any, Optional
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore
class DbrxAttentionConfig(PretrainedConfig):
    """Configuration for the Dbrx attention sub-module.

    Args:
        attn_pdrop: Dropout probability applied in the attention layers.
        clip_qkv: If not None, clip queries/keys/values to this magnitude.
        kv_n_heads: Number of key/value heads (grouped-query attention).
        rope_theta: Base frequency used for rotary position embeddings.
    """

    def __init__(
        self,
        attn_pdrop: float = 0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        # "model_type" can ride along when this dict is sliced out of a full
        # Dbrx config; it is not an attention option, so drop it before the
        # strict unknown-key check.
        kwargs.pop("model_type", None)
        if kwargs:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)
        cfg, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # A full Dbrx checkpoint nests the attention options; unwrap them.
        if cfg.get("model_type") == "dbrx":
            cfg = cfg["attn_config"]
        if (
            "model_type" in cfg
            and hasattr(cls, "model_type")
            and cfg["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all configurations of "
                "models and can yield errors.",
                cfg["model_type"],
                cls.model_type,
            )
        return cls.from_dict(cfg, **kwargs)
class DbrxFFNConfig(PretrainedConfig):
    """Configuration for the Dbrx feed-forward / mixture-of-experts block.

    Args:
        ffn_act_fn: Dict describing the activation ({"name": ..., **kwargs});
            defaults to SiLU.
        ffn_hidden_size: Hidden width of each expert FFN.
        moe_num_experts: Total number of experts.
        moe_top_k: Number of experts routed per token.
        moe_jitter_eps: Optional router jitter epsilon.
        moe_loss_weight: Weight of the auxiliary router loss.
        moe_normalize_expert_weights: Normalization factor for expert weights.
        uniform_expert_assignment: Force uniform routing (benchmarking only).
    """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        # kwargs are validated below rather than forwarded to the base class.
        super().__init__()
        self.ffn_act_fn = ffn_act_fn if ffn_act_fn is not None else {"name": "silu"}
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment

        # "model_type" may leak in when sliced out of a full Dbrx config.
        kwargs.pop("model_type", None)
        if kwargs:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)
        cfg, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # A full Dbrx checkpoint nests the FFN options; unwrap them.
        if cfg.get("model_type") == "dbrx":
            cfg = cfg["ffn_config"]
        if (
            "model_type" in cfg
            and hasattr(cls, "model_type")
            and cfg["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all "
                "configurations of models and can yield errors.",
                cfg["model_type"],
                cls.model_type,
            )
        return cls.from_dict(cfg, **kwargs)
class DbrxConfig(PretrainedConfig):
    """Top-level configuration for Dbrx models.

    Args:
        d_model: Embedding / hidden-state dimensionality.
        n_heads: Attention heads per layer.
        n_layers: Number of transformer layers.
        max_seq_len: Maximum sequence length.
        vocab_size: Token vocabulary size.
        resid_pdrop: Dropout on attention output before the residual add.
        emb_pdrop: Dropout on the embedding layer.
        attn_config: Attention sub-config (dict or DbrxAttentionConfig).
        ffn_config: FFN/MoE sub-config (dict or DbrxFFNConfig).
        use_cache: Whether to return key/value caches.
        initializer_range: Stddev for truncated-normal weight init.
        output_router_logits: Whether to return MoE router logits.
        router_aux_loss_coef: Weight of the router auxiliary loss.

    Example:
        >>> configuration = DbrxConfig()
        >>> model = DbrxModel(configuration)
    """

    model_type = "dbrx"
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        # Normalize the sub-configs: default-construct on None, promote
        # plain dicts, and accept ready-made config objects as-is.
        if attn_config is None:
            attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            attn_config = DbrxAttentionConfig(**attn_config)
        self.attn_config = attn_config
        if ffn_config is None:
            ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            ffn_config = DbrxFFNConfig(**ffn_config)
        self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        # Dbrx ties no embeddings; reject an explicit request to do so.
        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError("tie_word_embeddings is not supported for Dbrx models.")
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

View File

@@ -0,0 +1,688 @@
import math
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image, ImageOps
from transformers import (
AutoProcessor,
LlamaTokenizerFast,
PretrainedConfig,
ProcessorMixin,
)
def select_best_resolution(image_size, candidate_resolutions):
    """Pick the candidate (width, height) that best fits an image for cropping.

    Prefers the candidate that preserves the most effective (downscaled)
    resolution; ties are broken by the least wasted area.
    """
    orig_w, orig_h = image_size
    best = None
    best_effective = 0
    best_wasted = float("inf")
    for cand_w, cand_h in candidate_resolutions:
        # Scale so the image fits entirely inside the candidate box.
        ratio = min(cand_w / orig_w, cand_h / orig_h)
        scaled_w = int(orig_w * ratio)
        scaled_h = int(orig_h * ratio)
        # Effective resolution never exceeds the original pixel count.
        effective = min(scaled_w * scaled_h, orig_w * orig_h)
        wasted = cand_w * cand_h - effective
        if effective > best_effective or (
            effective == best_effective and wasted < best_wasted
        ):
            best_effective = effective
            best_wasted = wasted
            best = (cand_w, cand_h)
    return best
class DictOutput(object):
    """Mapping-style access layered over an instance's ``__dict__``.

    Lets attribute-holding output objects be used with dict syntax
    (``obj["key"]``, ``key in obj``, ``obj.items()``).
    """

    def __setitem__(self, key, value):
        self.__dict__[key] = value

    def __getitem__(self, key):
        return self.__dict__[key]

    def __contains__(self, key):
        return key in self.__dict__

    def keys(self):
        # Delegates straight to the attribute dictionary.
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()
@dataclass
class VLChatProcessorOutput(DictOutput):
    # Token ids of the full (text + image placeholder) sequence.
    input_ids: torch.LongTensor
    # Per-token labels; prompt positions may be masked with ignore_id.
    target_ids: torch.LongTensor
    pixel_values: (
        torch.Tensor
    )  # rename from "images" to "pixel_values" for compatibility
    # True at positions occupied by image tokens.
    images_seq_mask: torch.BoolTensor
    # Per-image cropping layout — presumably (rows, cols) tiling; confirm
    # against tokenize_with_images.
    images_spatial_crop: torch.LongTensor
    def __len__(self):
        # Length of the token sequence, not the number of images.
        return len(self.input_ids)
class ImageTransform(object):
    """ToTensor (+ optional Normalize) pipeline applied to PIL images."""

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True,
    ):
        self.mean = mean
        self.std = std
        self.normalize = normalize

        # Import lazily so torchvision is only required when this transform
        # is actually constructed.
        try:
            import torchvision.transforms as T
        except ImportError as err:
            raise ImportError(
                "Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
            ) from err

        steps = [T.ToTensor()]
        if normalize:
            steps.append(T.Normalize(mean, std))
        self.transform = T.Compose(steps)

    def __call__(self, pil_img: Image.Image):
        # Apply the composed pipeline; returns a (C, H, W) tensor.
        return self.transform(pil_img)
class DeepseekVLV2Processor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
    def __init__(
        self,
        tokenizer: LlamaTokenizerFast,
        candidate_resolutions: Tuple[Tuple[int, int]],
        patch_size: int,
        downsample_ratio: int,
        image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        normalize: bool = True,
        image_token: str = "<image>",
        pad_token: str = "<▁pad▁>",
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        **kwargs,
    ):
        """Build the DeepSeek-VL2 processor around an existing tokenizer.

        Registers the image token, grounding tokens, and SFT role tokens
        as special tokens, and prepares the image transform.
        """
        self.candidate_resolutions = candidate_resolutions
        # Base image size is taken from the first candidate resolution.
        self.image_size = candidate_resolutions[0][0]
        self.patch_size = patch_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.normalize = normalize
        self.downsample_ratio = downsample_ratio
        self.image_transform = ImageTransform(
            mean=image_mean, std=image_std, normalize=normalize
        )
        self.tokenizer = tokenizer
        # Padding side matters for batched inference; must be "left" here.
        self.tokenizer.padding_side = "left"
        # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
        if tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": pad_token})
        # Register the image placeholder token if the tokenizer lacks it.
        image_token_id = self.tokenizer.vocab.get(image_token)
        if image_token_id is None:
            special_tokens = [image_token]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
        self.image_token_id = self.tokenizer.vocab.get(image_token)
        # add five special tokens for grounding-related tasks
        # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
        special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
        special_tokens_dict = {"additional_special_tokens": special_tokens}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        # add special tokens for SFT data
        special_tokens = ["<|User|>", "<|Assistant|>"]
        special_tokens_dict = {"additional_special_tokens": special_tokens}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.image_token = image_token
        self.pad_token = pad_token
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id
        super().__init__(
            tokenizer,
            **kwargs,
        )
def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
"""play the role of format_messages_v2 and get_images_info in the last version"""
tokenized_data = []
masked_tokenized_data = [] # labels
images_list = []
images_seq_mask = []
images_spatial_crop = []
image_index = 0
image_token_cnt = messages.count(self.image_token)
tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
messages,
pil_images[image_index : image_index + image_token_cnt],
bos=True,
eos=True,
cropping=len(pil_images) <= 2,
max_req_input_len=max_req_input_len,
)
image_index = image_token_cnt
tokenized_data += tokenized_str
if self.mask_prompt:
masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
else:
masked_tokenized_data += tokenized_str
images_list += images
images_seq_mask += seq_mask
images_spatial_crop += spatial_crop
assert len(tokenized_data) == len(
images_seq_mask
), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
return (
tokenized_data,
masked_tokenized_data,
images_list,
images_seq_mask,
images_spatial_crop,
)
    @property
    def bos_id(self):
        """Token id of the tokenizer's beginning-of-sequence token."""
        return self.tokenizer.bos_token_id
    @property
    def eos_id(self):
        """Token id of the tokenizer's end-of-sequence token."""
        return self.tokenizer.eos_token_id
    @property
    def pad_id(self):
        """Token id of the tokenizer's padding token."""
        return self.tokenizer.pad_token_id
def encode(self, text: str, bos: bool = True, eos: bool = False):
t = self.tokenizer.encode(text, add_special_tokens=False)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
    def decode(self, t: List[int], **kwargs) -> str:
        """Decode a list of token ids back to text via the underlying tokenizer."""
        return self.tokenizer.decode(t, **kwargs)
    def process_one(
        self,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image.Image] = None,
        apply_sft_format: bool = False,
        inference_mode: bool = True,
        system_prompt: str = "",
        max_req_input_len: int = -1,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt;
                if conversations is not None, then it will always apply the SFT format to conversations;
            inference_mode (bool): if True, then remove the last eos token;
            system_prompt (str): the system prompt;
            **kwargs:
        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """
        # NOTE(review): `prompt`, `apply_sft_format`, `system_prompt` and
        # `**kwargs` are accepted but never read in this body; only
        # `conversations`, `images`, `inference_mode` and `max_req_input_len`
        # are actually used.
        assert (
            prompt is None or conversations is None
        ), "prompt and conversations cannot be used at the same time."
        (
            tokenized_str,
            masked_tokenized_str,
            images_list,
            images_seq_mask,
            images_spatial_crop,
        ) = self.format_messages_v2(conversations, images, max_req_input_len)
        assert (
            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
        ), (
            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
        )
        input_ids = torch.LongTensor(tokenized_str)
        target_ids = torch.LongTensor(masked_tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
            self.ignore_id
        )
        # Negative ids are placeholders; remap them to the pad id so the
        # embedding lookup stays in range.
        input_ids[input_ids < 0] = self.pad_id
        if inference_mode:
            # Drop the trailing EOS so generation can continue from the prompt.
            assert input_ids[-1] == self.eos_id
            input_ids = input_ids[:-1]
            target_ids = target_ids[:-1]
            images_seq_mask = images_seq_mask[:-1]
        if len(images_list) == 0:
            # Text-only request: emit a single zero image so downstream code
            # still receives a (1, 3, H, W) tensor and a (1, 2) crop tensor.
            images = torch.zeros((1, 3, self.image_size, self.image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
        else:
            images = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
        images_spatial_crop = torch.stack(
            [images_spatial_crop], dim=0
        )  # stack the tensor to make it a batch of 1
        prepare = VLChatProcessorOutput(
            input_ids=input_ids,
            target_ids=target_ids,
            pixel_values=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
        )
        return prepare
def __call__(
self,
*,
prompt: str = None,
conversations: List[Dict[str, str]] = None,
images: List[Image.Image] = None,
apply_sft_format: bool = False,
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
**kwargs,
):
prepare = self.process_one(
prompt=prompt,
conversations=conversations,
images=images,
apply_sft_format=apply_sft_format,
inference_mode=inference_mode,
system_prompt=system_prompt,
max_req_input_len=max_req_input_len,
)
return prepare
def find_all_indices(self, messages, target_value):
indices = []
for index, item in enumerate(messages):
if item == target_value:
indices.append(index)
return indices
    def tokenize_with_images(
        self,
        conversation: str,
        images: List[Image.Image],
        bos: bool = True,
        eos: bool = True,
        cropping: bool = True,
        max_req_input_len: int = -1,
    ):
        """Tokenize text with <image> tags."""
        # Outputs: transformed image tensors, a boolean mask marking the
        # image-token positions inside `tokenized_str`, and per-image
        # [num_width_tiles, num_height_tiles] counts.
        images_list, images_seq_mask, images_spatial_crop = [], [], []
        text_splits = conversation.split(self.image_token)
        tokenized_str = []
        # zip() pairs each text segment with the image that follows it; the
        # final segment (after the last <image>) is appended below the loop.
        for text_sep, image in zip(text_splits, images):
            """encode text_sep"""
            tokenized_sep = self.encode(text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)
            """select best resolution for anyres"""
            if cropping:
                best_width, best_height = select_best_resolution(
                    image.size, self.candidate_resolutions
                )
            else:
                best_width, best_height = self.image_size, self.image_size
            # print(image.size, (best_width, best_height)) # check the select_best_resolutions func
            """process the global view"""
            # Letterbox to a square, filling with the normalization mean so
            # padded pixels land near zero after normalization.
            global_view = ImageOps.pad(
                image,
                (self.image_size, self.image_size),
                color=tuple(int(x * 255) for x in self.image_transform.mean),
            )
            images_list.append(self.image_transform(global_view))
            """process the local views"""
            local_view = ImageOps.pad(
                image,
                (best_width, best_height),
                color=tuple(int(x * 255) for x in self.image_transform.mean),
            )
            # Tile the padded image into image_size x image_size crops,
            # row-major (top-to-bottom, left-to-right).
            for i in range(0, best_height, self.image_size):
                for j in range(0, best_width, self.image_size):
                    images_list.append(
                        self.image_transform(
                            local_view.crop(
                                (j, i, j + self.image_size, i + self.image_size)
                            )
                        )
                    )
            """record height / width crop num"""
            num_width_tiles, num_height_tiles = (
                best_width // self.image_size,
                best_height // self.image_size,
            )
            images_spatial_crop.append([num_width_tiles, num_height_tiles])
            """add image tokens"""
            # h (= w) image tokens per tile side: ViT patch grid shrunk by the
            # projector's downsample ratio.
            h = w = math.ceil(
                (self.image_size // self.patch_size) / self.downsample_ratio
            )
            # global views tokens h * (w + 1), 1 is for line separator
            tokenized_image = [self.image_token_id] * h * (w + 1)
            # add a separator between global and local views
            tokenized_image += [self.image_token_id]
            # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
            tokenized_image += (
                [self.image_token_id]
                * (num_height_tiles * h)
                * (num_width_tiles * w + 1)
            )
            tokenized_str += tokenized_image
            images_seq_mask += [True] * len(tokenized_image)
            # print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens
        """process the last text split"""
        tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
        # deal with video, limit with request len
        if max_req_input_len > -1:
            if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
                # Truncate the image/text prefix, reserving room for the tail
                # text plus a 1024-token safety margin.
                # NOTE(review): `rest` can be negative for small budgets, in
                # which case the slice drops tokens from the end instead of
                # keeping a prefix — confirm callers guarantee a large enough
                # max_req_input_len.
                rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
                tokenized_str = tokenized_str[:rest]
                images_seq_mask = images_seq_mask[:rest]
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)
        """add the bos and eos tokens"""
        if bos:
            tokenized_str = [self.bos_id] + tokenized_str
            images_seq_mask = [False] + images_seq_mask
        if eos:
            tokenized_str = tokenized_str + [self.eos_id]
            images_seq_mask = images_seq_mask + [False]
        assert len(tokenized_str) == len(
            images_seq_mask
        ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
        return tokenized_str, images_list, images_seq_mask, images_spatial_crop
class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
    """Configuration for the DeepSeek-VL2 SigLIP vision encoder.

    Class-level attributes are the defaults. ``__init__`` only exposes a
    subset of them as keyword arguments; the remaining fields (e.g.
    ``weight_init``, ``deterministic``, ``num_recomputing_layers``) keep their
    class-level values unless set via ``**kwargs`` — presumably handled by
    ``PretrainedConfig``; confirm if you need to override them.
    """

    model_type: str = "vision"
    model_name: str = "siglip_large_patch16_384"
    image_size: int = 384
    patch_size: int = 16
    width: int = 1024
    layers: int = 24
    heads: int = 16
    mlp_ratio: int = 4
    global_pool: str = "map"
    ignore_head: bool = True
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False
    weight_init: str = "skip"
    deterministic: bool = False
    num_recomputing_layers: int = 0

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: int = 384,
        patch_size: int = 16,
        width: int = 1024,
        layers: int = 24,
        heads: int = 16,
        mlp_ratio: int = 4,
        global_pool: str = "map",
        ignore_head: bool = True,
        class_token: bool = False,
        num_classes: int = 0,
        use_checkpoint: bool = False,
        **kwargs,
    ):
        self.model_name = model_name
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.global_pool = global_pool
        self.ignore_head = ignore_head
        self.class_token = class_token
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        super().__init__(**kwargs)
class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
    """Configuration for the MLP projector that maps vision features of
    ``input_dim`` into the language model's ``n_embed`` space.

    ``token_pooling`` is only settable through ``**kwargs`` (it is not an
    explicit ``__init__`` parameter).
    """

    model_type = "mlp_projector"
    projector_type: str = "downsample_mlp_gelu"
    input_dim: int = 1152
    n_embed: int = 2048
    depth: int = 2
    mlp_ratio: int = 1
    downsample_ratio: int = 2
    token_pooling: bool = False

    def __init__(
        self,
        projector_type: str = "downsample_mlp_gelu",
        input_dim: int = 1152,
        n_embed: int = 2048,
        depth: int = 2,
        mlp_ratio: int = 1,
        downsample_ratio: int = 2,
        **kwargs,
    ):
        self.projector_type = projector_type
        self.input_dim = input_dim
        self.n_embed = n_embed
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.downsample_ratio = downsample_ratio
        super().__init__(**kwargs)
class DeepseekV2Config(PretrainedConfig):
    """Configuration for the DeepSeek-V2 language backbone (MoE routing and
    MLA attention hyper-parameters) used by DeepSeek-VL2."""

    model_type = "deepseek_v2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=None,
        n_routed_experts=None,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        # NOTE(review): "gready" (sic) appears to be the literal default string
        # used by upstream DeepSeek configs — do not "fix" the spelling without
        # confirming what checkpoints/config.json files compare against.
        topk_method="gready",
        n_group=None,
        topk_group=None,
        num_experts_per_tok=None,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func="softmax",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        use_mla=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        # explicit float() cast — presumably guards against config files that
        # store the epsilon as a string; confirm before removing.
        self.rms_norm_eps = float(rms_norm_eps)
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_mla = use_mla
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class DeepseekVL2Config(PretrainedConfig):
    """Top-level DeepSeek-VL2 configuration bundling the vision encoder,
    the MLP projector and the DeepSeek-V2 language-model sub-configs.

    Sub-configs may be supplied through ``**kwargs`` either as plain dicts
    (e.g. loaded from ``config.json``) or, for ``language_config``, as an
    already-built :class:`DeepseekV2Config` instance.
    """

    model_type = "deepseek_vl_v2"
    vision_config: DeepseekVL2VisionEncoderConfig
    projector_config: DeepseekVL2MlpProjectorConfig
    language_config: DeepseekV2Config

    tile_tag: str = "2D"
    global_view_pos: str = "head"
    candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),)

    def __init__(
        self,
        # Fixed: the default used to be the placeholder string "tile_tag",
        # contradicting the class-level default "2D" above.
        tile_tag: str = "2D",
        global_view_pos: str = "head",
        candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),),
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Build sub-configs from kwargs (missing keys fall back to defaults).
        vision_config = kwargs.get("vision_config", {})
        self.vision_config = DeepseekVL2VisionEncoderConfig(**vision_config)
        projector_config = kwargs.get("projector_config", {})
        self.projector_config = DeepseekVL2MlpProjectorConfig(**projector_config)
        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, DeepseekV2Config):
            self.language_config = language_config
        else:
            self.language_config = DeepseekV2Config(**language_config)
        self.tile_tag = tile_tag
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions
        self.architectures = ["DeepseekVL2ForCausalLM"]
# Make AutoProcessor.from_pretrained resolve DeepseekVL2Config checkpoints to
# DeepseekVLV2Processor.
AutoProcessor.register(DeepseekVL2Config, DeepseekVLV2Processor)

# ---- View File ----
# @@ -0,0 +1,17 @@
import logging
from typing import Optional
import torch
logger = logging.getLogger(__name__)
class DeviceConfig:
    """Normalized description of the device SGLang should run on.

    Exposes both the raw device-type string (``device_type``) and the
    corresponding :class:`torch.device` object (``device``).
    """

    device: Optional[torch.device]

    def __init__(self, device: str = "cuda") -> None:
        # Guard clause: reject anything outside the supported device families.
        if device not in ("cuda", "xpu", "hpu", "cpu", "npu"):
            raise RuntimeError(f"Not supported device type: {device}")
        self.device_type = device
        self.device = torch.device(self.device_type)

# ---- View File ----
# @@ -0,0 +1,195 @@
# coding=utf-8
# Copyright 2024 The LG AI Research EXAONE Lab. All rights reserved.
# Copyright 2024 The LG CNS AI Engineering Team.
# Copyright 2023-2024 SGLang Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" EXAONE model configuration """
from typing import Any, Dict
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
# Mapping of pretrained model ids to config archive URLs (currently empty).
EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, Any] = {}
# ruff: noqa: E501
class ExaoneConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.ExaoneModel`. It is used to
    instantiate a EXAONE model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Exaone
    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 102400):
            Vocabulary size of the EXAONE model. Defines the number of different tokens that can be represented by the
            :obj:`inputs_ids` passed when calling :class:`~transformers.ExaoneModel`. Vocabulary size of the model.
            Defines the different tokens that can be represented by the `inputs_ids` passed to the forward method of
            :class:`~transformers.EXAONEModel`.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        hidden_size (:obj:`int`, `optional`, defaults to 2048):
            Dimensionality of the encoder layers and the pooler layer.
        num_layers (:obj:`int`, `optional`, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (:obj:`int`, `optional`):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        intermediate_size (:obj:`int`, `optional`, defaults to `hidden_size * 4`):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"silu"`):
            The non-linear activation function (function or string) in the decoder.
        rope_theta (:obj:`float`, `optional`, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (:obj:`Dict`, `optional`):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (:obj:`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (:obj:`float`, `optional`):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (:obj:`int`, `optional`):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (:obj:`float`, `optional`):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (:obj:`float`, `optional`):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (:obj:`float`, `optional`):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (:obj:`List[float]`, `optional`):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (:obj:`List[float]`, `optional`):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (:obj:`float`, `optional`):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (:obj:`float`, `optional`):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        embed_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if ``configs.is_decoder=True``.
        bos_token_id (:obj:`int`, `optional`, defaults to 0):
            Beginning of stream token id.
        eos_token_id (:obj:`int`, `optional`, defaults to 2):
            End of stream token id.
        tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to tie weight embeddings
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
    Example::
        >>> from transformers import EXAONEModel, ExaoneConfig
        >>> # Initializing a EXAONE configuration
        >>> configuration = ExaoneConfig()
        >>> # Initializing a model from configuration
        >>> model = EXAONEModel(configuration)
        >>> # Accessing the model configuration
        >>> configuration = model.configs
    """

    model_type = "exaone"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_hidden_layers": "num_layers"}

    def __init__(
        self,
        vocab_size=102400,
        max_position_embeddings=2048,
        hidden_size=2048,
        num_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        intermediate_size=None,
        activation_function="silu",
        rope_theta=10000.0,
        rope_scaling=None,
        embed_dropout=0.0,
        attention_dropout=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=0,
        eos_token_id=2,
        tie_word_embeddings=True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        # Explicit mirror of num_layers (also aliased through attribute_map).
        self.num_hidden_layers = num_layers
        # GQA fallback: no key/value head count means plain multi-head attention.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        # Default FFN width is the conventional 4x hidden size.
        if intermediate_size:
            self.intermediate_size = intermediate_size
        else:
            self.intermediate_size = hidden_size * 4
        self.activation_function = activation_function
        self.embed_dropout = embed_dropout
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )

# ---- View File ----
# @@ -0,0 +1,706 @@
import copy
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import sentencepiece as spm
from transformers import (
TOKENIZER_MAPPING,
GptOssConfig,
LlamaConfig,
PretrainedConfig,
PreTrainedTokenizer,
Qwen2Config,
Qwen3Config,
Qwen3MoeConfig,
)
from sglang.utils import logger
# Copied from: https://github.com/OpenGVLab/InternVL/blob/34a81000402bf8f716bab8c9b57aff1f6b436bd0/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py#L21
# Filename of the SentencePiece model expected alongside the checkpoint.
VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
# Mapping of pretrained model ids to vocab-file URLs (currently empty).
PRETRAINED_VOCAB_FILES_MAP = {}
# Modified from transformers.model.llama.configuration_llama.LlamaConfig
class InternLM2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
    an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`InternLM2Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings(`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
    Example:
    """

    model_type = "internlm2"
    _auto_class = "AutoConfig"

    def __init__(  # pylint: disable=W0102
        self,
        vocab_size=103168,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        bias=True,
        rope_theta=10000,
        rope_scaling=None,
        attn_implementation="eager",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.bias = bias
        # GQA fallback: no key/value head count means plain multi-head attention.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Validate rope_scaling eagerly so bad configs fail at construction time.
        self._rope_scaling_validation()
        self.attn_implementation = attn_implementation
        if self.attn_implementation is None:
            self.attn_implementation = "eager"
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        if self.rope_scaling is None:
            return
        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if (
            rope_scaling_factor is None
            or not isinstance(rope_scaling_factor, (float, int))
            or rope_scaling_factor < 1.0
        ):
            raise ValueError(
                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
            )
        # NOTE(review): this cast only rebinds the local variable — the value
        # stored in self.rope_scaling is left untouched, so it looks like a
        # no-op. Confirm whether the intent was to normalize the stored dict.
        if isinstance(rope_scaling_factor, int):
            rope_scaling_factor = float(rope_scaling_factor)
class InternVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
    instantiate a vision encoder according to the specified arguments, defining the model architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        num_channels (`int`, *optional*, defaults to 3):
            Number of color channels in the input images (e.g., 3 for RGB).
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        qkv_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the queries and values in the self-attention layers.
        hidden_size (`int`, *optional*, defaults to 3200):
            Dimensionality of the encoder layers and the pooler layer.
        num_attention_heads (`int`, *optional*, defaults to 25):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 12800):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        qk_normalization (`bool`, *optional*, defaults to `True`):
            Whether to normalize the queries and keys in the self-attention layers.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        use_flash_attn (`bool`, *optional*, defaults to `True`):
            Whether to use flash attention mechanism.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Dropout rate for stochastic depth.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 0.1):
            A factor for layer scale.
    """

    model_type = "intern_vit_6b"

    def __init__(
        self,
        num_channels=3,
        patch_size=14,
        image_size=224,
        qkv_bias=False,
        hidden_size=3200,
        num_attention_heads=25,
        intermediate_size=12800,
        qk_normalization=True,
        num_hidden_layers=48,
        use_flash_attn=True,
        hidden_act="gelu",
        layer_norm_eps=1e-6,
        dropout=0.0,
        drop_path_rate=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.qkv_bias = qkv_bias
        self.qk_normalization = qk_normalization
        self.use_flash_attn = use_flash_attn

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        """Load the vision config, unwrapping a nested ``vision_config`` entry
        when the checkpoint stores a composite (chat-model) configuration."""
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        # Composite checkpoints nest the vision settings under "vision_config".
        if "vision_config" in config_dict:
            config_dict = config_dict["vision_config"]
        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
        return cls.from_dict(config_dict, **kwargs)
class InternVLChatConfig(PretrainedConfig):
    """Composite configuration for InternVL chat models.

    Bundles an ``InternVisionConfig`` for the vision tower with one of the
    supported LLM configs, plus the glue hyper-parameters (LoRA ranks, image
    tiling, pixel-shuffle version, ...) used by the chat model.
    """

    model_type = "internvl_chat"
    is_composition = True

    # Maps the LLM's `architectures[0]` entry to its config class.
    _LLM_ARCH_TO_CONFIG = {
        "LlamaForCausalLM": LlamaConfig,
        "InternLM2ForCausalLM": InternLM2Config,
        "Qwen2ForCausalLM": Qwen2Config,
        "Qwen3MoeForCausalLM": Qwen3MoeConfig,
        "Qwen3ForCausalLM": Qwen3Config,
        "GptOssForCausalLM": GptOssConfig,
    }

    def __init__(
        self,
        vision_config=None,
        llm_config=None,
        use_backbone_lora=0,
        use_llm_lora=0,
        pad2square=False,
        select_layer=-1,
        force_image_size=None,
        downsample_ratio=0.5,
        template=None,
        dynamic_image_size=False,
        use_thumbnail=False,
        ps_version="v1",
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if vision_config is None:
            vision_config = {"architectures": ["InternVisionModel"]}
            logger.info(
                "vision_config is None. Initializing the InternVisionConfig with default values."
            )
        if llm_config is None:
            llm_config = {"architectures": ["InternLM2ForCausalLM"]}
            logger.info(
                "llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
            )
        self.vision_config = InternVisionConfig(**vision_config)
        # Instantiate the language-model config from its declared architecture.
        arch = llm_config.get("architectures")[0]
        llm_config_cls = self._LLM_ARCH_TO_CONFIG.get(arch)
        if llm_config_cls is None:
            raise ValueError("Unsupported architecture: {}".format(arch))
        self.llm_config = llm_config_cls(**llm_config)
        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.pad2square = pad2square
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.ps_version = ps_version  # pixel shuffle version
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.hidden_size = self.llm_config.hidden_size
        # By default, we use tie_word_embeddings=False for models of all sizes.
        self.tie_word_embeddings = False
        self.llm_config.tie_word_embeddings = self.tie_word_embeddings

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        # Sub-configs must be serialized recursively.
        output["vision_config"] = self.vision_config.to_dict()
        output["llm_config"] = self.llm_config.to_dict()
        output["model_type"] = self.__class__.model_type
        for attr_name in (
            "use_backbone_lora",
            "use_llm_lora",
            "select_layer",
            "force_image_size",
            "downsample_ratio",
            "template",
            "dynamic_image_size",
            "use_thumbnail",
            "ps_version",
            "min_dynamic_patch",
            "max_dynamic_patch",
        ):
            output[attr_name] = getattr(self, attr_name)
        return output
# # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
# class InternLM2TokenizerFast(PreTrainedTokenizerFast):
# vocab_files_names = VOCAB_FILES_NAMES
# slow_tokenizer_class = InternLM2Tokenizer
# padding_side = 'left'
# model_input_names = ['input_ids', 'attention_mask']
# _auto_class = 'AutoTokenizer'
#
# def __init__(
# self,
# vocab_file,
# unk_token='<unk>',
# bos_token='<s>',
# eos_token='</s>',
# pad_token='</s>',
# sp_model_kwargs: Optional[Dict[str, Any]] = None,
# add_bos_token=True,
# add_eos_token=False,
# decode_with_prefix_space=False,
# clean_up_tokenization_spaces=False,
# **kwargs,
# ):
# super().__init__(
# vocab_file=vocab_file,
# unk_token=unk_token,
# bos_token=bos_token,
# eos_token=eos_token,
# pad_token=pad_token,
# sp_model_kwargs=sp_model_kwargs,
# add_bos_token=add_bos_token,
# add_eos_token=add_eos_token,
# decode_with_prefix_space=decode_with_prefix_space,
# clean_up_tokenization_spaces=clean_up_tokenization_spaces,
# **kwargs,
# )
# self._add_bos_token = add_bos_token
# self._add_eos_token = add_eos_token
# self.update_post_processor()
# self.vocab_file = vocab_file
#
# @property
# def can_save_slow_tokenizer(self) -> bool:
# return os.path.isfile(self.vocab_file) if self.vocab_file else False
#
# def update_post_processor(self):
# """
# Updates the underlying post processor with the current `bos_token` and `eos_token`.
# """
# bos = self.bos_token
# bos_token_id = self.bos_token_id
# if bos is None and self.add_bos_token:
# raise ValueError('add_bos_token = True but bos_token = None')
#
# eos = self.eos_token
# eos_token_id = self.eos_token_id
# if eos is None and self.add_eos_token:
# raise ValueError('add_eos_token = True but eos_token = None')
#
# single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
# pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
#
# special_tokens = []
# if self.add_bos_token:
# special_tokens.append((bos, bos_token_id))
# if self.add_eos_token:
# special_tokens.append((eos, eos_token_id))
# self._tokenizer.post_processor = processors.TemplateProcessing(
# single=single, pair=pair, special_tokens=special_tokens
# )
#
# @property
# def add_eos_token(self):
# return self._add_eos_token
#
# @property
# def add_bos_token(self):
# return self._add_bos_token
#
# @add_eos_token.setter
# def add_eos_token(self, value):
# self._add_eos_token = value
# self.update_post_processor()
#
# @add_bos_token.setter
# def add_bos_token(self, value):
# self._add_bos_token = value
# self.update_post_processor()
#
# def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
# if not self.can_save_slow_tokenizer:
# raise ValueError(
# 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
# 'tokenizer.'
# )
#
# if not os.path.isdir(save_directory):
# logger.error(f'Vocabulary path ({save_directory}) should be a directory')
# return
# out_vocab_file = os.path.join(
# save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
# )
#
# if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
# copyfile(self.vocab_file, out_vocab_file)
#
# return (out_vocab_file,)
# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
class InternLM2Tokenizer(PreTrainedTokenizer):
    """
    Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.

    Args:
        vocab_file (`str`):
            Path to the SentencePiece vocabulary file.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids", "attention_mask"]
    _auto_class = "AutoTokenizer"

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="</s>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        decode_with_prefix_space=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        # Fix: removed a stray `print("register succeed")` debug statement that
        # wrote to stdout on every tokenizer instantiation.
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.decode_with_prefix_space = decode_with_prefix_space
        # The SentencePiece model must be loaded before super().__init__,
        # which may query vocab_size / special token ids.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        self._no_prefix_space_tokens = None  # built lazily (needs full vocab scan)
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def no_prefix_space_tokens(self):
        """Ids of tokens that should NOT receive a prepended space when they
        start a decoded sequence."""
        if self._no_prefix_space_tokens is None:
            vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
            # Fix: the predicate was `tok.startswith("")`, which is True for
            # every token and made this set always empty (the U+2581 marker was
            # evidently lost). SentencePiece prefixes word-initial pieces with
            # "▁"; only pieces *without* that marker belong here.
            self._no_prefix_space_tokens = {
                i for i, tok in enumerate(vocab) if not tok.startswith("▁")
            }
        return self._no_prefix_space_tokens

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    @property
    def bos_token_id(self) -> Optional[int]:
        return self.sp_model.bos_id()

    @property
    def eos_token_id(self) -> Optional[int]:
        return self.sp_model.eos_id()

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Returns a tokenized string."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def _maybe_add_prefix_space(self, tokens, decoded):
        # Prepend a space when the first piece is word-initial; the caller
        # strips one leading character afterwards (see convert_tokens_to_string).
        if tokens and tokens[0] not in self.no_prefix_space_tokens:
            return " " + decoded
        else:
            return decoded

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        out_string = self.clean_up_tokenization(out_string)
        out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
        # Drop the artificial leading character added above.
        return out_string[1:]

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )
        # Copy the on-disk model when available; otherwise dump the serialized
        # proto held in memory.
        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)
        return (out_vocab_file,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Prepend BOS and/or append EOS around one or two sequences,
        depending on `add_bos_token` / `add_eos_token`."""
        if self.add_bos_token:
            bos_token_ids = [self.bos_token_id]
        else:
            bos_token_ids = []
        output = bos_token_ids + token_ids_0
        if token_ids_1 is not None:
            output = output + token_ids_1
        if self.add_eos_token:
            output = output + [self.eos_token_id]
        return output

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
        use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of zeros.
        """
        eos = [self.eos_token_id]
        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
# Make InternVLChat checkpoints resolvable through AutoTokenizer
# (slow tokenizer only; the fast variant above is commented out).
TOKENIZER_MAPPING.register(
    InternVLChatConfig, (InternLM2Tokenizer, None), exist_ok=True
)

View File

@@ -0,0 +1,634 @@
# Adapted from:
# https://github.com/deepseek-ai/Janus/tree/main/janus/models
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
import numpy as np
import PIL
import torch
from PIL.Image import Image
from transformers import (
BaseImageProcessor,
BatchFeature,
LlamaConfig,
LlamaTokenizerFast,
PretrainedConfig,
ProcessorMixin,
)
from transformers.image_utils import to_numpy_array
from sglang.srt.configs.utils import register_image_processor, register_processor
from sglang.srt.multimodal.mm_utils import expand2square
class DictToObject(dict):
    """A dict whose keys are also reachable as attributes, recursively.

    Nested dict values are exposed as DictToObject attributes while the
    mapping itself keeps the original (plain dict) values.
    """

    def __init__(self, dictionary):
        # Bug fix: the original called `super(self).__init__(dictionary)`,
        # which raises TypeError — super()'s first argument must be a type,
        # not an instance.
        super().__init__(dictionary)
        for key, value in dictionary.items():
            if isinstance(value, dict):
                value = DictToObject(value)
            setattr(self, key, value)
class VisionConfig(PretrainedConfig):
    """Config naming a vision-encoder class plus its constructor params."""

    model_type = "vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        target = kwargs.get("cls", "")
        # Accept either a class object or its name; store the name only.
        self.cls = target if isinstance(target, str) else target.__name__
        self.params = kwargs.get("params", {})
class GenAlignerConfig(PretrainedConfig):
    """Config naming a generation-aligner class plus its constructor params."""

    model_type = "gen_aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        target = kwargs.get("cls", "")
        # Accept either a class object or its name; store the name only.
        self.cls = target if isinstance(target, str) else target.__name__
        self.params = kwargs.get("params", {})
class GenHeadConfig(PretrainedConfig):
    """Config naming a generation-head class plus its constructor params."""

    model_type = "gen_head"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        target = kwargs.get("cls", "")
        # Accept either a class object or its name; store the name only.
        self.cls = target if isinstance(target, str) else target.__name__
        self.params = kwargs.get("params", {})
class AlignerConfig(PretrainedConfig):
    """Config naming an aligner class plus its constructor params."""

    model_type = "aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        target = kwargs.get("cls", "")
        # Accept either a class object or its name; store the name only.
        self.cls = target if isinstance(target, str) else target.__name__
        self.params = kwargs.get("params", {})
class GenVisionConfig(PretrainedConfig):
    """Config naming a generation vision-tokenizer class plus its params."""

    model_type = "gen_vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        target = kwargs.get("cls", "")
        # Accept either a class object or its name; store the name only.
        self.cls = target if isinstance(target, str) else target.__name__
        self.params = kwargs.get("params", {})
@dataclass
class SigLIPVisionCfg:
    """Hyper-parameters of a SigLIP-style vision transformer encoder."""

    width: int = 1152  # hidden (embedding) dimension
    layers: Union[Tuple[int, int, int, int], int] = 27  # depth, or per-stage depths
    heads: int = 16  # attention heads
    patch_size: int = 14
    image_size: Union[Tuple[int, int], int] = 336  # (H, W) or square side
    global_pool: str = "map"  # pooling strategy for the final representation
    mlp_ratio: float = 3.7362  # MLP hidden dim as a multiple of `width`
    class_token: bool = False
    num_classes: int = 0  # 0 = no classification head
    use_checkpoint: bool = False  # gradient checkpointing toggle
class MultiModalityConfig(PretrainedConfig):
    """Top-level Janus config aggregating vision / aligner / generation /
    language sub-configs, each built from its kwargs sub-dict."""

    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig
    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig
    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Each sub-config defaults to an empty dict when absent.
        self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
        self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
        self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
        self.gen_aligner_config = GenAlignerConfig(
            **kwargs.get("gen_aligner_config", {})
        )
        self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))
        # The language config may already be a LlamaConfig instance.
        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)
class VLMImageProcessor(BaseImageProcessor):
    """Image preprocessor for Janus VL models: aspect-preserving resize, pad to
    square, rescale to [0, 1] and (optionally) normalize."""

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize
        # Padding color for expand2square: mid-gray when no mean is given,
        # otherwise the per-channel mean mapped back to 0-255.
        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """
        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """
        width, height = pil_img.size
        max_size = max(width, height)
        # Scale the longer side to image_size; clamp each side to min_size.
        # NOTE: size is (height, width) here.
        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]
        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            # print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        # Local torchvision-style resize: int -> shorter side, tuple -> (h, w).
        def resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        ):
            if isinstance(size, int):
                w, h = pil_img.size
                if (w <= h and w == size) or (h <= w and h == size):
                    return pil_img
                if w < h:
                    ow = size
                    oh = int(size * h / w)
                else:
                    oh = size
                    ow = int(size * w / h)
                size = (ow, oh)
            else:
                # PIL expects (width, height); size is (height, width).
                size = (size[1], size[0])
            return pil_img.resize(
                size, resample=interpolation, reducing_gap=None if antialias else 3.0
            )

        pil_img = resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        )
        # Pad with the background color to a square image.
        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)
        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))
        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        """Resize/pad, rescale and normalize a single image or a list of images;
        returns a BatchFeature with `pixel_values`."""
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        if not isinstance(images, list):
            images = [images]
        images: List[np.ndarray] = [self.resize(image) for image in images]
        # Drop any alpha channel.
        images = [image[:3, ...] for image in images]
        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]
        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]
        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        # Shape of one processed image: [C, H, W].
        return [3, self.image_size, self.image_size]
class DictOutput(object):
    """Mixin giving an object a minimal dict-like interface backed by its
    own attribute dictionary."""

    def items(self):
        return vars(self).items()

    def keys(self):
        return vars(self).keys()

    def __getitem__(self, item):
        return vars(self)[item]

    def __contains__(self, key):
        return key in vars(self)

    def __setitem__(self, key, value):
        vars(self)[key] = value
@dataclass
class VLChatProcessorOutput(DictOutput):
    """Processor output for a single sample (pre-batching)."""

    sft_format: str  # the formatted prompt string
    input_ids: torch.Tensor  # token ids with image tokens spliced in
    pixel_values: torch.Tensor  # preprocessed image tensor(s)
    num_image_tokens: torch.IntTensor  # image-token count per image

    def __len__(self):
        # Sequence length in tokens (image tokens included).
        return len(self.input_ids)
@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    """Left-padded, batched processor output ready for model forward."""

    sft_format: List[str]  # formatted prompts, one per sample
    input_ids: torch.Tensor  # [batch, max_len], left-padded
    pixel_values: torch.Tensor  # [batch, max_n_images, C, H, W]
    attention_mask: torch.Tensor  # [batch, max_len], 1 on real tokens
    images_seq_mask: torch.BoolTensor  # True where input_ids is an image token
    images_emb_mask: torch.BoolTensor  # [batch, max_n_images, num_image_tokens]
# FIXME: had to place Official Processor here, since image_processor module would not be imported in all threads,
# hence AutoProcessor registration would not be effective in some cases
class VLChatProcessor(ProcessorMixin):
    """Joint text+image processor for Janus-style VL chat models.

    Tokenizes the prompt, expands each image placeholder into
    `<begin_of_image>` + `num_image_tokens` placeholders + `<end_of_image>`,
    and preprocesses the images via `VLMImageProcessor`.
    """

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
    attributes = ["image_processor", "tokenizer"]

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        image_start_tag: str = "<begin_of_image>",
        image_end_tag: str = "<end_of_image>",
        pad_tag: str = "<▁pad▁>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,  # NOTE(review): accepted but unused here
        ignore_id: int = -100,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        # Register the image placeholder as a special token if it is missing.
        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            # print(f"Add image tag = {image_tag} to the tokenizer")
        self.image_tag = image_tag
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.pad_tag = pad_tag
        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.ignore_id = ignore_id
        super().__init__(
            image_processor,
            tokenizer,
            **kwargs,
        )

    @property
    def image_token(self):
        # The placeholder string for one image in the prompt.
        return self.image_tag

    @property
    def image_id(self) -> int:
        # Vocabulary id of the image placeholder token.
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def image_start_id(self):
        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
        return image_start_id

    @property
    def image_end_id(self):
        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
        return image_end_id

    @property
    def image_start_token(self):
        return self.image_start_tag

    @property
    def image_end_token(self):
        return self.image_end_tag

    @property
    def pad_id(self):
        pad_id = self.tokenizer.vocab.get(self.pad_tag)
        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """
        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]
        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """
        input_slices = []
        start = 0
        for index in image_indices:
            # Keep the placeholder token itself only when add_special_token is set.
            if self.add_special_token:
                end = index + 1
            else:
                end = index
            # original text tokens
            input_slices.append(input_ids[start:end])
            # add boi, image tokens, eoi and set the mask as False
            input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
            start = index + 1
        # the left part
        input_slices.append(input_ids[start:])
        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
        return input_ids, num_image_tokens

    def process_one(
        self,
        prompt: str = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            images (List[ImageType]): the list of images;
            **kwargs:
        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """
        sft_format = prompt
        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)
        # add image tokens to the input_ids
        image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )
        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")
        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )
        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """
        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:
        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """
        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )
        if force_batchify:
            prepare = self.batchify([prepare])
        return prepare

    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """
        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))
        input_token_max_len = max(seq_lens)
        # At least one image slot so pixel_values keeps a fixed rank.
        max_n_images = max(1, max(n_images))
        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()
        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id
            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True
            sft_format.append(prepare.sft_format)
        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )
        return batched_prepares
class VLMImageProcessorConfig(PretrainedConfig):
    """Serializable configuration mirroring VLMImageProcessor's constructor."""

    model_type = "deepseek_vlm"
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        # Record the preprocessing knobs before delegating to PretrainedConfig.
        self.do_normalize = do_normalize
        self.rescale_factor = rescale_factor
        self.image_std = image_std
        self.image_mean = image_mean
        self.min_size = min_size
        self.image_size = image_size
        super().__init__(**kwargs)
# Make the Janus processor/image-processor discoverable via the Auto* classes.
register_processor(MultiModalityConfig, VLChatProcessor)
register_image_processor(MultiModalityConfig, VLMImageProcessor)

View File

@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from typing import Optional, Union
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
class KimiVLConfig(PretrainedConfig):
    """Kimi-VL configuration: a MoonViT vision tower paired with a
    DeepseekV2 text backbone."""

    model_type = "kimi_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, MoonViTConfig]] = None,
        text_config: Optional[Union[dict, DeepseekV2Config]] = None,
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        **kwargs
    ):
        # Promote dict / None inputs to proper config objects.
        if isinstance(vision_config, dict):
            vision_config = MoonViTConfig(**vision_config)
        elif vision_config is None:
            vision_config = MoonViTConfig()
        self.vision_config = vision_config

        if isinstance(text_config, dict):
            text_config = DeepseekV2Config(**text_config)
        elif text_config is None:
            text_config = DeepseekV2Config()
        self.text_config = text_config

        self.ignore_index = ignore_index  # label id ignored by the loss
        self.media_placeholder_token_id = media_placeholder_token_id
        super().__init__(pad_token_id=pad_token_id, **kwargs)

View File

@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from transformers.configuration_utils import PretrainedConfig
class MoonViTConfig(PretrainedConfig):
    """Configuration of the MoonViT vision encoder used by Kimi-VL."""

    model_type = "moonvit"

    def __init__(
        self,
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 27,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Patch embedding config
        self.patch_size = patch_size
        # Positional embedding config
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        # Transformer config
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        # Patch merger config
        self.merge_kernel_size = merge_kernel_size

View File

@@ -0,0 +1,89 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum
import json
import logging
from dataclasses import dataclass, field
from typing import List, Optional, Union
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
class LoadFormat(str, enum.Enum):
    """Supported formats/strategies for loading model weights.

    Inherits from `str` so values compare equal to their plain-string form.
    """

    AUTO = "auto"  # safetensors first, fall back to pytorch bin
    PT = "pt"
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"  # pytorch weights + numpy cache for faster reloads
    DUMMY = "dummy"  # random init, for profiling
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"  # nf4 quantized weights
    MISTRAL = "mistral"
    LAYERED = "layered"
    JAX = "jax"
    REMOTE = "remote"
@dataclass
class LoadConfig:
    """
    download_dir: Directory to download and load the weights, default to the
        default cache directory of huggingface.
    load_format: The format of the model weights to load:
        "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is
            not available.
        "pt" will load the weights in the pytorch bin format.
        "safetensors" will load the weights in the safetensors format.
        "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
        "dummy" will initialize the weights with random values, which is
            mainly for profiling.
        "bitsandbytes" will load nf4 type weights.
    ignore_patterns: The list of patterns to ignore when loading the model.
        Default to "original/**/*" to avoid repeated loading of llama's
        checkpoints.
    decryption_key_file: If set, decrypts the output files with a password read
        from this file (after PBKDF2).
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None
    decryption_key_file: Optional[str] = None

    def __post_init__(self):
        # A JSON string is parsed in place; a dict (or None) is left untouched.
        raw_extra_config = self.model_loader_extra_config or {}
        if isinstance(raw_extra_config, str):
            self.model_loader_extra_config = json.loads(raw_extra_config)
        self._verify_load_format()
        if self.ignore_patterns:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Normalize a string load_format to a LoadFormat and reject formats
        unavailable on ROCm."""
        if not isinstance(self.load_format, str):
            return
        normalized = self.load_format.lower()
        self.load_format = LoadFormat(normalized)
        rocm_not_supported_load_format: List[str] = []
        if is_hip() and normalized in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f
                for f in LoadFormat.__members__
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{normalized}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )

View File

@@ -0,0 +1,104 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
FLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class LongcatFlashConfig(PretrainedConfig):
    """Configuration for the LongCat-Flash causal LM.

    The visible fields combine MLA-style attention knobs (kv/q LoRA ranks,
    decoupled rope/nope head dims, mla_scale_* flags) with routed-MoE knobs
    (n_routed_experts, moe_topk, zero experts). Both the checkpoint-style key
    names (num_layers, ffn_hidden_size) and the HF-conventional ones
    (num_hidden_layers, intermediate_size) are accepted.
    """

    model_type = "longcat_flash"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=6144,
        intermediate_size=None,
        ffn_hidden_size=12288,
        expert_ffn_hidden_size=2048,
        num_layers=28,
        num_hidden_layers=None,
        num_attention_heads=64,
        ep_size=1,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=128,
        qk_nope_head_dim=128,
        v_head_dim=128,
        n_routed_experts=512,
        moe_topk=12,
        norm_topk_prob=False,
        max_position_embeddings=131072,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mla_scale_q_lora=True,
        mla_scale_kv_lora=True,
        torch_dtype="bfloat16",
        params_dtype="bfloat16",
        # NOTE(review): "rounter" spelling presumably mirrors released
        # checkpoint configs — confirm before renaming.
        rounter_params_dtype="float32",
        router_bias=False,
        topk_method=None,
        routed_scaling_factor=6.0,
        zero_expert_num=256,
        zero_expert_type="identity",
        nextn_use_scmoe=False,
        num_nextn_predict_layers=1,
        **kwargs,
    ):
        # Token ids, dtypes and router/NextN fields are forwarded to
        # PretrainedConfig (together with any extra **kwargs) instead of
        # being stored as attributes below.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            torch_dtype=torch_dtype,
            params_dtype=params_dtype,
            rounter_params_dtype=rounter_params_dtype,
            topk_method=topk_method,
            router_bias=router_bias,
            nextn_use_scmoe=nextn_use_scmoe,
            num_nextn_predict_layers=num_nextn_predict_layers,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        # Accept either the HF-style or the checkpoint-style key for the
        # layer count and FFN width; the HF-style key wins when both are set.
        self.num_hidden_layers = (
            num_hidden_layers if num_hidden_layers is not None else num_layers
        )
        self.intermediate_size = (
            intermediate_size if intermediate_size is not None else ffn_hidden_size
        )
        self.moe_intermediate_size = expert_ffn_hidden_size
        self.num_attention_heads = num_attention_heads
        self.ep_size = ep_size
        # MLA projection ranks and per-head dims.
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        # MoE routing config.
        self.n_routed_experts = n_routed_experts
        self.moe_topk = moe_topk
        self.norm_topk_prob = norm_topk_prob
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mla_scale_q_lora = mla_scale_q_lora
        self.mla_scale_kv_lora = mla_scale_kv_lora
        self.zero_expert_num = zero_expert_num
        self.zero_expert_type = zero_expert_type
        self.routed_scaling_factor = routed_scaling_factor
        # Activation is hard-coded to SiLU; not configurable.
        self.hidden_act = "silu"

View File

@@ -0,0 +1,811 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import List, Optional, Set, Union
import torch
from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_generation_config,
get_hf_text_config,
get_sparse_attention_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip
from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
class AttentionArch(IntEnum):
    """Attention architecture family of the loaded model."""

    # Multi-head Latent Attention (compressed KV via kv_lora_rank etc.).
    MLA = 1
    # Standard Multi-Head Attention.
    MHA = 2
class ModelImpl(str, Enum):
    """Selector for the model implementation backend; values double as CLI strings."""

    AUTO = "auto"
    SGLANG = "sglang"
    TRANSFORMERS = "transformers"
class ModelConfig:
    """Normalized model configuration derived from a HuggingFace config.

    Wraps the raw ``hf_config`` and exposes unified attributes
    (``head_dim``, ``context_len``, kv-head counts, quantization method, ...)
    that the rest of the runtime reads, smoothing over per-architecture
    differences (MLA vs MHA, draft/NextN variants, multimodal flags).
    """

    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: str = "{}",  # JSON-encoded dict of config overrides
        is_embedding: Optional[bool] = None,
        enable_multimodal: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        override_config_file: Optional[str] = None,
        is_draft_model: bool = False,
        hybrid_kvcache_ratio: Optional[float] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
    ) -> None:
        # Parse args
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization
        self.model_impl = model_impl
        # For remote URLs, pull the config files to a local dir first and
        # repoint self.model_path at it.
        self.maybe_pull_model_tokenizer_from_remote()
        self.model_override_args = json.loads(model_override_args)
        kwargs = {}
        if override_config_file and override_config_file.strip():
            kwargs["_configuration_file"] = override_config_file.strip()
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
            **kwargs,
        )
        self.hf_generation_config = get_generation_config(
            self.model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs,
        )
        # For composite (e.g. multimodal) configs, this is the text sub-config.
        self.hf_text_config = get_hf_text_config(self.hf_config)
        self.attention_chunk_size = getattr(
            self.hf_text_config, "attention_chunk_size", None
        )
        # Hybrid SWA/full-attention KV cache handling (Llama4 only for now);
        # None means disabled.
        self.is_hybrid = is_hybrid_model(
            self.hf_config.architectures,
            hybrid_kvcache_ratio=hybrid_kvcache_ratio,
            context_length=context_length,
            attention_chunk_size=self.attention_chunk_size,
        )
        if self.is_hybrid is not None:
            self.swa_attention_layer_ids, self.full_attention_layer_ids = (
                get_hybrid_layer_ids(
                    self.hf_config.architectures, self.hf_text_config.num_hidden_layers
                )
            )
        if enable_multimodal is None:
            # Architectures whose multimodal path is off unless the user
            # explicitly opts in.
            mm_disabled_models = [
                "Gemma3ForConditionalGeneration",
                "Llama4ForConditionalGeneration",
                "Step3VLForConditionalGeneration",
            ]
            if self.hf_config.architectures[0] in mm_disabled_models:
                enable_multimodal = False
                logger.info(
                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                )
            else:
                enable_multimodal = True
        # For speculative-decoding draft models, rewrite the architecture
        # name so the NextN / MTP variant of the model class gets loaded.
        if (
            is_draft_model
            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
        ):
            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
        if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
            self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
        if (
            is_draft_model
            and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
        ):
            self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
            # The NextN draft only runs the predict layers.
            self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers
        if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
            self.hf_config.architectures[0] = "MiMoMTP"
        if (
            is_draft_model
            and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
        ):
            self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"
        if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM":
            self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = enable_multimodal and is_multimodal_model(
            self.hf_config.architectures
        )
        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
            self.hf_config.architectures
        )
        self.is_image_gen = enable_multimodal and is_image_gen_model(
            self.hf_config.architectures
        )
        self.is_audio_model = enable_multimodal and is_audio_model(
            self.hf_config.architectures
        )
        self.is_multimodal_chunked_prefill_supported = (
            enable_multimodal
            and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
        )
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        # Derive context length
        derived_context_len = get_context_length(self.hf_text_config)
        if context_length is not None:
            if context_length > derived_context_len:
                # Allow overshooting only behind an env var (or in CI), since
                # it can produce wrong outputs or CUDA errors.
                reason = "Target model's" if is_draft_model else "User-specified"
                msg = (
                    f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                    f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
                )
                if (
                    get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
                    or is_in_ci()  # FIXME: fix this special case
                ):
                    logger.warning(msg)
                    self.context_len = context_length
                else:
                    raise ValueError(
                        f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len
        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )
        # FIXME: temporary special judge for MLA architecture
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
            or "LongcatFlashForCausalLM" in self.hf_config.architectures
            or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
        ):
            # head_dim is overridden for MLA models rather than derived from
            # hidden_size / num_attention_heads.
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim
            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale
        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
            self.hf_text_config, "use_mla", True
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
        elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
            self.v_head_dim = self.hf_text_config.v_head_dim
            self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
        else:
            if (
                "MistralModel" in self.hf_config.architectures
                or "MixtralForCausalLM" in self.hf_config.architectures
                or "MistralForCausalLM" in self.hf_config.architectures
            ):
                if getattr(self, "head_dim", None) is None:
                    self.head_dim = (
                        self.hf_config.hidden_size // self.hf_config.num_attention_heads
                    )
                    # In transformers==4.52.3, the head_dim is null in MistralConfig
                    if (
                        not hasattr(self.hf_text_config, "head_dim")
                        or self.hf_text_config.head_dim is None
                    ):
                        setattr(self.hf_text_config, "head_dim", self.head_dim)
            self.attention_arch = AttentionArch.MHA
        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )
        # for Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )
        if self.num_key_value_heads is None:
            # Non-GQA models: kv heads == attention heads.
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.num_attention_layers = self.num_hidden_layers
        if "LongcatFlashForCausalLM" in self.hf_config.architectures:
            # NOTE(review): LongcatFlash appears to carry two attention
            # sub-layers per hidden layer — confirm against the model impl.
            self.num_attention_layers = self.num_hidden_layers * 2
        self.num_nextn_predict_layers = getattr(
            self.hf_text_config, "num_nextn_predict_layers", None
        )
        self.vocab_size = self.hf_text_config.vocab_size
        # Verify quantization
        self._verify_quantization()
        # Verify dual-chunk attention config
        self._verify_dual_chunk_attention_config()
        # Cache attributes
        self.hf_eos_token_id = self.get_hf_eos_token_id()
        # multimodal
        self.image_token_id = getattr(
            self.hf_config, "image_token_id", None
        ) or getattr(self.hf_config, "image_token_index", None)

    @staticmethod
    def from_server_args(
        server_args: ServerArgs,
        model_path: Optional[str] = None,
        model_revision: Optional[str] = None,
        **kwargs,
    ):
        """Build a ModelConfig from ServerArgs; explicit args override it."""
        return ModelConfig(
            model_path=model_path or server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=model_revision or server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
            model_impl=server_args.model_impl,
            **kwargs,
        )

    def get_total_num_attention_heads(self) -> int:
        """Total (pre-tensor-parallel) number of attention heads."""
        return self.num_attention_heads

    def get_num_attention_heads(self, tensor_parallel_size) -> int:
        """Attention heads per GPU under tensor parallelism (at least 1)."""
        total_num_attention_heads = self.num_attention_heads
        return max(1, total_num_attention_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1
        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )
        if self.hf_config.model_type in ["nemotron-nas"]:
            # Per-block GQA group sizes must all agree (VGQA unsupported).
            nkvh = {
                self.hf_config.num_attention_heads // block.attention.n_heads_in_group
                for block in self.hf_config.block_configs
                if not block.attention.no_op
            }
            if len(nkvh) == 0:
                raise RuntimeError("Couldn't determine number of kv heads")
            if len(nkvh) > 1:
                raise ValueError(
                    "Variable GQA (VGQA) is not yet supported for nemotron-nas in sglang"
                )
            return next(iter(nkvh))
        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
            # For Step3
            "num_attention_groups",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads
        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
        """Extract the quantization config dict from the HF config, or from a
        sidecar ``hf_quant_config.json`` (local or on the Hub); None if absent."""
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        if quant_cfg is None:
            # check if is modelopt or mixed-precision model -- Both of them don't have corresponding field
            # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
            # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
            # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
            is_local = os.path.exists(self.model_path)
            modelopt_quant_config = {"quant_method": "modelopt"}
            if not is_local:
                import huggingface_hub

                try:
                    from huggingface_hub import HfApi

                    hf_api = HfApi()
                    # Only the file's existence is checked; its contents are
                    # not downloaded here.
                    if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                        quant_cfg = modelopt_quant_config
                except huggingface_hub.errors.OfflineModeIsEnabled:
                    logger.warning(
                        "Offline mode is enabled, skipping hf_quant_config.json check"
                    )
                    pass
            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
                quant_config_file = os.path.join(
                    self.model_path, "hf_quant_config.json"
                )
                with open(quant_config_file) as f:
                    quant_config_dict = json.load(f)
                json_quant_configs = quant_config_dict["quantization"]
                quant_algo = json_quant_configs.get("quant_algo", None)
                if quant_algo == "MIXED_PRECISION":
                    quant_cfg = {"quant_method": "w4afp8"}
                else:
                    quant_cfg = modelopt_quant_config
        return quant_cfg

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        """Reconcile the user-requested quantization with the checkpoint's,
        raising on unknown/incompatible methods and warning on unoptimized ones."""
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
            "w8a8_fp8",
            "petit_nvfp4",
            "quark",
            "mxfp4",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
            "w8a8_int8",
            "w8a8_fp8",
            "moe_wna16",
            "qoq",
            "w4afp8",
            "petit_nvfp4",
        ]
        # Pairs where a user-requested method may legitimately differ from
        # the checkpoint's declared method.
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],
            "petit_nvfp4": ["modelopt"],
            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
        }
        if self.quantization is not None:
            self.quantization = self.quantization.lower()
        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()
        if quant_cfg is not None:
            quant_method = quant_cfg.get(
                "quant_method", "" if not self.quantization else self.quantization
            ).lower()
            # Detect which checkpoint is it
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break
            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                if (
                    self.quantization not in compatible_quantization_methods
                    or quant_method
                    not in compatible_quantization_methods[self.quantization]
                ):
                    raise ValueError(
                        "Quantization method specified in the model config "
                        f"({quant_method}) does not match the quantization "
                        f"method specified in the `quantization` argument "
                        f"({self.quantization})."
                    )
        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )

    def _verify_dual_chunk_attention_config(self) -> None:
        """Attach the sparse-attention sidecar config (if any) to the
        dual-chunk attention config and default-enable sparse attention."""
        if hasattr(self.hf_config, "dual_chunk_attention_config"):
            # Try loading the sparse attention config
            sparse_attn_config = get_sparse_attention_config(self.model_path)
            if not sparse_attn_config:
                return
            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
                sparse_attn_config
            )
            if (
                "sparse_attention_enabled"
                not in self.hf_config.dual_chunk_attention_config
            ):
                self.hf_config.dual_chunk_attention_config[
                    "sparse_attention_enabled"
                ] = True

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        """Union of EOS token ids from hf_config and the generation config."""
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        if eos_ids is None:
            eos_ids = set()
        if self.hf_generation_config:
            generation_eos_ids = getattr(
                self.hf_generation_config, "eos_token_id", None
            )
            if generation_eos_ids:
                generation_eos_ids = (
                    {generation_eos_ids}
                    if isinstance(generation_eos_ids, int)
                    else set(generation_eos_ids)
                )
                eos_ids = eos_ids | generation_eos_ids
        return eos_ids

    def maybe_pull_model_tokenizer_from_remote(self) -> None:
        """
        Pull the model config files to a temporary
        directory in case of remote.

        Args:
            model: The model name or path.
        """
        from sglang.srt.connector import create_remote_connector
        from sglang.srt.utils import is_remote_url

        if is_remote_url(self.model_path):
            logger.info("Pulling model configs from remote...")
            # BaseConnector implements __del__() to clean up the local dir.
            # Since config files need to exist all the time, so we DO NOT use
            # with statement to avoid closing the client.
            client = create_remote_connector(self.model_path)
            if is_remote_url(self.model_path):
                client.pull_files(allow_pattern=["*config.json"])
                # Keep the remote URL as the weights source; configs are read
                # from the connector's local dir from here on.
                self.model_weights = self.model_path
                self.model_path = client.get_local_dir()
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested dtype against the checkpoint's torch_dtype.

    "auto" keeps the checkpoint dtype, except float32 checkpoints, which are
    downcast (bfloat16 for gemma-family models, float16 otherwise). Any
    mismatch between requested and checkpoint dtype is allowed but logged.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    checkpoint_dtype = getattr(config, "torch_dtype", None)
    if isinstance(checkpoint_dtype, str):
        checkpoint_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(checkpoint_dtype, None)
    if checkpoint_dtype is None:
        checkpoint_dtype = torch.float32
    if isinstance(dtype, torch.dtype):
        resolved = dtype
    elif not isinstance(dtype, str):
        raise ValueError(f"Unknown dtype: {dtype}")
    else:
        requested = dtype.lower()
        if requested != "auto":
            if requested not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {requested}")
            resolved = _STR_DTYPE_TO_TORCH_DTYPE[requested]
        elif checkpoint_dtype != torch.float32:
            resolved = checkpoint_dtype
        elif config.model_type.startswith("gemma"):
            # "gemma" has no version suffix; "gemma2"/"gemma3" carry it at
            # index 5 of the model_type string.
            if config.model_type == "gemma":
                gemma_version = ""
            else:
                gemma_version = config.model_type[5]
            logger.info(
                f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                "of float16 by default. Please specify `dtype` if you "
                "want to use float16."
            )
            resolved = torch.bfloat16
        else:
            # Following the common practice, we use float16 for float32
            # models.
            resolved = torch.float16
    # Verify the dtype: any implicit cast is permitted but logged.
    if resolved != checkpoint_dtype:
        if resolved == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", checkpoint_dtype, resolved)
        elif checkpoint_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", checkpoint_dtype, resolved)
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", checkpoint_dtype, resolved)
    return resolved
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    """Decide whether the model generates text.

    Architecture names known to be embedding / reward / classification models
    never generate; otherwise the `is_embedding` server arg decides.
    """
    non_generation_archs = {
        "LlamaEmbeddingModel",
        "MistralModel",
        "LlamaForSequenceClassification",
        "LlamaForSequenceClassificationWithNormal_Weights",
        "InternLM2ForRewardModel",
        "Qwen2ForRewardModel",
        "Qwen2ForSequenceClassification",
        "Qwen3ForSequenceClassification",
        "CLIPModel",
        "BertModel",
        "Contriever",
        "BertForSequenceClassification",
        "XLMRobertaModel",
        "XLMRobertaForSequenceClassification",
    }
    if any(arch in non_generation_archs for arch in model_architectures):
        return False
    return not is_embedding
multimodal_model_archs = [
"CLIPModel",
"DeepseekVL2ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3nForConditionalGeneration",
"Glm4vForConditionalGeneration",
"Glm4vMoeForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
"LlavaLlamaForCausalLM",
"Llama4ForConditionalGeneration",
"LlavaMistralForCausalLM",
"LlavaQwenForCausalLM",
"LlavaForConditionalGeneration",
"LlavaVidForCausalLM",
"MiniCPMO",
"MiniCPMV",
"Mistral3ForConditionalGeneration",
"MultiModalityCausalLM",
"MllamaForConditionalGeneration",
"Qwen2AudioForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"InternS1ForConditionalGeneration",
"Phi4MMForCausalLM",
"VILAForConditionalGeneration",
"Step3VLForConditionalGeneration",
]
def is_multimodal_model(model_architectures: List[str]):
if any(
multi_model_arch in model_architectures
for multi_model_arch in multimodal_model_archs
):
return True
else:
return False
def is_multimodal_gen_model(model_architectures: List[str]):
    # No architectures are classified as multimodal-generation models yet.
    return False
def is_image_gen_model(model_architectures: List[str]):
    # No architectures are classified as image-generation models yet.
    return False
def is_audio_model(model_architectures: List[str]):
    # No architectures are classified as audio models yet.
    return False
def is_encoder_decoder_model(model_architectures: List[str]):
    """Mllama is the only architecture treated as encoder-decoder."""
    return any(arch == "MllamaForConditionalGeneration" for arch in model_architectures)
def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
    """Check if chunked prefill is supported for a MultiModal model."""
    unsupported = frozenset(
        (
            "Grok1VForCausalLM",
            "Grok1AForCausalLM",
            "LlavaLlamaForCausalLM",
            "MllamaForConditionalGeneration",
            "CLIPModel",
        )
    )
    # Supported iff none of the architectures is on the unsupported list.
    return unsupported.isdisjoint(model_architectures)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """YaRN attention-scale correction factor; identity for scale <= 1."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
):
    """Return the hybrid-KV-cache ratio when applicable, else None.

    Only Llama4, with a positive ratio and a context length longer than the
    attention chunk, qualifies.
    """
    if hybrid_kvcache_ratio is None:
        return None
    qualifies = (
        hybrid_kvcache_ratio > 0
        and model_architectures[0] == "Llama4ForConditionalGeneration"
        and context_length > attention_chunk_size
    )
    return hybrid_kvcache_ratio if qualifies else None
def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    """Partition layer ids into (SWA, full-attention) lists for Llama4.

    Every 4th layer (1-indexed) gets full attention; the rest use sliding
    window attention. Non-Llama4 architectures get (None, None).
    """
    if "Llama4ForConditionalGeneration" not in model_architectures:
        return None, None
    swa_attention_layer_ids = []
    full_attention_layer_ids = []
    for layer_id in range(num_hidden_layers):
        if (layer_id + 1) % 4 == 0:
            full_attention_layer_ids.append(layer_id)
        else:
            swa_attention_layer_ids.append(layer_id)
    return swa_attention_layer_ids, full_attention_layer_ids

View File

@@ -0,0 +1,326 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3Hybrid model configuration"""
import enum
import os
import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.dp_attention import get_attention_tp_size
logger = logging.get_logger(__name__)
# NOTE: HybridLayerType
class HybridLayerType(enum.Enum):
    """Layer kinds a hybrid-attention model may mix; values are the
    config-file strings used to tag each layer."""

    full_attention = "attention"
    swa_attention = "swa_attention"
    linear_attention = "linear_attention"
    mamba2 = "mamba"
class Qwen3NextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
    Qwen3-Next model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of
    Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `inputs_ids`.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 5632):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf).
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        partial_rotary_factor (`float`, *optional*, defaults to 0.25):
            Percentage of the query and keys which will have rotary embedding.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        head_dim (`int`, *optional*, defaults to 256):
            Projection weights dimension in multi-head attention.
        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
            Kernel size of the convolution used in linear attention layers.
        linear_key_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each key head in linear attention.
        linear_value_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each value head in linear attention.
        linear_num_key_heads (`int`, *optional*, defaults to 16):
            Number of key heads used in linear attention layers.
        linear_num_value_heads (`int`, *optional*, defaults to 32):
            Number of value heads used in linear attention layers.
        decoder_sparse_step (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 512):
            Intermediate size of the routed expert.
        shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
            Intermediate size of the shared expert.
        num_experts_per_tok (`int`, *optional*, defaults to 10):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 512):
            Number of routed experts.
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the topk probabilities.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        mlp_only_layers (`list[int]`, *optional*, defaults to `None`, treated as `[]`):
            Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock
            The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
        layer_types (`list[str]`, *optional*, defaults to None):
            Types of each layer (attention or linear).
    ```python
    >>> from transformers import Qwen3NextModel, Qwen3NextConfig
    >>> # Initializing a Qwen3Next style configuration
    >>> configuration = Qwen3NextConfig()
    >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
    >>> model = Qwen3NextModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "qwen3_next"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=5632,
        num_hidden_layers=48,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        partial_rotary_factor=0.25,
        attention_bias=False,
        attention_dropout=0.0,
        head_dim=256,
        linear_conv_kernel_dim=4,
        linear_key_head_dim=128,
        linear_value_head_dim=128,
        linear_num_key_heads=16,
        linear_num_value_heads=32,
        decoder_sparse_step=1,
        moe_intermediate_size=512,
        shared_expert_intermediate_size=512,
        num_experts_per_tok=10,
        num_experts=512,
        norm_topk_prob=True,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        layer_types=None,
        **kwargs,
    ):
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.partial_rotary_factor = partial_rotary_factor
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.head_dim = head_dim
        rope_config_validation(self)
        # linear attention (gdn now part)
        self.linear_conv_kernel_dim = linear_conv_kernel_dim
        self.linear_key_head_dim = linear_key_head_dim
        self.linear_value_head_dim = linear_value_head_dim
        self.linear_num_key_heads = linear_num_key_heads
        self.linear_num_value_heads = linear_num_value_heads
        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.shared_expert_intermediate_size = shared_expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        # Avoid the shared-mutable-default pitfall: previously the default `[]`
        # was one list object shared by every config instance.
        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
        # Previously this argument was accepted but silently discarded; store it
        # so checkpoint-provided layer type hints survive round-tripping.
        self.layer_types = layer_types

    @property
    def layers_block_type(self):
        """Per-layer type tags derived from `full_attention_interval`.

        NOTE(review): `full_attention_interval` is not set in `__init__`; it is
        presumably injected from the checkpoint config via PretrainedConfig's
        extra-kwargs handling — accessing this property without it raises
        AttributeError.
        """
        layer_type_list = []
        for l in range(self.num_hidden_layers):
            # Every `full_attention_interval`-th layer (1-based) is full attention.
            if (l + 1) % self.full_attention_interval == 0:
                layer_type_list.append(HybridLayerType.full_attention.value)
            else:
                layer_type_list.append(HybridLayerType.linear_attention.value)
        return layer_type_list

    @property
    def linear_layer_ids(self):
        """Indices of the linear-attention layers."""
        return [
            i
            for i, type_value in enumerate(self.layers_block_type)
            if type_value == HybridLayerType.linear_attention.value
        ]

    @property
    def full_attention_layer_ids(self):
        """Indices of the full-attention layers."""
        return [
            i
            for i, type_value in enumerate(self.layers_block_type)
            if type_value == HybridLayerType.full_attention.value
        ]

    @property
    def hybrid_gdn_params(self):
        """Shapes/dtypes needed to allocate the hybrid GDN (linear-attention) caches.

        Returns (conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype,
        mamba_layer_ids), with shapes already divided across attention TP ranks.
        """
        world_size = get_attention_tp_size()
        # Conv input channels: Q and K heads (2x key dim) plus V heads.
        conv_dim = (
            self.linear_key_head_dim * self.linear_num_key_heads * 2
            + self.linear_value_head_dim * self.linear_num_value_heads
        )
        conv_state_shape = (
            divide(conv_dim, world_size),
            self.linear_conv_kernel_dim - 1,
        )
        temporal_state_shape = (
            divide(self.linear_num_value_heads, world_size),
            self.linear_key_head_dim,
            self.linear_value_head_dim,
        )
        conv_dtype = torch.bfloat16
        dtype_map = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
        }
        # NOTE(review): SGLANG_MAMBA_SSM_DTYPE must be set (and be one of the
        # keys above) before this property is read; otherwise this raises.
        ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
        mamba_layers = self.linear_layer_ids
        return (
            conv_state_shape,
            temporal_state_shape,
            conv_dtype,
            ssm_dtype,
            mamba_layers,
        )

    @property
    def mamba_cache_per_req(self):
        """Bytes of GDN cache needed per request across all linear layers."""
        conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
            self.hybrid_gdn_params
        )
        mamba_layers_len = len(mamba_layers)
        return (
            int(np.prod(conv_state_shape)) * conv_dtype.itemsize
            + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
        ) * mamba_layers_len

View File

@@ -0,0 +1,172 @@
from typing import Any, Optional, Union
from transformers.configuration_utils import PretrainedConfig
class Step3VisionEncoderConfig(PretrainedConfig):
    """Configuration for the Step3 vision encoder tower."""

    model_type = "step3_vision_encoder"

    def __init__(
        self,
        hidden_size=1792,
        intermediate_size=3072,
        output_hidden_size=4096,
        num_hidden_layers=63,
        num_attention_heads=16,
        num_channels=3,
        image_size=728,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        # Record the encoder hyper-parameters on the instance before handing
        # the remaining generic options to PretrainedConfig.
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.output_hidden_size = output_hidden_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        super().__init__(**kwargs)
class Step3TextConfig(PretrainedConfig):
    """Configuration for the Step3 text (causal LM) backbone.

    Plain hyper-parameter container: every constructor argument is stored
    verbatim on the instance; leftover kwargs go to PretrainedConfig.
    """

    model_type = "step3_text"
    architectures = ["Step3TextForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 7168,
        intermediate_size: int = 18432,
        num_attention_heads: int = 64,
        num_attention_groups: int = 1,
        num_hidden_layers: int = 61,
        max_seq_len: int = 65536,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 5120,
        moe_num_experts: int = 48,
        moe_top_k: int = 3,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 65536,
        share_expert_dim: int = 5120,
        share_q_dim: int = 2048,
        head_dim: int = 256,
        norm_expert_weight: bool = False,
        # Layers 4..59 (inclusive) are MoE layers by default; this compact form
        # is value-identical to the former 56-element literal tuple.
        moe_layers_enum: tuple[int, ...] = tuple(range(4, 60)),
        **kwargs,
    ) -> None:
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embedding = max_position_embedding
        self.share_expert_dim = share_expert_dim
        self.share_q_dim = share_q_dim
        self.head_dim = head_dim
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum
        super().__init__(**kwargs)
class Step3VLConfig(PretrainedConfig):
    """Top-level config tying the Step3 vision encoder to its text backbone."""

    model_type = "step3_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3TextConfig]] = None,
        understand_projector_stride: int = 1,
        projector_bias: bool = True,
        image_token_id: int = 128001,
        **kwargs,
    ) -> None:
        # Each sub-config may arrive as None (use defaults), a plain dict
        # (construct from it), or an already-built config object (keep as-is).
        if isinstance(vision_config, dict):
            vision_config = Step3VisionEncoderConfig(**vision_config)
        elif vision_config is None:
            vision_config = Step3VisionEncoderConfig()
        self.vision_config = vision_config

        if isinstance(text_config, dict):
            text_config = Step3TextConfig(**text_config)
        elif text_config is None:
            text_config = Step3TextConfig()
        self.text_config = text_config

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # Mirror the text model's width so generic code can read it here.
        self.hidden_size = text_config.hidden_size
        self.image_token_id = image_token_id
        super().__init__(**kwargs)

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
from typing import TYPE_CHECKING
DEFAULT_MOE_PADDING_SIZE = 32
if TYPE_CHECKING:
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
def may_get_weight_block_size(model_config, load_config):
    """Return the quantization weight block size for this model, or None.

    Resolves the model class, builds its quantization config, and reads the
    optional `weight_block_size` attribute from it.
    """
    from sglang.srt.model_loader.loader import _get_quantization_config
    from sglang.srt.model_loader.utils import get_model_architecture

    model_class, _ = get_model_architecture(model_config)
    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
    quant_config = _get_quantization_config(
        model_config, load_config, packed_modules_mapping
    )
    if quant_config is None:
        return None
    # getattr with a default collapses the original hasattr/getattr pair.
    return getattr(quant_config, "weight_block_size", None)
def get_moe_padding_size(weight_block_size):
    """Alignment unit for MoE intermediate sizes.

    With block-wise quantization active, the (square) block edge is the
    required alignment; otherwise fall back to the default padding size.
    """
    if weight_block_size is None:
        return DEFAULT_MOE_PADDING_SIZE
    # See NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
    assert (
        len(weight_block_size) == 2
    ), "Only len(weight_block_size) == 2 is supported"
    assert (
        weight_block_size[0] == weight_block_size[1]
    ), "Only weight_block_size[0] == weight_block_size[1] is supported"
    return weight_block_size[0]
def get_num_heads_padding_size(tp_size, weight_block_size):
    """Head-count alignment unit for tensor parallelism.

    When block-wise quantization is active and tp_size is odd, heads must be
    padded to an even multiple, so the unit is doubled.
    """
    if weight_block_size is not None and tp_size % 2 == 1:
        return tp_size * 2
    return tp_size
def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
    """Pad the named intermediate-size attribute up to a multiple of
    `intermediate_padding_size` and write it back onto the config mirrors.

    Reads the current value from hf_config first, then from the config object
    itself, defaulting to the padding size when absent. The padded value is
    written to hf_config (when present) and to hf_text_config, else to the
    config object directly. Returns the (mutated) model_config.
    """
    has_hf_config = hasattr(model_config, "hf_config")

    if has_hf_config and hasattr(model_config.hf_config, attr_name):
        attr_value = getattr(model_config.hf_config, attr_name)
    elif hasattr(model_config, attr_name):
        attr_value = getattr(model_config, attr_name)
    else:
        attr_value = intermediate_padding_size

    if attr_value % intermediate_padding_size != 0:
        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

        attr_value = pad_vocab_size(attr_value, intermediate_padding_size)

    if has_hf_config:
        setattr(model_config.hf_config, attr_name, attr_value)
    if hasattr(model_config, "hf_text_config"):
        setattr(model_config.hf_text_config, attr_name, attr_value)
    else:
        setattr(model_config, attr_name, attr_value)
    return model_config
def adjust_config_with_unaligned_cpu_tp(
    model_config: ModelConfig, load_config: LoadConfig, tp_size: int
) -> ModelConfig:
    """Mutate model_config so head counts and intermediate sizes divide tp_size.

    Pads num_key_value_heads / num_attention_heads (preserving the
    query-heads-per-kv ratio) and rounds the MoE/MLP intermediate sizes up to
    a TP-aligned multiple. Original values are stashed as `original_*`
    attributes on hf_config / hf_text_config. A siglip vision tower, when
    present, gets the same treatment. Returns the mutated model_config.
    """
    # Support the case where the num_attention_heads is not divisible by the TP size.
    weight_block_size = may_get_weight_block_size(model_config, load_config)
    # Preserve the pre-padding head counts so downstream code can recover them.
    model_config.hf_config.original_num_attention_heads = (
        model_config.num_attention_heads
    )
    model_config.hf_text_config.original_num_attention_heads = (
        model_config.num_attention_heads
    )
    model_config.hf_config.original_total_num_kv_heads = (
        model_config.get_total_num_kv_heads()
    )
    model_config.hf_text_config.original_total_num_kv_heads = (
        model_config.get_total_num_kv_heads()
    )
    if (
        model_config.num_attention_heads % tp_size != 0
        or model_config.get_total_num_kv_heads() % tp_size != 0
    ):
        # Compute the head_dim using the model_config.num_attention_heads before padding
        if not hasattr(model_config.hf_config, "head_dim"):
            model_config.hf_config.head_dim = (
                model_config.hidden_size // model_config.num_attention_heads
            )
        # Keep the GQA ratio fixed: pad kv heads first, then scale q heads.
        query_heads_per_kv = (
            model_config.num_attention_heads // model_config.get_total_num_kv_heads()
        )
        total_kv_heads = model_config.get_total_num_kv_heads()
        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

        pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
        num_key_value_heads = pad_vocab_size(total_kv_heads, pad_size)
        # Mirror the padded counts onto every config view that exposes them.
        model_config.num_key_value_heads = num_key_value_heads
        model_config.hf_config.num_key_value_heads = num_key_value_heads
        model_config.hf_text_config.num_key_value_heads = num_key_value_heads
        num_attention_heads = num_key_value_heads * query_heads_per_kv
        model_config.num_attention_heads = num_attention_heads
        model_config.hf_config.num_attention_heads = num_attention_heads
        model_config.hf_text_config.num_attention_heads = num_attention_heads
    # Intermediate sizes must align with both TP sharding and any quant block.
    intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
    model_config = update_intermediate_size(
        model_config, "moe_intermediate_size", intermediate_padding_size
    )
    model_config = update_intermediate_size(
        model_config, "intermediate_size", intermediate_padding_size
    )
    model_config = update_intermediate_size(
        model_config, "intermediate_size_mlp", intermediate_padding_size
    )
    # Apply the same padding to a siglip vision tower, when the model has one.
    if (
        hasattr(model_config.hf_config, "vision_config")
        and model_config.hf_config.vision_config.model_type == "siglip_vision_model"
    ):
        model_config.hf_config.vision_config.original_num_attention_heads = (
            model_config.num_attention_heads
        )
        if model_config.hf_config.vision_config.num_attention_heads % tp_size != 0:
            model_config.hf_config.vision_config.head_dim = (
                model_config.hf_config.vision_config.hidden_size
                // model_config.hf_config.vision_config.num_attention_heads
            )
            from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

            pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
            model_config.hf_config.vision_config.num_attention_heads = pad_vocab_size(
                model_config.hf_config.vision_config.num_attention_heads, pad_size
            )
            model_config.hf_config.vision_config = update_intermediate_size(
                model_config.hf_config.vision_config,
                "intermediate_size",
                intermediate_padding_size,
            )
    return model_config

View File

@@ -0,0 +1,25 @@
from typing import Type
from transformers import (
AutoImageProcessor,
AutoProcessor,
BaseImageProcessor,
PretrainedConfig,
ProcessorMixin,
)
def register_image_processor(
    config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
):
    """
    Register a customized HF image processor for `config`, replacing any
    stock implementation. The two `None` positional arguments presumably
    clear the slow/fast processor slots so only the custom class remains —
    verify against the installed transformers `AutoImageProcessor.register`
    signature.
    """
    AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)
def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
    """
    Register a customized HF processor for `config` with AutoProcessor;
    `exist_ok=True` silently overrides any previously registered processor.
    """
    AutoProcessor.register(config, processor, exist_ok=True)

View File

@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
import enum
import logging
from sglang.srt.connector.base_connector import (
BaseConnector,
BaseFileConnector,
BaseKVConnector,
)
from sglang.srt.connector.redis import RedisConnector
from sglang.srt.connector.s3 import S3Connector
from sglang.srt.utils import parse_connector_type
logger = logging.getLogger(__name__)
class ConnectorType(str, enum.Enum):
    """Broad category of a connector backend; str-valued so members compare
    and serialize as plain strings."""

    FS = "filesystem"  # file-like remote storage (e.g. S3)
    KV = "KV"  # key-value store (e.g. Redis)
def create_remote_connector(url, **kwargs) -> BaseConnector:
    """Instantiate the connector matching the URL's scheme.

    Supports redis:// and s3:// URLs; anything else raises ValueError.
    (kwargs are currently accepted but unused.)
    """
    connector_type = parse_connector_type(url)
    if connector_type == "redis":
        return RedisConnector(url)
    if connector_type == "s3":
        return S3Connector(url)
    raise ValueError(f"Invalid connector type: {url}")
def get_connector_type(client: BaseConnector) -> ConnectorType:
    """Classify a connector instance as KV-backed or file-backed."""
    # KV is checked first, matching the original precedence.
    for base_cls, connector_type in (
        (BaseKVConnector, ConnectorType.KV),
        (BaseFileConnector, ConnectorType.FS),
    ):
        if isinstance(client, base_cls):
            return connector_type
    raise ValueError(f"Invalid connector type: {client}")
# Public API of the connector package.
__all__ = [
    "BaseConnector",
    "BaseFileConnector",
    "BaseKVConnector",
    "RedisConnector",
    "S3Connector",
    "ConnectorType",
    "create_remote_connector",
    "get_connector_type",
]

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
import os
import shutil
import signal
import tempfile
from abc import ABC, abstractmethod
from typing import Generator, List, Optional, Tuple
import torch
class BaseConnector(ABC):
    """
    Base class for remote connectors that stage remote data in a local
    temporary directory.

    URL layouts:
    For fs connector such as s3:
    <connector_type>://<path>/<filename>
    For kv connector such as redis:
    <connector_type>://<host>:<port>/<model_name>/keys/<key>
    <connector_type>://<host>:<port>/<model_name>/files/<filename>
    """

    def __init__(self, url: str):
        self.url = url
        self.closed = False
        # All pulled files land here; the directory is removed by close().
        self.local_dir = tempfile.mkdtemp()
        # Chain SIGINT/SIGTERM handlers so the temp dir is cleaned up even on
        # interruption, while still invoking any previously installed handler.
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))

    def get_local_dir(self):
        """Return the temporary directory files are staged into."""
        return self.local_dir

    @abstractmethod
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield (weight_name, tensor) pairs for the given rank."""
        raise NotImplementedError()

    @abstractmethod
    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Download remote files (filtered by the patterns) into local_dir."""
        raise NotImplementedError()

    def close(self):
        """Remove the staging directory. Idempotent via the `closed` flag."""
        if self.closed:
            return

        self.closed = True
        if os.path.exists(self.local_dir):
            shutil.rmtree(self.local_dir)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        # Best-effort cleanup when the connector is garbage collected.
        self.close()

    def _close_by_signal(self, existing_handler=None):
        # Wrap an existing signal handler: clean up first, then delegate.
        def new_handler(signum, frame):
            self.close()
            if existing_handler:
                existing_handler(signum, frame)

        return new_handler
class BaseKVConnector(BaseConnector):
    """Connector backed by a key-value store (e.g. Redis), adding typed
    get/set primitives for tensors and strings plus key listing."""

    @abstractmethod
    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the tensor stored under *key*, or None if absent."""
        raise NotImplementedError()

    @abstractmethod
    def getstr(self, key: str) -> Optional[str]:
        """Return the string stored under *key*, or None if absent."""
        raise NotImplementedError()

    @abstractmethod
    def set(self, key: str, obj: torch.Tensor) -> None:
        """Store a tensor under *key*."""
        raise NotImplementedError()

    @abstractmethod
    def setstr(self, key: str, obj: str) -> None:
        """Store a string under *key*."""
        raise NotImplementedError()

    @abstractmethod
    def list(self, prefix: str) -> List[str]:
        """Return all keys starting with *prefix*."""
        raise NotImplementedError()
class BaseFileConnector(BaseConnector):
    """
    Connector backed by a remote filesystem (e.g. S3).

    `glob` lists full file names from the remote fs path filtered by an
    allow pattern.

    Args:
        allow_pattern: A list of patterns of which files to pull.

    Returns:
        list[str]: List of full paths allowed by the pattern
    """

    @abstractmethod
    def glob(self, allow_pattern: str) -> List[str]:
        raise NotImplementedError()

View File

@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Generator, List, Optional, Tuple
from urllib.parse import urlparse
import torch
from sglang.srt.connector import BaseKVConnector
from sglang.srt.connector.serde import create_serde
from sglang.srt.connector.utils import pull_files_from_db
logger = logging.getLogger(__name__)
class RedisConnector(BaseKVConnector):
    """KV connector that stores tensors (via safetensors serde) and strings
    in a Redis instance.

    Expected URL shape: redis://<host>:<port>/<model_name>
    """

    def __init__(self, url: str):
        # Imported lazily so redis is only required when this connector is used.
        import redis

        super().__init__(url)
        parsed_url = urlparse(url)
        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
        self.model_name = parsed_url.path.lstrip("/")
        # TODO: more serde options
        self.s, self.d = create_serde("safe")

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Fetch and deserialize the tensor at *key*; None (with a log) if missing."""
        val = self.connection.get(key)

        if val is None:
            logger.error("Key %s not found", key)
            return None

        return self.d.from_bytes(val)

    def getstr(self, key: str) -> Optional[str]:
        """Fetch the UTF-8 string at *key*; None (with a log) if missing."""
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None

        return val.decode("utf-8")

    def set(self, key: str, tensor: torch.Tensor) -> None:
        """Serialize *tensor* and store it under *key*."""
        assert tensor is not None
        self.connection.set(key, self.s.to_bytes(tensor))

    def setstr(self, key: str, obj: str) -> None:
        self.connection.set(key, obj)

    def list(self, prefix: str) -> List[str]:
        """Return all keys matching ``prefix*`` using the SCAN cursor protocol
        (cursor == 0 signals the final batch)."""
        cursor = 0
        all_keys: List[bytes] = []

        while True:
            ret: Tuple[int, List[bytes]] = self.connection.scan(
                cursor=cursor, match=f"{prefix}*"
            )  # type: ignore
            cursor, keys = ret
            all_keys.extend(keys)

            if cursor == 0:
                break

        return [key.decode("utf-8") for key in all_keys]

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, bytes], None, None]:
        """Yield (weight_name, tensor) for keys under <model>/keys/rank_<rank>/."""
        keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
        for key in keys:
            val = self.get(key)
            # Strip the namespace prefix so callers see bare weight names.
            key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
            yield key, val

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)

    def close(self):
        # Close the Redis connection before the base class removes local_dir.
        self.connection.close()

        super().close()

View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
import fnmatch
import os
from pathlib import Path
from typing import Generator, Optional, Tuple
import torch
from sglang.srt.connector import BaseFileConnector
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
return [
path
for path in paths
if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
return [
path
for path in paths
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def list_files(
    s3,
    path: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> tuple[str, str, list[str]]:
    """
    List object keys under an S3 path, filtered by allow/ignore patterns.

    Args:
        s3: S3 client to use.
        path: The S3 path to list from.
        allow_pattern: A list of patterns of which files to pull.
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple of (bucket name, prefix-as-directory string, list of filtered keys).
    """
    # "s3://bucket/some/prefix" -> bucket="bucket", prefix="some/prefix"
    bucket_name, _, prefix = path.removeprefix("s3://").partition("/")

    objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
    paths = [obj["Key"] for obj in objects.get("Contents", [])]
    # Directory placeholder keys end in "/"; drop them before user filters.
    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)
    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)
    return bucket_name, prefix, paths
class S3Connector(BaseFileConnector):
    """File connector that pulls model files from an S3 bucket into the
    local staging directory."""

    def __init__(self, url: str) -> None:
        # Imported lazily so boto3 is only required when this connector is used.
        import boto3

        super().__init__(url)
        self.client = boto3.client("s3")

    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
        """Return fully-qualified s3:// URLs under self.url matching the patterns."""
        bucket_name, _, paths = list_files(
            self.client, path=self.url, allow_pattern=allow_pattern
        )
        return [f"s3://{bucket_name}/{path}" for path in paths]

    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        """
        Pull files from S3 storage (under self.url) into the temporary directory.

        Args:
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.
        """
        bucket_name, base_dir, files = list_files(
            self.client, self.url, allow_pattern, ignore_pattern
        )
        if len(files) == 0:
            return

        for file in files:
            # Recreate the remote layout (relative to base_dir) under local_dir.
            destination_file = os.path.join(self.local_dir, file.removeprefix(base_dir))
            local_dir = Path(destination_file).parent
            os.makedirs(local_dir, exist_ok=True)
            self.client.download_file(bucket_name, file, destination_file)

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        from sglang.srt.model_loader.weight_utils import (
            runai_safetensors_weights_iterator,
        )

        # only support safetensor files now
        hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
        return runai_safetensors_weights_iterator(hf_weights_files)

    def close(self):
        # Release the boto3 client before the base class removes local_dir.
        self.client.close()

        super().close()

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# inspired by LMCache
from typing import Optional, Tuple
import torch
from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
from sglang.srt.connector.serde.serde import Deserializer, Serializer
def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
    """Build a (serializer, deserializer) pair for the given serde scheme.

    Only the "safe" (safetensors-backed) scheme is currently supported.
    """
    if serde_type != "safe":
        raise ValueError(f"Unknown serde type: {serde_type}")
    return SafeSerializer(), SafeDeserializer()
# Public API of the serde subpackage.
__all__ = [
    "Serializer",
    "Deserializer",
    "SafeSerializer",
    "SafeDeserializer",
    "create_serde",
]

View File

@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Union
import torch
from safetensors.torch import load, save
from sglang.srt.connector.serde.serde import Deserializer, Serializer
class SafeSerializer(Serializer):
    """Serializer backed by safetensors' in-memory `save()`."""

    def __init__(self):
        super().__init__()

    def to_bytes(self, t: torch.Tensor) -> bytes:
        # Move to CPU and make contiguous before handing off to safetensors.
        payload = t.cpu().contiguous()
        return save({"tensor_bytes": payload})
class SafeDeserializer(Deserializer):
    """Deserializer using safetensors' `load()`; inverse of SafeSerializer."""

    def __init__(self):
        # TODO: dtype options
        super().__init__(torch.float32)

    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        # Normalize bytearray input to bytes, then read back the single
        # "tensor_bytes" entry written by SafeSerializer.
        raw = bytes(b)
        return load(raw)["tensor_bytes"]

    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        return self.from_bytes_normal(b)

View File

@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
import abc
from abc import ABC, abstractmethod
import torch
class Serializer(ABC):
    """Abstract base class for tensor serializers."""

    @abstractmethod
    def to_bytes(self, t: torch.Tensor) -> bytes:
        """
        Serialize *t* into a self-describing byte string.

        The output must carry both the tensor data and its metadata
        (shape, dtype, etc.). *t* may live on any device and have any
        shape or dtype.
        """
        raise NotImplementedError
class Deserializer(metaclass=abc.ABCMeta):
    """Abstract base class for tensor deserializers.

    Concrete subclasses record a target ``dtype`` and implement
    ``from_bytes`` to reverse a Serializer's output.
    """

    def __init__(self, dtype):
        # Target dtype hint used by concrete deserializers.
        self.dtype = dtype

    @abstractmethod
    def from_bytes(self, bs: bytes) -> torch.Tensor:
        """Reconstruct a tensor from a byte stream produced by a Serializer."""
        raise NotImplementedError

View File

@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
import os
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
from sglang.srt.connector import BaseConnector
def parse_model_name(url: str) -> str:
    """Extract the model name from a connector URL.

    Only used for db connectors. The model name is the URL path with the
    leading slash stripped, e.g. "redis://host:6379/llama" -> "llama".
    """
    return urlparse(url).path.lstrip("/")
def pull_files_from_db(
    connector: "BaseConnector",
    model_name: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> None:
    """Download every file stored under a model's prefix into the
    connector's local directory, preserving the relative layout.

    Args:
        connector: Source of the files; must provide ``list``,
            ``get_local_dir`` and ``getstr``.
        model_name: Model whose "<model_name>/files/" prefix is pulled.
        allow_pattern: Currently unused (kept for interface parity).
        ignore_pattern: Currently unused (kept for interface parity).
    """
    prefix = f"{model_name}/files/"
    local_dir = connector.get_local_dir()
    files = connector.list(prefix)
    for file in files:
        # Always join against the connector root. The previous version
        # reassigned `local_dir` to the parent of the file just written,
        # so every file after the first landed in the wrong directory.
        destination_file = os.path.join(local_dir, file.removeprefix(prefix))
        os.makedirs(Path(destination_file).parent, exist_ok=True)
        with open(destination_file, "wb") as f:
            f.write(connector.getstr(file).encode("utf-8"))

View File

@@ -0,0 +1,3 @@
# GPU Memory Types
# String tags naming which category of GPU memory an operation targets
# (e.g. when selectively releasing or restoring memory occupation).
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
GPU_MEMORY_TYPE_WEIGHTS = "weights"

View File

@@ -0,0 +1,213 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The baseclass of a backend for grammar-guided constrained decoding."""
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from threading import Event
from typing import Dict, List, Optional, Tuple
import torch
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
class BaseGrammarObject:
    """Interface for a per-request grammar state used in constrained decoding.

    Concrete backends (outlines / xgrammar / llguidance) subclass this and
    implement token acceptance, vocab-mask handling, and (optionally)
    jump-forward decoding.
    """
    def __init__(self):
        # Backing field for the `finished` property.
        self._finished = False
    def accept_token(self, token: int) -> None:
        """
        Accept a token in the grammar.
        """
        raise NotImplementedError()
    def rollback(self, k: int):
        # Undo the last k accepted tokens (used by speculative decoding).
        raise NotImplementedError()
    def is_terminated(self):
        # Whether the grammar has reached a terminal state; the base object
        # never terminates.
        return False
    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        # Create a (batch_size, vocab) mask buffer in the backend's format.
        raise NotImplementedError()
    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        # Write the allowed-token mask for row `idx` of the batch.
        raise NotImplementedError()
    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        # Transfer the mask to `device` (backend-specific semantics).
        raise NotImplementedError()
    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        # Mask disallowed tokens in `logits` in place.
        raise NotImplementedError()
    def copy(self) -> "BaseGrammarObject":
        # The base object is stateless, so sharing is safe.
        return self
    @property
    def finished(self):
        return self._finished
    @finished.setter
    def finished(self, finished):
        self._finished = finished
    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """
        Try to jump forward in the grammar.
        Returns:
            A jump forward helper which may be used in `jump_forward_str_state`.
            None if the jump forward is not possible.
        """
        raise NotImplementedError()
    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """
        Jump forward for the grammar.
        Returns:
            A tuple of the jump forward string and the next state of the grammar
            (which can be used in `jump_and_retokenize` if needed).
        """
        raise NotImplementedError()
    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ) -> None:
        """
        Jump forward occurs, and update the grammar state if needed.
        """
        raise NotImplementedError()
# Sentinel returned by backends when grammar compilation fails.
INVALID_GRAMMAR_OBJ = BaseGrammarObject()
@dataclass
class CacheEntry:
    # Compiled grammar object held while compilation is tracked.
    value: BaseGrammarObject
    # Signals waiters once the grammar has finished compiling.
    event: Event
class BaseGrammarBackend:
    """Base class for grammar compilation backends.

    Compiles (key_type, key_string) requests into grammar objects on a
    thread pool and memoizes the results.
    """

    def __init__(self):
        # Compilation runs off the main thread; results are cached below.
        self.executor = ThreadPoolExecutor()
        # Maps (key_type, key_string) -> compiled grammar object.
        # NOTE: the cache stores the BaseGrammarObject itself (see
        # `set_cache` / `.copy()` below), not a CacheEntry wrapper as a
        # previous annotation incorrectly claimed.
        self.cache: Dict[Tuple[str, str], "BaseGrammarObject"] = {}

    def _not_supported(self, key_type: str, key_string: str) -> None:
        """Log and skip a constraint type this backend does not implement."""
        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")

    def dispatch_fallback(
        self, key_type: str, key_string: str
    ) -> Optional["BaseGrammarObject"]:
        """
        This function should not be reached in any case.
        """
        raise ValueError(f"Invalid key_type: {key_type}={key_string}")

    def dispatch_json(self, key_string: str) -> Optional["BaseGrammarObject"]:
        return self._not_supported("json", key_string)

    def dispatch_regex(self, key_string: str) -> Optional["BaseGrammarObject"]:
        return self._not_supported("regex", key_string)

    def dispatch_ebnf(self, key_string: str) -> Optional["BaseGrammarObject"]:
        return self._not_supported("ebnf", key_string)

    def dispatch_structural_tag(self, key_string: str) -> Optional["BaseGrammarObject"]:
        return self._not_supported("structural_tag", key_string)

    def dispatch_structural_pattern(
        self, key_string: str
    ) -> Optional["BaseGrammarObject"]:
        # Added so the "structural_pattern" branch of _init_value_dispatch
        # degrades gracefully instead of raising AttributeError.
        return self._not_supported("structural_pattern", key_string)

    def _init_value_dispatch(
        self, key: Tuple[str, str]
    ) -> Optional["BaseGrammarObject"]:
        key_type, key_string = key
        if key_type == "json":
            return self.dispatch_json(key_string)
        elif key_type == "regex":
            return self.dispatch_regex(key_string)
        elif key_type == "ebnf":
            return self.dispatch_ebnf(key_string)
        elif key_type == "structural_tag":
            return self.dispatch_structural_tag(key_string)
        elif key_type == "structural_pattern":
            return self.dispatch_structural_pattern(key_string)
        else:
            return self.dispatch_fallback(key_type, key_string)

    def get_cached_or_future_value(self, key: Tuple[str, str]):
        """Look up `key` in the cache, or schedule compilation.

        Returns:
            (value, cache_hit): on a hit, `value` is a copy of the cached
            grammar object; on a miss, `value` is a Future resolving to the
            compiled object (or None if the type is unsupported).
        """
        value = self.cache.get(key)
        if value:
            return value.copy(), True
        value = self.executor.submit(self._init_value_dispatch, key)
        return value, False

    def set_cache(self, key: Tuple[str, str], value: "BaseGrammarObject"):
        self.cache[key] = value

    def reset(self):
        self.cache.clear()
def create_grammar_backend(
    server_args: ServerArgs,
    tokenizer,
    vocab_size: int,
    eos_token_ids: Optional[set] = None,
) -> Optional[BaseGrammarBackend]:
    """Instantiate the grammar backend selected by `server_args`.

    Returns None when the backend is "none"; otherwise the backend,
    optionally wrapped so it skips masking inside reasoning sections.
    """
    name = server_args.grammar_backend
    if name == "none":
        return None

    if name == "outlines":
        from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

        backend = OutlinesGrammarBackend(
            tokenizer,
            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
        )
    elif name == "xgrammar":
        from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

        # xgrammar expects a list of EOS ids rather than a set.
        eos_list = list(eos_token_ids) if eos_token_ids else None
        backend = XGrammarGrammarBackend(
            tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
        )
    elif name == "llguidance":
        from sglang.srt.constrained.llguidance_backend import GuidanceBackend

        backend = GuidanceBackend(
            tokenizer=tokenizer,
            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
        )
    else:
        raise ValueError(f"Invalid grammar backend: {name}")

    # When a reasoning parser is active and the tokenizer marks the end of
    # the "think" section, wrap the backend so reasoning tokens bypass it.
    if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
        from sglang.srt.constrained.reasoner_grammar_backend import (
            ReasonerGrammarBackend,
        )

        backend = ReasonerGrammarBackend(backend, tokenizer.think_end_id)
    return backend

View File

@@ -0,0 +1,174 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constrained decoding with llguidance backend."""
import json
import logging
import os
from typing import List, Optional, Tuple
import torch
from llguidance import LLMatcher, LLTokenizer, StructTag, grammar_from
from llguidance.hf import from_tokenizer
from llguidance.torch import (
allocate_token_bitmask,
apply_token_bitmask_inplace,
fill_next_token_bitmask,
)
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
BaseGrammarBackend,
BaseGrammarObject,
)
logger = logging.getLogger(__name__)
class GuidanceGrammar(BaseGrammarObject):
    """Per-request grammar state backed by an llguidance ``LLMatcher``."""
    def __init__(self, llguidance_tokenizer: LLTokenizer, serialized_grammar: str):
        super().__init__()
        self.llguidance_tokenizer = llguidance_tokenizer
        self.serialized_grammar = serialized_grammar
        # Log level is forwarded straight to llguidance; defaults to "1".
        self.ll_matcher = LLMatcher(
            self.llguidance_tokenizer,
            self.serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )
        self.finished = False
        # Token bitmask buffer, allocated lazily and grown on demand.
        self.bitmask = None
    def accept_token(self, token: int):
        # consume_token returns False when the token is rejected by the
        # grammar; in that case we stop constraining instead of crashing.
        if not self.ll_matcher.consume_token(token):
            logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
            self.finished = True
    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        # A stopped matcher means the grammar is complete.
        if self.ll_matcher.is_stopped():
            self.finished = True
        fill_next_token_bitmask(self.ll_matcher, vocab_mask, idx)
    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        # NOTE(review): sizes the mask from the llguidance tokenizer's vocab,
        # ignoring the `vocab_size` argument — confirm the two always agree.
        if self.bitmask is None or self.bitmask.shape[0] < batch_size:
            # only create bitmask when batch gets larger
            self.bitmask = allocate_token_bitmask(
                batch_size, self.llguidance_tokenizer.vocab_size
            )
            bitmask = self.bitmask
        else:
            bitmask = self.bitmask[:batch_size]
        return bitmask
    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask.to(device, non_blocking=True)
    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        apply_token_bitmask_inplace(logits, vocab_mask)
    def copy(self):
        # Builds a fresh matcher from the serialized grammar, so the copy
        # starts from the grammar's initial state.
        return GuidanceGrammar(
            llguidance_tokenizer=self.llguidance_tokenizer,
            serialized_grammar=self.serialized_grammar,
        )
    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        # llguidance computes forced tokens directly, so the helper's string
        # component is always empty.
        ff_tokens = self.ll_matcher.compute_ff_tokens()
        if ff_tokens:
            return ff_tokens, ""
        else:
            return None
    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        # No extra string/state handling needed for this backend.
        return "", -1
    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        pass
class GuidanceBackend(BaseGrammarBackend):
    """Grammar backend built on llguidance."""

    def __init__(
        self,
        tokenizer,
        whitespace_pattern: Optional[str] = None,
        n_vocab: Optional[int] = None,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.whitespace_pattern = whitespace_pattern
        self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)

    def _from_serialized(self, serialized_grammar) -> Optional[GuidanceGrammar]:
        """Wrap a serialized grammar; degrade to the invalid sentinel on error."""
        try:
            return GuidanceGrammar(
                llguidance_tokenizer=self.llguidance_tokenizer,
                serialized_grammar=serialized_grammar,
            )
        except Exception as e:
            logger.error(f"Hit invalid grammar: {serialized_grammar=}, {e=}")
            return INVALID_GRAMMAR_OBJ

    def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
        try:
            serialized_grammar = LLMatcher.grammar_from_json_schema(
                key_string,
                defaults={
                    "whitespace_pattern": self.whitespace_pattern,
                },
            )
        except Exception as e:
            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_serialized(serialized_grammar)

    def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
        # Guard grammar_from like the other dispatchers so a malformed regex
        # returns the invalid sentinel instead of propagating an exception.
        try:
            serialized_grammar = grammar_from("regex", key_string)
        except Exception as e:
            logger.error(f"Hit invalid regex: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_serialized(serialized_grammar)

    def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]:
        try:
            serialized_grammar = grammar_from("ebnf", key_string)
            return self._from_serialized(serialized_grammar)
        except ValueError as e:
            logger.error(f"Hit invalid ebnf: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ

    def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
        try:
            structural_tag = json.loads(key_string)
            tags = [
                StructTag(
                    begin=structure["begin"],
                    grammar=structure["schema"],
                    end=structure["end"],
                    trigger=structural_tag["triggers"][0],  # TODO?
                )
                for structure in structural_tag["structures"]
            ]
            g = StructTag.to_grammar(tags)
            return self._from_serialized(g)
        except Exception as e:
            # Use the module logger (was `logging.error`, which logs to the
            # root logger and bypasses this module's configuration).
            logger.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ

View File

@@ -0,0 +1,191 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constrained decoding with outlines backend."""
import json
import logging
from typing import Dict, List, Optional, Tuple, Union
import interegular
import torch
from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
BaseGrammarBackend,
BaseGrammarObject,
)
from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
try:
from outlines.fsm.json_schema import build_regex_from_schema
except ImportError:
from outlines_core.fsm.json_schema import build_regex_from_schema
logger = logging.getLogger(__name__)
class OutlinesGrammar(BaseGrammarObject):
    """Per-request grammar state driven by an outlines ``RegexGuide`` FSM."""
    def __init__(
        self,
        guide: RegexGuide,
        jump_forward_map: Union[OutlinesJumpForwardMap, None],
    ) -> None:
        super().__init__()
        self.guide = guide
        self.jump_forward_map = jump_forward_map
        # Current FSM state; the guide starts at state 0.
        self.state = 0
        self.finished = False
    def accept_token(self, token: int):
        self.state = self.guide.get_next_state(self.state, token)
    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        # Mask is already allocated on the target device; nothing to move.
        return vocab_mask
    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        # Mark everything masked (True), then clear the allowed token ids
        # reported by the guide for the current state.
        tokens = torch.tensor(
            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
        ).to(vocab_mask.device, non_blocking=True)
        vocab_mask = vocab_mask[idx]
        vocab_mask.fill_(1)
        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
        logits.masked_fill_(vocab_mask, float("-inf"))
    def copy(self):
        # NOTE(review): shares the guide and jump-forward map with the
        # original; the copy's per-request `state` restarts at 0 — confirm
        # the sharing is intended.
        return OutlinesGrammar(self.guide, self.jump_forward_map)
    def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
        if not self.jump_forward_map:
            return None
        jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
        if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
            return None
        # preprocess the jump forward string
        suffix_bytes = []
        # 0x80-0xBF are UTF-8 continuation bytes: consume them first so the
        # symbol-level jump forward starts at a character boundary.
        continuation_range = range(0x80, 0xC0)
        cur_state = self.state
        while (
            len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range
        ):
            # continuation bytes
            byte_edge = jump_forward_bytes.pop(0)
            suffix_bytes.append(byte_edge[0])
            cur_state = byte_edge[1]
        # Map raw bytes onto byte-fallback tokens of the form "<0xAB>".
        suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
        suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
        return suffix_ids, cur_state
    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, cur_state = helper
        return self.jump_forward_map.jump_forward_symbol(cur_state)
    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        self.state = next_state
class OutlinesGrammarBackend(BaseGrammarBackend):
    """Grammar backend that compiles constraints to outlines regex guides."""
    def __init__(
        self,
        tokenizer,
        # Regex snippet (or None) controlling whitespace in generated JSON;
        # previous annotation said `bool`, but callers pass a string/None.
        whitespace_pattern: Optional[str],
    ):
        super().__init__()
        try:
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
        except AttributeError:
            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
            # Those tokenizers expose pad_token_id as a read-only property;
            # temporarily install a setter so TransformerTokenizer can
            # assign it, then restore the original value.
            origin_pad_token_id = tokenizer.pad_token_id
            def fset(self, value):
                self._value = value
            type(tokenizer).pad_token_id = property(
                fget=type(tokenizer).pad_token_id.fget, fset=fset
            )
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token = (
                self.outlines_tokenizer.tokenizer.pad_token
            )
            self.outlines_tokenizer.vocabulary = (
                self.outlines_tokenizer.tokenizer.get_vocab()
            )
        self.whitespace_pattern = whitespace_pattern
    def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
        # Compile a regex into a guide; returns the invalid-grammar sentinel
        # when the regex has bad syntax.
        try:
            if hasattr(RegexGuide, "from_regex"):
                # outlines >= 0.1.1
                guide = RegexGuide.from_regex(regex, self.outlines_tokenizer)
            else:
                # outlines <= 0.0.46
                guide = RegexGuide(regex, self.outlines_tokenizer)
        except interegular.patterns.InvalidSyntax as e:
            logger.error(f"Hit invalid regex schema: {regex=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        # Jump-forward maps are currently disabled for this backend.
        jump_forward_map = None
        return OutlinesGrammar(guide, jump_forward_map)
    def dispatch_ebnf(self, key_string: str):
        # EBNF is unsupported by outlines; base class logs and skips.
        return super().dispatch_ebnf(key_string)
    def dispatch_structural_tag(self, key_string: str):
        # Structural tags are unsupported by outlines; base class logs and skips.
        return super().dispatch_structural_tag(key_string)
    def dispatch_json(self, key_string: str):
        try:
            regex = build_regex_from_object(
                key_string,
                whitespace_pattern=self.whitespace_pattern,
            )
        except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._compile_regex(regex)
    def dispatch_regex(self, key_string: str):
        return self._compile_regex(key_string)
def build_regex_from_object(
    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
):
    """Build an outlines regex from a JSON schema.

    Accepts a pydantic model class, a dict schema, or a schema string;
    normalizes it to a JSON string before delegating to outlines.
    """
    if isinstance(object, type(BaseModel)):
        # A pydantic model class: dump its generated JSON schema.
        schema = json.dumps(object.model_json_schema())
    elif isinstance(object, Dict):
        schema = json.dumps(object)
    else:
        # Already a schema string.
        schema = object
    return build_regex_from_schema(schema, whitespace_pattern)

View File

@@ -0,0 +1,200 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Faster constrained decoding with jump forward decoding / compressed finite state machine.
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
"""
import dataclasses
import logging
from collections import defaultdict
from typing import Optional
import interegular
from interegular import InvalidSyntax
from outlines.caching import cache
from sglang.srt.utils import get_bool_env_var
try:
# outlines >= 0.1.0
from outlines_core.fsm.outlines_core_rs import FSMInfo
from outlines_core.fsm.regex import make_byte_level_fsm, make_deterministic_fsm
except ImportError:
# outlines <= 0.0.46
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
# Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class JumpEdge:
    # Single-character symbol transition; None when only a byte edge exists.
    symbol: Optional[str] = None
    symbol_next_state: Optional[int] = None
    # Byte-level transition; None when only a symbol edge exists.
    byte: Optional[int] = None
    byte_next_state: Optional[int] = None
def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
    """Return outlines' disk-cache decorator, or a no-op decorator when
    caching is disabled via SGLANG_DISABLE_OUTLINES_DISK_CACHE.

    The disabled branch must hand back the function unchanged. The previous
    version returned ``lambda fn: None``, which replaced the decorated
    function with None and made every call to it raise TypeError (and the
    env var defaults to true, so that was the common path).
    """
    if not DISABLE_DISK_CACHE:
        return cache(expire, typed, ignore)
    else:
        return lambda fn: fn
@disk_cache()
def init_state_to_jump_forward(regex_string):
    # Build a byte-level DFA for the regex and precompute, for every state
    # with exactly one outgoing edge, a deterministic "jump forward"
    # transition (text that can be emitted without sampling). Returns a
    # dict state -> JumpEdge, or None for an unparseable regex.
    try:
        regex_pattern = interegular.parse_pattern(regex_string)
    except InvalidSyntax as e:
        logger.warning(f"skip invalid regex: {regex_string}, {e=}")
        return
    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
    fsm_info: FSMInfo = regex_fsm.fsm_info
    symbol_to_id = fsm_info.alphabet_symbol_mapping
    # Invert the alphabet mapping: one transition id may cover many symbols.
    id_to_symbol = {}
    for symbol, id_ in symbol_to_id.items():
        id_to_symbol.setdefault(id_, []).append(symbol)
    transitions = fsm_info.transitions
    outgoings_ct = defaultdict(int)
    # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
    for s in fsm_info.finals:
        outgoings_ct[s] = 1
    # First pass: symbol-level (single printable character) edges. A state
    # keeps its JumpEdge only while it has exactly one outgoing edge.
    state_to_jump_forward = {}
    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            # Arbitrarily symbol cannot be recognized as jump forward
            continue
        symbols = id_to_symbol[id_]
        for c in symbols:
            if len(c) > 1:
                # Skip byte level transitions like c = "5E"
                continue
            outgoings_ct[state] += 1
            if outgoings_ct[state] > 1:
                if state in state_to_jump_forward:
                    del state_to_jump_forward[state]
                break
            state_to_jump_forward[state] = JumpEdge(
                symbol=c,
                symbol_next_state=next_state,
            )
    # Process the byte level jump forward
    # Second pass: byte-level edges (ASCII chars plus hex-encoded bytes),
    # merged into any JumpEdge created by the symbol pass.
    outgoings_ct = defaultdict(int)
    for s in fsm_info.finals:
        outgoings_ct[s] = 1
    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            continue
        symbols = id_to_symbol[id_]
        for c in symbols:
            byte_ = None
            if len(c) == 1 and ord(c) < 0x80:
                # ASCII character
                byte_ = ord(c)
            elif len(c) > 1:
                # FIXME: This logic is due to the leading \x00
                # https://github.com/outlines-dev/outlines/pull/930
                byte_ = int(symbols[0][1:], 16)
            if byte_ is not None:
                outgoings_ct[state] += 1
                if outgoings_ct[state] > 1:
                    if state in state_to_jump_forward:
                        del state_to_jump_forward[state]
                    break
                e = state_to_jump_forward.get(state, JumpEdge())
                e.byte = byte_
                e.byte_next_state = next_state
                state_to_jump_forward[state] = e
    return state_to_jump_forward
class OutlinesJumpForwardMap:
    """Precomputed jump-forward edges for a regex FSM.

    Wraps the state -> JumpEdge table built by ``init_state_to_jump_forward``
    and provides traversal helpers at symbol and byte granularity.
    """

    def __init__(self, regex_string):
        self.state_to_jump_forward = init_state_to_jump_forward(regex_string)

    def jump_forward_symbol(self, state):
        """Follow symbol edges from `state`; return (string, final state)."""
        parts = []
        next_state = state
        while True:
            edge = self.state_to_jump_forward.get(state)
            if edge is None or edge.symbol is None:
                break
            parts.append(edge.symbol)
            next_state = edge.symbol_next_state
            state = next_state
        return "".join(parts), next_state

    def jump_forward_byte(self, state):
        """Follow byte edges from `state`.

        Returns a list of (byte, next_state) pairs, or None when `state`
        has no jump-forward edge at all.
        """
        if state not in self.state_to_jump_forward:
            return None
        jump_forward_bytes = []
        while state in self.state_to_jump_forward:
            edge = self.state_to_jump_forward[state]
            assert edge.byte is not None and edge.byte_next_state is not None
            jump_forward_bytes.append((edge.byte, edge.byte_next_state))
            state = edge.byte_next_state
        return jump_forward_bytes

    def is_jump_forward_symbol_state(self, state):
        """True when `state` has a symbol-level jump-forward edge."""
        edge = self.state_to_jump_forward.get(state)
        return edge is not None and edge.symbol is not None
def test_main(regex_string):
    """Debug harness: print symbol- and byte-level jump forwards for every
    state of the FSM built from `regex_string`."""
    jump_forward_map = OutlinesJumpForwardMap(regex_string)
    for state, edge in jump_forward_map.state_to_jump_forward.items():
        if edge.symbol is None:
            continue
        jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
        print(f"{state} -> {next_state}", jump_forward_str)
        bytes_ = jump_forward_map.jump_forward_byte(state)
        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
if __name__ == "__main__":
    import outlines
    # Clear the outlines disk cache so the FSMs below are rebuilt fresh.
    outlines.caching.clear_cache()
    test_main(r"The google's DNS sever address is " + IP_REGEX)
    # Multi-byte UTF-8 alternation exercises the byte-level jump forward:
    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
    test_main(r"[-+]?[0-9]+[ ]*")

View File

@@ -0,0 +1,90 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The baseclass of a backend for reasoner grammar-guided constrained decoding."""
from typing import List, Optional, Tuple
import torch
from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject
class ReasonerGrammarObject(BaseGrammarObject):
    """Grammar wrapper that stays inert while the model is "thinking".

    Tokens emitted before `think_end_id` bypass the wrapped grammar; once
    the end-of-think token is seen, all operations pass through.
    """
    def __init__(self, grammar: BaseGrammarObject, think_end_id):
        super().__init__()
        self.grammar = grammar
        self.think_end_id = think_end_id
        self.is_in_reasoning = True
    def accept_token(self, token: int):
        # The end-of-think token flips us out of the reasoning section.
        if token == self.think_end_id:
            self.is_in_reasoning = False
        # Forward only ordinary tokens produced after the reasoning section;
        # the end-of-think token itself is never fed to the grammar.
        if not self.is_in_reasoning and token != self.think_end_id:
            self.grammar.accept_token(token)
    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device)
    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        # While reasoning, leave the mask untouched (no constraint applied).
        if not self.is_in_reasoning:
            self.grammar.fill_vocab_mask(vocab_mask, idx)
    def move_vocab_mask(self, vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return self.grammar.move_vocab_mask(vocab_mask, device)
    @property
    def apply_vocab_mask(self):
        return self.grammar.apply_vocab_mask
    def copy(self) -> BaseGrammarObject:
        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
    @property
    def finished(self):
        # Delegated so the wrapper mirrors the inner grammar's lifecycle.
        return self.grammar.finished
    @finished.setter
    def finished(self, finished):
        self.grammar.finished = finished
    def try_jump_forward(self, tokenizer):
        return self.grammar.try_jump_forward(tokenizer)
    def jump_forward_str_state(self, helper):
        return self.grammar.jump_forward_str_state(helper)
    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        return self.grammar.jump_and_retokenize(
            old_output_ids, new_output_ids, next_state
        )
class ReasonerGrammarBackend(BaseGrammarBackend):
    """Wraps another backend so its grammars ignore reasoning tokens."""

    def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
        super().__init__()
        self.grammar_backend = grammar_backend
        self.think_end_id = think_end_id

    def _init_value_dispatch(
        self, key: Tuple[str, str]
    ) -> Optional[ReasonerGrammarObject]:
        # Delegate compilation, then wrap the result (if any) so tokens
        # emitted inside the reasoning section bypass the grammar.
        grammar = self.grammar_backend._init_value_dispatch(key)
        if grammar is None:
            return None
        return ReasonerGrammarObject(grammar, self.think_end_id)

View File

@@ -0,0 +1,141 @@
# Adapt from
# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
from typing import List, Optional, Union
import torch
import triton
import triton.language as tl
from sglang.srt.utils import get_device_core_count
@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    logits_strides,
    bitmask_strides,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
    where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
    the masked logits will be set to -inf.
    Parameters
    ----------
    logits_ptr : tl.tensor
        Pointer to the logits tensor to apply the bitmask to.
    bitmask_ptr : tl.tensor
        Pointer to the bitmask tensor to apply.
    indices_ptr : Optional[tl.tensor]
        Optional pointer to indices tensor specifying which rows to apply the mask to.
    num_rows : int
        Number of rows to process. If indices_ptr is provided, this is the number of unique indices.
    vocab_size : int
        Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the
        same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary.
    logits_strides : int
        Stride between rows in the logits tensor.
    bitmask_strides : int
        Stride between rows in the bitmask tensor.
    NUM_SMS : int
        Number of streaming multiprocessors to use.
    BLOCK_SIZE : int
        Size of processing blocks.
    """
    pid = tl.program_id(0)
    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
    # Persistent-style loop: program `pid` handles work items
    # pid, pid + NUM_SMS, pid + 2*NUM_SMS, ... over all (row, block) pairs.
    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
        row_id = work_id // num_blocks
        block_offset = (work_id % num_blocks) * BLOCK_SIZE
        # Indirect row lookup when only a subset of the batch is masked.
        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
        # 32 vocab entries are packed into each int32 of the bitmask.
        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
        vocab_mask = offsets < vocab_size
        # bitmask_strides equals the bitmask row width in int32s, so it also
        # bounds the valid packed-word offsets for a row.
        packed_bitmask_mask = bitmask_offsets < bitmask_strides
        packed_bitmask = tl.load(
            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
            packed_bitmask_mask,
        )
        # Unpack bits: a 0 bit means "masked", so the boolean is True where
        # the logit must be replaced with -inf.
        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
        bitmask = bitmask.reshape(BLOCK_SIZE)
        tl.store(
            logits_ptr + batch_id * logits_strides + offsets,
            -float("inf"),
            vocab_mask & bitmask,
        )
def apply_token_bitmask_inplace_triton(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
):
    """Launch the Triton kernel that sets disallowed-token logits to -inf.

    Args:
        logits: (vocab,) or (batch, vocab) float tensor, modified in place.
        bitmask: int32 tensor, 32 tokens packed per element
            (bit 1 = allowed, bit 0 = masked).
        indices: Optional row indices mapping bitmask rows to logits rows.
    """
    NUM_SMS = get_device_core_count()
    BLOCK_SIZE = 4096
    BITS_PER_BLOCK = 32
    # Check input dtype
    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"
    # Check input tensor shapes.
    logits_shape = logits.shape
    bitmask_shape = bitmask.shape
    if logits.ndim == 1:
        logits_shape = (1, logits_shape[0])
    if bitmask.ndim == 1:
        bitmask_shape = (1, bitmask_shape[0])
    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
    assert required_bitmask_width >= bitmask_shape[1], (
        f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
    )
    # Only the vocab range covered by both tensors is processed.
    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)
    num_rows = None
    if isinstance(indices, list) or isinstance(indices, torch.Tensor):
        # NOTE(review): torch.tensor() copies `indices` even when it is
        # already an int32 tensor on the right device — confirm the extra
        # copy is intended.
        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
        num_rows = indices.shape[0]
    else:
        assert (
            logits_shape[0] == bitmask_shape[0]
        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
        num_rows = logits_shape[0]
    if NUM_SMS > 0:
        # Persistent launch: one program per SM strides over all work items.
        grid = (NUM_SMS,)
    else:
        # Core count unknown: launch one program per (row, block) work item
        # and give the kernel a stride covering the whole grid.
        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
        grid = (num_rows * num_blocks,)
        NUM_SMS = triton.next_power_of_2(grid[0])
    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        num_rows,
        vocab_size,
        logits_shape[1],
        bitmask_shape[1],
        NUM_SMS,
        BLOCK_SIZE,
        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
        num_stages=3,
    )

View File

@@ -0,0 +1,239 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constrained decoding with xgrammar backend."""
import json
import logging
from typing import List, Optional, Tuple, Union
import torch
from xgrammar import (
CompiledGrammar,
GrammarCompiler,
GrammarMatcher,
StructuralTagItem,
TokenizerInfo,
allocate_token_bitmask,
)
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
BaseGrammarBackend,
BaseGrammarObject,
)
from sglang.srt.utils import is_hip
_is_hip = is_hip()
if _is_hip:
from sgl_kernel import apply_token_bitmask_inplace_cuda
else:
from sglang.srt.constrained.triton_ops.bitmask_ops import (
apply_token_bitmask_inplace_triton,
)
logger = logging.getLogger(__name__)
MAX_ROLLBACK_TOKENS = 200
class XGrammarGrammar(BaseGrammarObject):
    """Per-request constrained-decoding state backed by an xgrammar matcher.

    Tracks every accepted token (for debugging), supports rollback for
    speculative decoding, and applies next-token bitmasks to logits.
    """

    def __init__(
        self,
        matcher: GrammarMatcher,
        vocab_size: int,
        ctx: CompiledGrammar,
        override_stop_tokens: Optional[Union[List[int], int]],
        key_string: Optional[str] = None,  # TODO (sk): for debugging, remove later
    ) -> None:
        self.matcher = matcher
        self.vocab_size = vocab_size
        self.ctx = ctx
        self.override_stop_tokens = override_stop_tokens
        self.finished = False
        self.accepted_tokens = []
        self.key_string = key_string

    def accept_token(self, token: int):
        """Advance the matcher by one generated token.

        Raises:
            ValueError: if the grammar rejects the token.
        """
        if not self.is_terminated():
            accepted = self.matcher.accept_token(token)
            if not accepted:
                # log for debugging
                raise ValueError(
                    f"Tokens not accepted: {token}\n"
                    f"Accepted tokens: {self.accepted_tokens}\n"
                    f"Key string: {self.key_string}"
                )
            else:
                self.accepted_tokens.append(token)

    def rollback(self, k: int):
        """Undo the last *k* accepted tokens (no-op for k <= 0)."""
        # Guard k <= 0: `self.accepted_tokens[:-0]` evaluates to [] and would
        # wrongly wipe the entire acceptance history.
        if k <= 0:
            return
        self.matcher.rollback(k)
        self.accepted_tokens = self.accepted_tokens[:-k]

    def is_terminated(self):
        """Whether the grammar has reached a terminal state."""
        return self.matcher.is_terminated()

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        # Allocated host-side; `move_vocab_mask` transfers it to the device.
        return allocate_token_bitmask(batch_size, vocab_size)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Fill row *idx* of the bitmask with the allowed next tokens."""
        self.matcher.fill_next_token_bitmask(vocab_mask, idx)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask.to(device, non_blocking=True)

    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Mask out disallowed token logits in place."""
        if logits.device.type == "cuda":
            if _is_hip:
                apply_token_bitmask_inplace_cuda(logits, vocab_mask)
            else:
                apply_token_bitmask_inplace_triton(logits, vocab_mask)
            return
        # The CPU implementation may be absent or None depending on the base
        # class, so look it up defensively instead of risking AttributeError.
        cpu_impl = getattr(self, "apply_vocab_mask_cpu", None)
        if logits.device.type == "cpu" and cpu_impl:
            cpu_impl(logits, vocab_mask)
            return
        raise RuntimeError(f"Unsupported device: {logits.device.type}")

    def copy(self):
        """Create a fresh, independent matcher over the same compiled grammar."""
        matcher = GrammarMatcher(
            self.ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            override_stop_tokens=self.override_stop_tokens,
        )
        return XGrammarGrammar(
            matcher,
            self.vocab_size,
            self.ctx,
            self.override_stop_tokens,
            self.key_string,
        )

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """Return ([], forced_string) when the grammar forces a unique continuation."""
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, data = helper
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        """Re-sync the matcher after jump-forward changed the tokenization.

        NOTE(review): this path updates the matcher but not
        ``self.accepted_tokens`` — the debug history may diverge here; confirm.
        """
        k = 0
        for i, old_id in enumerate(old_output_ids):
            if old_id == new_output_ids[i]:
                k = i + 1
            else:
                break
        # rollback to the last token that is the same
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)
        for i in range(k, len(new_output_ids)):
            assert self.matcher.accept_token(new_output_ids[i])

    def __repr__(self):
        return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
class XGrammarGrammarBackend(BaseGrammarBackend):
    """Grammar backend that compiles constraints with xgrammar.

    Supports JSON schema, EBNF, regex, and structural-tag constraints.
    Compilation failures are logged and reported as ``INVALID_GRAMMAR_OBJ``
    instead of raising.
    """

    def __init__(
        self,
        tokenizer,
        vocab_size: int,
        model_eos_token_ids: Optional[List[int]] = None,
    ):
        super().__init__()
        if hasattr(tokenizer, "init_xgrammar"):
            # For special tokenizer
            tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
        else:
            # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens.
            # This ensures consistency between what the model considers EOS and what XGrammar uses.
            tokenizer_info = TokenizerInfo.from_huggingface(
                tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
            )
            override_stop_tokens = None
        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
        self.vocab_size = vocab_size
        self.override_stop_tokens = override_stop_tokens

    def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
        """Wrap a compiled grammar in a fresh matcher/grammar object."""
        matcher = GrammarMatcher(
            ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            override_stop_tokens=self.override_stop_tokens,
        )
        return XGrammarGrammar(
            matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
        )

    def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a JSON-schema constraint ("$$ANY$$" = any valid JSON)."""
        try:
            if key_string == "$$ANY$$":
                # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)
                ctx = self.grammar_compiler.compile_builtin_json_grammar()
            else:
                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
        except (RuntimeError, json.decoder.JSONDecodeError) as e:
            # Use the module-level logger (not the root logger) for consistency.
            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_context(ctx, key_string)

    def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile an EBNF grammar constraint."""
        try:
            ctx = self.grammar_compiler.compile_grammar(key_string)
        except RuntimeError as e:
            logger.error(f"Hit invalid ebnf: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_context(ctx, key_string)

    def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a regex constraint."""
        try:
            ctx = self.grammar_compiler.compile_regex(key_string)
        except RuntimeError as e:
            logger.error(f"Hit invalid regex: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_context(ctx, key_string)

    def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a structural-tag constraint from its JSON description."""
        try:
            structural_tag = json.loads(key_string)
            tags = [
                StructuralTagItem(
                    begin=structure["begin"],
                    schema=json.dumps(structure["schema"]),
                    end=structure["end"],
                )
                for structure in structural_tag["structures"]
            ]
            ctx = self.grammar_compiler.compile_structural_tag(
                tags, structural_tag["triggers"]
            )
        except (RuntimeError, json.decoder.JSONDecodeError) as e:
            logger.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_context(ctx, key_string)

    def reset(self):
        """Drop all cached compiled grammars."""
        self.grammar_compiler.clear_cache()

View File

@@ -0,0 +1,102 @@
from torch import nn
from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
is_cuda,
is_hip,
is_npu,
is_xpu,
)
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_is_xpu = is_xpu()
class CustomOp(nn.Module):
    """Base op that routes ``forward`` to a platform-specific implementation.

    The target implementation (CUDA / HIP / NPU / XPU / CPU / native) is
    selected once at construction time; ``enter_torch_compile`` /
    ``leave_torch_compile`` temporarily swap it for a compile-friendly one.
    """

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()
        # Snapshot of the dispatched method + flag, used to undo
        # `enter_torch_compile`.
        self._original_forward_method = None
        self.is_torch_compile = False

    def enter_torch_compile(self, num_tokens: int):
        """Switch to a torch.compile-friendly forward implementation."""
        # Ops such as RotaryEmbedding are reused across layers, so this may be
        # invoked many times; only the first call may capture the original
        # forward method, otherwise the snapshot would be overwritten.
        if self.is_torch_compile:
            return
        self._original_forward_method = self._forward_method
        op_name = self.__class__.__name__
        if "FusedMoE" in op_name:
            # Workaround: torch.compile on this layer is not always fast for
            # bs > 1, so the native path is used only for single-token batches.
            if num_tokens == 1:
                from sglang.srt.layers.moe.fused_moe_native import (
                    fused_moe_forward_native,
                )

                self._forward_method = fused_moe_forward_native
        elif "TopK" in op_name:
            if num_tokens == 1:
                self._forward_method = self.forward_native
        else:
            self._forward_method = self.forward_native
        self.is_torch_compile = True

    def leave_torch_compile(self):
        """Restore the forward implementation saved by ``enter_torch_compile``."""
        if not self.is_torch_compile:
            return
        self._forward_method = self._original_forward_method
        self._original_forward_method = None
        self.is_torch_compile = False

    # Please do not override this method, because `self._forward_method` can
    # change when in torch compile mode.
    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_npu(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # HIP reuses the CUDA implementation by default.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_hpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        """Pick the forward implementation for the current platform."""
        if _is_cuda:
            return self.forward_cuda
        if _is_hip:
            return self.forward_hip
        if _is_cpu and _is_cpu_amx_available:
            # The CPU fast path additionally requires AMX support.
            return self.forward_cpu
        if _is_npu:
            return self.forward_npu
        if _is_xpu:
            return self.forward_xpu
        return self.forward_native

View File

@@ -0,0 +1,168 @@
import argparse
import functools
from pathlib import Path
import polars as pl
import torch
from sglang.srt.debug_utils.dump_loader import find_row, read_meta
from sglang.srt.debug_utils.dumper import get_truncated_value
def main(args):
    """Compare dumped tensors in the target directory against the baseline.

    Iterates over all target dumps (sorted by rank/dump order, filtered to
    the [start_id, end_id] forward-pass range), finds each dump's baseline
    counterpart by metadata, and prints a per-pair comparison report.
    """
    df_target = read_meta(args.target_path)
    df_target = df_target.sort("rank", "dump_index")
    df_target = df_target.filter(
        (pl.col("forward_pass_id") >= args.start_id)
        & (pl.col("forward_pass_id") <= args.end_id)
    )
    assert all(
        c in df_target.columns
        for c in ["rank", "forward_pass_id", "dump_index", "name"]
    )
    df_baseline = read_meta(args.baseline_path)
    print("df_target", df_target)
    print("df_baseline", df_baseline)
    for row in df_target.iter_rows(named=True):
        path_target = Path(args.target_path) / row["filename"]
        # Match on all metadata except run-specific fields; the forward pass id
        # is remapped through the start-id offsets of the two runs.
        row_baseline = find_row(
            df_baseline,
            conditions=dict(
                forward_pass_id=row["forward_pass_id"]
                - args.start_id
                + args.baseline_start_id,
                **{
                    k: v
                    for k, v in row.items()
                    if k not in ["forward_pass_id", "dump_index", "filename"]
                },
            ),
        )
        if row_baseline is None:
            # No counterpart: still show a sample of the target for inspection.
            print(f"Skip: target={str(path_target)} since no baseline")
            x_target = _load_object(path_target)
            if x_target is not None:
                print(f"x_target(sample)={get_truncated_value(x_target)}")
            continue
        path_baseline = Path(args.baseline_path) / row_baseline["filename"]
        print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
        check_tensor_pair(
            path_baseline=path_baseline, path_target=path_target, name=row["name"]
        )
        print()
def check_tensor_pair(path_baseline, path_target, name=""):
    """Load two dumped tensors and report summary stats plus elementwise diffs.

    Prints shape/dtype info before and after preprocessing, distribution
    statistics for both tensors, and abs/relative differences; dumps value
    samples when the max abs difference exceeds 1e-3.
    """
    x_baseline = _load_object(path_baseline)
    x_target = _load_object(path_target)
    print(
        f"Raw "
        f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
        f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
    )
    x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
    x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)
    print(
        f"After preprocessor "
        f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
        f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
    )
    x_target = x_target.float()
    x_baseline = x_baseline.float()
    # Use `stat_name`, not `name`, so the loop does not shadow the parameter.
    for stat_name, fn in (
        ("mean", torch.mean),
        ("std", torch.std),
        ("min", torch.min),
        ("max", torch.max),
        ("p1", functools.partial(torch.quantile, q=0.01)),
        ("p5", functools.partial(torch.quantile, q=0.05)),
        ("p95", functools.partial(torch.quantile, q=0.95)),
        ("p99", functools.partial(torch.quantile, q=0.99)),
    ):
        value_baseline = fn(x_baseline).item()
        value_target = fn(x_target).item()
        print(
            f"[{stat_name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
        )
    if x_baseline.shape != x_target.shape:
        print(f"⚠️ Shape mismatch")
        return
    raw_abs_diff = (x_target - x_baseline).abs()
    max_abs_diff = raw_abs_diff.max().item()
    mean_abs_diff = raw_abs_diff.mean().item()
    rel_diff = _calc_rel_diff(x_target, x_baseline)
    needs_print = max_abs_diff > 1e-3
    print(
        "\t".join(
            f"{'' if value > 1e-3 else ''} {metric_name}={value}"
            for metric_name, value in [
                ("rel_diff", rel_diff),
                ("max_abs_diff", max_abs_diff),
                ("mean_abs_diff", mean_abs_diff),
            ]
        )
    )
    if needs_print:
        print(f"x_baseline(sample)={get_truncated_value(x_baseline)}")
        print(f"x_target(sample)={get_truncated_value(x_target)}")
def _try_unify_shape(x: torch.Tensor, target_shape):
x_shape = x.shape
num_dim_to_remove = len(x_shape) - len(target_shape)
if (x_shape[num_dim_to_remove:] == target_shape) and all(
val == 1 for val in x_shape[:num_dim_to_remove]
):
out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)
print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
return out
return x
# Copied from DeepGEMM
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def _comparison_preprocessor(x_baseline, x_target, name):
# can insert arbitrary adhoc postprocessing logic here
return x_baseline, x_target
def _load_object(path):
    """Load a dumped object from *path*; return it on GPU if a tensor, else None."""
    x = torch.load(path, weights_only=False)
    if isinstance(x, torch.Tensor):
        return x.cuda()
    print(f"Skip load {path} since {type(x)=} is not a Tensor")
    return None
if __name__ == "__main__":
    # CLI: compare two dump directories produced by sglang.srt.debug_utils.dumper.
    parser = argparse.ArgumentParser()
    parser.add_argument("--baseline-path", type=str)
    parser.add_argument("--target-path", type=str)
    # Only forward passes in [start_id, end_id] of the target dump are checked.
    parser.add_argument("--start-id", type=int, default=0)
    parser.add_argument("--end-id", type=int, default=1000000)
    # Offset applied when the baseline run started counting at a different pass id.
    parser.add_argument("--baseline-start-id", type=int, default=0)
    args = parser.parse_args()
    main(args)

View File

@@ -0,0 +1,97 @@
import functools
import os
from pathlib import Path
from typing import Any, Dict
import polars as pl
import torch
class DumpLoader:
    """Loads tensors previously written by ``dumper``, keyed by dump metadata.

    Enabled by setting ``SGLANG_DUMP_LOADER_DIR`` to a dump directory; when
    the env var is unset, ``enable`` is False and ``load`` must not be called.
    """
    def __init__(self):
        directory = os.environ.get("SGLANG_DUMP_LOADER_DIR")
        self._enable = directory is not None
        if self._enable:
            # NOTE: _directory/_df only exist when the loader is enabled.
            self._directory = Path(directory)
            self._df = read_meta(directory)
    @property
    def enable(self):
        # Whether SGLANG_DUMP_LOADER_DIR was provided.
        return self._enable
    def load(self, name, **kwargs):
        """Load the unique dumped object matching *name*, the dumper's current
        forward pass id, and any extra metadata filters in *kwargs*."""
        assert self._enable, "Please call DumpLoader.load only when it is enabled"
        # Lazy import — presumably avoids an import cycle with dumper; confirm.
        from sglang.srt.debug_utils.dumper import dumper
        forward_pass_id = dumper._forward_pass_id
        conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs)
        row = find_row(self._df, conditions=conditions)
        assert (
            row is not None
        ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"
        path = self._directory / row["filename"]
        output = torch.load(path, weights_only=False)
        print(
            f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
        )
        return output
def read_meta(directory):
    """Build a metadata DataFrame from the dump files in *directory*.

    Each ``*.pt`` filename encodes key=value pairs joined by ``___``
    (the format written by dumper); one row is produced per file, plus a
    ``filename`` column.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"
    rows = []
    for p in directory.glob("*.pt"):
        full_kwargs = {}
        for kv in p.stem.split("___"):
            # Split on the first "=" only, so values containing "=" survive.
            k, v = kv.split("=", 1)
            full_kwargs[k] = v
        rows.append(
            {
                "filename": str(p.name),
                **full_kwargs,
            }
        )
    df = pl.DataFrame(rows)
    # These columns are always present (dumper writes them for every file).
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
        pl.col("dump_index").cast(int),
    )
    return df
def find_row(df, conditions: Dict[str, Any]):
    """Return the single row of *df* matching all *conditions*, or None.

    Each condition value is cast to the column's polars dtype before
    comparison; asserts that at most one row matches.
    """
    predicates = [
        pl.col(col_name) == _cast_to_polars_dtype(value, df.schema[col_name])
        for col_name, value in conditions.items()
    ]
    combined = functools.reduce(lambda lhs, rhs: lhs & rhs, predicates)
    matches = df.filter(combined)
    assert len(matches) <= 1
    if len(matches) == 0:
        return None
    return matches.to_dicts()[0]
def _cast_to_polars_dtype(value, target_dtype):
    """Coerce *value* to the Python type matching the polars *target_dtype*."""
    if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
        return int(value)
    if target_dtype in (pl.Float64, pl.Float32):
        return float(value)
    if target_dtype == pl.Boolean:
        return bool(value)
    if target_dtype == pl.String:
        return str(value)
    # Unknown dtype: leave the value untouched.
    return value
dump_loader = DumpLoader()

View File

@@ -0,0 +1,116 @@
import os
import time
from pathlib import Path
from typing import Optional
import torch
import torch.distributed as dist
class _Dumper:
    """Utility to dump tensors, which can be useful when comparison checking models.
    Example usage:
    dumper.on_forward_pass_start()
    dumper.dump("layer_start__hidden_states", hidden_states, layer_id=self.layer_id)
    Import from non-SGLang system:
    ```
    import sys
    sys.path.append("/YOUR_PATH/sglang/python/sglang/srt/debug_utils")
    from dumper import dumper
    ```
    Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison

    Environment variables:
        SGLANG_DUMPER_ENABLE: "0" disables dumping entirely (default "1").
        SGLANG_DUMPER_DIR: base output directory (default "/tmp").
        SGLANG_DUMPER_WRITE_FILE: "0" prints metadata only, no .pt files (default "1").
    """
    def __init__(self):
        # Do not import `sglang` to make this file standalone
        self._enable = bool(int(os.environ.get("SGLANG_DUMPER_ENABLE", "1")))
        self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp"))
        self._enable_write_file = bool(
            int(os.environ.get("SGLANG_DUMPER_WRITE_FILE", "1"))
        )
        # Run identifier shared across ranks; created lazily on first dump.
        self._partial_name: Optional[str] = None
        # Monotonic per-process dump counter.
        self._dump_index = 0
        self._forward_pass_id = 0
    def on_forward_pass_start(self):
        """Mark the start of a forward pass; must be called before `dump`."""
        self._forward_pass_id += 1
        print(
            f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}"
        )
    def dump(self, name, value, **kwargs):
        """Dump *value* under *name*; extra kwargs become filename metadata."""
        if not self._enable:
            return
        assert (
            self._forward_pass_id >= 1
        ), "Do you forget to call `dumper.on_forward_pass_start()`?"
        self._dump_index += 1
        if self._partial_name is None:
            self._partial_name = _get_partial_name()
        rank = _get_rank()
        full_kwargs = dict(
            forward_pass_id=self._forward_pass_id,
            rank=rank,
            name=name,
            dump_index=self._dump_index,
            **kwargs,
        )
        # Filename encodes metadata as key=value pairs joined by "___",
        # which the dump loader/comparator parses back.
        full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt"
        path = self._base_dir / f"sglang_dump_{self._partial_name}" / full_filename
        sample_value = get_truncated_value(value)
        print(
            f"[Dumper] [{rank}, {time.time()}] {path} "
            f"type={type(value)} "
            f"shape={value.shape if isinstance(value, torch.Tensor) else None} "
            f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} "
            f"sample_value={sample_value}"
        )
        if self._enable_write_file:
            path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(value, str(path))
def _get_partial_name():
    """Create a run identifier (a timestamp) shared by all distributed ranks."""
    rank = _get_rank()
    payload = [str(time.time()) if rank == 0 else None]
    if dist.is_initialized():
        # Rank 0 broadcasts its timestamp so every rank uses the same name.
        dist.broadcast_object_list(payload, device="cuda")
    return payload[0]
def _get_rank():
if dist.is_initialized():
return dist.get_rank()
else:
return 0
def get_truncated_value(value):
    """Return a small, log-friendly sample of *value*.

    - None and non-tensor scalars -> None
    - tuples and lists -> elementwise truncation (generalized from tuple-only)
    - tensors with fewer than 200 elements -> returned unchanged
    - larger tensors -> first 5 entries along each dim larger than 200
    """
    if value is None:
        return None
    if isinstance(value, (tuple, list)):
        return [get_truncated_value(x) for x in value]
    if not isinstance(value, torch.Tensor):
        return None
    if value.numel() < 200:
        return value
    slices = [
        slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape
    ]
    return value[tuple(slices)]
dumper = _Dumper()

View File

@@ -0,0 +1,234 @@
import argparse
import hashlib
import json
from pathlib import Path
import polars as pl
_DESCRIPTION = """Compare and find differences to benchmark outputs.
Supported inputs:
* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`
* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)
"""
def main(args):
    """Compare baseline vs target benchmark outputs and dump per-prompt diffs."""
    # simple_evals JSON has its own schema; everything else goes through the
    # ndjson reader followed by per-format normalization.
    if args.data_type == "simple_evals":
        df_input = _compute_df_input_mode_simple_evals(args)
    else:
        df_input = _transform_df_input(_compute_df_raw(args))
    assert all(
        c in df_input.columns
        for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
    )
    df_meta = _compute_df_meta(df_input)
    # Mean correctness per (category, trial) pair.
    df_correctness_per_trial = df_input.group_by(
        "category", "trial_index", maintain_order=True
    ).agg(pl.col("correct").mean())
    # Histogram of per-prompt correctness changes (target - baseline).
    df_correctness_delta = (
        df_meta.group_by("correctness_delta").len().sort("correctness_delta")
    )
    df_good_to_bad = df_meta.filter(pl.col("correctness_delta") < 0)
    df_bad_to_good = df_meta.filter(pl.col("correctness_delta") > 0)
    print(f"Dump output to {args.output_path}")
    Path(args.output_path).write_text(
        json.dumps(
            dict(
                df_meta=df_meta.to_dicts(),
                df_good_to_bad=df_good_to_bad.to_dicts(),
                df_bad_to_good=df_bad_to_good.to_dicts(),
            ),
            indent=4,
        ),
    )
    if not args.disable_print_details:
        # Widen polars' display limits so full prompts/outputs are visible.
        with pl.Config(
            fmt_str_lengths=10000,
            tbl_cols=-1,
            tbl_rows=-1,
            tbl_width_chars=-1,
            tbl_formatting="UTF8_FULL",
        ):
            print("====== Correctness per trial ======")
            print(df_correctness_per_trial)
            print(
                "====== Correctness Delta (-1.0 means all-right becomes all-wrong) ======"
            )
            print(df_correctness_delta)
            for name, df in [
                ("Good->Bad", df_good_to_bad),
                ("Bad->Good", df_bad_to_good),
            ]:
                print(f"====== Concrete Examples: {name} ======")
                print(df)
def _compute_df_input_mode_simple_evals(args):
    """Load and concatenate all simple_evals result files into one DataFrame."""
    frames = [
        _compute_df_input_one_mode_simple_evals(**info)
        for info in _get_file_infos(args=args)
    ]
    return pl.concat(frames)
def _compute_df_input_one_mode_simple_evals(path, category, trial_index):
    """Convert one simple_evals JSON output file into the normalized schema."""
    data = json.loads(Path(path).read_text())
    rows = []
    for single_eval_result in data["metadata"]["single_eval_results"]:
        example_meta = single_eval_result["example_level_metadata"]
        prompt = example_meta["actual_queried_prompt_messages"]
        score = single_eval_result["score"]
        # simple_evals scores are expected to be binary.
        assert score in {0.0, 1.0}, f"{score=}"
        rows.append(
            dict(
                category=category,
                trial_index=trial_index,
                prompt_id=_compute_id_from_object(prompt),
                prompt=json.dumps(prompt),
                output=example_meta["response_text"],
                correct=score == 1.0,
            )
        )
    return pl.DataFrame(rows)
def _compute_id_from_object(obj):
    """Stable SHA-256 hex digest of *obj*'s canonical JSON representation."""
    if isinstance(obj, pl.Series):
        obj = obj.to_list()
    canonical = json.dumps(obj, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
def _compute_df_raw(args):
    """Read every baseline/target ndjson file and stack them into one frame."""
    frames = []
    for info in _get_file_infos(args=args):
        frames.append(
            _read_df_raw(
                path=info["path"],
                category=info["category"],
                trial_index=info["trial_index"],
            )
        )
    return pl.concat(frames)
def _get_file_infos(args):
return [
dict(path=path, category=category, trial_index=trial_index)
for category, paths in [
("baseline", args.baseline_path),
("target", args.target_path),
]
for trial_index, path in enumerate(paths)
]
def _read_df_raw(path: str, category: str, trial_index: int):
    """Load one ndjson result file, tagging rows with category and trial index."""
    df = pl.read_ndjson(path)
    return df.with_columns(category=pl.lit(category), trial_index=trial_index)
def _transform_df_input(df: pl.DataFrame):
    """Normalize a raw results frame to the common schema.

    Detects the producer by its columns: lm_eval sample logs (have
    ``doc_id``) are reshaped; SGLang bench outputs (have ``prompt_id``)
    pass through unchanged.
    """
    if "doc_id" in df.columns:
        print("Transform mode: lm_eval")
        filter_names = df["filter"].unique(maintain_order=True).to_list()
        if len(filter_names) > 1:
            # Multiple lm_eval answer-extraction filters: keep only the first.
            filter_name = filter_names[0]
            print(f"Choose {filter_name=} among {filter_names}")
            df = df.filter(pl.col("filter") == filter_name)
        df = df.select(
            pl.col("category"),
            pl.col("trial_index"),
            prompt_id=pl.col("doc_id"),
            prompt=pl.col("arguments").struct.field("gen_args_0").struct.field("arg_0"),
            output=pl.col("resps").list.get(0).list.get(0),
            correct=pl.col("exact_match").cast(bool),
        )
        return df
    elif "prompt_id" in df.columns:
        print("Transform mode: SGLang bench")
        return df
    else:
        raise Exception(
            f"Unknown data: {df.columns}. You may need to set `--data-type` if using e.g. simple_evals."
        )
def _compute_df_meta(df_input: pl.DataFrame):
    """Aggregate the per-sample input frame into one comparison row per prompt."""
    df_input = df_input.sort("prompt_id", "category", "trial_index")
    per_prompt_rows = [
        _handle_one_prompt(df_one_prompt)
        for df_one_prompt in df_input.partition_by("prompt_id", maintain_order=True)
    ]
    df_meta = pl.DataFrame(per_prompt_rows)
    # Negative delta = regressions; positive = improvements.
    df_meta = df_meta.with_columns(
        correctness_delta=pl.col("correctness_target")
        - pl.col("correctness_baseline"),
    )
    return df_meta.sort("correctness_delta", "output_same_prefix_len")
def _handle_one_prompt(df_one_prompt: pl.DataFrame):
    """Summarize all baseline/target trials of a single prompt.

    Returns correctness means per category, the longest shared output
    prefix across any baseline/target output pair, and the raw outputs.
    """
    # All rows grouped under one prompt_id must carry the identical prompt.
    assert (
        len(set(_compute_id_from_object(obj) for obj in df_one_prompt["prompt"])) == 1
    )
    df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
    df_target = df_one_prompt.filter(pl.col("category") == "target")
    outputs_baseline = df_baseline["output"].to_list()
    outputs_target = df_target["output"].to_list()
    # default=0 keeps this robust when a prompt is present on only one side
    # (an empty cross product would otherwise make max() raise ValueError).
    output_same_prefix_len = max(
        (
            _compute_str_prefix_len(output_baseline, output_target)
            for output_baseline in outputs_baseline
            for output_target in outputs_target
        ),
        default=0,
    )
    return dict(
        prompt_id=df_one_prompt[0, "prompt_id"],
        correctness_baseline=df_baseline["correct"].mean(),
        correctness_target=df_target["correct"].mean(),
        output_same_prefix_len=output_same_prefix_len,
        prompt=df_one_prompt[0, "prompt"],
        outputs_baseline=outputs_baseline,
        outputs_target=outputs_target,
    )
def _compute_str_prefix_len(a: str, b: str) -> int:
min_len = min(len(a), len(b))
for i in range(min_len):
if a[i] != b[i]:
return i
return min_len
if __name__ == "__main__":
    # CLI entry point; see _DESCRIPTION for the supported input formats.
    parser = argparse.ArgumentParser(description=_DESCRIPTION)
    # "auto" infers lm_eval vs SGLang bench from columns; "simple_evals" must be explicit.
    parser.add_argument("--data-type", type=str, default="auto")
    # Multiple paths per side = multiple trials of the same benchmark.
    parser.add_argument("--baseline-path", type=str, nargs="+")
    parser.add_argument("--target-path", type=str, nargs="+")
    parser.add_argument(
        "--output-path", type=str, default="/tmp/text_comparator_output.json"
    )
    parser.add_argument("--disable-print-details", action="store_true")
    args = parser.parse_args()
    main(args)

View File

@@ -0,0 +1,6 @@
from sglang.srt.disaggregation.ascend.conn import (
AscendKVBootstrapServer,
AscendKVManager,
AscendKVReceiver,
AscendKVSender,
)

View File

@@ -0,0 +1,117 @@
import concurrent.futures
import logging
from typing import List, Tuple
import numpy as np
import numpy.typing as npt
from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.mooncake.conn import (
MooncakeKVBootstrapServer,
MooncakeKVManager,
MooncakeKVReceiver,
MooncakeKVSender,
)
from sglang.srt.utils import get_local_ip_by_remote
logger = logging.getLogger(__name__)
class AscendKVManager(MooncakeKVManager):
    """Mooncake KV manager variant that transfers via the Ascend engine."""
    def init_engine(self):
        # TransferEngine initialized on ascend.
        local_ip = get_local_ip_by_remote()
        self.engine = AscendTransferEngine(
            hostname=local_ip,
            npu_id=self.kv_args.gpu_id,
            disaggregation_mode=self.disaggregation_mode,
        )
    def register_buffer_to_engine(self):
        """Register all KV and auxiliary buffers with the transfer engine."""
        self.engine.batch_register(self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens)
        # The Ascend backend optimizes batch registration for small memory blocks.
        self.engine.batch_register(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        )
    def send_kvcache(
        self,
        mooncake_session_id: str,
        prefill_kv_indices: npt.NDArray[np.int32],
        dst_kv_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int32],
        executor: concurrent.futures.ThreadPoolExecutor,
    ):
        """Copy this request's KV cache from prefill to decode.

        Returns 0 on success, or the first nonzero transfer status.
        """
        # Group by indices
        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
            prefill_kv_indices, dst_kv_indices
        )
        num_layers = len(self.kv_args.kv_data_ptrs)
        # One (src base ptr, dst base ptr, per-item length) triple per layer.
        layers_params = [
            (
                self.kv_args.kv_data_ptrs[layer_id],
                dst_kv_ptrs[layer_id],
                self.kv_args.kv_item_lens[layer_id],
            )
            for layer_id in range(num_layers)
        ]
        # Translate contiguous index runs into (src_addr, dst_addr, length) triples.
        def set_transfer_blocks(
            src_ptr: int, dst_ptr: int, item_len: int
        ) -> List[Tuple[int, int, int]]:
            transfer_blocks = []
            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                src_addr = src_ptr + int(prefill_index[0]) * item_len
                dst_addr = dst_ptr + int(decode_index[0]) * item_len
                length = item_len * len(prefill_index)
                transfer_blocks.append((src_addr, dst_addr, length))
            return transfer_blocks
        # Worker function for processing a single layer
        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
            return self._transfer_data(mooncake_session_id, transfer_blocks)
        # Worker function for processing all layers in a batch
        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
            transfer_blocks = []
            for src_ptr, dst_ptr, item_len in layers_params:
                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
            return self._transfer_data(mooncake_session_id, transfer_blocks)
        if self.enable_custom_mem_pool:
            # Per-layer transfers in parallel; cancel the rest on first failure.
            futures = [
                executor.submit(
                    process_layer,
                    src_ptr,
                    dst_ptr,
                    item_len,
                )
                for (src_ptr, dst_ptr, item_len) in layers_params
            ]
            for future in concurrent.futures.as_completed(futures):
                status = future.result()
                if status != 0:
                    for f in futures:
                        f.cancel()
                    return status
        else:
            # Combining all layers' params in one batch transfer is more efficient
            # compared to using multiple threads
            return process_layers(layers_params)
        return 0
class AscendKVSender(MooncakeKVSender):
    """Ascend KV sender; inherits the Mooncake behavior unchanged."""
    pass
class AscendKVReceiver(MooncakeKVReceiver):
    """Ascend KV receiver; inherits the Mooncake behavior unchanged."""
    pass
class AscendKVBootstrapServer(MooncakeKVBootstrapServer):
    """Ascend bootstrap server; inherits the Mooncake behavior unchanged."""
    pass

View File

@@ -0,0 +1,58 @@
import logging
import os
from typing import List, Optional
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
logger = logging.getLogger(__name__)
class AscendTransferEngine(MooncakeTransferEngine):
    """Mooncake-compatible transfer engine backed by Ascend's mf_adapter."""
    def __init__(
        self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
    ):
        try:
            from mf_adapter import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
            ) from e
        self.engine = TransferEngine()
        self.hostname = hostname
        self.npu_id = npu_id
        # Centralized storage address of the AscendTransferEngine
        # NOTE(review): an unset ASCEND_MF_STORE_URL leaves this None —
        # presumably initialize() then fails; confirm expected behavior.
        self.store_url = os.getenv("ASCEND_MF_STORE_URL")
        if disaggregation_mode == DisaggregationMode.PREFILL:
            self.role = "Prefill"
        elif disaggregation_mode == DisaggregationMode.DECODE:
            self.role = "Decode"
        else:
            logger.error(f"Unsupported DisaggregationMode: {disaggregation_mode}")
            raise ValueError(f"Unsupported DisaggregationMode: {disaggregation_mode}")
        # Session id: local host plus the adapter's RPC port.
        self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
        self.initialize()
    def initialize(self) -> None:
        """Initialize the ascend transfer instance."""
        ret_value = self.engine.initialize(
            self.store_url,
            self.session_id,
            self.role,
            self.npu_id,
        )
        if ret_value != 0:
            logger.error("Ascend Transfer Engine initialization failed.")
            raise RuntimeError("Ascend Transfer Engine initialization failed.")
    def batch_register(self, ptrs: List[int], lengths: List[int]):
        """Register memory regions; failures are logged (debug), not raised."""
        try:
            ret_value = self.engine.batch_register_memory(ptrs, lengths)
        except Exception:
            # Mark register as failed
            ret_value = -1
        if ret_value != 0:
            logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")

View File

@@ -0,0 +1,8 @@
from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)

View File

@@ -0,0 +1,134 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional
import numpy as np
import numpy.typing as npt
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.disaggregation.utils import DisaggregationMode
class KVArgs:
    """Plain data holder describing one engine rank's KV-cache memory layout.

    All pointers are raw device addresses (ints); lengths are in bytes —
    TODO confirm units against the concrete managers.
    """
    engine_rank: int
    # Per-buffer device pointers / total lengths / per-item lengths of the KV cache
    kv_data_ptrs: List[int]
    kv_data_lens: List[int]
    kv_item_lens: List[int]
    # Auxiliary buffers transferred alongside the KV cache
    aux_data_ptrs: List[int]
    aux_data_lens: List[int]
    aux_item_lens: List[int]
    ib_device: str
    ib_traffic_class: str
    gpu_id: int
    # for different tp
    decode_tp_size: int
    kv_head_num: int
    page_size: int
    # for pp prefill
    prefill_pp_size: int
    pp_rank: int
    prefill_start_layer: int
    # for system dp
    system_dp_rank: int
class KVPoll:
    """Status codes returned by KV sender/receiver ``poll()`` (plain ints)."""
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4
class BaseKVManager(ABC):
    """Base class for managing KV-cache transfer state for one engine rank."""
    @abstractmethod
    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
        is_mla_backend: Optional[bool] = False,
    ): ...
class BaseKVSender(ABC):
    """Prefill-side handle that pushes one request's KV cache to a decoder."""
    @abstractmethod
    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: int,
        dest_tp_ranks: List[int],
        pp_rank: int,
    ): ...
    @abstractmethod
    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """
        Notify the decoder server about the kv indices length and aux index
        """
        ...
    @abstractmethod
    def send(self, kv_indices: npt.NDArray[np.int32]):
        """
        Send the kv cache at the given kv indices to the decoder server
        """
        ...
    @abstractmethod
    def poll(self) -> KVPoll:
        """
        Check the status of the kv cache transfer
        """
        ...
    @abstractmethod
    def failure_exception(self):
        """
        Raise an exception if the kv cache transfer fails
        """
        ...
class BaseKVReceiver(ABC):
    """Decode-side handle that receives one request's KV cache from prefill."""
    @abstractmethod
    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ): ...
    @abstractmethod
    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
        """
        Notify the prefill server about the kv indices and aux index
        """
        ...
    @abstractmethod
    def poll(self) -> KVPoll:
        """
        Check the status of the kv cache transfer
        """
        ...
    @abstractmethod
    def failure_exception(self):
        """
        Raise an exception if the kv cache transfer fails
        """
        ...
class BaseKVBootstrapServer(ABC):
    """Server through which prefill/decode peers exchange bootstrap metadata."""
    @abstractmethod
    def __init__(self, host: str, port: int): ...

View File

@@ -0,0 +1,5 @@
from sglang.srt.disaggregation.common.conn import (
CommonKVBootstrapServer,
CommonKVManager,
CommonKVReceiver,
)

View File

@@ -0,0 +1,438 @@
from __future__ import annotations
import asyncio
import logging
import socket
import threading
from functools import cache
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web
from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
get_free_port,
get_ip,
get_local_ip_by_remote,
is_valid_ipv6_address,
maybe_wrap_ipv6_address,
)
logger = logging.getLogger(__name__)
class CommonKVManager(BaseKVManager):
    """Common KV-transfer manager shared by concrete backends.

    In PREFILL mode it registers this rank with the bootstrap server;
    in DECODE mode it keeps per-address connection/topology caches that
    ``CommonKVReceiver`` fills in lazily.
    """

    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
        is_mla_backend: Optional[bool] = False,
    ):
        self.kv_args = args
        self.is_mla_backend = is_mla_backend
        self.disaggregation_mode = disaggregation_mode
        # for p/d multi node infer
        self.bootstrap_host = server_args.host
        self.bootstrap_port = server_args.disaggregation_bootstrap_port
        self.dist_init_addr = server_args.dist_init_addr
        self.tp_size = server_args.tp_size
        self.dp_size = server_args.dp_size
        self.enable_dp_attention = server_args.enable_dp_attention
        if not server_args.enable_dp_attention and server_args.dp_size != 1:
            raise ValueError(
                "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
            )
        # Port this rank listens on for peer-to-peer KV traffic.
        self.rank_port = get_free_port()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self._register_to_bootstrap()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Caches keyed by bootstrap address, populated by CommonKVReceiver.
            self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
            self.prefill_tp_size_table: Dict[str, int] = {}
            self.prefill_dp_size_table: Dict[str, int] = {}
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )

    def _register_to_bootstrap(self):
        """Register KVSender to bootstrap server via HTTP POST."""
        if self.dist_init_addr:
            # multi node: bootstrap server's host is dist_init_addr
            if self.dist_init_addr.startswith("["):  # [ipv6]:port or [ipv6]
                if self.dist_init_addr.endswith("]"):
                    host = self.dist_init_addr
                else:
                    host, _ = self.dist_init_addr.rsplit(":", 1)
            else:
                host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
        else:
            # single node: bootstrap server's host is same as http server's host
            host = self.bootstrap_host
        host = maybe_wrap_ipv6_address(host)
        bootstrap_server_url = f"{host}:{self.bootstrap_port}"
        url = f"http://{bootstrap_server_url}/route"
        payload = {
            "role": "Prefill",
            "tp_size": self.tp_size,
            "dp_size": self.dp_size,
            "rank_ip": get_local_ip_by_remote(),
            "rank_port": self.rank_port,
            "engine_rank": self.kv_args.engine_rank,
        }
        try:
            # NOTE: registration is best-effort; failures are logged, not raised.
            response = requests.put(url, json=payload)
            if response.status_code == 200:
                logger.debug("Prefill successfully registered to bootstrap server.")
            else:
                logger.error(
                    f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
                )
        except Exception as e:
            logger.error(f"Prefill Failed to register to bootstrap server: {e}")

    @cache
    def _connect(self, endpoint: str, is_ipv6: bool = False):
        """Create (and memoize) a ZMQ PUSH socket connected to *endpoint*.

        NOTE(review): `functools.cache` on an instance method keys on `self`
        and keeps the manager alive for the cache's lifetime (ruff B019).
        Acceptable here because a KV manager lives for the whole process.
        """
        # Renamed from `socket` to avoid shadowing the stdlib `socket`
        # module used in _register_to_bootstrap.
        sock = zmq.Context().socket(zmq.PUSH)
        if is_ipv6:
            sock.setsockopt(zmq.IPV6, 1)
        sock.connect(endpoint)
        return sock
class CommonKVReceiver(BaseKVReceiver):
    """Decode-side KV receiver shared by concrete transfer backends.

    On construction it resolves the prefill instance's parallel topology
    from the bootstrap server, maps this decode rank onto one or more
    prefill TP ranks, and caches per-(addr, dp_group, tp_rank) bootstrap
    infos on the manager so later requests skip the HTTP round-trips.
    """

    # Process-wide ZMQ context and socket cache shared by all receivers.
    _ctx = zmq.Context()
    _socket_cache = {}
    _socket_locks = {}
    _global_lock = threading.Lock()

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
        prefill_dp_rank: Optional[int] = None,
    ):
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.kv_mgr = mgr
        if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
            # First contact with this prefill instance: fetch its topology.
            self.prefill_tp_size, self.prefill_dp_size = (
                self._get_prefill_dp_size_from_server()
            )
            if self.prefill_tp_size is None or self.prefill_dp_size is None:
                logger.error(
                    f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
                )
            else:
                self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
                    self.prefill_tp_size
                )
                self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                    self.prefill_dp_size
                )
        else:
            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
                self.bootstrap_addr
            ]
            self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                self.bootstrap_addr
            ]
        # Currently, we don't allow prefill instance and decode instance to
        # have different TP sizes per DP rank, except for models using MLA.
        # NOTE(review): if the topology fetch above failed these divisions
        # raise TypeError; upstream callers are expected to abort earlier.
        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
            # Same TP layout: one-to-one rank mapping.
            self.target_tp_rank = (
                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
            )
            self.required_dst_info_num = 1
            self.target_tp_ranks = [self.target_tp_rank]
        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
            # Decode TP larger: several decode ranks share one prefill rank.
            self.target_tp_rank = (
                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
            self.required_dst_info_num = (
                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
            )
            self.target_tp_ranks = [self.target_tp_rank]
        else:
            assert (
                self.kv_mgr.is_mla_backend
            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
            # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
            self.target_tp_ranks = [
                rank
                for rank in range(
                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
                )
            ]
            # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
            # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
            # or the KVPoll will never be set correctly
            self.target_tp_rank = self.target_tp_ranks[0]
            self.required_dst_info_num = 1
        if prefill_dp_rank is not None:
            logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
            self.prefill_dp_rank = prefill_dp_rank
        else:
            # Spread rooms across prefill DP ranks deterministically.
            self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
        # FIXME: alias here: target_dp_group -> prefill_dp_rank
        self.target_dp_group = self.prefill_dp_rank
        # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
        bootstrap_key = (
            f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
        )
        if bootstrap_key not in self.kv_mgr.connection_pool:
            bootstrap_infos = []
            for target_tp_rank in self.target_tp_ranks:
                bootstrap_info = self._get_bootstrap_info_from_server(
                    target_tp_rank,
                    self.target_dp_group,
                )
                if bootstrap_info is not None:
                    # NOTE: only support MLA for now: select one prefill rank as real rank
                    bootstrap_info["is_dummy"] = not bool(
                        target_tp_rank == self.target_tp_rank
                        or self.target_tp_rank is None
                    )
                    bootstrap_infos.append(bootstrap_info)
                else:
                    logger.error(
                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
                    )
            self.bootstrap_infos = bootstrap_infos
            if len(self.bootstrap_infos) == 0:
                logger.error(
                    f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                )
            else:
                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
                # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
                self._register_kv_args()
        else:
            self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
        assert len(self.bootstrap_infos) > 0

    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
        """Fetch the bootstrap info from the bootstrap server.

        Returns the info dict, or None on any HTTP/network failure.
        """
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
            response = requests.get(url)
            if response.status_code == 200:
                bootstrap_info = response.json()
                return bootstrap_info
            else:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
        except Exception as e:
            logger.error(f"Error fetching prefill info from bootstrap: {e}")
            return None

    def _get_prefill_dp_size_from_server(self) -> Tuple[Optional[int], Optional[int]]:
        """Fetch the prefill parallel info from the bootstrap server.

        Returns (tp_size, dp_size) on success, or (None, None) on failure.
        BUGFIX: the original returned a bare None on failure, which made the
        caller's tuple-unpack raise TypeError before its error handling ran.
        """
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
            response = requests.get(url)
            if response.status_code == 200:
                prefill_parallel_info = response.json()
                return int(prefill_parallel_info["prefill_tp_size"]), int(
                    prefill_parallel_info["prefill_dp_size"]
                )
            else:
                logger.error(
                    f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                )
                return None, None
        except Exception as e:
            logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
            return None, None

    @classmethod
    def _connect(cls, endpoint: str, is_ipv6: bool = False):
        """Return a cached (socket, lock) pair for *endpoint*, creating it once."""
        with cls._global_lock:
            if endpoint not in cls._socket_cache:
                sock = cls._ctx.socket(zmq.PUSH)
                if is_ipv6:
                    sock.setsockopt(zmq.IPV6, 1)
                sock.connect(endpoint)
                cls._socket_cache[endpoint] = sock
                cls._socket_locks[endpoint] = threading.Lock()
            return cls._socket_cache[endpoint], cls._socket_locks[endpoint]

    @classmethod
    def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
        """Connect to the peer rank described by *bootstrap_info* (rank_ip/rank_port)."""
        ip_address = bootstrap_info["rank_ip"]
        port = bootstrap_info["rank_port"]
        is_ipv6_address = is_valid_ipv6_address(ip_address)
        sock, lock = cls._connect(
            format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
        )
        return sock, lock

    def _register_kv_args(self):
        # Backends that need to push kv_args to the prefill side override this.
        pass

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")
class CommonKVBootstrapServer(BaseKVBootstrapServer):
    """HTTP server prefill ranks register with (PUT /route) and decode ranks query (GET /route).

    Runs an aiohttp app on a daemon thread with its own event loop.
    """

    def __init__(self, host: str, port: int):
        self.host = host
        self.port = port
        self.app = web.Application()
        self.store = dict()
        self.lock = asyncio.Lock()
        self._setup_routes()
        self.tp_size = None
        self.dp_size = None
        self.tp_size_per_dp_rank = None
        # dp_group -> tp_rank_in_dp_group -> {"rank_ip", "rank_port"}
        self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
        # BUGFIX: initialize eagerly so close() is safe even if the server
        # thread has not created the event loop yet.
        self._loop = None
        self._runner = None
        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()

    def run(self):
        self.thread.start()

    def _setup_routes(self):
        self.app.router.add_route("*", "/route", self._handle_route)

    async def _handle_route(self, request: web.Request):
        """Dispatch /route by HTTP method; only PUT and GET are supported."""
        method = request.method
        if method == "PUT":
            return await self._handle_route_put(request)
        elif method == "GET":
            return await self._handle_route_get(request)
        else:
            return web.Response(
                text="Method not allowed", status=405, content_type="application/json"
            )

    async def _handle_route_put(self, request: web.Request):
        """Record a prefill rank's (ip, port) in the port table."""
        data = await request.json()
        role = data["role"]
        tp_size = data["tp_size"]
        dp_size = data["dp_size"]
        rank_ip = data["rank_ip"]
        rank_port = int(data["rank_port"])
        engine_rank = int(data["engine_rank"])
        # First registrant fixes the instance-wide topology.
        if self.tp_size is None:
            self.tp_size = tp_size
        if self.dp_size is None:
            self.dp_size = dp_size
        tp_size_per_dp_rank = tp_size // dp_size
        if self.tp_size_per_dp_rank is None:
            self.tp_size_per_dp_rank = tp_size_per_dp_rank
        # Add lock to make sure thread-safe
        if role == "Prefill":
            dp_group = engine_rank // tp_size_per_dp_rank
            tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
            async with self.lock:
                if dp_group not in self.prefill_port_table:
                    self.prefill_port_table[dp_group] = {}
                self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
                    "rank_ip": rank_ip,
                    "rank_port": rank_port,
                }
            logger.debug(
                f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
            )
        return web.Response(text="OK", status=200)

    async def _handle_route_get(self, request: web.Request):
        """Serve either the parallel-info sync query (-1/-1) or a rank lookup."""
        engine_rank = request.query.get("engine_rank")
        target_dp_group = request.query.get("target_dp_group")
        if not engine_rank or not target_dp_group:
            return web.Response(text="Missing inputs for bootstrap server.", status=400)
        # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
        if int(engine_rank) == -1 and int(target_dp_group) == -1:
            prefill_parallel_info = {
                "prefill_tp_size": self.tp_size,
                "prefill_dp_size": self.dp_size,
            }
            return web.json_response(prefill_parallel_info, status=200)
        # Find corresponding prefill info
        # BUGFIX: use .get() so an unknown rank yields a 404 instead of an
        # unhandled KeyError (HTTP 500).
        async with self.lock:
            bootstrap_info = self.prefill_port_table.get(int(target_dp_group), {}).get(
                int(engine_rank)
            )
        if bootstrap_info is not None:
            return web.json_response(bootstrap_info, status=200)
        else:
            return web.Response(text="Bootstrap info not Found", status=404)

    def _run_server(self):
        try:
            # Event Loop
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)
            self._runner = web.AppRunner(self.app)
            self._loop.run_until_complete(self._runner.setup())
            site = web.TCPSite(self._runner, host=self.host, port=self.port)
            self._loop.run_until_complete(site.start())
            self._loop.run_forever()
        except Exception as e:
            logger.error(f"Server error: {str(e)}")
        finally:
            # Cleanup: guard against failures before loop/runner were created.
            if self._loop is not None:
                if self._runner is not None:
                    self._loop.run_until_complete(self._runner.cleanup())
                self._loop.close()

    def close(self):
        """Shutdown"""
        if self._loop is not None and self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)
            logger.info("Stopping server loop...")
        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...

View File

@@ -0,0 +1,42 @@
import threading
from collections import deque
from typing import List, Tuple
import numpy as np
import numpy.typing as npt
class FastQueue:
    """A minimal blocking FIFO built on a deque plus a condition variable."""

    def __init__(self):
        self._items = deque()
        self._not_empty = threading.Condition()

    def put(self, item):
        """Append *item* and wake one consumer blocked in get()."""
        with self._not_empty:
            self._items.append(item)
            self._not_empty.notify()

    def get(self):
        """Remove and return the oldest item, blocking while the queue is empty."""
        with self._not_empty:
            # wait_for re-checks the predicate on every wakeup (spurious-safe).
            self._not_empty.wait_for(lambda: len(self._items) > 0)
            return self._items.popleft()
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int32], dst_indices: npt.NDArray[np.int32]
) -> Tuple[List[List[int]], List[List[int]]]:
    """Split paired index arrays into runs contiguous in BOTH arrays.

    A new group starts wherever either sequence jumps by more than 1, so
    each (src, dst) group pair can be transferred as one contiguous copy.

    Vectorised NumPy implementation.

    Args:
        src_indices: source indices; must have the same length as dst_indices.
        dst_indices: destination indices.

    Returns:
        (src_groups, dst_groups) as lists of Python-int lists (note: the
        groups are converted via .tolist(), not returned as ndarrays).
        Both are empty lists when the input is empty.
    """
    if src_indices.size == 0:
        return [], []
    # Break positions: index i+1 starts a new group if either diff != 1.
    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
    src_groups = [g.tolist() for g in np.split(src_indices, brk)]
    dst_groups = [g.tolist() for g in np.split(dst_indices, brk)]
    return src_groups, dst_groups

View File

@@ -0,0 +1,894 @@
"""
Life cycle of a request in the decode server
1. PreallocQueue:
a. Initialize a receiver for each request
b. The request handshakes first, and pre-allocate kv once there is available kv.
c. Move the request to TransferQueue.
2. TransferQueue:
a. Poll the receiver to check the transfer state
b. If the transfer has finished, move the request to waiting queue
3. WaitingQueue:
a. Use the requests in the queue to construct a PrebuiltExtendBatch
b. Skip the prefill forward but only populate metadata
4. RunningBatch:
a. Merge the resolved PrebuiltExtendBatch into running batch to run decoding
"""
from __future__ import annotations
import logging
from collections import deque
from dataclasses import dataclass
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
import torch
from torch.distributed import ProcessGroup
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
DisaggregationMode,
KVClassType,
MetadataBuffers,
ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
is_mla_backend,
kv_to_page_indices,
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import get_int_env_var, require_mlp_sync
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import Scheduler
CLIP_MAX_NEW_TOKEN = get_int_env_var("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", 4096)
class DecodeReqToTokenPool:
    """
    The difference of DecodeReqToTokenPool and ReqToTokenPool is that
    DecodeReqToTokenPool subscribes memory for pre-allocated requests.

    In ReqToTokenPool, if `--max-running-requests` is 8,
    #pre-allocated + #transfer + #running <= 8, but there are in fact more memory can carry pre-allocated requests.

    In DecodeReqToTokenPool, if `--max-running-requests` is 8,
    #running <= 8, #pre-allocated + #transfer <= pre_alloc_size, so we can use the free memory to pre-allocate requests to unblock prefill.
    """

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        pre_alloc_size: int,
    ):
        """Allocate a (size + pre_alloc_size, max_context_len) int32 token map.

        Args:
            size: number of slots reserved for running requests.
            max_context_len: row width (max tokens per request).
            device: torch device for the token map.
            enable_memory_saver: route the allocation through the memory saver.
            pre_alloc_size: extra slots reserved for pre-allocated requests.
        """
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )
        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        self.pre_alloc_size = pre_alloc_size
        with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
            self.req_to_token = torch.zeros(
                (size + pre_alloc_size, max_context_len),
                dtype=torch.int32,
                device=device,
            )
        # Free-slot indices over the combined (running + pre-alloc) range.
        self.free_slots = list(range(size + pre_alloc_size))

    def write(self, indices, values):
        """Write token values into the map at the given (row, col) indices."""
        self.req_to_token[indices] = values

    def available_size(self) -> int:
        """Number of currently free request slots."""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> List[int]:
        """Take *need_size* free slots; returns None when not enough are free."""
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Return one slot (int) or several slots (list) to the free pool."""
        # Simplified from isinstance(free_index, (int,)) — single-type check.
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Mark every slot free again (does not zero the token map)."""
        self.free_slots = list(range(self.size + self.pre_alloc_size))
@dataclass
class DecodeRequest:
    """A request tracked through the decode-side prealloc/transfer queues."""

    # The scheduler request being served.
    req: Req
    # Receiver driving the KV-cache transfer for this request.
    kv_receiver: BaseKVReceiver
    # Set once the remote side reports KVPoll.WaitingForInput (handshake done).
    waiting_for_input: bool = False
    # Slot index in the metadata buffers; -1 until allocated in pop_preallocated.
    metadata_buffer_index: int = -1
class DecodePreallocQueue:
    """
    Store the requests that are preallocating.

    Requests enter via add()/extend(), handshake with the prefill side,
    and leave via pop_preallocated() once KV memory has been reserved.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        draft_token_to_kv_pool: Optional[KVCache],
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: MetadataBuffers,
        scheduler: Scheduler,
        transfer_queue: DecodeTransferQueue,
        tree_cache: BasePrefixCache,
        gloo_group: ProcessGroup,
        tp_rank: int,
        tp_size: int,
        dp_size: int,
        gpu_id: int,
        bootstrap_port: int,
        max_total_num_tokens: int,
        prefill_pp_size: int,
        num_reserved_decode_tokens: int,
        transfer_backend: TransferBackend,
    ):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
        self.draft_token_to_kv_pool = draft_token_to_kv_pool
        self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
        self.metadata_buffers = metadata_buffers
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.scheduler = scheduler
        self.transfer_queue = transfer_queue
        self.tree_cache = tree_cache  # this is always a chunk cache
        self.gloo_group = gloo_group
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dp_size = dp_size
        self.gpu_id = gpu_id
        self.bootstrap_port = bootstrap_port
        self.max_total_num_tokens = max_total_num_tokens
        self.prefill_pp_size = prefill_pp_size
        self.num_reserved_decode_tokens = num_reserved_decode_tokens
        self.transfer_backend = transfer_backend
        # Queue for requests pending pre-allocation
        self.queue: List[DecodeRequest] = []
        self.retracted_queue: List[Req] = []
        # NOTE(review): duplicate assignment — prefill_pp_size was set above.
        self.prefill_pp_size = prefill_pp_size
        self.kv_manager = self._init_kv_manager()

    def _init_kv_manager(self) -> BaseKVManager:
        """Build the backend-specific KVManager describing this rank's KV buffers."""
        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
        kv_args = kv_args_class()
        attn_tp_size = get_attention_tp_size()
        kv_args.engine_rank = self.tp_rank % (attn_tp_size)
        kv_args.decode_tp_size = attn_tp_size
        # Note(shangming): pp is not supported on the decode side yet, so its rank is fixed to 0
        kv_args.pp_rank = 0
        kv_args.system_dp_rank = self.scheduler.dp_rank
        kv_args.prefill_pp_size = self.prefill_pp_size
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )
        if self.draft_token_to_kv_pool is not None:
            # We should also transfer draft model kv cache. The indices are
            # always shared with a target model.
            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
            )
            kv_data_ptrs += draft_kv_data_ptrs
            kv_data_lens += draft_kv_data_lens
            kv_item_lens += draft_kv_item_lens
        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens
        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
            self.metadata_buffers.get_buf_infos()
        )
        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
        kv_args.gpu_id = self.scheduler.gpu_id
        kv_manager_class: Type[BaseKVManager] = get_kv_class(
            self.transfer_backend, KVClassType.MANAGER
        )
        kv_manager: BaseKVManager = kv_manager_class(
            kv_args,
            DisaggregationMode.DECODE,
            self.scheduler.server_args,
            self.is_mla_backend,
        )
        return kv_manager

    def add(self, req: Req, is_retracted: bool = False) -> None:
        """Add a request to the pending queue.

        Retracted requests go to retracted_queue; new requests get a KV
        receiver (a FAKE one for warm-up requests) and enter the main queue.
        """
        if self._check_if_req_exceed_kv_capacity(req):
            return
        if is_retracted:
            self.retracted_queue.append(req)
        else:
            if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
                kv_receiver_class = get_kv_class(
                    TransferBackend.FAKE, KVClassType.RECEIVER
                )
            else:
                kv_receiver_class = get_kv_class(
                    self.transfer_backend, KVClassType.RECEIVER
                )
            kv_receiver = kv_receiver_class(
                mgr=self.kv_manager,
                bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
                bootstrap_room=req.bootstrap_room,
                prefill_dp_rank=req.data_parallel_rank,
            )
            self.queue.append(
                DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
            )

    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
        """Abort (and stream the abort) if the prompt alone cannot fit in KV cache."""
        if len(req.origin_input_ids) > self.max_total_num_tokens:
            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
            logger.error(message)
            prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST)
            self.scheduler.stream_output([req], req.return_logprob)
            return True
        return False

    def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
        """Add a request to the pending queue."""
        for req in reqs:
            self.add(req, is_retracted=is_retracted)

    def resume_retracted_reqs(self) -> List[Req]:
        """Re-admit retracted requests whose memory can be re-allocated now."""
        # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible
        # allocate memory
        resumed_reqs = []
        indices_to_remove = set()
        allocatable_tokens = self._allocatable_tokens(count_retracted=False)
        for i, req in enumerate(self.retracted_queue):
            if self.req_to_token_pool.available_size() <= 0:
                break
            # Prompt + generated-so-far + headroom for future decode steps.
            required_tokens_for_request = (
                len(req.origin_input_ids)
                + len(req.output_ids)
                + self.num_reserved_decode_tokens
            )
            if required_tokens_for_request > allocatable_tokens:
                break
            resumed_reqs.append(req)
            indices_to_remove.add(i)
            req.is_retracted = False
            self._pre_alloc(req)
            allocatable_tokens -= required_tokens_for_request
            # load from cpu, release the cpu copy
            req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
        self.retracted_queue = [
            entry
            for i, entry in enumerate(self.retracted_queue)
            if i not in indices_to_remove
        ]
        return resumed_reqs

    def _update_handshake_waiters(self) -> None:
        """Poll receivers still handshaking; mark ready ones, abort failed ones."""
        if not self.queue:
            return
        # Skip the collective poll when every request is already past handshake.
        if all(decode_req.waiting_for_input for decode_req in self.queue):
            return
        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )
        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Bootstrapping:
                pass
            elif poll == KVPoll.WaitingForInput:
                decode_req.waiting_for_input = True
            elif poll == KVPoll.Failed:
                error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                try:
                    decode_req.kv_receiver.failure_exception()
                except Exception as e:
                    error_message += f" with exception {e}"
                logger.error(error_message)
                # Mark aborted here; pop_preallocated streams it out and drops it.
                prepare_abort(
                    decode_req.req,
                    error_message,
                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                )
                if self.scheduler.enable_metrics:
                    self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
            else:
                raise ValueError(f"Unexpected poll case: {poll}")

    def pop_preallocated(self) -> List[DecodeRequest]:
        """Pop the preallocated requests from the pending queue (FIFO)."""
        self._update_handshake_waiters()
        preallocated_reqs = []
        indices_to_remove = set()
        # We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request
        # Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted.
        retractable_tokens = sum(
            len(r.origin_input_ids) + len(r.output_ids)
            for r in self.scheduler.running_batch.reqs
        )
        allocatable_tokens = self._allocatable_tokens(
            retractable_tokens=retractable_tokens, count_retracted=True
        )
        # First, remove all failed requests from the queue
        for i, decode_req in enumerate(self.queue):
            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
                self.scheduler.stream_output(
                    [decode_req.req], decode_req.req.return_logprob
                )
                indices_to_remove.add(i)
        # Then, preallocate the remaining requests if possible
        for i, decode_req in enumerate(self.queue):
            if i in indices_to_remove:
                continue
            if not decode_req.waiting_for_input:
                continue
            if self.req_to_token_pool.available_size() <= 0:
                break
            if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                break
            # Memory estimation: don't add if the projected memory cannot be met
            # TODO: add new_token ratio
            origin_input_len = len(decode_req.req.origin_input_ids)
            required_tokens_for_request = (
                origin_input_len + self.num_reserved_decode_tokens
            )
            if (
                max(
                    required_tokens_for_request,
                    origin_input_len
                    + min(
                        decode_req.req.sampling_params.max_new_tokens,
                        CLIP_MAX_NEW_TOKEN,
                    )
                    - retractable_tokens,
                )
                > allocatable_tokens
            ):
                break
            # NOTE(review): redundant with the max(...) check above, which
            # already includes required_tokens_for_request.
            if required_tokens_for_request > allocatable_tokens:
                break
            allocatable_tokens -= required_tokens_for_request
            self._pre_alloc(decode_req.req)
            kv_indices = (
                self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
                    : len(decode_req.req.origin_input_ids)
                ]
                .cpu()
                .numpy()
            )
            decode_req.metadata_buffer_index = (
                self.req_to_metadata_buffer_idx_allocator.alloc()
            )
            assert decode_req.metadata_buffer_index is not None
            page_indices = kv_to_page_indices(
                kv_indices, self.token_to_kv_pool_allocator.page_size
            )
            # Tell the prefill side where to write the KV pages and metadata.
            decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
            preallocated_reqs.append(decode_req)
            indices_to_remove.add(i)
        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]
        return preallocated_reqs

    @property
    def num_tokens_pre_allocated(self):
        # Tokens held by requests whose KV transfer is still in flight.
        return sum(
            len(decode_req.req.fill_ids) for decode_req in self.transfer_queue.queue
        )

    def _allocatable_tokens(
        self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
    ) -> int:
        """Estimate how many KV tokens can still be promised to new requests."""
        # Worst-case tokens one running request may still need if all others retract.
        need_space_for_single_req = (
            max(
                [
                    min(x.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKEN)
                    + len(x.origin_input_ids)
                    - retractable_tokens
                    for x in self.scheduler.running_batch.reqs
                ]
            )
            if retractable_tokens is not None
            and len(self.scheduler.running_batch.reqs) > 0
            else 0
        )
        if self.scheduler.model_config.is_hybrid:
            # Hybrid models: bounded by the smaller of full/SWA pools.
            available_size = min(
                self.token_to_kv_pool_allocator.full_available_size(),
                self.token_to_kv_pool_allocator.swa_available_size(),
            )
        else:
            available_size = self.token_to_kv_pool_allocator.available_size()
        allocatable_tokens = available_size - max(
            # preserve some space for future decode
            self.num_reserved_decode_tokens
            * (
                len(self.scheduler.running_batch.reqs)
                + len(self.transfer_queue.queue)
                + len(self.scheduler.waiting_queue)
            ),
            # make sure each request can finish if reach max_tokens with all other requests retracted
            need_space_for_single_req,
        )
        # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
        # the extend batch is not in any queue, so we need to explicitly add the tokens slots here
        if (
            self.scheduler.last_batch
            and self.scheduler.last_batch.forward_mode.is_extend()
        ):
            allocatable_tokens -= self.num_reserved_decode_tokens * len(
                self.scheduler.last_batch.reqs
            )
        if count_retracted:
            # Retracted requests will need their memory back when resumed.
            allocatable_tokens -= sum(
                [
                    len(req.origin_input_ids)
                    + len(req.output_ids)
                    + self.num_reserved_decode_tokens
                    for req in self.retracted_queue
                ]
            )
        return allocatable_tokens

    def _pre_alloc(self, req: Req) -> torch.Tensor:
        """Pre-allocate the memory for req_to_token and token_kv_pool"""
        req_pool_indices = self.req_to_token_pool.alloc(1)
        assert (
            req_pool_indices is not None
        ), "req_pool_indices is full! There is a bug in memory estimation."
        req.req_pool_idx = req_pool_indices[0]
        if self.token_to_kv_pool_allocator.page_size == 1:
            kv_loc = self.token_to_kv_pool_allocator.alloc(
                len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
            )
        else:
            # Page-aligned allocation: treat the whole request as one extend
            # from an empty prefix.
            num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
            kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
                prefix_lens=torch.tensor(
                    [0],
                    dtype=torch.int64,
                    device=self.token_to_kv_pool_allocator.device,
                ),
                seq_lens=torch.tensor(
                    [num_tokens],
                    dtype=torch.int64,
                    device=self.token_to_kv_pool_allocator.device,
                ),
                last_loc=torch.tensor(
                    [-1],
                    dtype=torch.int64,
                    device=self.token_to_kv_pool_allocator.device,
                ),
                extend_num_tokens=num_tokens,
            )
        assert (
            kv_loc is not None
        ), "KV cache is full! There is a bug in memory estimation."
        self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
        # populate metadata
        req.fill_ids = req.origin_input_ids + req.output_ids
        req.extend_input_len = len(req.origin_input_ids)
        return kv_loc
class DecodeTransferQueue:
    """
    Store the requests whose KV-cache transfer is in flight, polling until
    each transfer succeeds or fails.
    """

    def __init__(
        self,
        gloo_group: ProcessGroup,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        tp_rank: int,
        metadata_buffers: MetadataBuffers,
        scheduler: Scheduler,
        tree_cache: BasePrefixCache,
    ):
        self.queue: List[DecodeRequest] = []
        self.gloo_group = gloo_group
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.tp_rank = tp_rank
        self.metadata_buffers = metadata_buffers
        self.scheduler = scheduler
        self.tree_cache = tree_cache
        self.spec_algorithm = scheduler.spec_algorithm

    def add(self, decode_req: DecodeRequest) -> None:
        """Enqueue one request whose transfer has been initiated."""
        self.queue.append(decode_req)

    def extend(self, decode_reqs: List[DecodeRequest]) -> None:
        """Enqueue several requests whose transfers have been initiated."""
        self.queue.extend(decode_reqs)

    def pop_transferred(self) -> List[Req]:
        """Poll all in-flight transfers; return requests that completed.

        Failed transfers are aborted and streamed out; successful ones have
        their first output token (and optional logprobs / hidden states)
        copied out of the metadata buffers before the buffer slot is freed.
        """
        if not self.queue:
            return []
        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )
        transferred_reqs = []
        indices_to_remove = set()
        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Failed:
                error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                try:
                    decode_req.kv_receiver.failure_exception()
                except Exception as e:
                    error_message += f" with exception {e}"
                logger.error(error_message)
                prepare_abort(
                    decode_req.req,
                    error_message,
                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                )
                self.scheduler.stream_output(
                    [decode_req.req], decode_req.req.return_logprob
                )
                # unlock the kv cache or it will have memory leak
                self.tree_cache.cache_finished_req(decode_req.req)
                indices_to_remove.add(i)
                if self.scheduler.enable_metrics:
                    self.scheduler.metrics_collector.increment_transfer_failed_reqs()
                continue
            elif poll == KVPoll.Success:
                # Read the first decoded token + metadata written by prefill.
                idx = decode_req.metadata_buffer_index
                (
                    output_id,
                    output_token_logprobs_val,
                    output_token_logprobs_idx,
                    output_top_logprobs_val,
                    output_top_logprobs_idx,
                    output_hidden_states,
                ) = self.metadata_buffers.get_buf(idx)
                decode_req.req.output_ids.append(output_id[0].item())
                if not self.spec_algorithm.is_none():
                    # Speculative decoding needs the prefill hidden states.
                    decode_req.req.hidden_states_tensor = output_hidden_states
                if decode_req.req.return_logprob:
                    decode_req.req.output_token_logprobs_val.append(
                        output_token_logprobs_val[0].item()
                    )
                    decode_req.req.output_token_logprobs_idx.append(
                        output_token_logprobs_idx[0].item()
                    )
                    decode_req.req.output_top_logprobs_val.append(
                        output_top_logprobs_val[
                            : decode_req.req.top_logprobs_num
                        ].tolist()
                    )
                    decode_req.req.output_top_logprobs_idx.append(
                        output_top_logprobs_idx[
                            : decode_req.req.top_logprobs_num
                        ].tolist()
                    )
                if hasattr(decode_req.kv_receiver, "clear"):
                    decode_req.kv_receiver.clear()
                # special handling for sampling_params.max_new_tokens == 1
                if decode_req.req.sampling_params.max_new_tokens == 1:
                    # finish immediately
                    decode_req.req.check_finished()
                    self.scheduler.stream_output(
                        [decode_req.req], decode_req.req.return_logprob
                    )
                    self.tree_cache.cache_finished_req(decode_req.req)
                else:
                    transferred_reqs.append(decode_req.req)
                indices_to_remove.add(i)
            elif poll in [
                KVPoll.Bootstrapping,
                KVPoll.WaitingForInput,
                KVPoll.Transferring,
            ]:
                # Still in progress; keep in queue.
                pass
            else:
                raise ValueError(f"Unexpected poll case: {poll}")
        # Release metadata buffer slots for every request leaving the queue.
        for i in indices_to_remove:
            idx = self.queue[i].metadata_buffer_index
            assert idx != -1
            self.req_to_metadata_buffer_idx_allocator.free(idx)
        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]
        return transferred_reqs
class SchedulerDisaggregationDecodeMixin:
    """Scheduler event loops and batch construction for a *decode* worker in
    prefill/decode disaggregation mode.

    On a decode instance the prefill forward pass happened on another worker;
    requests arrive with their KV cache already transferred in.  The "extend"
    batches built here are therefore prebuilt fakes (see
    `get_new_prebuilt_batch`) whose only purpose is to splice newly arrived
    requests into the running decode batch.
    """
    @torch.no_grad()
    def event_loop_normal_disagg_decode(self: Scheduler):
        """A normal scheduler loop for decode worker in disaggregation mode."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # polling and allocating kv cache
            self.process_decode_queue()
            batch = self.get_next_disagg_decode_batch_to_run()
            self.cur_batch = batch
            # Whether this step must participate in an MLP sync across workers.
            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
            if batch:
                # Generate fake extend output.
                if batch.forward_mode.is_extend():
                    # Note: Logprobs should be handled on the prefill engine.
                    self.stream_output(
                        batch.reqs, any(req.return_logprob for req in batch.reqs)
                    )
                    if prepare_mlp_sync_flag:
                        # The fake extend runs no forward pass; run an idle
                        # batch so sync peers still advance together.
                        self._prepare_idle_batch_and_run(None)
                else:
                    if prepare_mlp_sync_flag:
                        self.prepare_mlp_sync_batch(batch)
                    result = self.run_batch(batch)
                    self.process_batch_result(batch, result)
            elif prepare_mlp_sync_flag:
                # No local work, but peers may be syncing: run an idle batch.
                batch, _ = self._prepare_idle_batch_and_run(None)
            if batch is None and (
                len(self.waiting_queue)
                + len(self.disagg_decode_transfer_queue.queue)
                + len(self.disagg_decode_prealloc_queue.queue)
                == 0
            ):
                # Fully idle: nothing running, waiting, transferring or
                # being preallocated.
                self.self_check_during_idle()
            self.last_batch = batch
    @torch.no_grad()
    def event_loop_overlap_disagg_decode(self: Scheduler):
        """Overlap variant of the decode loop: each batch's result is queued
        and processed one iteration later so CPU-side result handling
        overlaps GPU execution of the next batch."""
        result_queue = deque()
        self.last_batch: Optional[ScheduleBatch] = None
        self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # polling and allocating kv cache
            self.process_decode_queue()
            batch = self.get_next_disagg_decode_batch_to_run()
            self.cur_batch = batch
            last_batch_in_queue = False
            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
            if batch:
                # Generate fake extend output.
                if batch.forward_mode.is_extend():
                    # Note: Logprobs should be handled on the prefill engine.
                    self.stream_output(
                        batch.reqs, any(req.return_logprob for req in batch.reqs)
                    )
                    if prepare_mlp_sync_flag:
                        batch_, result = self._prepare_idle_batch_and_run(
                            None, delay_process=True
                        )
                        if batch_:
                            result_queue.append((batch_.copy(), result))
                        # NOTE(review): here the flag is set even when batch_
                        # is None, unlike the elif branch below where it is
                        # only set after a successful enqueue — confirm this
                        # asymmetry is intentional.
                        last_batch_in_queue = True
                else:
                    if prepare_mlp_sync_flag:
                        self.prepare_mlp_sync_batch(batch)
                    result = self.run_batch(batch)
                    result_queue.append((batch.copy(), result))
                    if (self.last_batch is None) or (not self.last_batch_in_queue):
                        # Create a dummy first batch to start the pipeline for overlap schedule.
                        # It is now used for triggering the sampling_info_done event.
                        tmp_batch = ScheduleBatch(
                            reqs=None,
                            forward_mode=ForwardMode.DUMMY_FIRST,
                            next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                        )
                        self.set_next_batch_sampling_info_done(tmp_batch)
                    last_batch_in_queue = True
            elif prepare_mlp_sync_flag:
                batch, result = self._prepare_idle_batch_and_run(
                    None, delay_process=True
                )
                if batch:
                    result_queue.append((batch.copy(), result))
                    last_batch_in_queue = True
            # Process the results of the previous batch but skip if the last batch is extend
            if self.last_batch and self.last_batch_in_queue:
                tmp_batch, tmp_result = result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)
            if batch is None and (
                len(self.waiting_queue)
                + len(self.disagg_decode_transfer_queue.queue)
                + len(self.disagg_decode_prealloc_queue.queue)
                == 0
            ):
                # Fully idle: run self checks.
                self.self_check_during_idle()
            self.last_batch = batch
            self.last_batch_in_queue = last_batch_in_queue
    def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
        """Run an (possibly idle) MLP-sync batch.

        Returns the prepared batch and its result; when `delay_process` is
        True the caller is responsible for calling `process_batch_result`.
        """
        batch = self.prepare_mlp_sync_batch(batch)
        result = None
        if batch:
            result = self.run_batch(batch)
            if not delay_process:
                self.process_batch_result(batch, result)
        return batch, result
    def get_next_disagg_decode_batch_to_run(
        self: Scheduler,
    ) -> Optional[ScheduleBatch]:
        """Create fake completed prefill if possible and merge with running batch"""
        # Merge the prefill batch into the running batch
        last_batch = self.last_batch
        if last_batch and last_batch.forward_mode.is_extend():
            # chunked prefill doesn't happen in decode instance.
            assert self.chunked_req is None
            # Filter finished batches.
            last_batch.filter_batch()
            if not last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = last_batch
                else:
                    # merge running_batch with prefill batch
                    self.running_batch.merge_batch(last_batch)
        # A fresh prebuilt "extend" batch takes priority over plain decode.
        new_prebuilt_batch = self.get_new_prebuilt_batch()
        ret: Optional[ScheduleBatch] = None
        if new_prebuilt_batch:
            ret = new_prebuilt_batch
        else:
            if self.running_batch.is_empty():
                ret = None
            else:
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None
        return ret
    def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
        """Create a schedulebatch for fake completed prefill"""
        if self.grammar_queue:
            self.move_ready_grammar_requests()
        if len(self.waiting_queue) == 0:
            return None
        curr_batch_size = self.running_batch.batch_size()
        # Overall batch-size budget for this worker.
        batch_size = min(self.req_to_token_pool.size, self.max_running_requests)
        num_not_used_batch = batch_size - curr_batch_size
        # pop req from waiting queue
        can_run_list: List[Req] = []
        waiting_queue: List[Req] = []
        for i in range(len(self.waiting_queue)):
            req = self.waiting_queue[i]
            # at most `num_not_used_batch` new requests can join the running batch
            if i < num_not_used_batch:
                can_run_list.append(req)
                req.init_next_round_input(self.tree_cache)
            else:
                waiting_queue.append(req)
        self.waiting_queue = waiting_queue
        if len(can_run_list) == 0:
            return None
        # construct a schedule batch with those requests and mark as decode
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        # construct fake completed prefill
        new_batch.prepare_for_prebuilt_extend()
        new_batch.process_prebuilt_extend(self.server_args, self.model_config)
        return new_batch
    def process_decode_queue(self: Scheduler):
        """Move requests through the prealloc -> transfer -> waiting pipeline."""
        # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
        resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
        self.waiting_queue.extend(resumed_reqs)
        if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
            # if there are still retracted requests, we do not allocate new requests
            return
        req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
        self.disagg_decode_transfer_queue.extend(req_conns)
        alloc_reqs = (
            self.disagg_decode_transfer_queue.pop_transferred()
        )  # the requests which kv has arrived
        self.waiting_queue.extend(alloc_reqs)

View File

@@ -0,0 +1,159 @@
from __future__ import annotations
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING
import torch
from sglang.srt.disaggregation.utils import prepare_abort
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.server_args import ServerArgs
class ScheduleBatchDisaggregationDecodeMixin:
    """ScheduleBatch helpers that fabricate a "completed prefill" on a decode
    worker whose KV cache was filled by a remote prefill instance."""
    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend by populate metadata
        Adapted from .prepare_for_extend().
        """
        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        # Tokens beyond the cached prefix for each request.
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []
        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)
            # KV slots already populated by the remote prefill transfer.
            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len
            pre_len = len(req.prefix_indices)
            # max(0, len(output_ids) - 1): the last output token has not been
            # appended to the KV cache yet.
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
            # NOTE(review): cached_tokens is credited from pre_len here —
            # confirm this matches the accounting in prepare_for_extend.
            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            # Logprobs for the prompt are handled on the prefill engine.
            req.extend_logprob_start_len = 0
        extend_input_logprob_token_ids = None
        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.orig_seq_lens = torch.tensor(
            seq_lens, dtype=torch.int32, device=self.device
        )
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
        self.multimodal_inputs = [r.multimodal_inputs for r in reqs]
        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )
    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to schedule batch"""
        self.output_ids = []
        for req in self.reqs:
            # The last token produced by the remote prefill.
            self.output_ids.append(req.output_ids[-1])
            self.tree_cache.cache_unfinished_req(req)
            if req.grammar is not None:
                # FIXME: this try-except block is for handling unexpected xgrammar issue.
                try:
                    req.grammar.accept_token(req.output_ids[-1])
                except ValueError as e:
                    # Grammar accept_token can raise ValueError if the token is not in the grammar.
                    # This can happen if the grammar is not set correctly or the token is invalid.
                    error_message = f"Grammar accept_token failed for req {req.rid} with token {req.output_ids[-1]}: {e}"
                    self.tree_cache.cache_finished_req(req)
                    prepare_abort(
                        req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
                    )
                req.grammar.finished = req.finished()
        self.output_ids = torch.tensor(self.output_ids, device=self.device)
        # Simulate the eagle run. We add mock data to hidden states for the
        # ease of implementation now meaning the first token will have acc rate
        # of 0.
        if not self.spec_algorithm.is_none():
            b = len(self.reqs)
            # Deterministic, strictly decreasing mock top-k probabilities.
            topk_p = torch.arange(
                b * server_args.speculative_eagle_topk,
                0,
                -1,
                device=self.device,
                dtype=torch.float32,
            )
            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
            topk_p /= b * server_args.speculative_eagle_topk
            topk_index = torch.arange(
                b * server_args.speculative_eagle_topk, device=self.device
            )
            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
            hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
            hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
            # local import to avoid circular import
            from sglang.srt.speculative.eagle_utils import EagleDraftInput
            spec_info = EagleDraftInput(
                topk_p=topk_p,
                topk_index=topk_index,
                hidden_states=hidden_states,
                verified_id=self.output_ids,
            )
            spec_info.prepare_for_extend(self)
            spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
            self.spec_info = spec_info

View File

@@ -0,0 +1 @@
from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender

View File

@@ -0,0 +1,85 @@
import logging
from typing import List, Optional
import numpy as np
import numpy.typing as npt
from sglang.srt.disaggregation.base.conn import (
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVPoll,
)
logger = logging.getLogger(__name__)
# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
class FakeKVSender(BaseKVSender):
    """Sender stub for warmup requests: pretends every handshake and
    transfer finishes instantly and never moves any KV data."""

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: int,
        dest_tp_ranks: List[int],
        pp_rank: int,
    ):
        # All constructor arguments are ignored; only track send() calls.
        self.has_sent = False

    def poll(self) -> KVPoll:
        if not self.has_sent:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        # Assume transfer completed instantly
        logger.debug("FakeKVSender poll success")
        return KVPoll.Success

    def init(
        self,
        kv_indices: list[int],
        aux_index: Optional[int] = None,
    ):
        # Nothing to set up for the fake path; just record the call.
        logger.debug(
            f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
        )

    def send(
        self,
        kv_indices: npt.NDArray[np.int32],
    ):
        logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
        self.has_sent = True

    def failure_exception(self):
        raise Exception("Fake KVSender Exception")
class FakeKVReceiver(BaseKVReceiver):
    """Receiver stub for warmup requests: reports success as soon as
    init() has been called, without transferring any KV data."""

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
        prefill_dp_rank: Optional[int] = None,
    ):
        # Constructor arguments are ignored; only track init() calls.
        self.has_init = False

    def poll(self) -> KVPoll:
        if not self.has_init:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        # Assume transfer completed instantly
        logger.debug("FakeKVReceiver poll success")
        return KVPoll.Success

    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
        logger.debug(
            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
        )
        self.has_init = True

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")

View File

@@ -0,0 +1,412 @@
"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
KV caching events
"""
import atexit
import logging
import queue
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from itertools import count
from queue import Queue
from typing import Any, Callable, Optional, Union
import msgspec
import zmq
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class EventBatch(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,  # type: ignore[call-arg]
):
    """A timestamped batch of events, tagged with the attention DP rank that produced it."""
    # Producer-side timestamp (seconds).
    ts: float
    # Event payload; subclasses narrow the element type.
    events: list[Any]
    # Attention DP rank of the producing scheduler; the publisher fills this
    # in if left unset.
    attn_dp_rank: Optional[int] = None
class KVCacheEvent(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,  # type: ignore[call-arg]
    # tag=True embeds the subclass name in the wire format so consumers can
    # distinguish event types when decoding.
    tag=True,
):
    """Base class for all KV cache-related events"""
class BlockStored(KVCacheEvent):
    """A batch of blocks written into the KV cache."""
    block_hashes: list[int]
    # Hash of the block these blocks chain from, if any.
    parent_block_hash: Optional[int]
    token_ids: list[int]
    block_size: int
    lora_id: Optional[int]
class BlockRemoved(KVCacheEvent):
    """A batch of blocks removed from the KV cache, identified by hash."""
    block_hashes: list[int]
class AllBlocksCleared(KVCacheEvent):
    """The entire KV cache was cleared."""
    pass
class KVEventBatch(EventBatch):
    """EventBatch specialized to KV cache events."""
    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
class EventPublisher(ABC):
    """
    Lightweight publisher for EventBatch batches with
    support for DP attention.
    In DP attention - each rank has its own Scheduler and
    KV cache instance in order to avoid duplicate events
    and ensure proper event attribution. In our implementation
    - Each DP rank has its own EventPublisher
    - Publishers annotate events with the dp rank
    - This allows consumers to distinguish events from different DP ranks
    """
    def __init__(self, attn_dp_rank: int = 0):
        # Rank used to tag outgoing batches so consumers can attribute them.
        self._attn_dp_rank = attn_dp_rank
    @abstractmethod
    def publish(self, events: EventBatch) -> None:
        """Emit events in order.
        Implementations should guarantee at-least-once delivery and
        monotonic ordering (e.g., via sequence numbers).
        """
    @abstractmethod
    def shutdown(self) -> None:
        """Shutdown the publisher."""
class NullEventPublisher(EventPublisher):
    """No-op implementation (default when disabled)."""

    def publish(self, events) -> None:
        # Silently drop every batch.
        return None

    def shutdown(self) -> None:
        # Nothing to release.
        return None
class ZmqEventPublisher(EventPublisher):
    """Reliable PUB/ROUTER publisher with an in-memory replay buffer.
    Spawns a separate thread to handle publishing from a queue.
    Parameters
    ----------
    endpoint:
        PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
        connect.
    replay_endpoint:
        Optional ROUTER address for replay requests. When given, subscribers can
        request missed batches by sending the starting sequence number as an
        8-byte big-endian integer.
    buffer_steps:
        Number of past batches to keep for replay.
    hwm:
        ZeroMQ high-water-mark for PUB socket.
    max_queue_size:
        Maximum number of events to buffer in memory.
    topic:
        Topic to publish events to.
    """
    # Seconds to wait for queue drain / thread join on shutdown.
    SHUTDOWN_TIMEOUT: float = 1.0
    # Sentinel sequence number that terminates a replay stream.
    END_SEQ = (-1).to_bytes(8, "big", signed=True)
    def __init__(
        self,
        attn_dp_rank: int,
        endpoint: str = "tcp://*:5557",
        replay_endpoint: Optional[str] = None,
        buffer_steps: int = 10_000,
        hwm: int = 100_000,
        max_queue_size: int = 100_000,
        topic: str = "",
    ) -> None:
        # Storage
        super().__init__(attn_dp_rank)
        # None is used as a shutdown sentinel in the queue.
        self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
        # Replay buffer of (sequence number, encoded payload) pairs.
        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
        # ZMQ sockets
        self._ctx = zmq.Context.instance()
        self._pub: Optional[zmq.Socket] = None
        self._replay: Optional[zmq.Socket] = None
        self._dp_rank = attn_dp_rank
        # Each DP rank publishes on its own port (base port + rank).
        self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
        self._replay_endpoint = self.offset_endpoint_port(
            replay_endpoint, self._dp_rank
        )
        self._hwm = hwm
        self._socket_setup()
        # Payload
        self._seq_gen = count()
        self._topic_bytes = topic.encode("utf-8")
        # Thread
        self._running = True
        logger.info("Starting ZMQ publisher thread")
        self._thread = threading.Thread(
            target=self._publisher_thread, daemon=True, name="zmq-publisher"
        )
        self._thread.start()
        atexit.register(self.shutdown)
    def publish(self, events: EventBatch) -> None:
        """Enqueue a batch for asynchronous publishing.

        Blocks when the internal queue is full. Raises RuntimeError if the
        publisher has been shut down.
        """
        if not self._running:
            raise RuntimeError("Publisher is closed")
        if events.attn_dp_rank is None:
            # Tag the batch with this publisher's DP rank.
            events.attn_dp_rank = self._dp_rank
        self._event_queue.put(events)
    def shutdown(self) -> None:
        """Stop the publisher thread and clean up resources."""
        self._running = False
        # NOTE(review): put_nowait can raise queue.Full if the queue is at
        # capacity; the sentinel would then never be delivered — confirm.
        self._event_queue.put_nowait(None)
        # Give the worker thread a bounded amount of time to drain the queue.
        start = time.time()
        pending_items = True
        while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
            pending_items = not self._event_queue.empty()
            if pending_items:
                time.sleep(0.1)
        if pending_items:
            logger.warning(
                "Warning: Queue still has %s items after %s seconds timeout",
                self._event_queue.qsize(),
                self.SHUTDOWN_TIMEOUT,
            )
        if self._thread.is_alive():
            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)
        # Clean up ZMQ resources
        try:
            if self._pub is not None:
                self._pub.close(linger=0)
            if self._replay is not None:
                self._replay.close(linger=0)
        finally:
            pass  # Do not terminate context; other sockets may use it
    def _socket_setup(self) -> None:
        """Initialize sockets
        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
        """
        if self._pub is None:
            self._pub = self._ctx.socket(zmq.PUB)
            self._pub.set_hwm(self._hwm)
            # Heuristic: bind if wildcard / * present, else connect.
            # bind stable, connect volatile convention
            if (
                "*" in self._endpoint
                or "::" in self._endpoint
                or self._endpoint.startswith("ipc://")
                or self._endpoint.startswith("inproc://")
            ):
                self._pub.bind(self._endpoint)
            else:
                self._pub.connect(self._endpoint)
        # Set up replay socket: use ROUTER
        # 1) handles multiple REQ clients (identities)
        # 2) lets us send back one request → many replies (streamed events)
        # 3) works in our nonblocking poll loop alongside PUB
        if self._replay_endpoint is not None:
            self._replay = self._ctx.socket(zmq.ROUTER)
            self._replay.bind(self._replay_endpoint)
    def _publisher_thread(self) -> None:
        """Background thread that processes the event queue."""
        self._pack = msgspec.msgpack.Encoder()
        assert self._pub is not None  # narrows type for mypy
        # Keep draining the queue after shutdown is requested until empty
        # or the sentinel is seen.
        while self._running or self._event_queue.qsize() > 0:
            # --- replay (non-critical) ---------------------------------
            if self._replay is not None and self._replay.poll(0):
                try:
                    self._service_replay()
                except Exception as e:
                    logger.exception("Error in replay: %s", e)
            # --- main queue (critical) ---------------------------------
            try:
                event = self._event_queue.get(timeout=0.1)
                if event is None:
                    break  # Sentinel received, exit thread
            except queue.Empty:
                continue
            try:
                seq = next(self._seq_gen)
                payload = self._pack.encode(event)
                seq_bytes = seq.to_bytes(8, "big")
                # Frames: topic, sequence number, msgpack payload.
                self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))
                self._buffer.append((seq, payload))
                self._event_queue.task_done()
            except Exception as e:
                # Publishing failed; back-off a bit to avoid a tight error loop
                logger.exception("Error in publisher thread: %s", e)
                time.sleep(0.1)
    def _service_replay(self) -> None:
        """If a replay request is waiting, send buffered batches."""
        assert self._replay is not None  # narrows type for mypy
        frame = self._replay.recv_multipart()
        # Expected frames from a REQ client via ROUTER: [identity, empty, seq].
        if len(frame) != 3:
            logger.warning("Invalid replay request: %s", frame)
            return
        client_id, _, start_seq_bytes = frame
        start_seq = int.from_bytes(start_seq_bytes, "big")
        for seq, buf in self._buffer:
            if seq >= start_seq:
                # [identity, empty_delim, seq_bytes, payload]
                # (identity, empty_delim) are stripped off by the router
                # receiving payload is (seq_bytes, payload)
                self._replay.send_multipart(
                    (client_id, b"", seq.to_bytes(8, "big"), buf)
                )
        # Send end of sequence marker
        # receiving payload is (-1, b""")
        self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
    @staticmethod
    def offset_endpoint_port(
        endpoint: Optional[str], data_parallel_rank: int
    ) -> Optional[str]:
        """Helper function to offset the port in an endpoint by
        the data parallel rank.
        Args:
            endpoint: The endpoint string
                (e.g., "tcp://*:5557" or "inproc://cache")
            data_parallel_rank: The data parallel rank to offset by
        Returns:
            The endpoint with the port offset by data_parallel_rank
            or suffix appended
        """
        # Do nothing if input is None or data_parallel_rank is 0
        if not endpoint or data_parallel_rank == 0:
            return endpoint
        if "inproc" in endpoint:
            # inproc endpoints have no port; disambiguate with a suffix.
            return f"{endpoint}_dp{data_parallel_rank}"
        if "tcp" in endpoint:
            if endpoint and ":" in endpoint:
                # Get everything after the last colon (the port)
                last_colon_idx = endpoint.rfind(":")
                base_addr = endpoint[:last_colon_idx]
                base_port = int(endpoint[last_colon_idx + 1 :])
                new_port = base_port + data_parallel_rank
                return f"{base_addr}:{new_port}"
            return endpoint
        raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
class KVEventsConfig(BaseModel):
    """Configuration for KV event publishing."""
    publisher: str = "null"
    """The publisher to use for publishing kv events. Can be "null", "zmq".
    """
    endpoint: str = "tcp://*:5557"
    """The zmq endpoint to use for publishing kv events.
    """
    replay_endpoint: Optional[str] = None
    """The zmq endpoint to use for replaying kv events.
    """
    buffer_steps: int = 10_000
    """The number of steps to cache for replay endpoint. Will only save
    events from the last N steps for the replay endpoint.
    """
    hwm: int = 100_000
    """The zmq high water mark for the event publisher. After queueing N events,
    events will start dropping if the consumer is not keeping up.
    """
    max_queue_size: int = 100_000
    """The maximum number of events to queue while waiting for publishing.
    """
    topic: str = ""
    """The topic to use for the event publisher. Consumers can subscribe to
    this topic to receive events.
    """
    @classmethod
    def from_cli(cls, cli_value: str) -> "KVEventsConfig":
        """Parse the CLI value for the event publisher config."""
        # cli_value is expected to be a JSON object string.
        return KVEventsConfig.model_validate_json(cli_value)
class EventPublisherFactory:
    """Registry that maps publisher names to their constructors."""

    _registry: dict[str, Callable[..., EventPublisher]] = {
        "null": NullEventPublisher,
        "zmq": ZmqEventPublisher,
    }

    @classmethod
    def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:
        """Register *ctor* under *name*; duplicate names are rejected."""
        if name in cls._registry:
            raise KeyError(f"publisher '{name}' already registered")
        cls._registry[name] = ctor

    @classmethod
    def create(cls, config: Optional[str], attn_dp_rank: int = 0) -> EventPublisher:
        """Create publisher from a config mapping."""
        if not config:
            # Publishing disabled: fall back to the no-op publisher.
            return NullEventPublisher()
        parsed = KVEventsConfig.from_cli(config)
        kwargs = parsed.model_dump()
        kind = kwargs.pop("publisher", "null")
        try:
            constructor = cls._registry[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown event publisher '{kind}'") from exc
        return constructor(attn_dp_rank=attn_dp_rank, **kwargs)

View File

@@ -0,0 +1,6 @@
# This module was intentionally removed; importing it fails loudly with
# migration instructions instead of silently breaking callers.
raise RuntimeError(
    """The 'mini_lb' module has been relocated to the 'sglang_router' package.
We recommend installing 'sglang-router' with Rust support for optimal performance.
If you encounter issues building the router with Rust, set the environment variable
'SGLANG_ROUTER_BUILD_NO_RUST=1' and add '--mini-lb' to the command line to use the Python version of 'mini_lb'."""
)

View File

@@ -0,0 +1,6 @@
from sglang.srt.disaggregation.mooncake.conn import (
MooncakeKVBootstrapServer,
MooncakeKVManager,
MooncakeKVReceiver,
MooncakeKVSender,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,164 @@
import logging
from typing import List, Optional
from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address
logger = logging.getLogger(__name__)
class MooncakeTransferEngine:
    """Thin wrapper around mooncake's TransferEngine for KV transfer.

    Registration/transfer failures are deliberately best-effort: errors are
    logged (or returned as negative codes) rather than raised, so a failing
    transfer does not kill the worker's execution thread.
    """
    def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
        try:
            from mooncake.engine import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run SGLang with MooncakeTransferEngine."
            ) from e
        self.engine = TransferEngine()
        self.hostname = hostname
        self.gpu_id = gpu_id
        self.ib_device = ib_device
        self.initialize(
            hostname=self.hostname,
            device_name=self.ib_device,
        )
        # Peers address this engine as "<host>:<rpc_port>".
        self.session_id = (
            f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
        )
    def register(self, ptr, length):
        """Best-effort registration of one memory region with the engine."""
        try:
            ret_value = self.engine.register_memory(ptr, length)
        except Exception:
            # Mark register as failed
            ret_value = -1
        if ret_value != 0:
            logger.debug("Mooncake memory registration %s failed.", ptr)
    def deregister(self, ptr):
        """Best-effort deregistration of one memory region."""
        try:
            ret_value = self.engine.unregister_memory(ptr)
        except Exception:
            # Mark deregister as failed
            ret_value = -1
        if ret_value != 0:
            logger.debug("Mooncake memory deregistration %s failed.", ptr)
    def batch_register(self, ptrs: List[int], lengths: List[int]) -> int:
        """Batch register multiple memory regions."""
        try:
            ret_value = self.engine.batch_register_memory(ptrs, lengths)
        except Exception:
            # Mark batch register as failed
            ret_value = -1
        # Distinguish "API missing" from an ordinary failure so the user gets
        # an actionable upgrade message.
        if not hasattr(self.engine, "batch_register_memory"):
            raise RuntimeError(
                "Mooncake's batch register requires a newer version of mooncake-transfer-engine. "
                "Please upgrade Mooncake."
            )
        if ret_value != 0:
            logger.debug("Mooncake batch memory registration failed.")
        return ret_value
    def batch_deregister(self, ptrs: List[int]) -> int:
        """Batch deregister multiple memory regions."""
        try:
            ret_value = self.engine.batch_unregister_memory(ptrs)
        except Exception:
            # Mark batch deregister as failed
            ret_value = -1
        if ret_value != 0:
            logger.debug("Mooncake batch memory deregistration failed.")
        return ret_value
    def initialize(
        self,
        hostname: str,
        device_name: Optional[str],
    ) -> None:
        """Initialize the mooncake instance."""
        if get_bool_env_var("ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE", "false"):
            # Ascend/NPU path: the engine expects "<host>:<port>:npu_<id>".
            hostname += f":{get_free_port()}:npu_{self.gpu_id}"
            ret_value = self.engine.initialize(
                hostname,
                "P2PHANDSHAKE",
                "ascend",
                device_name if device_name is not None else "",
            )
        else:
            ret_value = self.engine.initialize(
                hostname,
                "P2PHANDSHAKE",
                "rdma",
                device_name if device_name is not None else "",
            )
        if ret_value != 0:
            # Initialization failure is fatal, unlike per-transfer errors.
            logger.error("Mooncake Transfer Engine initialization failed.")
            raise RuntimeError("Mooncake Transfer Engine initialization failed.")
    def transfer_sync(
        self, session_id: str, buffer: int, peer_buffer_address: int, length: int
    ) -> int:
        """Synchronously transfer data to the specified address.

        Returns the engine's return code; negative on failure.
        """
        try:
            # the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
            # later: based on the cached queue pair to send data
            ret = self.engine.transfer_sync_write(
                session_id, buffer, peer_buffer_address, length
            )
        except Exception:
            # Mark transfer request as failed
            ret = -1
        if ret < 0:
            # Do not raise an exception here, since some transfer requests fail should be accepted and the execution thread should not be stopped.
            logger.debug(
                "Failed to transfer data from %s to %s - %s.",
                buffer,
                session_id,
                peer_buffer_address,
            )
        return ret
    def batch_transfer_sync(
        self,
        session_id: str,
        buffers: List[int],
        peer_buffer_addresses: List[int],
        lengths: List[int],
    ) -> int:
        """Synchronously transfer data to the specified addresses in batches."""
        try:
            ret = self.engine.batch_transfer_sync_write(
                session_id, buffers, peer_buffer_addresses, lengths
            )
        except Exception:
            ret = -1
            # Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2
            if not hasattr(self.engine, "batch_transfer_sync_write"):
                raise RuntimeError(
                    "Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. "
                    "Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'"
                )
        if ret < 0:
            logger.debug(
                "Failed to batch transfer data. Buffers: %s, Session: %s, Peer addresses: %s",
                buffers,
                session_id,
                peer_buffer_addresses,
            )
        return ret
    def get_session_id(self):
        """Return the "<host>:<rpc_port>" identifier peers use to reach this engine."""
        return self.session_id

View File

@@ -0,0 +1,6 @@
from sglang.srt.disaggregation.nixl.conn import (
NixlKVBootstrapServer,
NixlKVManager,
NixlKVReceiver,
NixlKVSender,
)

View File

@@ -0,0 +1,696 @@
from __future__ import annotations
import asyncio
import dataclasses
import logging
import queue
import socket
import struct
import threading
import uuid
from collections import defaultdict
from functools import cache
from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web
from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll
from sglang.srt.disaggregation.common.conn import (
CommonKVBootstrapServer,
CommonKVManager,
CommonKVReceiver,
)
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
get_local_ip_auto,
is_valid_ipv6_address,
)
logger = logging.getLogger(__name__)
GUARD = "NixlMsgGuard".encode("ascii")
@dataclasses.dataclass
class TransferInfo:
    """Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""

    room: int
    endpoint: str
    dst_port: int
    agent_name: str
    dst_kv_indices: npt.NDArray[np.int32]
    dst_aux_index: int
    required_dst_info_num: int

    def is_dummy(self):
        # A receiver that reserved no destination KV slots is a placeholder.
        return self.dst_kv_indices.size == 0

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Deserialize the byte frames of a receiver handshake message."""
        def text(i: int) -> str:
            return msg[i].decode("ascii")

        return cls(
            room=int(text(0)),
            endpoint=text(1),
            dst_port=int(text(2)),
            agent_name=text(3),
            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
            dst_aux_index=int(text(5)),
            required_dst_info_num=int(text(6)),
        )
@dataclasses.dataclass
class KVArgsRegisterInfo:
    """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""

    room: str
    endpoint: str
    dst_port: int
    agent_name: str
    agent_metadata: bytes
    dst_kv_ptrs: list[int]
    dst_aux_ptrs: list[int]
    gpu_id: int
    decode_tp_size: int
    decode_tp_rank: int
    dst_kv_item_len: int

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Deserialize the byte frames of a one-time registration message."""
        def text(i: int) -> str:
            return msg[i].decode("ascii")

        def u64_list(raw: bytes) -> list:
            # Frames 5/6 are packed arrays of unsigned 64-bit pointers.
            return list(struct.unpack(f"{len(raw)//8}Q", raw))

        return cls(
            room=str(text(0)),
            endpoint=text(1),
            dst_port=int(text(2)),
            agent_name=text(3),
            agent_metadata=msg[4],
            dst_kv_ptrs=u64_list(msg[5]),
            dst_aux_ptrs=u64_list(msg[6]),
            gpu_id=int(text(7)),
            decode_tp_size=int(text(8)),
            decode_tp_rank=int(text(9)),
            dst_kv_item_len=int(text(10)),
        )
@dataclasses.dataclass
class TransferStatus:
    """Per-room progress tracker so the KV receiver knows when a transfer is done."""

    # Chunk IDs of KV pieces that have arrived so far.
    received_kvs: Set[int] = dataclasses.field(default_factory=set)
    # Total number of KV chunks; only known once the last chunk's notif arrives.
    num_kvs_expected: Optional[int] = None
    # True once the auxiliary-data notification has been seen.
    received_aux: bool = False

    def is_done(self):
        """True when every KV chunk and the aux payload have been received."""
        if self.num_kvs_expected is None:
            # The final chunk (which carries the expected count) hasn't arrived yet.
            return False
        return self.received_aux and len(self.received_kvs) == self.num_kvs_expected
class NixlKVManager(CommonKVManager):
    """KV-cache transfer manager backed by the NIXL transfer engine.

    In PREFILL mode it runs a bootstrap thread that receives peer
    registrations and per-request transfer info over a ZMQ PULL socket, and
    pushes KV/aux data to decode peers via NIXL "WRITE" transfers.  In
    DECODE mode it tracks incoming-transfer completion via NIXL
    notifications (see update_transfer_status).
    """

    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
        is_mla_backend: Optional[bool] = False,
    ):
        super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
        try:
            from nixl._api import nixl_agent
        except ImportError as e:
            raise ImportError(
                "Please install NIXL by following the instructions at "
                "https://github.com/ai-dynamo/nixl/blob/main/README.md "
                "to run SGLang with NixlTransferEngine."
            ) from e
        # One NIXL agent per process; the random UUID is its unique name.
        self.agent = nixl_agent(str(uuid.uuid4()))
        self.local_ip = get_local_ip_auto()
        self.server_socket = zmq.Context().socket(zmq.PULL)
        if is_valid_ipv6_address(self.local_ip):
            self.server_socket.setsockopt(zmq.IPV6, 1)
        self.register_buffer_to_engine()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # room -> bootstrap/transfer state (KVPoll) on the prefill side.
            self.request_status: Dict[int, KVPoll] = {}
            # room -> {agent_name -> TransferInfo} collected during bootstrap.
            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
            # agent_name -> registration info (base pointers etc.) of decode peers.
            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
            self._start_bootstrap_thread()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # room -> TransferStatus, lazily created on first notification.
            self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
                TransferStatus
            )
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )

    def check_status(self, bootstrap_room: int):
        """Return the current KVPoll state recorded for this room."""
        return self.request_status[bootstrap_room]

    def update_status(self, bootstrap_room: int, status: KVPoll):
        """Record `status` for the room, never moving the state backwards."""
        if bootstrap_room not in self.request_status:
            self.request_status[bootstrap_room] = status
        else:
            # NOTE: The prefill engine could recv bootstrapping first
            self.request_status[bootstrap_room] = max(
                self.request_status[bootstrap_room], status
            )

    def register_buffer_to_engine(self):
        """Register local KV (VRAM) and aux (DRAM) buffers with the NIXL agent.

        Raises if either registration returns a falsy descriptor handle.
        """
        kv_addrs = []
        for kv_data_ptr, kv_data_len in zip(
            self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
        ):
            kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
        logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
        if not self.kv_descs:
            raise Exception("NIXL memory registration failed for kv tensors")
        aux_addrs = []
        for aux_data_ptr, aux_data_len in zip(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        ):
            # Aux buffers live in host memory; device id is fixed to 0.
            aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
        logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
        if not self.aux_descs:
            raise Exception("NIXL memory registration failed for aux tensors")

    def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
        """Store a decode peer's registration and attach its NIXL agent (idempotent)."""
        agent_name = decode_kv_args.agent_name
        if agent_name in self.decode_kv_args_table:
            logger.info(f"Peer {agent_name} was already registered, ignoring.")
            return
        self.decode_kv_args_table[agent_name] = decode_kv_args
        self.agent.add_remote_agent(decode_kv_args.agent_metadata)

    def send_kvcache(
        self,
        peer_name: str,
        prefill_kv_indices: npt.NDArray[np.int32],
        dst_kv_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int32],
        dst_gpu_id: int,
        notif: str,
    ):
        """Write local KV pages to the peer when TP layouts match 1:1.

        Contiguous index runs are coalesced into single descriptors; returns
        the NIXL transfer handle so callers can poll completion.
        """
        # group by indices
        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
            prefill_kv_indices, dst_kv_indices
        )
        logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
        # Make descs
        num_layers = len(self.kv_args.kv_data_ptrs)
        src_addrs = []
        dst_addrs = []
        for layer_id in range(num_layers):
            src_ptr = self.kv_args.kv_data_ptrs[layer_id]
            dst_ptr = dst_kv_ptrs[layer_id]
            item_len = self.kv_args.kv_item_lens[layer_id]
            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                # One descriptor per contiguous run of pages.
                src_addr = src_ptr + int(prefill_index[0]) * item_len
                dst_addr = dst_ptr + int(decode_index[0]) * item_len
                length = item_len * len(prefill_index)
                src_addrs.append((src_addr, length, self.kv_args.gpu_id))
                dst_addrs.append((dst_addr, length, dst_gpu_id))
        logger.debug(
            f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
        )
        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
        # Transfer data
        xfer_handle = self.agent.initialize_xfer(
            "WRITE",
            src_descs,
            dst_descs,
            peer_name,
            notif.encode("ascii"),  # type: ignore
        )
        if not xfer_handle:
            raise Exception("KVSender failed to create transfer")
        state = self.agent.transfer(xfer_handle)
        if state == "ERR":
            raise Exception("KVSender failed to post transfer")
        return xfer_handle

    def send_kvcache_slice(
        self,
        peer_name: str,
        prefill_kv_indices: npt.NDArray[np.int32],
        dst_kv_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int32],
        dst_gpu_id: int,
        notif: str,
        prefill_tp_size: int,
        decode_tp_size: int,
        decode_tp_rank: int,
        dst_kv_item_len: int,
    ):
        """Write KV data to a peer with a different TP size (head-sliced copy).

        Only the KV-head slice owned by / destined for the involved ranks is
        sent; each token slot in each page becomes its own descriptor pair.
        Returns the NIXL transfer handle.
        """
        # Get configuration from kv_args
        local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
        dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
        num_kv_heads = self.kv_args.kv_head_num
        # Calculate head distribution
        src_heads_per_rank = num_kv_heads
        dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
        src_kv_item_len = self.kv_args.kv_item_lens[0]
        page_size = self.kv_args.page_size
        bytes_per_head_slice_to_send = (
            dst_kv_item_len // page_size // dst_heads_per_rank
        )
        # Determine which heads to send
        if prefill_tp_size > decode_tp_size:
            # Multiple prefill ranks to one decode rank
            src_head_start_offset = 0
            num_heads_to_send = src_heads_per_rank
            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
        else:
            # Send KVCache from 1 prefill instance to multiple decode instances
            src_head_start_offset = (
                dst_tp_rank_in_group * dst_heads_per_rank
            ) % src_heads_per_rank
            num_heads_to_send = dst_heads_per_rank
            dst_head_start_offset = 0
        # Create transfer descriptors
        # NOTE(review): src_addrs/dst_addrs and the two bytes_per_token_*
        # values are re-assigned identically a few lines below; these first
        # assignments are redundant but harmless.
        src_addrs = []
        dst_addrs = []
        bytes_per_token_on_prefill = src_kv_item_len // page_size
        bytes_per_token_on_decode = dst_kv_item_len // page_size
        # KV pointer list is laid out as [K ptrs per layer..., V ptrs per layer...].
        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
        dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
        dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
        # Calculate precise byte offset and length for the sub-slice within the token
        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
        src_dst_ptr_pairs = [
            (
                src_k_ptrs[layer_id],
                dst_k_ptrs[layer_id],
            )
            for layer_id in range(len(src_k_ptrs))
        ] + [
            (
                src_v_ptrs[layer_id],
                dst_v_ptrs[layer_id],
            )
            for layer_id in range(len(src_v_ptrs))
        ]
        src_addrs = []
        dst_addrs = []
        # Calculate strides for a single token slot
        bytes_per_token_on_prefill = src_kv_item_len // page_size
        bytes_per_token_on_decode = dst_kv_item_len // page_size
        for src_ptr, dst_ptr in src_dst_ptr_pairs:
            for i in range(len(prefill_kv_indices)):
                prefill_page_idx = int(prefill_kv_indices[i])
                decode_page_idx = int(dst_kv_indices[i])
                # Get the starting addresses for the current src and dst pages
                src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
                dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
                # Iterate through each valid token slot within the current page
                for token_slot_in_page in range(page_size):
                    # Calculate the start address of the current token slot
                    src_token_slot_start_addr = (
                        src_page_start_addr
                        + token_slot_in_page * bytes_per_token_on_prefill
                    )
                    dst_token_slot_start_addr = (
                        dst_page_start_addr
                        + token_slot_in_page * bytes_per_token_on_decode
                    )
                    # Calculate final src and dst addresses by applying head-slice offsets
                    src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
                    dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
                    src_addrs.append(
                        (
                            src_slice_addr,
                            heads_bytes_per_token_to_send,
                            self.kv_args.gpu_id,
                        )
                    )
                    dst_addrs.append(
                        (dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
                    )
        # Use NIXL agent for transfer
        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
        xfer_handle = self.agent.initialize_xfer(
            "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
        )
        if not xfer_handle:
            raise Exception("Failed to create sliced KV transfer")
        state = self.agent.transfer(xfer_handle)
        if state == "ERR":
            raise Exception("Failed to post sliced KV transfer")
        return xfer_handle

    def send_aux(
        self,
        peer_name: str,
        prefill_aux_index: int,
        dst_aux_ptrs: list[int],
        dst_aux_index: int,
        notif: str,
    ):
        """Write one aux-buffer slot (host DRAM) to the peer; returns the handle."""
        # Make descs
        aux_item_len = self.kv_args.aux_item_lens[0]
        prefill_aux_addr = (
            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
        )
        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
        src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
        dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
        # Transfer data
        xfer_handle = self.agent.initialize_xfer(
            "WRITE",
            src_descs,
            dst_descs,
            peer_name,
            notif.encode("ascii"),  # type: ignore
        )
        if not xfer_handle:
            raise Exception("KVSender failed to create transfer")
        state = self.agent.transfer(xfer_handle)
        if state == "ERR":
            raise Exception("KVSender failed to post transfer")
        return xfer_handle

    def add_transfer_request(
        self,
        bootstrap_room: int,
        kv_indices: npt.NDArray[np.int32],
        index_slice: slice,
        is_last: bool,
        chunk_id: int,
        aux_index: Optional[int] = None,
    ):
        """Send one KV chunk (and, on the last chunk, the aux slot) to every
        registered destination for this room; returns the transfer handles.

        Picks the plain path when decode TP size equals local TP size, else
        the head-sliced path. The room's transfer info is dropped after the
        last chunk.
        """
        assert self.disaggregation_mode == DisaggregationMode.PREFILL
        assert not is_last or (is_last and aux_index is not None)
        reqs_to_be_processed = self.transfer_infos[bootstrap_room].values()
        handles = []
        for req in reqs_to_be_processed:
            assert bootstrap_room == req.room
            if req.is_dummy():
                continue
            chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
            assert len(chunked_dst_kv_indice) == len(kv_indices)
            assert req.agent_name in self.decode_kv_args_table
            # Notification format parsed by update_transfer_status:
            # "<room>_kv_<chunk_id>_<is_last>".
            notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
            decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
            if decode_tp_size == self.tp_size:
                kv_xfer_handle = self.send_kvcache(
                    req.agent_name,
                    kv_indices,
                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
                    chunked_dst_kv_indice,
                    self.decode_kv_args_table[req.agent_name].gpu_id,
                    notif,
                )
            else:
                kv_xfer_handle = self.send_kvcache_slice(
                    req.agent_name,
                    kv_indices,
                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
                    chunked_dst_kv_indice,
                    self.decode_kv_args_table[req.agent_name].gpu_id,
                    notif,
                    prefill_tp_size=self.tp_size,
                    decode_tp_size=decode_tp_size,
                    decode_tp_rank=self.decode_kv_args_table[
                        req.agent_name
                    ].decode_tp_rank,
                    dst_kv_item_len=self.decode_kv_args_table[
                        req.agent_name
                    ].dst_kv_item_len,
                )
            handles.append(kv_xfer_handle)
            # Only the last chunk we need to send the aux data.
            if is_last:
                assert aux_index is not None
                aux_xfer_handle = self.send_aux(
                    req.agent_name,
                    aux_index,
                    self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
                    req.dst_aux_index,
                    str(req.room) + "_aux",
                )
                handles.append(aux_xfer_handle)
        if is_last:
            del self.transfer_infos[bootstrap_room]
        return handles

    def update_transfer_status(self):
        """Drain NIXL notifications and fold them into per-room TransferStatus."""
        # Process notifications from received transfers.
        notif_map = self.agent.get_new_notifs()
        for peer_name, messages in notif_map.items():
            # We could also check that self.bootstrap_info['agent_name'] matches
            # the message sender. But the bootstrap room alone should be
            # sufficient to map the status.
            for msg in messages:
                components = msg.decode("ascii").split("_")
                room = int(components[0])
                if components[1] == "kv":
                    chunk_id = int(components[2])
                    is_last = bool(int(components[3]))
                    self.transfer_statuses[room].received_kvs.add(chunk_id)
                    if is_last:
                        # Chunk IDs start at 0, so the last one fixes the count.
                        self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
                elif components[1] == "aux":
                    self.transfer_statuses[room].received_aux = True

    def check_transfer_done(self, room: int):
        """True once all KV chunks and the aux payload for `room` have arrived."""
        if room not in self.transfer_statuses:
            return False
        return self.transfer_statuses[room].is_done()

    def _bind_server_socket(self):
        """Bind the ZMQ PULL socket used by the bootstrap thread."""
        self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))

    def _start_bootstrap_thread(self):
        """Spawn the daemon-less thread that receives bootstrap messages."""
        self._bind_server_socket()

        def bootstrap_thread():
            """This thread recvs transfer info from the decode engine"""
            while True:
                waiting_req_bytes = self.server_socket.recv_multipart()
                logger.debug(
                    f"Received multipart with total byte size {sum(len(x) for x in waiting_req_bytes)}"
                )
                # Every legitimate message starts with the GUARD frame.
                assert (
                    waiting_req_bytes[0] == GUARD
                ), f"First message should be {GUARD}. Foreign traffic?"
                waiting_req_bytes = waiting_req_bytes[1:]
                room = waiting_req_bytes[0].decode("ascii")
                agent_name = waiting_req_bytes[3].decode("ascii")
                if room == "None":
                    # Register new peer and save KV base pointers.
                    self._add_remote_peer(
                        KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
                    )
                    logger.debug(f"Register KVArgs from {agent_name} successfully")
                    continue
                room = int(room)
                if room not in self.transfer_infos:
                    self.transfer_infos[room] = {}
                self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
                    waiting_req_bytes
                )
                required_dst_info_num = self.transfer_infos[room][
                    agent_name
                ].required_dst_info_num
                logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
                # Bootstrap completes once every expected destination reported in.
                if len(self.transfer_infos[room]) == required_dst_info_num:
                    logger.debug(f"{room=} is bootstrapped")
                    self.update_status(room, KVPoll.WaitingForInput)

        threading.Thread(target=bootstrap_thread).start()
class NixlKVSender(BaseKVSender):
    """Prefill-side sender: streams KV chunks for one request through the manager."""

    def __init__(
        self,
        mgr: NixlKVManager,
        bootstrap_addr: str,
        bootstrap_room: int,
        dest_tp_ranks: List[int],
        pp_rank: int,
    ):
        self.kv_mgr = mgr
        self.bootstrap_room = bootstrap_room
        self.bootstrap_server_url = bootstrap_addr
        self.aux_index = None
        self.xfer_handles = []
        self.chunk_id = 0
        self.has_sent = False
        # inner state: next position within the request's kv-index sequence
        self.curr_idx = 0
        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Record the total kv-index count and the aux slot for this request."""
        self.num_kv_indices = num_kv_indices
        self.aux_index = aux_index

    def send(
        self,
        kv_indices: npt.NDArray[np.int32],
    ):
        """Push the next chunk of KV indices; the final chunk also ships aux data."""
        start = self.curr_idx
        end = start + len(kv_indices)
        self.curr_idx = end
        is_last = end == self.num_kv_indices
        new_handles = self.kv_mgr.add_transfer_request(
            self.bootstrap_room,
            kv_indices,
            slice(start, end),
            is_last,
            self.chunk_id,
            self.aux_index,
        )
        self.xfer_handles.extend(new_handles)
        self.chunk_id += 1
        if is_last:
            self.has_sent = True
            # Bootstrap bookkeeping for this room is no longer needed.
            del self.kv_mgr.request_status[self.bootstrap_room]

    def poll(self) -> KVPoll:
        """Report bootstrap state before sending, then transfer state afterwards."""
        if not self.has_sent:
            return self.kv_mgr.check_status(self.bootstrap_room)
        check_state = self.kv_mgr.agent.check_xfer_state
        states = [check_state(handle) for handle in self.xfer_handles]
        if any(state == "ERR" for state in states):
            raise Exception("KVSender transfer encountered an error.")
        if all(state == "DONE" for state in states):
            return KVPoll.Success  # type: ignore
        return KVPoll.WaitingForInput  # type: ignore

    def failure_exception(self):
        raise Exception("Fake KVSender Exception")
class NixlKVReceiver(CommonKVReceiver):
    """Decode-side receiver: registers buffers with the prefill bootstrap
    thread and polls the manager for transfer completion."""

    def __init__(
        self,
        mgr: NixlKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
        prefill_dp_rank: Optional[int] = None,
    ):
        # Set before super().__init__ because the base class may trigger
        # callbacks that expect these attributes to exist.
        self.started_transfer = False
        # Cached terminal KVPoll state once the transfer concludes.
        self.conclude_state = None
        super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)

    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
        """Send the per-request TransferInfo to every bootstrap destination.

        Frame order must match TransferInfo.from_zmq (after the GUARD frame).
        """
        for bootstrap_info in self.bootstrap_infos:
            logger.debug(
                f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
            )
            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
            is_dummy = bootstrap_info["is_dummy"]
            logger.debug(
                f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
            )
            with lock:
                sock.send_multipart(
                    [
                        GUARD,
                        str(self.bootstrap_room).encode("ascii"),
                        self.kv_mgr.local_ip.encode("ascii"),
                        str(self.kv_mgr.rank_port).encode("ascii"),
                        self.kv_mgr.agent.name.encode("ascii"),
                        # A dummy destination sends an empty index buffer.
                        kv_indices.tobytes() if not is_dummy else b"",
                        str(aux_index).encode("ascii"),
                        str(self.required_dst_info_num).encode("ascii"),
                    ]
                )
        self.started_transfer = True

    def poll(self) -> KVPoll:
        """Non-blocking check; caches Success so later polls skip the manager."""
        if self.conclude_state is not None:
            return self.conclude_state
        if not self.started_transfer:
            return KVPoll.WaitingForInput  # type: ignore
        self.kv_mgr.update_transfer_status()
        if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
            self.conclude_state = KVPoll.Success
            # Free the per-room status entry once the result is cached.
            del self.kv_mgr.transfer_statuses[self.bootstrap_room]
            return KVPoll.Success  # type: ignore
        return KVPoll.WaitingForInput  # type: ignore

    def _register_kv_args(self):
        """Send the one-time registration message (room frame "None").

        Frame order must match KVArgsRegisterInfo.from_zmq (after GUARD);
        pointers are packed as unsigned 64-bit integers.
        """
        for bootstrap_info in self.bootstrap_infos:
            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
            packed_kv_data_ptrs = b"".join(
                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
            )
            packed_aux_data_ptrs = b"".join(
                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
            )
            with lock:
                sock.send_multipart(
                    [
                        GUARD,
                        "None".encode("ascii"),
                        self.kv_mgr.local_ip.encode("ascii"),
                        str(self.kv_mgr.rank_port).encode("ascii"),
                        self.kv_mgr.agent.name.encode("ascii"),
                        self.kv_mgr.agent.get_agent_metadata(),
                        packed_kv_data_ptrs,
                        packed_aux_data_ptrs,
                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
                        str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
                        str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
                        str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
                    ]
                )

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")
class NixlKVBootstrapServer(CommonKVBootstrapServer):
    """NIXL bootstrap server; inherits all behavior from CommonKVBootstrapServer."""

    pass

# ---- end of NIXL connector module; prefill-side disaggregation scheduler follows ----
"""
Life cycle of a request in the prefill server
1. Bootstrap Queue
a. Initialize a sender for each request
b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished
c. Poll senders to check bootstrap state
d. Once bootstrap is complete, move request to Waiting Queue
2. Waiting Queue
a. Use PrefillAdder to pop requests
b. Run forward
c. Add the request to Inflight Queue
3. Inflight Queue
a. Poll (non-blocking) the sender of the request
b. Once the transfer has finished, return the request
"""
from __future__ import annotations
import logging
import threading
from collections import deque
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Type
import torch
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
DisaggregationMode,
KVClassType,
MetadataBuffers,
ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
is_mla_backend,
kv_to_page_indices,
kv_to_page_num,
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.utils import (
DynamicGradMode,
broadcast_pyobj,
point_to_point_pyobj,
require_mlp_sync,
)
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
from sglang.srt.mem_cache.memory_pool import KVCache
logger = logging.getLogger(__name__)
class PrefillBootstrapQueue:
    """
    Store the requests in bootstrapping

    Each added request gets a KV sender attached; pop_bootstrapped() moves
    requests whose handshake/preallocation finished to the caller.
    """

    def __init__(
        self,
        token_to_kv_pool: KVCache,
        draft_token_to_kv_pool: Optional[KVCache],
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: MetadataBuffers,
        tp_rank: int,
        tp_size: int,
        gpu_id: int,
        bootstrap_port: int,
        gloo_group: ProcessGroup,
        max_total_num_tokens: int,
        decode_tp_size: int,
        decode_dp_size: int,
        scheduler: Scheduler,
        pp_rank: int,
        pp_size: int,
        transfer_backend: TransferBackend,
    ):
        self.token_to_kv_pool = token_to_kv_pool
        self.draft_token_to_kv_pool = draft_token_to_kv_pool
        self.is_mla_backend = is_mla_backend(token_to_kv_pool)
        self.metadata_buffers = metadata_buffers
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.decode_tp_size = decode_tp_size
        self.decode_dp_size = decode_dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.gpu_id = gpu_id
        self.bootstrap_port = bootstrap_port
        # Requests currently waiting for bootstrap to finish.
        self.queue: List[Req] = []
        self.gloo_group = gloo_group
        self.max_total_num_tokens = max_total_num_tokens
        self.scheduler = scheduler
        self.transfer_backend = transfer_backend
        self.kv_manager = self._init_kv_manager()

    def _init_kv_manager(self) -> BaseKVManager:
        """Build the backend-specific KV manager from local pool/buffer layout."""
        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
        kv_args = kv_args_class()
        kv_args.engine_rank = self.tp_rank
        kv_args.pp_rank = self.pp_rank
        kv_args.system_dp_rank = self.scheduler.dp_rank
        # Per-DP-group decode TP size, matching what decode peers report.
        kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
        kv_args.prefill_pp_size = self.pp_size
        kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )
        if self.draft_token_to_kv_pool is not None:
            # We should also transfer draft model kv cache. The indices are
            # always shared with a target model.
            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
            )
            kv_data_ptrs += draft_kv_data_ptrs
            kv_data_lens += draft_kv_data_lens
            kv_item_lens += draft_kv_item_lens
        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens
        if not self.is_mla_backend:
            kv_args.kv_head_num = self.token_to_kv_pool.head_num
        kv_args.page_size = self.token_to_kv_pool.page_size
        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
            self.metadata_buffers.get_buf_infos()
        )
        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
        kv_args.gpu_id = self.scheduler.gpu_id
        kv_manager_class: Type[BaseKVManager] = get_kv_class(
            self.transfer_backend, KVClassType.MANAGER
        )
        kv_manager: BaseKVManager = kv_manager_class(
            kv_args,
            DisaggregationMode.PREFILL,
            self.scheduler.server_args,
            self.is_mla_backend,
        )
        return kv_manager

    def add(self, req: Req, num_kv_heads: int) -> None:
        """Attach a KV sender to `req` and enqueue it for bootstrapping.

        Requests exceeding the KV capacity are aborted instead of queued.
        """
        if self._check_if_req_exceed_kv_capacity(req):
            return
        # FAKE_BOOTSTRAP_HOST selects the fake backend (used for testing).
        if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
            kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
        else:
            kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
        dest_tp_ranks = [self.tp_rank]
        req.disagg_kv_sender = kv_sender_class(
            mgr=self.kv_manager,
            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
            bootstrap_room=req.bootstrap_room,
            dest_tp_ranks=dest_tp_ranks,
            pp_rank=self.pp_rank,
        )
        self._process_req(req)
        self.queue.append(req)

    def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
        """Add a batch of requests (see add())."""
        for req in reqs:
            self.add(req, num_kv_heads)

    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
        """Abort and stream an error if the prompt can't fit in the KV pool."""
        if len(req.origin_input_ids) > self.max_total_num_tokens:
            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
            logger.error(message)
            prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST)
            self.scheduler.stream_output([req], req.return_logprob)
            return True
        return False

    def _process_req(self, req: Req) -> None:
        """
        Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
        """
        req.sampling_params.max_new_tokens = 1

    def pop_bootstrapped(
        self,
        return_failed_reqs: bool = False,
        rids_to_check: Optional[List[str]] = None,
    ) -> List[Req]:
        """
        pop the reqs which has finished bootstrapping

        return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
        rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
        """
        bootstrapped_reqs = []
        failed_reqs = []
        indices_to_remove = set()
        if len(self.queue) == 0:
            if return_failed_reqs is False:
                return []
            else:
                return [], []
        # Poll all senders and all-reduce so every rank agrees on the states.
        polls = poll_and_all_reduce(
            [req.disagg_kv_sender for req in self.queue], self.gloo_group
        )
        for i, (req, poll) in enumerate(zip(self.queue, polls)):
            if rids_to_check is not None:
                # if req not in reqs_info_to_check, skip
                if req.rid not in rids_to_check:
                    continue
                # Either waiting for input or failed
                assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed
            if poll == KVPoll.Bootstrapping:
                continue
            elif poll == KVPoll.Failed:
                error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
                try:
                    req.disagg_kv_sender.failure_exception()
                except Exception as e:
                    error_message += f" with exception {e}"
                logger.error(error_message)
                prepare_abort(
                    req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
                )
                self.scheduler.stream_output([req], req.return_logprob)
                indices_to_remove.add(i)
                failed_reqs.append(req)
                if self.scheduler.enable_metrics:
                    self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
                continue
            # KV.WaitingForInput - init here
            num_kv_indices = len(req.origin_input_ids)
            if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
                # No metadata slot available; retry the rest next round.
                break
            req.metadata_buffer_index = (
                self.req_to_metadata_buffer_idx_allocator.alloc()
            )
            assert req.metadata_buffer_index is not None
            num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
            req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
            bootstrapped_reqs.append(req)
            indices_to_remove.add(i)
        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]
        if return_failed_reqs is False:
            return bootstrapped_reqs
        else:
            return bootstrapped_reqs, failed_reqs
class SchedulerDisaggregationPrefillMixin:
"""
Mixin for Scheduler to handle disaggregation prefill
"""
    @torch.no_grad()
    def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
        """A normal scheduler loop for prefill worker in disaggregation mode."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # Requests whose bootstrap finished join the regular waiting queue.
            self.waiting_queue.extend(
                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
            )
            self.process_prefill_chunk()
            batch = self.get_new_batch_prefill()
            if require_mlp_sync(self.server_args):
                batch = self.prepare_mlp_sync_batch(batch)
            self.cur_batch = batch
            if batch:
                result = self.run_batch(batch)
                self.process_batch_result_disagg_prefill(batch, result)
            if len(self.disagg_prefill_inflight_queue) > 0:
                self.process_disagg_prefill_inflight_queue()
            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
                # Nothing to compute and nothing in flight: run idle checks.
                self.self_check_during_idle()
            self.last_batch = batch
            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
            # Otherwise, it hangs under high concurrency
            self.running_batch.batch_is_full = False
    @torch.no_grad()
    def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
        """Overlap-scheduling variant: results are processed one iteration late
        via a FIFO queue so CPU work overlaps with GPU execution."""
        self.result_queue = deque()
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # Requests whose bootstrap finished join the regular waiting queue.
            self.waiting_queue.extend(
                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
            )
            self.process_prefill_chunk()
            batch = self.get_new_batch_prefill()
            if require_mlp_sync(self.server_args):
                batch = self.prepare_mlp_sync_batch(batch)
            self.cur_batch = batch
            if batch:
                result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), result))
                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.set_next_batch_sampling_info_done(tmp_batch)
            if self.last_batch:
                # Process the previous iteration's result (one-step delay).
                tmp_batch, tmp_result = self.result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
            if len(self.disagg_prefill_inflight_queue) > 0:
                self.process_disagg_prefill_inflight_queue()
            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
                # Nothing to compute and nothing in flight: run idle checks.
                self.self_check_during_idle()
            self.last_batch = batch
            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
            # Otherwise, it hangs under high concurrency
            self.running_batch.batch_is_full = False
    def process_batch_result_disagg_prefill(
        self: Scheduler,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
        launch_done: Optional[threading.Event] = None,
    ) -> None:
        """
        Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
        Adapted from process_batch_result_prefill
        """
        (
            logits_output,
            next_token_ids,
            extend_input_len_per_req,
            extend_logprob_start_len_per_req,
        ) = (
            result.logits_output,
            result.next_token_ids,
            result.extend_input_len_per_req,
            result.extend_logprob_start_len_per_req,
        )
        logprob_pt = 0
        # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
        if self.enable_overlap:
            # wait
            logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
                launch_done
            )
        else:
            # Move device tensors to host containers before per-req processing.
            next_token_ids = result.next_token_ids.tolist()
            if batch.return_logprob:
                if logits_output.next_token_logprobs is not None:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs.tolist()
                    )
                if logits_output.input_token_logprobs is not None:
                    logits_output.input_token_logprobs = tuple(
                        logits_output.input_token_logprobs.tolist()
                    )
        hidden_state_offset = 0
        for i, (req, next_token_id) in enumerate(
            zip(batch.reqs, next_token_ids, strict=True)
        ):
            req: Req
            if req.is_chunked <= 0:
                # Prefill finished for this request.
                # There is no output_ids for prefill
                req.output_ids.append(next_token_id)
                self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                self.disagg_prefill_inflight_queue.append(req)
                if (
                    logits_output is not None
                    and logits_output.hidden_states is not None
                ):
                    # Keep only the last position's hidden state for this req.
                    last_hidden_index = (
                        hidden_state_offset + extend_input_len_per_req[i] - 1
                    )
                    req.hidden_states_tensor = (
                        logits_output.hidden_states[last_hidden_index].cpu().clone()
                    )
                    hidden_state_offset += extend_input_len_per_req[i]
                else:
                    req.hidden_states_tensor = None
                if req.return_logprob:
                    assert extend_logprob_start_len_per_req is not None
                    assert extend_input_len_per_req is not None
                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                    extend_input_len = extend_input_len_per_req[i]
                    num_input_logprobs = extend_input_len - extend_logprob_start_len
                    self.add_logprob_return_values(
                        i,
                        req,
                        logprob_pt,
                        next_token_ids,
                        num_input_logprobs,
                        logits_output,
                    )
                    logprob_pt += num_input_logprobs
                # Last chunk: kick off the KV transfer to the decode side.
                self.send_kv_chunk(req, last_chunk=True)
                if req.grammar is not None:
                    # FIXME: this try-except block is for handling unexpected xgrammar issue.
                    try:
                        req.grammar.accept_token(next_token_id)
                    except ValueError as e:
                        # Grammar accept_token can raise ValueError if the token is not in the grammar.
                        # This can happen if the grammar is not set correctly or the token is invalid.
                        error_message = f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                        self.tree_cache.cache_finished_req(req)
                        prepare_abort(
                            req,
                            error_message,
                            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                        )
                    req.grammar.finished = req.finished()
            else:
                # being chunked reqs' prefill is not finished
                req.is_chunked -= 1
                if req.return_logprob:
                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                    extend_input_len = extend_input_len_per_req[i]
                    if extend_logprob_start_len < extend_input_len:
                        # Update input logprobs.
                        num_input_logprobs = extend_input_len - extend_logprob_start_len
                        self.add_input_logprob_return_values(
                            i,
                            req,
                            logits_output,
                            logprob_pt,
                            num_input_logprobs,
                            last_prefill_chunk=False,
                        )
                        logprob_pt += num_input_logprobs
                if self.enable_overlap:
                    # Overlap mode streams partial KV chunks as they complete.
                    self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
        # We need to remove the sync in the following function for overlap schedule.
        self.set_next_batch_sampling_info_done(batch)
        self.maybe_send_health_check_signal()
def process_disagg_prefill_inflight_queue(
    self: Scheduler, rids_to_check: Optional[List[str]] = None
) -> List[Req]:
    """
    Poll the requests in the middle of KV transfer. If done, finalize and return them.

    Args:
        rids_to_check: For PP, on rank > 0, only requests whose rid appears here
            (the consensus list from the previous rank) are allowed to finish;
            all others are kept in flight regardless of their local poll state.

    Returns:
        The list of requests that completed (successfully or with failure) and
        were removed from `disagg_prefill_inflight_queue`.
    """
    if len(self.disagg_prefill_inflight_queue) == 0:
        return []

    done_reqs = []

    # Poll every sender once and reduce across the attention-TP group so all
    # ranks observe the same per-request state.
    polls = poll_and_all_reduce(
        [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
        self.attn_tp_cpu_group,
    )

    undone_reqs: List[Req] = []
    # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
    for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
        if rids_to_check is not None:
            # If req not in the rids to check, skip
            if req.rid not in rids_to_check:
                undone_reqs.append(req)
                continue
            # Either success or failed
            assert poll == KVPoll.Success or poll == KVPoll.Failed

        if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
            # Transfer still in progress; keep waiting.
            undone_reqs.append(req)
        elif poll == KVPoll.Success:  # transfer done
            self.tree_cache.cache_finished_req(req)  # unlock the tree
            # Prefill-side output length is 0: decode continues the generation.
            req.finished_reason = FINISH_LENGTH(length=0)
            # FIXME: clean up req's data in transfer engine
            if hasattr(req.disagg_kv_sender, "clear"):
                req.disagg_kv_sender.clear()
            done_reqs.append(req)
        elif poll == KVPoll.Failed:
            error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
            try:
                req.disagg_kv_sender.failure_exception()
            except Exception as e:
                error_message += f" with exception {e}"
            logger.warning(error_message)
            self.tree_cache.cache_finished_req(req)  # unlock the tree
            # Mark the request aborted so the client receives an error.
            prepare_abort(
                req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
            )
            done_reqs.append(req)
            if self.enable_metrics:
                self.metrics_collector.increment_transfer_failed_reqs()
        else:
            assert False, f"Unexpected polling state {poll=}"

    # Stream requests which have finished transfer
    self.stream_output(
        done_reqs,
        any(req.return_logprob for req in done_reqs),
        None,
    )
    for req in done_reqs:
        req: Req
        # Release the metadata slot now that the first-token metadata has been
        # handed off to the decode server.
        self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
        req.metadata_buffer_index = -1

    self.disagg_prefill_inflight_queue = undone_reqs

    return done_reqs
def get_transferred_rids(self: Scheduler) -> List[str]:
    """
    Used by PP, get the transferred rids but **do not pop**
    """
    inflight = self.disagg_prefill_inflight_queue
    polls = poll_and_all_reduce(
        [req.disagg_kv_sender for req in inflight],
        self.tp_worker.get_tp_group().cpu_group,
    )
    # A request counts as transferred once its sender reached a terminal
    # state — success and failure both qualify.
    return [
        req.rid
        for req, poll in zip(inflight, polls)
        if poll in (KVPoll.Success, KVPoll.Failed)
    ]
def process_prefill_chunk(self: Scheduler) -> None:
    """Handle the in-progress chunked request after the last prefill batch.

    Excludes the chunked request from the finished batch, caches its prefix,
    and ships the newly prefilled KV chunk (immediately, or deferred when
    overlap scheduling is enabled).
    """
    if self.last_batch and self.last_batch.forward_mode.is_extend():
        if self.chunked_req:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
            if self.enable_overlap:
                # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
                self.chunked_req.tmp_end_idx = min(
                    len(self.chunked_req.fill_ids),
                    len(self.chunked_req.origin_input_ids),
                )
            else:
                self.send_kv_chunk(self.chunked_req)
            # chunked request keeps its rid but will get a new req_pool_idx
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
            self.running_batch.batch_is_full = False
def send_kv_chunk(
    self: Scheduler,
    req: Req,
    last_chunk: bool = False,
    end_idx: Optional[int] = None,
) -> None:
    """
    Send a prefilled chunk to the decode server.

    Args:
        req: request whose KV slice [req.start_send_idx, end_idx) is sent.
        last_chunk: True for the final chunk of the prefill; also publishes
            the first-output-token metadata buffer.
        end_idx: explicit end of the token range; defaults to the number of
            tokens prefilled so far.
    """
    page_size = self.token_to_kv_pool_allocator.page_size
    start_idx = req.start_send_idx
    end_idx = (
        end_idx
        if end_idx is not None
        else min(len(req.fill_ids), len(req.origin_input_ids))
    )

    if not last_chunk:
        # if not the last chunk and the last page is partial, delay the last partial page to the next send
        end_idx = end_idx - end_idx % page_size

    # Token -> KV-slot mapping for the range being sent, materialized on CPU.
    kv_indices = (
        self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
        .cpu()
        .numpy()
    )
    # Remember where the next chunk should start.
    req.start_send_idx = end_idx
    if last_chunk:
        # Publish first output token (and logprobs / hidden states) so the
        # decode server can pick them up together with the KV cache.
        self.disagg_metadata_buffers.set_buf(req)
    page_indices = kv_to_page_indices(kv_indices, page_size)
    if len(page_indices) == 0:
        # Can happen when a non-last chunk is smaller than one page.
        logger.info(
            f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
        )
        return
    req.disagg_kv_sender.send(page_indices)
# PP
@DynamicGradMode()
def event_loop_pp_disagg_prefill(self: Scheduler):
    """
    An event loop for the prefill server in pipeline parallelism.

    Rules:
    1. Each stage runs in the same order and is notified by the previous stage.
    2. Each send/recv operation is blocking and matched by the neighboring stage.

    Regular Schedule:
    ====================================================================
        Stage i                   |   Stage i+1
    send ith req                  |   recv ith req
    send ith proxy                |   recv ith proxy
    send prev (i+1)th carry       |   recv prev (i+1)th carry
    ====================================================================

    Prefill Server Schedule:
    ====================================================================
        Stage i                        |   Stage i+1
    send ith req                       |   recv ith req
    send ith bootstrap req             |   recv ith bootstrap req
    send ith transferred req           |   recv ith transferred req
    send ith proxy                     |   recv ith proxy
    send prev (i+1)th carry            |   recv prev (i+1)th carry
    send prev (i+1)th release req      |   recv prev (i+1)th release req
    ====================================================================

    There are two additional elements compared to the regular schedule:

    1. Bootstrap Requests:
        a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
        b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
        c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.

    2. Transferred Requests + Release Requests:
        a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
        b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
        c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
    """
    from sglang.srt.managers.scheduler import GenerationBatchResult

    # One microbatch slot per PP stage; mbs[i] is the batch currently owned
    # by microbatch lane i.
    mbs = [None] * self.pp_size
    last_mbs = [None] * self.pp_size
    self.running_mbs = [
        ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
    ]
    bids = [None] * self.pp_size
    pp_outputs: Optional[PPProxyTensors] = None

    # Either success or failed
    bootstrapped_rids: List[str] = []
    transferred_rids: List[str] = []
    release_rids: Optional[List[str]] = None

    # transferred microbatch
    tmbs = [None] * self.pp_size

    ENABLE_RELEASE = True  # For debug

    while True:
        server_is_idle = True

        for mb_id in range(self.pp_size):
            self.running_batch = self.running_mbs[mb_id]
            self.last_batch = last_mbs[mb_id]

            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            if self.pp_group.is_first_rank:
                # First rank, pop the bootstrap reqs from the bootstrap queue
                bootstrapped_reqs, failed_reqs = (
                    self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                        return_failed_reqs=True
                    )
                )
                # Failed rids are included so downstream stages also drop them.
                bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
                    req.rid for req in failed_reqs
                ]
                self.waiting_queue.extend(bootstrapped_reqs)
            else:
                # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
                bootstrapped_rids = self.recv_pyobj_from_prev_stage()
                bootstrapped_reqs = (
                    self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                        rids_to_check=bootstrapped_rids
                    )
                )
                self.waiting_queue.extend(bootstrapped_reqs)

            if self.pp_group.is_first_rank:
                transferred_rids = self.get_transferred_rids()
            # if other ranks,
            else:
                # 1. recv previous stage's transferred reqs info
                prev_transferred_rids = self.recv_pyobj_from_prev_stage()
                # 2. get the current stage's transferred reqs info
                curr_transferred_rids = self.get_transferred_rids()
                # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
                transferred_rids = list(
                    set(prev_transferred_rids) & set(curr_transferred_rids)
                )
            tmbs[mb_id] = transferred_rids

            self.process_prefill_chunk()
            mbs[mb_id] = self.get_new_batch_prefill()
            self.running_mbs[mb_id] = self.running_batch

            self.cur_batch = mbs[mb_id]
            if self.cur_batch:
                server_is_idle = False
                result = self.run_batch(self.cur_batch)

            # send the outputs to the next step
            if self.pp_group.is_last_rank:
                if self.cur_batch:
                    next_token_ids, bids[mb_id] = (
                        result.next_token_ids,
                        result.bid,
                    )
                    pp_outputs = PPProxyTensors(
                        {
                            "next_token_ids": next_token_ids,
                        }
                    )
                    # send the output from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            if ENABLE_RELEASE:
                if self.pp_group.is_last_rank:
                    # At the last stage, all stages has reached the consensus to release memory for transferred_rids
                    release_rids = transferred_rids
                    # send to the first rank
                    self.send_pyobj_to_next_stage(release_rids)

            # receive outputs and post-process (filter finished reqs) the coming microbatch
            next_mb_id = (mb_id + 1) % self.pp_size
            next_pp_outputs = None
            next_release_rids = None

            if mbs[next_mb_id] is not None:
                next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                    self.pp_group.recv_tensor_dict(
                        all_gather_group=self.attn_tp_group
                    )
                )
                mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                # Rebuild a batch result for the incoming microbatch; logits
                # and per-req extend metadata are not carried across stages.
                output_result = GenerationBatchResult(
                    logits_output=None,
                    pp_hidden_states_proxy_tensors=None,
                    next_token_ids=next_pp_outputs["next_token_ids"],
                    extend_input_len_per_req=None,
                    extend_logprob_start_len_per_req=None,
                    bid=bids[next_mb_id],
                    can_run_cuda_graph=result.can_run_cuda_graph,
                )
                self.process_batch_result_disagg_prefill(
                    mbs[next_mb_id], output_result
                )

                last_mbs[next_mb_id] = mbs[next_mb_id]

            if ENABLE_RELEASE:
                if tmbs[next_mb_id] is not None:
                    # recv consensus rids from the previous rank
                    next_release_rids = self.recv_pyobj_from_prev_stage()
                    self.process_disagg_prefill_inflight_queue(next_release_rids)

            # carry the outputs to the next stage
            if not self.pp_group.is_last_rank:
                if self.cur_batch:
                    bids[mb_id] = result.bid
                if pp_outputs:
                    # send the outputs from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )
                if ENABLE_RELEASE:
                    if release_rids is not None:
                        self.send_pyobj_to_next_stage(release_rids)

            if not self.pp_group.is_last_rank:
                # send out reqs to the next stage
                self.send_pyobj_to_next_stage(recv_reqs)
                self.send_pyobj_to_next_stage(bootstrapped_rids)
                self.send_pyobj_to_next_stage(transferred_rids)

                # send out proxy tensors to the next stage
                if self.cur_batch:
                    self.pp_group.send_tensor_dict(
                        result.pp_hidden_states_proxy_tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            pp_outputs = next_pp_outputs
            release_rids = next_release_rids

            self.running_batch.batch_is_full = False

        if not ENABLE_RELEASE:
            if len(self.disagg_prefill_inflight_queue) > 0:
                self.process_disagg_prefill_inflight_queue()

        # When the server is idle, self-check and re-init some states
        if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
            self.check_memory()
            self.check_tree_cache()
            self.new_token_ratio = self.init_new_token_ratio
def send_pyobj_to_next_stage(self, data):
    """Send a picklable object to the same-offset peer on the next PP stage."""
    if self.attn_tp_rank != 0:
        # Only the first rank of each attention-TP group drives the
        # inter-stage object channel.
        return
    dp_offset = self.attn_dp_rank * self.attn_tp_size
    src_rank = self.pp_rank * self.tp_size + dp_offset
    dst_rank = ((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset
    point_to_point_pyobj(
        data,
        src_rank,
        self.world_group.device_group,
        src_rank,
        dst_rank,
    )
def recv_pyobj_from_prev_stage(self):
    """Receive a picklable object from the previous PP stage and fan it out
    within the local TP group."""
    if self.attn_tp_rank == 0:
        dp_offset = self.attn_dp_rank * self.attn_tp_size
        self_rank = self.pp_rank * self.tp_size + dp_offset
        prev_rank = ((self.pp_rank - 1) % self.pp_size) * self.tp_size + dp_offset
        data = point_to_point_pyobj(
            [],
            self_rank,
            self.world_group.device_group,
            prev_rank,
            self_rank,
        )
    else:
        # Non-leader ranks obtain the object via the broadcast below.
        data = None

    if self.tp_size != 1:
        data = broadcast_pyobj(
            data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0]
        )
    return data

View File

@@ -0,0 +1,329 @@
from __future__ import annotations
import os
import random
from collections import deque
from contextlib import nullcontext
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Type, Union
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.utils import is_npu
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
#########################
# Constants & Enums
#########################
# Placeholder bootstrap host — presumably used by the fake/loopback transfer
# backend where no real decode server is contacted; confirm against callers.
FAKE_BOOTSTRAP_HOST = "2.2.2.2"


class DisaggregationMode(Enum):
    """Role of a server instance in prefill/decode (PD) disaggregation."""

    NULL = "null"  # disaggregation disabled
    PREFILL = "prefill"  # this instance runs prefill and ships KV out
    DECODE = "decode"  # this instance runs decode, receiving KV
#########################
# Synchronization
#########################
# env var for testing failure, convert to float explicitly
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))


def poll_and_all_reduce(pollers, gloo_group):
    """Poll every KV sender/receiver and reduce the states across `gloo_group`.

    Taking the elementwise MIN across ranks makes every rank observe the same
    per-request state (presumably lower KVPoll values encode less-advanced or
    failed states — confirm against the KVPoll enum ordering).
    """
    # at a certain prob, the poll is failed to simulate failure
    if FAILURE_PROB > 0:
        from sglang.srt.disaggregation.base import KVPoll

        polls = [
            int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())
            for poller in pollers
        ]
    else:
        polls = [int(poller.poll()) for poller in pollers]
    # uint8 on CPU keeps the gloo collective cheap.
    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
    dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
    return tensor_to_reduce.tolist()
#########################
# Metadata Buffers
#########################
class ReqToMetadataIdxAllocator:
    """A memory pool that maps a request to its first output token location.

    Slot indices are handed out FIFO: the lowest-numbered free slots go first,
    and freed slots are recycled at the back of the queue.
    """

    def __init__(
        self,
        size: int,
    ):
        self.size = size
        self.free_slots = deque(range(size))

    def available_size(self):
        """Number of slots currently free."""
        return len(self.free_slots)

    def alloc(self) -> Optional[int]:
        """Pop the oldest free slot index, or None when the pool is exhausted."""
        if not self.free_slots:
            return None
        return self.free_slots.popleft()

    def free(self, free_index: int):
        """Return a slot index to the pool."""
        self.free_slots.append(free_index)
class MetadataBuffers:
    """Pre-allocated, RDMA-registerable buffers holding the first output token's
    metadata (ids, logprobs, hidden states) for each in-flight request slot."""

    def __init__(
        self,
        size: int,
        hidden_size: int,
        dtype: torch.dtype,
        max_top_logprobs_num: int = 128,
        custom_mem_pool: torch.cuda.MemPool = None,
    ):
        # size: number of request slots; hidden_size/dtype: shape of the
        # per-request hidden-state row (used by PD + speculative decoding).
        self.custom_mem_pool = custom_mem_pool
        device = "cpu"
        if is_npu():
            # For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
            device = "npu"
        elif self.custom_mem_pool:
            # TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free
            device = "cpu"
        with (
            torch.cuda.use_mem_pool(self.custom_mem_pool)
            if self.custom_mem_pool
            else nullcontext()
        ):
            # TODO: abort top_logprobs_num > 128 in PD

            # We transfer the metadata of first output token to decode
            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)

            self.output_token_logprobs_val = torch.zeros(
                (size, 16), dtype=torch.float32, device=device
            )
            self.output_token_logprobs_idx = torch.zeros(
                (size, 16), dtype=torch.int32, device=device
            )
            self.output_top_logprobs_val = torch.zeros(
                (size, max_top_logprobs_num), dtype=torch.float32, device=device
            )
            self.output_top_logprobs_idx = torch.zeros(
                (size, max_top_logprobs_num), dtype=torch.int32, device=device
            )
            self.output_hidden_states = torch.zeros(
                (size, hidden_size), dtype=dtype, device=device
            )

    def get_buf_infos(self):
        """Return parallel lists of (base pointer, total bytes, bytes-per-slot)
        for each buffer — the registration info a transfer engine needs."""
        ptrs = [
            self.output_ids.data_ptr(),
            self.output_token_logprobs_val.data_ptr(),
            self.output_token_logprobs_idx.data_ptr(),
            self.output_top_logprobs_val.data_ptr(),
            self.output_top_logprobs_idx.data_ptr(),
            self.output_hidden_states.data_ptr(),
        ]
        data_lens = [
            self.output_ids.nbytes,
            self.output_token_logprobs_val.nbytes,
            self.output_token_logprobs_idx.nbytes,
            self.output_top_logprobs_val.nbytes,
            self.output_top_logprobs_idx.nbytes,
            self.output_hidden_states.nbytes,
        ]
        item_lens = [
            self.output_ids[0].nbytes,
            self.output_token_logprobs_val[0].nbytes,
            self.output_token_logprobs_idx[0].nbytes,
            self.output_top_logprobs_val[0].nbytes,
            self.output_top_logprobs_idx[0].nbytes,
            self.output_hidden_states[0].nbytes,
        ]
        return ptrs, data_lens, item_lens

    def get_buf(self, idx: int):
        """Return the per-request rows of every buffer for slot `idx`."""
        return (
            self.output_ids[idx],
            self.output_token_logprobs_val[idx],
            self.output_token_logprobs_idx[idx],
            self.output_top_logprobs_val[idx],
            self.output_top_logprobs_idx[idx],
            self.output_hidden_states[idx],
        )

    def set_buf(self, req: Req):
        """Copy the first output token's metadata from `req` into its slot."""
        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
        if req.return_logprob:
            if req.output_token_logprobs_val:  # not none or empty list
                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
                    req.output_token_logprobs_val[0]
                )
            if req.output_token_logprobs_idx:  # not none or empty list
                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
                    req.output_token_logprobs_idx[0]
                )

            if req.output_top_logprobs_val:  # not none or empty list
                self.output_top_logprobs_val[req.metadata_buffer_index][
                    : len(req.output_top_logprobs_val[0])
                ] = torch.tensor(
                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
                )
            if req.output_top_logprobs_idx:  # not none or empty list
                self.output_top_logprobs_idx[req.metadata_buffer_index][
                    : len(req.output_top_logprobs_idx[0])
                ] = torch.tensor(
                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
                )
        # for PD + spec decode
        if req.hidden_states_tensor is not None:
            self.output_hidden_states[req.metadata_buffer_index].copy_(
                req.hidden_states_tensor
            )
#########################
# Transfer Backend
#########################
class TransferBackend(Enum):
    """Available KV-cache transfer engine implementations."""

    MOONCAKE = "mooncake"
    NIXL = "nixl"
    ASCEND = "ascend"
    FAKE = "fake"  # testing backend (see sglang.srt.disaggregation.fake)
class KVClassType(Enum):
    """Component roles resolvable via get_kv_class for a transfer backend."""

    KVARGS = "kvargs"
    MANAGER = "manager"
    SENDER = "sender"
    RECEIVER = "receiver"
    BOOTSTRAP_SERVER = "bootstrap_server"
def get_kv_class(
    transfer_backend: TransferBackend, class_type: KVClassType
) -> Optional[Type]:
    """Resolve the concrete class implementing `class_type` for `transfer_backend`.

    Returns None when the backend does not provide the requested role (e.g.
    the fake backend has no manager or bootstrap server). Raises ValueError
    for an unknown backend. Backend modules are imported lazily inside each
    branch to avoid pulling in optional dependencies.
    """
    # NOTE: imported for every backend even though only the FAKE branch uses
    # these names directly.
    from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

    if transfer_backend == TransferBackend.MOONCAKE:
        from sglang.srt.disaggregation.base import KVArgs
        from sglang.srt.disaggregation.mooncake import (
            MooncakeKVBootstrapServer,
            MooncakeKVManager,
            MooncakeKVReceiver,
            MooncakeKVSender,
        )

        class_mapping = {
            KVClassType.KVARGS: KVArgs,
            KVClassType.MANAGER: MooncakeKVManager,
            KVClassType.SENDER: MooncakeKVSender,
            KVClassType.RECEIVER: (MooncakeKVReceiver),
            KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
        }
        return class_mapping.get(class_type)
    elif transfer_backend == TransferBackend.ASCEND:
        from sglang.srt.disaggregation.ascend import (
            AscendKVBootstrapServer,
            AscendKVManager,
            AscendKVReceiver,
            AscendKVSender,
        )
        from sglang.srt.disaggregation.base import KVArgs

        class_mapping = {
            KVClassType.KVARGS: KVArgs,
            KVClassType.MANAGER: AscendKVManager,
            KVClassType.SENDER: AscendKVSender,
            KVClassType.RECEIVER: (AscendKVReceiver),
            KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer,
        }
        return class_mapping.get(class_type)
    elif transfer_backend == TransferBackend.NIXL:
        from sglang.srt.disaggregation.base import KVArgs
        from sglang.srt.disaggregation.nixl import (
            NixlKVBootstrapServer,
            NixlKVManager,
            NixlKVReceiver,
            NixlKVSender,
        )

        class_mapping = {
            KVClassType.KVARGS: KVArgs,
            KVClassType.MANAGER: NixlKVManager,
            KVClassType.SENDER: NixlKVSender,
            KVClassType.RECEIVER: (NixlKVReceiver),
            KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
        }
        return class_mapping.get(class_type)
    elif transfer_backend == TransferBackend.FAKE:
        from sglang.srt.disaggregation.base import KVArgs
        from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

        # The fake backend intentionally omits MANAGER and BOOTSTRAP_SERVER.
        class_mapping = {
            KVClassType.KVARGS: KVArgs,
            KVClassType.SENDER: FakeKVSender,
            KVClassType.RECEIVER: (FakeKVReceiver),
        }
        return class_mapping.get(class_type)

    raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
#########################
# KV Pages
#########################
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
    """Map per-token KV slot indices to the indices of the pages holding them.

    Every page is guaranteed to be full except possibly the last one, so
    sampling one slot per page is enough to recover each page index.
    """
    if page_size == 1:  # shortcut: pages and slots coincide
        return kv_indices
    # First slot of each page, then slot index -> page index.
    leading_slots = kv_indices[::page_size]
    return leading_slots // page_size
def kv_to_page_num(num_kv_indices: int, page_size: int):
    """Number of pages needed to hold `num_kv_indices` KV slots."""
    # Ceiling division written via negated floor division.
    return -(-num_kv_indices // page_size)
#########################
# Misc
#########################
def is_mla_backend(target_kv_pool) -> bool:
    """Return True if `target_kv_pool` is an MLA token-to-KV pool."""
    # Imported lazily to avoid a circular import with the memory-pool module.
    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool

    return isinstance(target_kv_pool, MLATokenToKVPool)
def prepare_abort(req: Req, error_message: str, status_code=None):
    """Mark `req` as aborted so it can be streamed back to the client.

    Populates the finish metadata and, when logprobs were requested, clears
    every input-logprob accumulator so the abort response carries no partial
    logprob state.
    """
    from sglang.srt.managers.schedule_batch import FINISH_ABORT

    # populate finish metadata and stream output
    req.finished_reason = FINISH_ABORT(error_message, status_code)

    if req.return_logprob:
        for attr in (
            "input_token_logprobs_val",
            "input_token_logprobs_idx",
            "input_top_logprobs_val",
            "input_top_logprobs_idx",
            "input_token_ids_logprobs_val",
            "input_token_ids_logprobs_idx",
        ):
            setattr(req, attr, [])

View File

@@ -0,0 +1,3 @@
from sglang.srt.distributed.communication_op import *
from sglang.srt.distributed.parallel_state import *
from sglang.srt.distributed.utils import *

View File

@@ -0,0 +1,35 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed
from .parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group.

    Thin wrapper delegating to the process-global TP group coordinator.
    """
    return get_tp_group().all_reduce(input_)
def tensor_model_parallel_all_gather(
    input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
    """All-gather the input tensor across model parallel group.

    The gathered shards are concatenated along `dim`.
    """
    return get_tp_group().all_gather(input_, dim)
def tensor_model_parallel_gather(
    input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
    """Gather the input tensor across model parallel group.

    Presumably only rank `dst` receives the gathered tensor and other ranks
    get None (hence the Optional return) — confirm against
    GroupCoordinator.gather.
    """
    return get_tp_group().gather(input_, dst, dim)
def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
    """Broadcast a dict of tensors/objects from rank `src` within the TP group.

    Acts as a passthrough when torch.distributed has not been initialized
    (single-process mode).
    """
    if torch.distributed.is_initialized():
        return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
    return tensor_dict

View File

@@ -0,0 +1,183 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""
import ctypes
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa
logger = logging.getLogger(__name__)
# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
# CUDA runtime enum types are plain C ints at the ABI level.
cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int


class cudaIpcMemHandle_t(ctypes.Structure):
    """Opaque 128-byte IPC memory handle, mirroring CUDA's cudaIpcMemHandle_t."""

    _fields_ = [("internal", ctypes.c_byte * 128)]
@dataclass
class Function:
    """ctypes signature spec for one function exported from libcudart."""

    name: str  # symbol name in the shared library
    restype: Any  # ctypes return type
    argtypes: List[Any]  # ctypes argument types, in order
def find_loaded_library(lib_name) -> Optional[str]:
    """
    According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
    the file `/proc/self/maps` contains the memory maps of the process, which
    includes the shared libraries loaded by the process. We can use this file
    to find the path of a loaded library.

    Returns the absolute path of the mapped library file, or None when
    `lib_name` is not loaded in the current process.
    """  # noqa
    found = False
    with open("/proc/self/maps") as f:
        for line in f:
            if lib_name in line:
                found = True
                break
    if not found:
        # the library is not loaded in the current process
        return None
    # if lib_name is libcudart, we need to match a line with:
    # address /path/to/libcudart-hash.so.11.0
    start = line.index("/")
    path = line[start:].strip()
    filename = path.split("/")[-1]
    # Sanity check: the basename (minus the trailing ".so[.version]") must
    # start with the requested library name. The assertion message reports the
    # actual filename (the interpolation had been corrupted to a literal).
    assert filename.rpartition(".so")[0].startswith(
        lib_name
    ), f"Unexpected filename: {filename} for library {lib_name}"
    return path
class CudaRTLibrary:
    """Pure-Python ctypes wrapper around a small set of CUDA runtime functions.

    Loads libcudart via ctypes (no compiled extension needed) and exposes
    thin, error-checked methods for device control, memory management, and
    IPC memory handles.
    """

    exported_functions = [
        # cudaError_t cudaSetDevice ( int device )
        Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
        # cudaError_t cudaDeviceSynchronize ( void )
        Function("cudaDeviceSynchronize", cudaError_t, []),
        # cudaError_t cudaDeviceReset ( void )
        Function("cudaDeviceReset", cudaError_t, []),
        # const char* cudaGetErrorString ( cudaError_t error )
        Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
        # cudaError_t cudaMalloc ( void** devPtr, size_t size )
        Function(
            "cudaMalloc",
            cudaError_t,
            [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
        ),
        # cudaError_t cudaFree ( void* devPtr )
        Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
        # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
        Function(
            "cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
        ),
        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
        Function(
            "cudaMemcpy",
            cudaError_t,
            [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
        ),
        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
        Function(
            "cudaIpcGetMemHandle",
            cudaError_t,
            [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
        ),
        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
        Function(
            "cudaIpcOpenMemHandle",
            cudaError_t,
            [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
        ),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        # Auto-discover libcudart from the current process's memory maps when
        # no explicit .so path is given.
        if so_file is None:
            so_file = find_loaded_library("libcudart")
            assert so_file is not None, "libcudart is not loaded in the current process"
        if so_file not in CudaRTLibrary.path_to_library_cache:
            lib = ctypes.CDLL(so_file)
            CudaRTLibrary.path_to_library_cache[so_file] = lib
        self.lib = CudaRTLibrary.path_to_library_cache[so_file]

        if so_file not in CudaRTLibrary.path_to_dict_mapping:
            # Bind each exported symbol once with its declared signature.
            _funcs = {}
            for func in CudaRTLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]

    def CUDART_CHECK(self, result: cudaError_t) -> None:
        """Raise RuntimeError when a cudart call returns a non-success code."""
        if result != 0:
            error_str = self.cudaGetErrorString(result)
            raise RuntimeError(f"CUDART error: {error_str}")

    def cudaGetErrorString(self, error: cudaError_t) -> str:
        return self.funcs["cudaGetErrorString"](error).decode("utf-8")

    def cudaSetDevice(self, device: int) -> None:
        self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))

    def cudaDeviceSynchronize(self) -> None:
        self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())

    def cudaDeviceReset(self) -> None:
        self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())

    def cudaMalloc(self, size: int) -> ctypes.c_void_p:
        """Allocate `size` bytes of device memory and return the device pointer."""
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
        return devPtr

    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
        self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))

    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
        self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))

    def cudaMemcpy(
        self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int
    ) -> None:
        # cudaMemcpyDefault (4) lets the driver infer the copy direction.
        cudaMemcpyDefault = 4
        kind = cudaMemcpyDefault
        self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))

    def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
        """Create an IPC handle for a device allocation, shareable across processes."""
        handle = cudaIpcMemHandle_t()
        self.CUDART_CHECK(
            self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)
        )
        return handle

    def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
        """Map another process's allocation into this process via its IPC handle."""
        # cudaIpcMemLazyEnablePeerAccess (1) is the only documented flag.
        cudaIpcMemLazyEnablePeerAccess = 1
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(
            self.funcs["cudaIpcOpenMemHandle"](
                ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess
            )
        )
        return devPtr

View File

@@ -0,0 +1,421 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py
import ctypes
import logging
import os
from contextlib import contextmanager
from typing import Any, List, Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from sglang.srt import _custom_ops as ops
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check,
is_full_nvlink,
is_weak_contiguous,
)
from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
# Probe at import time whether a custom all-reduce implementation is usable.
try:
    if ops.use_vllm_custom_allreduce and not _is_hip:
        # Use vLLM custom allreduce; meta_size() raises when the native op
        # library is unavailable.
        ops.meta_size()
    else:
        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
        import sgl_kernel
    custom_ar = True
except Exception:
    # For CPUs / environments without the native kernels.
    custom_ar = False
# NOTE: removed a redundant re-initialization of `logger` here; the
# module-level logger is already created above.
def _can_p2p(rank: int, world_size: int) -> bool:
    """Return True if `rank` can do GPU peer-to-peer with the other ranks.

    When SGLANG_SKIP_P2P_CHECK=1, the driver's own P2P report is trusted
    instead of running the (slower) explicit access check.
    """
    # SGLANG_SKIP_P2P_CHECK can be set to False in sglang
    SGLANG_SKIP_P2P_CHECK = os.getenv("SGLANG_SKIP_P2P_CHECK", "0") == "1"
    for i in range(world_size):
        if i == rank:
            continue
        if SGLANG_SKIP_P2P_CHECK:
            logger.info("Skipping P2P check and trusting the driver's P2P report.")
            # NOTE(review): in skip mode this returns after consulting only the
            # first peer device; remaining peers are never checked — confirm
            # this is intended.
            return torch.cuda.can_device_access_peer(rank, i)
        if not gpu_p2p_access_check(rank, i):
            return False
    return True
class CustomAllreduce:
    """Intra-node all-reduce over IPC-shared GPU buffers (CUDA and ROCm).

    The constructor runs a series of eligibility checks (library present,
    single node, supported world size, NVLink/P2P support); on any failure it
    leaves ``self.disabled = True`` and callers fall back to NCCL.
    """
    # World sizes the custom kernel supports.
    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
    # Max allreduce payload in bytes handled by the custom kernel (8 MiB).
    _MAX_CAR_SIZE = 8192 * 1024
    if _is_hip:
        # crossover is at 16MB buffer size for ROCm
        _MAX_CAR_SIZE = 2 * 8192 * 1024
    # max_size: max supported allreduce size
    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        max_size=_MAX_CAR_SIZE,
    ) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bound to f"cuda:{local_rank}".
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        self.disabled = True
        if not custom_ar:
            # disable because of missing custom allreduce library
            # e.g. in a non-cuda environment
            return
        self.group = group
        assert (
            dist.get_backend(group) != dist.Backend.NCCL
        ), "CustomAllreduce should be attached to a non-NCCL group."
        if not all(in_the_same_node_as(group, source_rank=0)):
            # No need to initialize custom allreduce for multi-node case.
            logger.warning(
                "Custom allreduce is disabled because this process group"
                " spans across nodes."
            )
            return
        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize custom allreduce for single GPU case.
            return
        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "Custom allreduce is disabled due to an unsupported world"
                " size: %d. Supported world sizes: %s. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly.",
                world_size,
                str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
            )
            return
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # Map the group-local device index back to the physical device id,
        # honoring any CUDA_VISIBLE_DEVICES remapping.
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices:
            device_ids = list(map(int, cuda_visible_devices.split(",")))
        else:
            device_ids = list(range(torch.cuda.device_count()))
        physical_device_id = device_ids[device.index]
        tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
        gather_list = [
            torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
        ]
        dist.all_gather(gather_list, tensor, group=self.group)
        physical_device_ids = [t.item() for t in gather_list]
        # test nvlink first, this will filter out most of the cases
        # where custom allreduce is not supported
        # this checks hardware and driver support for NVLink
        if _is_cuda or _is_hip:
            full_nvlink = is_full_nvlink(physical_device_ids, world_size)
        if world_size > 2 and not full_nvlink:
            logger.warning(
                "Custom allreduce is disabled because it's not supported on"
                " more than two PCIe-only GPUs. To silence this warning, "
                "specify disable_custom_all_reduce=True explicitly."
            )
            return
        # test P2P capability, this checks software/cudaruntime support
        # this is expensive to compute at the first time
        # then we cache the result
        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
        if not _is_hip and not _can_p2p(rank, world_size):
            logger.warning(
                "Custom allreduce is disabled because your platform lacks "
                "GPU P2P capability or P2P test failed. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly."
            )
            return
        self.max_size = max_size
        self.rank = rank
        self.world_size = world_size
        self.full_nvlink = full_nvlink
        if not _is_hip:
            # Buffers memory are owned by this Python class and passed to C++.
            # Meta data composes of two parts: meta data for synchronization and a
            # temporary buffer for storing intermediate allreduce results.
            self.meta_ptrs = self.create_shared_buffer(
                ops.meta_size() + max_size, group=group
            )
            # This is a pre-registered IPC buffer. In eager mode, input tensors
            # are first copied into this buffer before allreduce is performed
            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
            # This is a buffer for storing the tuples of pointers pointing to
            # IPC buffers from all ranks. Each registered tuple has size of
            # 8*world_size bytes where world_size is at most 8. Allocating 8MB
            # is enough for 131072 such tuples. The largest model I've seen only
            # needs less than 10000 of registered tuples.
            self.rank_data = torch.empty(
                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
            )
            self._ptr = ops.init_custom_ar(
                self.meta_ptrs, self.rank_data, rank, self.full_nvlink
            )
            ops.register_buffer(self._ptr, self.buffer_ptrs)
        else:
            # meta data buffers need to be "uncached" for signal on MI200
            self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
            self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
            handle = ops.get_meta_buffer_ipc_handle(self.meta)
            shard_data = (
                bytes(handle),  # ipc handle to base ptr
                0,  # offset of base ptr
            )
            handles, offsets = self._gather_ipc_meta(shard_data)
            self.rank_data = torch.empty(
                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
            )
            self._ptr = ops.init_custom_ar(
                self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
            )
            self.register_buffer(self.buffer)
        # All checks passed: enable the custom path.
        self.disabled = False
    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
        """
        Creates a shared buffer and returns a list of pointers
        representing the buffer on all processes in the group.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        handles = [None] * world_size
        # Exchange IPC handles so every rank can map every peer's allocation.
        dist.all_gather_object(handles, handle, group=group)
        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                pointers.append(pointer.value)  # type: ignore
            else:
                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore
        return pointers
    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
        # Each rank frees only the pointer it allocated locally in
        # create_shared_buffer (pointers[rank]).
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        lib.cudaFree(ctypes.c_void_p(pointers[rank]))
    @contextmanager
    def capture(self):
        """
        The main responsibility of this context manager is the
        `register_graph_buffers` call at the end of the context.
        It records all the buffer addresses used in the CUDA graph.
        """
        try:
            self._IS_CAPTURING = True
            yield
        finally:
            self._IS_CAPTURING = False
            if not self.disabled:
                self.register_graph_buffers()
    def _get_ipc_meta(self, inp: torch.Tensor):
        # _share_cuda_() doesn't accept meta buffer not allocated from
        # PyTorch cache allocator, use direct HIP call to get IPC handle
        handle = ops.get_meta_buffer_ipc_handle(inp)
        shard_data = (
            bytes(handle),  # ipc handle to base ptr
            0,  # offset of base ptr
        )
        return self._gather_ipc_meta(shard_data)
    def _gather_ipc_meta(self, shard_data):
        """All-gather `(handle, offset)` tuples from every rank via per-rank
        broadcasts; returns (handles, offsets) lists ordered by rank."""
        # Note: don't use `[[None]] * self.world_size` here
        # because it will create a list of the same reference
        all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)]
        all_data[self.rank][0] = shard_data
        ranks = dist.get_process_group_ranks(group=self.group)
        ranks.sort()
        for i, rank in enumerate(ranks):
            dist.broadcast_object_list(
                all_data[i], src=rank, group=self.group, device="cpu"
            )
        # we cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.
        handles = []
        offsets = []
        for i in range(len(all_data)):
            handles.append(all_data[i][0][0])  # type: ignore
            offsets.append(all_data[i][0][1])  # type: ignore
        return handles, offsets
    def register_buffer(self, inp: torch.Tensor):
        """Register a tensor's IPC handle with the native all-reduce context."""
        handles, offsets = self._get_ipc_meta(inp)
        ops.register_buffer(self._ptr, inp, handles, offsets)
    def register_graph_buffers(self):
        """Exchange and register the buffer addresses recorded during CUDA
        graph capture with the native context."""
        if _is_hip:
            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
            handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
            logger.info("Registering %d cuda graph addresses", len(offset))
            ops.register_graph_buffers(self._ptr, handles, offsets)
        else:
            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
            logger.info("Registering %d cuda graph addresses", len(offset))
            # We cannot directly use `dist.all_gather_object` here
            # because it is incompatible with `gloo` backend under inference mode.
            # see https://github.com/pytorch/pytorch/issues/126032 for details.
            all_data = [
                [None, None] for _ in range(dist.get_world_size(group=self.group))
            ]
            all_data[self.rank] = [handle, offset]
            ranks = sorted(dist.get_process_group_ranks(group=self.group))
            for i, rank in enumerate(ranks):
                dist.broadcast_object_list(
                    all_data[i], src=rank, group=self.group, device="cpu"
                )
            # Unpack list of tuples to tuple of lists.
            handles = [d[0] for d in all_data]  # type: ignore
            offsets = [d[1] for d in all_data]  # type: ignore
            ops.register_graph_buffers(self._ptr, handles, offsets)
    def should_custom_ar(self, inp: torch.Tensor):
        """Return True when `inp` is eligible for the custom all-reduce path."""
        if self.disabled:
            return False
        inp_size = inp.numel() * inp.element_size()
        # custom allreduce requires input byte size to be multiples of 16
        if inp_size % 16 != 0:
            return False
        if not is_weak_contiguous(inp):
            return False
        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
        if not _is_hip:
            if self.world_size == 2 or self.full_nvlink:
                return inp_size < self.max_size
            return False
        if _is_hip:
            if self.full_nvlink:
                return inp_size < self.max_size
            return False
        return False
    # all reduce, assuming inp tensor is IPC registered with register_buffer,
    # or, in the context of cuda graphs, register_graph_buffers
    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        ops.all_reduce_reg(self._ptr, inp, out)
        return out
    # all reduce, assuming inp tensor is NOT IPC registered
    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
        return out
    def all_reduce(
        self,
        inp: torch.Tensor,
        *,
        out: torch.Tensor = None,
        registered: bool = False,
    ):
        """Performs an out-of-place all reduce.
        If registered is True, this assumes inp's pointer is already
        IPC-registered. Otherwise, inp is first copied into a pre-registered
        buffer.
        """
        if out is None:
            out = torch.empty_like(inp)
        if registered:
            ops.all_reduce(self._ptr, inp, out, 0, 0)
        else:
            ops.all_reduce(
                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
            )
        return out
    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
        """The main allreduce API that provides support for cuda graph."""
        # When custom allreduce is disabled, this will be None.
        if self.disabled or not self.should_custom_ar(input):
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                if _is_hip:
                    return self.all_reduce_reg(input)
                else:
                    return self.all_reduce(input, registered=True)
            else:
                # If warm up, mimic the allocation pattern since custom
                # allreduce is out-of-place.
                return torch.zeros_like(input)
        else:
            if _is_hip:
                # note: outside of cuda graph context,
                # custom allreduce incurs a cost of cudaMemcpy, which should
                # be small(<=1% of overall latency) compared to the performance
                # gains of using custom kernels
                return self.all_reduce_unreg(input)
            else:
                return self.all_reduce(input, registered=False)
    def close(self):
        """Dispose the native context and free IPC buffers; `_ptr = 0` makes
        repeated calls (including from __del__) no-ops."""
        if not self.disabled and self._ptr:
            ops.dispose(self._ptr)
            if _is_cuda:
                self.free_shared_buffer(self.meta_ptrs)
                self.free_shared_buffer(self.buffer_ptrs)
            self._ptr = 0
    def __del__(self):
        self.close()

View File

@@ -0,0 +1,386 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py
import ctypes
import json
import logging
import os
import pickle
import subprocess
import sys
import tempfile
from functools import wraps
from itertools import product
from typing import Callable, Dict, List, Optional, Sequence, TypeVar
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from typing_extensions import ParamSpec
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
try:
import pynvml
except ImportError as e:
logger.warning("Failed to import pynvml with %r", e)
if _is_hip:
try:
from amdsmi import (
AmdSmiException,
amdsmi_get_processor_handles,
amdsmi_init,
amdsmi_shut_down,
amdsmi_topo_get_link_type,
)
except ImportError as e:
logger.warning("Failed to import amdsmi with %r", e)
_P = ParamSpec("_P")
_R = TypeVar("_R")
def update_environment_variables(envs: Dict[str, str]):
    """Set each environment variable in `envs`, logging a warning whenever an
    existing value is overwritten with a different one."""
    for name, value in envs.items():
        previous = os.environ.get(name)
        if previous is not None and previous != value:
            logger.warning(
                "Overwriting environment variable %s " "from '%s' to '%s'",
                name,
                previous,
                value,
            )
        os.environ[name] = value
def producer(
    batch_src: Sequence[int],
    producer_queue,
    consumer_queue,
    result_queue,
    cuda_visible_devices: Optional[str] = None,
):
    """Producer half of the pairwise P2P probe; run in a spawned process.

    For each source GPU in `batch_src`: allocate a 1 KiB buffer filled with 1,
    publish its CUDA IPC handle on `producer_queue`, wait for the consumer to
    report whether it could open the handle and (if so) write 2 into the
    buffer, then read the buffer back and verify every byte changed to 2.
    One bool per pair is pushed to `result_queue`; the device is reset
    between pairs so mappings don't leak across tests.
    """
    if cuda_visible_devices is not None:
        update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
    lib = CudaRTLibrary()
    for i in batch_src:
        lib.cudaSetDevice(i)
        pointer = lib.cudaMalloc(1024)
        lib.cudaMemset(pointer, 1, 1024)
        lib.cudaDeviceSynchronize()
        handle = lib.cudaIpcGetMemHandle(pointer)
        producer_queue.put(handle)
        open_success = consumer_queue.get()
        if open_success:
            # use two queues to simulate barrier
            producer_queue.put(0)
            consumer_queue.get()
            # check if the memory is modified
            host_data = (ctypes.c_char * 1024)()
            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
            # NOTE: this loop reuses `i`, shadowing the device index; harmless
            # because the outer loop reassigns it on the next iteration.
            for i in range(1024):
                if ord(host_data[i]) != 2:
                    open_success = False
                    break
        result_queue.put(open_success)
        lib.cudaDeviceReset()
def consumer(
    batch_tgt: Sequence[int],
    producer_queue,
    consumer_queue,
    result_queue,
    cuda_visible_devices: Optional[str] = None,
):
    """Consumer half of the pairwise P2P probe; run in a spawned process.

    For each target GPU in `batch_tgt`: receive the producer's IPC handle,
    try to map it, report success/failure back, and on success write 2 into
    the shared 1 KiB buffer and re-read it to confirm the write took effect.
    One bool per pair is pushed to `result_queue`; the device is reset
    between pairs so mappings don't leak across tests.
    """
    if cuda_visible_devices is not None:
        update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
    lib = CudaRTLibrary()
    for j in batch_tgt:
        lib.cudaSetDevice(j)
        handle = producer_queue.get()
        open_success = False
        try:
            pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore
            open_success = True
        except RuntimeError:
            # cannot error out here, because the producer process
            # is still waiting for the response.
            pass
        consumer_queue.put(open_success)
        if open_success:
            # modify the memory
            lib.cudaMemset(pointer, 2, 1024)
            lib.cudaDeviceSynchronize()
            # use two queues to simulate barrier
            producer_queue.get()
            consumer_queue.put(0)
            # check if the memory is modified
            host_data = (ctypes.c_char * 1024)()
            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
            for i in range(1024):
                if ord(host_data[i]) != 2:
                    open_success = False
                    break
        result_queue.put(open_success)
        lib.cudaDeviceReset()
def can_actually_p2p(
    batch_src: Sequence[int],
    batch_tgt: Sequence[int],
) -> Sequence[bool]:
    """
    Usually, checking if P2P access is enabled can be done by
    `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
    the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
    returns `True` even if P2P access is not actually possible.
    See https://github.com/vllm-project/vllm/issues/2728 and
    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
    Therefore, we have to perform a real P2P access to check if it is actually
    possible.
    Note on p2p and cuda IPC:
    Usually, one process uses one GPU:
    GPU src --> cuda context src --> tensor src --> process src
    We need to combine p2p and cuda IPC, so that:
    GPU src --> cuda context src --> tensor src --> process src
                                         |shared|
    GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
    That is to say, process src creates a tensor in GPU src, passes IPC handle to
    process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
    tensor in process tgt will be reflected in the tensor in process src, because
    they are the same memory segment.
    It is important to note that process tgt accesses the tensor in GPU tgt, not
    GPU src. That's why we need p2p access.
    The most time-consuming part is the process creation. To avoid creating
    processes for every pair of GPUs, we use batched testing. We create two
    processes for testing all pairs of GPUs in batch. The trick is to reset
    the device after each test (which is not available in PyTorch).
    """  # noqa
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    # pass the CUDA_VISIBLE_DEVICES to the child process
    # to make sure they see the same set of GPUs
    # make sure the processes are spawned
    smp = mp.get_context("spawn")
    producer_queue = smp.Queue()
    consumer_queue = smp.Queue()
    result_queue = smp.Queue()
    p_src = smp.Process(
        target=producer,
        args=(
            batch_src,
            producer_queue,
            consumer_queue,
            result_queue,
            cuda_visible_devices,
        ),
    )
    p_tgt = smp.Process(
        target=consumer,
        args=(
            batch_tgt,
            producer_queue,
            consumer_queue,
            result_queue,
            cuda_visible_devices,
        ),
    )
    p_src.start()
    p_tgt.start()
    p_src.join()
    p_tgt.join()
    assert p_src.exitcode == 0 and p_tgt.exitcode == 0
    # Both child processes report one bool per (src, tgt) pair; a pair is P2P
    # capable only if both sides agree.
    result: List[bool] = []
    for src, tgt in zip(batch_src, batch_tgt):
        a = result_queue.get()
        b = result_queue.get()
        if a != b:
            logger.warning(
                "Two processes do not agree on the P2P access"
                " status on %d -> %d, treat as disabled.",
                src,
                tgt,
            )
            result.append(False)
        else:
            result.append(a)
    return result
# why do we need this cache?
# we are testing peer-to-peer (p2p) access between GPUs, across processes.
# if we test it every time, it will be very slow, because we need to create
# N * N * 2 processes, where N is the world size. This is very slow.
# to reduce the time, we use a cache file to store the p2p access status.
# the cache file is generated by the master process if it does not exist.
# then all the processes can read the cache file to check the p2p access status.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
    """Check if GPU src can access GPU tgt.

    Results for all device pairs are computed once by the local master rank
    (via a subprocess running this file's __main__), persisted to a JSON cache
    keyed by CUDA_VISIBLE_DEVICES, and then memoized in-process.
    """
    # if the cache variable is already calculated,
    # read from the cache instead of checking it again
    global _gpu_p2p_access_cache
    if _gpu_p2p_access_cache is not None:
        return _gpu_p2p_access_cache[f"{src}->{tgt}"]
    is_distributed = dist.is_initialized()
    num_dev = torch.cuda.device_count()
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if cuda_visible_devices is None:
        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
    # VLLM_CACHE_ROOT -> SGLANG_CACHE_ROOT
    # "~/.cache/vllm" -> "~/.cache/sglang"
    SGLANG_CACHE_ROOT = os.path.expanduser("~/.cache/sglang")
    path = os.path.join(
        SGLANG_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
    )
    os.makedirs(os.path.dirname(path), exist_ok=True)
    from sglang.srt.distributed.parallel_state import get_world_group
    if (not is_distributed or get_world_group().local_rank == 0) and (
        not os.path.exists(path)
    ):
        # only the local master process (with local_rank == 0) can
        # enter this block to calculate the cache
        logger.info("generating GPU P2P access cache in %s", path)
        cache: Dict[str, bool] = {}
        ids = list(range(num_dev))
        # batch of all pairs of GPUs
        batch_src, batch_tgt = zip(*list(product(ids, ids)))
        # NOTE: we use `subprocess` rather than `multiprocessing` here
        # because the caller might not have `if __name__ == "__main__":`,
        # in that case we cannot use spawn method in multiprocessing.
        # However, `can_actually_p2p` requires spawn method.
        # The fix is, we use `subprocess` to call the function,
        # where we have `if __name__ == "__main__":` in this file.
        # use a temporary file to store the result
        # we don't use the output of the subprocess directly,
        # because the subprocess might produce logging output
        with tempfile.NamedTemporaryFile() as output_file:
            input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
            returned = subprocess.run(
                [sys.executable, __file__], input=input_bytes, capture_output=True
            )
            # check if the subprocess is successful
            try:
                returned.check_returncode()
            except Exception as e:
                # wrap raised exception to provide more information
                raise RuntimeError(
                    f"Error happened when batch testing "
                    f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
                    f"{returned.stderr.decode()}"
                ) from e
            with open(output_file.name, "rb") as f:
                result = pickle.load(f)
            for _i, _j, r in zip(batch_src, batch_tgt, result):
                cache[f"{_i}->{_j}"] = r
        with open(path, "w") as f:
            json.dump(cache, f, indent=4)
    if is_distributed:
        # Non-master ranks wait here until the cache file exists.
        get_world_group().barrier()
    logger.info("reading GPU P2P access cache from %s", path)
    with open(path) as f:
        cache = json.load(f)
    _gpu_p2p_access_cache = cache
    return _gpu_p2p_access_cache[f"{src}->{tgt}"]
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorator that brackets `fn` with GPU management-library setup/teardown:
    amdsmi on ROCm builds, NVML (pynvml) otherwise."""
    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        if not _is_hip:
            # NVIDIA path: a failed nvmlInit needs no shutdown, so it stays
            # outside the try block.
            pynvml.nvmlInit()
            try:
                return fn(*args, **kwargs)
            finally:
                pynvml.nvmlShutdown()
        # ROCm path: amdsmi_init sits inside the try, so amdsmi_shut_down
        # runs even when initialization itself raises.
        try:
            amdsmi_init()
            return fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()
    return wrapper
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
    """Return True when every pair of the given physical GPUs is directly
    connected: 1-hop XGMI on ROCm, NVLink (per NVML P2P status) on CUDA.
    Any query failure is treated as "not fully connected"."""
    if _is_hip:
        """
        query if the set of gpus are fully connected by xgmi (1 hop)
        """
        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
        # Check each unordered pair exactly once (i < j).
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                        # type is 2 for XGMI
                        if link_type["hops"] != 1 or link_type["type"] != 2:
                            return False
                    except AmdSmiException as error:
                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                        return False
        return True
    else:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Check each unordered pair exactly once (i < j).
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception(
                            "NVLink detection failed. This is normal if your"
                            " machine has no NVLink equipped."
                        )
                        return False
        return True
def is_weak_contiguous(inp: torch.Tensor):
    """A tensor is "weakly contiguous" if it is contiguous, or if its elements
    exactly fill the storage behind it (so it can be treated as one flat
    buffer even when its strides are non-standard, e.g. a transpose)."""
    elem = inp.element_size()
    occupies_whole_storage = (
        inp.storage().nbytes() - inp.storage_offset() * elem == inp.numel() * elem
    )
    return inp.is_contiguous() or occupies_whole_storage
__all__ = ["gpu_p2p_access_check"]
if __name__ == "__main__":
    # Subprocess entry point used by gpu_p2p_access_check: reads a pickled
    # (batch_src, batch_tgt, output_file) tuple from stdin, runs the real
    # pairwise P2P probe, and pickles the list of per-pair results to
    # output_file (stdout is not used because logging may pollute it).
    batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
    result = can_actually_p2p(batch_src, batch_tgt)
    with open(output_file, "wb") as f:
        f.write(pickle.dumps(result))

View File

@@ -0,0 +1,49 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from sglang.srt.utils import is_hpu
if is_hpu():
import habana_frameworks.torch as htorch # noqa: F401
class HpuCommunicator:
    """Collective-communication helper for Habana HPU devices; becomes a
    disabled no-op when not running on HPU hardware."""
    def __init__(self, group: ProcessGroup):
        if not is_hpu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)
    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce `x` in place across the group (default reduce op: sum)
        and return it."""
        # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
        # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
        # (which is required for tensor parallel HPUGraph inference)
        htorch.core.mark_step()
        dist.all_reduce(x, group=self.group)
        return x
    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Gather `x` from every rank and concatenate the shards along `dim`."""
        world_size = self.world_size
        if dim < 0:
            # Convert negative dim to positive.
            dim += x.dim()
        input_size = x.size()
        # Allocate output tensor with a leading world_size axis.
        output_tensor = torch.empty(
            (world_size,) + input_size, dtype=x.dtype, device=x.device
        )
        # All-gather.
        htorch.core.mark_step()
        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
        # Reshape: move the gathered axis next to `dim` and merge the two.
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
        )
        return output_tensor

View File

@@ -0,0 +1,39 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from sglang.srt.utils import is_npu
class NpuCommunicator:
    """Collective-communication helper for Ascend NPU devices; becomes a
    disabled no-op when not running on NPU hardware."""
    def __init__(self, group: ProcessGroup):
        if not is_npu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)
    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce `x` in place across the group (default reduce op: sum)
        and return it."""
        dist.all_reduce(x, group=self.group)
        return x
    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Gather `x` from every rank and concatenate the shards along `dim`."""
        world_size = self.world_size
        if dim < 0:
            # Convert negative dim to positive.
            dim += x.dim()
        input_size = x.size()
        # Flat output: rank shards are stacked along the first axis.
        output_size = (input_size[0] * world_size,) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(output_size, dtype=x.dtype, device=x.device)
        # All-gather.
        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
        # Reshape: expose the rank axis, move it next to `dim`, then merge.
        output_tensor = output_tensor.reshape((world_size,) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
        )
        return output_tensor

View File

@@ -0,0 +1,315 @@
import bisect
import logging
import math
import os
from contextlib import contextmanager
from enum import IntEnum
from typing import Any, Callable, List, Optional, TypeVar, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
# Availability flag for the MSCCL++ all-reduce kernels, set by probing below.
mscclpp_is_available = False
if _is_hip:
    # TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
    mscclpp_is_available = False
if _is_cuda:
    try:
        import sgl_kernel
        mscclpp_is_available = True
    except Exception:
        # Narrowed from a bare `except:`: any import failure still means
        # "unavailable", but SystemExit/KeyboardInterrupt are no longer
        # swallowed.
        mscclpp_is_available = False
class MscclContextSelection(IntEnum):
    """Context variant identifier passed (as int) to ops.mscclpp_init_context;
    selected by world size (8 -> one node, 16 -> two nodes)."""
    MSCCL1SHOT1NODELL = 1
    MSCCL1SHOT2NODELL = 2
def mscclpp_is_weak_contiguous(inp: torch.Tensor):
    """Weak-contiguity check: `inp` is contiguous, or its elements exactly
    fill the storage that backs it (usable as one flat buffer)."""
    if inp.is_contiguous():
        return True
    storage_bytes = inp.storage().nbytes()
    start_byte = inp.storage_offset() * inp.element_size()
    payload_bytes = inp.numel() * inp.element_size()
    return storage_bytes - start_byte == payload_bytes
def mscclpp_convert_to_bytes(size_str):
    """Convert a human-readable size string into a number of bytes.

    Accepts values like "1MB", "2.5kb", "3 GB" (binary units: 1KB = 1024B).
    A bare number such as "4096" is interpreted as bytes.

    Args:
        size_str (str): A size with an optional unit suffix (B, KB, MB, GB).

    Returns:
        int: Number of bytes.

    Raises:
        ValueError: If the string is empty, the numeric part is invalid, or
            the unit is not one of B, KB, MB, GB.
    """
    size_str = size_str.strip().lower()
    if not size_str:
        raise ValueError("Empty input string")
    # Find where the numeric prefix ends. The previous scan had no for-else:
    # for all-digit inputs (e.g. "1024") `i` stopped at len-1, so the last
    # digit was mis-parsed as the unit. Default to the full length so
    # unit-less inputs parse as plain bytes.
    split = len(size_str)
    for idx, ch in enumerate(size_str):
        if not (ch.isdigit() or ch == "."):
            split = idx
            break
    num_str = size_str[:split]
    unit = size_str[split:].strip()
    try:
        num = float(num_str)
    except ValueError:
        raise ValueError(f"Invalid numeric value in '{size_str}'")
    # Binary conversion factors; "" means no unit suffix (plain bytes).
    factors = {"": 1, "b": 1, "kb": 1024, "mb": 1024 * 1024, "gb": 1024 * 1024 * 1024}
    if unit not in factors:
        raise ValueError(f"Unsupported unit: {unit}, support B, KB, MB, GB only")
    return int(num * factors[unit])
def mscclpp_bench_time(func, test_niter: int = 10, warmup_niter: int = 2):
    """Measure the average wall time of `func()` in microseconds via CUDA events.

    Runs `warmup_niter` untimed calls, then synchronizes the device and
    barriers all ranks so every process starts timing together, then times
    `test_niter` back-to-back calls.

    Args:
        func: zero-argument callable launching the work to benchmark.
        test_niter: number of timed iterations.
        warmup_niter: number of untimed warm-up iterations.

    Returns:
        float: average per-iteration cost in microseconds.
    """
    # warmup
    for _ in range(warmup_niter):
        func()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    dist.barrier()
    start_event.record()
    for _ in range(test_niter):
        func()
    end_event.record()
    end_event.synchronize()
    # elapsed_time() returns milliseconds; * 1000 converts per-iter cost to us.
    func_cost_us = start_event.elapsed_time(end_event) / test_niter * 1000
    return func_cost_us
class PyMscclppCommunicator:
_SUPPORTED_WORLD_SIZES = [8, 16]
_MAX_BYTES = mscclpp_convert_to_bytes(os.getenv("SGLANG_MSCCLPP_MAX_BYTES", "1MB"))
_SUPPORTED_DTYPE = [torch.float, torch.float16, torch.bfloat16]
# max_bytes: max supported mscclpp allreduce size
    # in A100 mscclpp is faster than nccl only under condition of msg size smaller than 1MB
    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        max_bytes=_MAX_BYTES,
    ) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bound to f"cuda:{local_rank}".
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        self.disabled = True
        if not mscclpp_is_available:
            # disable because of missing mscclpp library
            # e.g. in a non-cuda environment
            return
        self.group = group
        assert (
            dist.get_backend(group) != dist.Backend.NCCL
        ), "CustomAllreduce should be attached to a non-NCCL group."
        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize mscclpp for single GPU case.
            return
        if world_size not in PyMscclppCommunicator._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "PyMscclpp is disabled due to an unsupported world"
                " size: %d. Supported world sizes: %s. To silence this "
                "warning, specify disable_mscclpp=True explicitly.",
                world_size,
                str(PyMscclppCommunicator._SUPPORTED_WORLD_SIZES),
            )
            return
        self.ranks = torch.distributed.get_process_group_ranks(group)
        self.nranks_per_node = torch.cuda.device_count()
        # for now mscclpp with stride in the communicator is not tested
        if not (abs(self.ranks[-1] - self.ranks[0]) == world_size - 1):
            logger.warning(
                "PyMscclpp is disabled due to an unsupported group %s."
                "Please ensure all ranks in the group are consecutive."
                "To silence this warning, specify disable_mscclpp=True explicitly.",
                str(self.ranks),
            )
            return
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        self.max_bytes = max_bytes
        self.rank = rank
        self.world_size = world_size
        # Rank 0 creates the MSCCL++ bootstrap id and broadcasts it to peers.
        if dist.get_rank(group) == 0:
            unique_id = [ops.mscclpp_generate_unique_id()]
        else:
            unique_id = [None]
        dist.broadcast_object_list(unique_id, src=self.ranks[0], group=self.group)
        self.unique_id = unique_id[0]
        self.rank_to_node, self.rank_to_ib = list(range(world_size)), list(
            range(world_size)
        )
        for r in range(world_size):
            self.rank_to_node[r] = r // 8
            # NOTE(review): every entry gets this process's own rank modulo 8
            # (`self.rank`), not the loop rank `r` — looks like it should be
            # `r % 8`; confirm against upstream intent.
            self.rank_to_ib[r] = self.rank % 8
        self._context = None
        self.context_selection = None
        # Power-of-two message sizes (1 KiB .. max_bytes) to benchmark.
        self.msg_size_for_finetune = [
            2**i for i in range(10, math.floor(math.log2(self.max_bytes)) + 1)
        ]
        self.msg_size2best_config = {}
        if world_size == 8:
            self.context_selection = MscclContextSelection.MSCCL1SHOT1NODELL
        elif world_size == 16:
            self.context_selection = MscclContextSelection.MSCCL1SHOT2NODELL
        if not _is_hip:
            # Scratch/put buffers handed to the native context.
            self.scratch = torch.empty(
                self.max_bytes * 8,
                dtype=torch.uint8,
                device=self.device,
            )
            self.put_buffer = torch.empty(
                self.max_bytes * 8 // self.nranks_per_node,
                dtype=torch.uint8,
                device=self.device,
            )
            self._context = ops.mscclpp_init_context(
                self.unique_id,
                self.rank,
                self.world_size,
                self.scratch,
                self.put_buffer,
                self.nranks_per_node,
                self.rank_to_node,
                self.rank_to_ib,
                int(self.context_selection),
            )
        else:
            raise NotImplementedError("HIP Mscclpp is not supported yet.")
        self.msg_size2best_config = {}
        # Benchmark launch configs locally, then adopt rank 0's results so all
        # ranks launch identical kernels.
        self.pre_tune_config()
        if dist.get_rank(group) == 0:
            msg_size2best_config = [self.msg_size2best_config]
        else:
            msg_size2best_config = [None]
        dist.broadcast_object_list(
            msg_size2best_config, src=self.ranks[0], group=self.group
        )
        self.msg_size2best_config = msg_size2best_config[0]
        # PyMscclpp is enabled only in cuda graph
        self.disabled = True
def pre_tune_config(self, dtype=torch.bfloat16) -> None:
    """Benchmark MSCCL++ allreduce launch configs and cache the fastest.

    For every message size in ``self.msg_size_for_finetune``, each candidate
    (nthreads, nblocks) pair is timed with ``mscclpp_bench_time`` and the
    best pair is stored in ``self.msg_size2best_config``; ``all_reduce``
    later looks launch configs up in that table.  Runs on every rank; the
    caller broadcasts rank 0's table afterwards so all ranks agree.

    Args:
        dtype: element type of the benchmark buffers (default bfloat16).
    """
    logger.debug("start to pre-tune configs for rank %d", self.rank)
    nthreads_to_try = [256, 512, 1024]
    nblocks_to_try = [21, 42, 84]
    # Allocate one buffer of the largest tuned size on this communicator's
    # device and slice it per message size.  (Previously hard-coded to
    # "cuda", which silently depended on the current device being set to
    # this rank's device.)
    inp_randn = torch.ones(
        self.msg_size_for_finetune[-1] // dtype.itemsize,
        dtype=dtype,
        device=self.device,
    )
    oup_randn = torch.empty_like(inp_randn)
    for msg_size in self.msg_size_for_finetune:
        mock_inp, mock_outp = (
            inp_randn[: msg_size // dtype.itemsize],
            oup_randn[: msg_size // dtype.itemsize],
        )
        best_config, best_time = None, None
        for nthreads in nthreads_to_try:
            for nblocks in nblocks_to_try:
                cur_cost = mscclpp_bench_time(
                    lambda: ops.mscclpp_allreduce(
                        self._context, mock_inp, mock_outp, nthreads, nblocks
                    )
                )
                if best_time is None or cur_cost < best_time:
                    best_config = (nthreads, nblocks)
                    best_time = cur_cost
        self.msg_size2best_config[msg_size] = best_config
        if self.rank == 0:
            logger.debug(
                f"for msg_size {msg_size}, best_config: {best_config}, best_time: {best_time}us"
            )
def should_mscclpp_allreduce(
    self, inp: torch.Tensor, op: ReduceOp = ReduceOp.SUM
) -> bool:
    """Return True when ``inp`` can take the MSCCL++ allreduce fast path.

    The fast path requires: communicator enabled and initialized, a
    supported dtype, (weakly) contiguous storage, a SUM reduction, and a
    payload that fits in the pre-allocated scratch capacity.
    """
    if self.disabled or self._context is None:
        return False
    payload_bytes = inp.numel() * inp.element_size()
    return (
        inp.dtype in PyMscclppCommunicator._SUPPORTED_DTYPE
        and mscclpp_is_weak_contiguous(inp)
        # the kernels only implement summation
        and op == ReduceOp.SUM
        and payload_bytes <= self.max_bytes
    )
def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM):
    # Out-of-place MSCCL++ allreduce: returns a NEW tensor with the reduced
    # values; the input tensor is not modified.  Callers are expected to have
    # gated this with should_mscclpp_allreduce(), so `op` is always SUM and
    # the payload is no larger than the largest pre-tuned message size.
    if self._IS_CAPTURING:
        if torch.cuda.is_current_stream_capturing():
            # Record (dtype, numel) of inputs seen during CUDA graph capture.
            # NOTE(review): `_IS_CAPTURING` and `graph_input_set` are created
            # outside this view — presumably for graph replay bookkeeping;
            # confirm against the rest of the class.
            self.graph_input_set.add((tensor.dtype, tensor.numel()))
    msg_size = tensor.numel() * tensor.itemsize
    # Round the payload up to the nearest pre-tuned bucket and use that
    # bucket's best (nthreads, nblocks) launch configuration.
    index = bisect.bisect_left(self.msg_size_for_finetune, msg_size)
    msg_size_finetune = self.msg_size_for_finetune[index]
    nthreads, nblocks = self.msg_size2best_config[msg_size_finetune]
    result = torch.empty_like(tensor)
    ops.mscclpp_allreduce(self._context, tensor, result, nthreads, nblocks)
    return result
@contextmanager
def change_state(
self,
enable: Optional[bool] = None,
):
if enable is None:
# guess a default value when not specified
enable = self.available
old_disable = self.disabled
self.disabled = not enable
yield
self.disabled = old_disable

View File

@@ -0,0 +1,341 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
import logging
from contextlib import contextmanager
from typing import Optional, Union
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from sglang.srt.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary,
buffer_type,
cudaStream_t,
ncclComm_t,
ncclDataTypeEnum,
ncclRedOpTypeEnum,
ncclUniqueId,
)
from sglang.srt.distributed.utils import StatelessProcessGroup
logger = logging.getLogger(__name__)
class PyNcclCommunicator:
    """A NCCL communicator driven directly through the NCCL C library.

    Unlike the ``torch.distributed`` NCCL backend, this wrapper owns its own
    ``ncclComm_t`` and a dedicated CUDA stream, which allows collectives to
    be issued inside CUDA graph capture.  Instances start *disabled*; wrap
    calls in ``with comm.change_state(enable=True):`` to run collectives.
    """

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the PyNcclCommunicator to. If None,
                it will be bind to f"cuda:{local_rank}".
            library_path: the path to the NCCL library. If None, it will
                use the default library path.
        It is the caller's responsibility to make sure each communicator
        is bind to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert (
                dist.get_backend(group) != dist.Backend.NCCL
            ), "PyNcclCommunicator should be attached to a non-NCCL group."
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size
        self.group = group
        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            self.stream = None
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            self.stream = None
            return
        self.available = True
        self.disabled = False
        # raw integer version, e.g. 22703 == 2.27.3; consulted by the
        # symmetric-memory code to gate window registration
        self.nccl_version = self.nccl.ncclGetRawVersion()
        if self.rank == 0:
            logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()
        if not isinstance(group, StatelessProcessGroup):
            # ship the unique id to all ranks through the (non-NCCL)
            # torch.distributed group, then rebuild the ctypes struct
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank
            )
            self.stream = torch.cuda.Stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            self.stream.synchronize()
            del data
        # by default it is disabled, e.g. in profiling models and prefill phase.
        # to use it, use under `with obj.change_state(enable=True)`, usually
        # when we are using CUDA graph.
        self.disabled = True

    def all_reduce(
        self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
    ):
        """In-place allreduce of `tensor` (send and recv buffers alias)."""
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = self.stream
        self.nccl.ncclAllReduce(
            buffer_type(tensor.data_ptr()),
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def all_gather(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        stream=None,
        sizes: Optional[list[int]] = None,
    ):
        """All-gather `input_tensor` from every rank into `output_tensor`.

        When `sizes` is given (variable-length gather), it is implemented as
        a grouped sequence of per-rank broadcasts into slices of the output.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}"
        )
        if stream is None:
            stream = self.stream
        if sizes is not None:
            split_offset = 0
            self.nccl.ncclGroupStart()
            for root, split_size in enumerate(sizes):
                dst_slice = output_tensor[split_offset : split_offset + split_size]
                # sendbuff is only read on the root rank of each broadcast
                self.nccl.ncclBroadcast(
                    buffer_type(input_tensor.data_ptr()),
                    buffer_type(dst_slice.data_ptr()),
                    dst_slice.numel(),
                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
                    root,
                    self.comm,
                    cudaStream_t(stream.cuda_stream),
                )
                split_offset += split_size
            self.nccl.ncclGroupEnd()
        else:
            self.nccl.ncclAllGather(
                buffer_type(input_tensor.data_ptr()),
                buffer_type(output_tensor.data_ptr()),
                input_tensor.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )

    def reduce_scatter(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
        sizes: Optional[list[int]] = None,
    ):
        """Reduce-scatter `input_tensor` across ranks into `output_tensor`.

        When `sizes` is given (variable-length chunks), it is implemented as
        a grouped sequence of per-root `ncclReduce` calls over input chunks.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}"
        )
        if stream is None:
            stream = self.stream
        if sizes is not None:
            split_offset = 0
            self.nccl.ncclGroupStart()
            for root, split_size in enumerate(sizes):
                chunk = input_tensor[split_offset : split_offset + split_size, ...]
                self.nccl.ncclReduce(
                    buffer_type(chunk.data_ptr()),
                    buffer_type(output_tensor.data_ptr()),
                    chunk.numel(),
                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
                    ncclRedOpTypeEnum.from_torch(op),
                    root,
                    self.comm,
                    cudaStream_t(stream.cuda_stream),
                )
                split_offset += split_size
            self.nccl.ncclGroupEnd()
        else:
            self.nccl.ncclReduceScatter(
                buffer_type(input_tensor.data_ptr()),
                buffer_type(output_tensor.data_ptr()),
                output_tensor.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                ncclRedOpTypeEnum.from_torch(op),
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Point-to-point send of `tensor` to group-rank `dst`."""
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = self.stream
        self.nccl.ncclSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            dst,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Point-to-point receive into `tensor` from group-rank `src`."""
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = self.stream
        self.nccl.ncclRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast `tensor` from group-rank `src` to all ranks, in place."""
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = self.stream
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(
            sendbuff,
            recvbuff,
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def register_comm_window_raw(self, ptr: int, size: int):
        """Register [ptr, ptr+size) as a NCCL communication window."""
        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)

    def deregister_comm_window(self, window):
        """Deregister a window previously returned by register_comm_window_raw."""
        return self.nccl.ncclCommWindowDeregister(self.comm, window)

    def group_start(self):
        """Begin a NCCL group (batch subsequent calls until group_end)."""
        self.nccl.ncclGroupStart()

    def group_end(self):
        """End the current NCCL group, launching the batched operations."""
        self.nccl.ncclGroupEnd()

    @contextmanager
    def change_state(
        self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
    ):
        """
        A context manager to change the state of the communicator.

        Temporarily enables/disables the communicator and swaps its stream;
        both are restored on exit, even when the with-body raises (the
        original skipped restoration on error because the ``yield`` was not
        wrapped in try/finally).
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available
        if stream is None:
            stream = self.stream
        old_disable = self.disabled
        old_stream = self.stream
        self.stream = stream
        self.disabled = not enable
        try:
            yield
        finally:
            self.disabled = old_disable
            self.stream = old_stream

View File

@@ -0,0 +1,133 @@
import tempfile
import torch
from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.managers.schedule_batch import global_server_args_dict
# Inline C++ source for a CUDA pluggable allocator whose alloc/free hooks are
# backed by NCCL symmetric memory (ncclMemAlloc / ncclMemFree).  Compiled
# lazily by get_nccl_mem_pool() via torch's inline extension builder.
# NOTE(review): the ncclResult_t return codes are ignored, so an allocation
# failure surfaces as a NULL pointer rather than an error — confirm intended.
nccl_allocator_source = """
#include <nccl.h>
extern "C" {
void* nccl_alloc_plug(size_t size, int device, void* stream) {
void* ptr;
ncclResult_t err = ncclMemAlloc(&ptr, size);
return ptr;
}
void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
ncclResult_t err = ncclMemFree(ptr);
}
}
"""
# Lazily-created process-wide singletons (see get_nccl_mem_pool).
_allocator = None  # CUDAPluggableAllocator wrapping the compiled hooks
_mem_pool = None  # torch.cuda.MemPool backed by _allocator
_registered_base_addrs = set()  # segment base addresses already registered with NCCL
_graph_pool_id = None  # CUDA graph memory pool id, set via set_graph_pool_id()
def is_symmetric_memory_enabled():
    """Return True when NCCL symmetric memory is enabled in the server args."""
    enabled = global_server_args_dict["enable_symm_mem"]
    return enabled
def set_graph_pool_id(graph_pool_id):
    # Record the CUDA graph memory pool id so use_symmetric_memory can pause
    # and resume graph-pool allocation around symmetric-memory allocations.
    global _graph_pool_id
    _graph_pool_id = graph_pool_id
def get_nccl_mem_pool():
    """Return the process-wide torch MemPool backed by NCCL symmetric memory.

    On first call, compiles the pluggable-allocator C++ source above into a
    shared library and wires it into a CUDAPluggableAllocator; subsequent
    calls return the cached pool.  Not thread-safe — assumes first call
    happens before concurrent use (TODO confirm).
    """
    global _allocator, _mem_pool
    if _mem_pool is None:
        out_dir = tempfile.gettempdir()
        nccl_allocator_libname = "nccl_allocator"
        # NOTE(review): relies on `torch.utils.cpp_extension` already being
        # imported somewhere (plain `import torch` does not guarantee this
        # submodule attribute exists) — confirm or import it explicitly.
        torch.utils.cpp_extension.load_inline(
            name=nccl_allocator_libname,
            cpp_sources=nccl_allocator_source,
            with_cuda=True,
            extra_ldflags=["-lnccl"],  # requires libnccl on the linker path
            verbose=True,
            is_python_module=False,
            build_directory=out_dir,
        )
        _allocator = CUDAPluggableAllocator(
            f"{out_dir}/{nccl_allocator_libname}.so",
            "nccl_alloc_plug",
            "nccl_free_plug",
        ).allocator()
        _mem_pool = torch.cuda.MemPool(_allocator)
    return _mem_pool
class use_symmetric_memory:
    """Context manager that serves CUDA allocations from the NCCL
    symmetric-memory pool.

    While active, tensor allocations go to the MemPool returned by
    get_nccl_mem_pool(); on exit, newly allocated pool segments are
    registered as NCCL communication windows with the group's pynccl
    communicator.  When symmetric memory is disabled in server args, every
    method is a no-op.
    """

    def __init__(self, group_coordinator: GroupCoordinator):
        if not is_symmetric_memory_enabled():
            # Disabled: leave all state None so __enter__/__exit__ no-op.
            self.group_coordinator = None
            self._mem_pool_ctx = None
            self.is_graph_capture = None
            self.device = None
            self.pre_2_8_0 = None
        else:
            self.group_coordinator = group_coordinator
            self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
            self.is_graph_capture = torch.cuda.is_current_stream_capturing()
            self.device = torch.cuda.current_device()
            # PyTorch < 2.8.0 uses different private pool-control APIs and
            # has a multi-thread MemPool bug (worked around in __exit__).
            self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")

    def __enter__(self):
        if not is_symmetric_memory_enabled():
            return self
        assert (
            self.group_coordinator.pynccl_comm is not None
        ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
        assert (
            self.group_coordinator.pynccl_comm.nccl_version >= 22703
        ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
        if self.is_graph_capture:
            assert (
                _graph_pool_id is not None
            ), "graph_pool_id is not set under graph capture"
            # Pause graph memory pool to use symmetric memory with cuda graph
            if self.pre_2_8_0:
                torch._C._cuda_endAllocateCurrentStreamToPool(
                    self.device, _graph_pool_id
                )
            else:
                torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
        self._mem_pool_ctx.__enter__()
        return self

    def tag(self, tensor: torch.Tensor):
        # Mark a tensor as living in symmetric memory so downstream code can
        # choose the window-registered communication path (attribute is read
        # elsewhere — outside this view).
        if not is_symmetric_memory_enabled():
            return
        tensor.symmetric_memory = True

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not is_symmetric_memory_enabled():
            return
        global _registered_base_addrs
        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
        # Register any segments newly allocated in the pool as NCCL
        # communication windows, once per base address (process-wide set).
        for segment in get_nccl_mem_pool().snapshot():
            if segment["address"] not in _registered_base_addrs:
                if segment["stream"] == 0 and self.pre_2_8_0:
                    # PyTorch version < 2.8.0 has a multi-thread MemPool bug
                    # See https://github.com/pytorch/pytorch/issues/152861
                    # Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
                    # WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
                    continue
                self.group_coordinator.pynccl_comm.register_comm_window_raw(
                    segment["address"], segment["total_size"]
                )
                _registered_base_addrs.add(segment["address"])
        if self.is_graph_capture:
            # Resume routing allocations back into the CUDA graph pool.
            if self.pre_2_8_0:
                torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
            else:
                torch._C._cuda_beginAllocateCurrentThreadToPool(
                    self.device, _graph_pool_id
                )

Some files were not shown because too many files have changed in this diff Show More