sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
167
python/pyproject.toml
Executable file
167
python/pyproject.toml
Executable file
@@ -0,0 +1,167 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "sglang"
|
||||
version = "0.5.2"
|
||||
description = "SGLang is a fast serving framework for large language models and vision language models."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = { file = "LICENSE" }
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
]
|
||||
dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
runtime_common = [
|
||||
"blobfile==3.0.0",
|
||||
"build",
|
||||
"compressed-tensors",
|
||||
"datasets",
|
||||
"einops",
|
||||
"fastapi",
|
||||
"hf_transfer",
|
||||
"huggingface_hub",
|
||||
"interegular",
|
||||
"llguidance>=0.7.11,<0.8.0",
|
||||
"modelscope",
|
||||
"msgspec",
|
||||
"ninja",
|
||||
"openai==1.99.1",
|
||||
"openai-harmony==0.0.4",
|
||||
"orjson",
|
||||
"outlines==0.1.11",
|
||||
"packaging",
|
||||
"partial_json_parser",
|
||||
"pillow",
|
||||
"prometheus-client>=0.20.0",
|
||||
"psutil",
|
||||
"pybase64",
|
||||
"pydantic",
|
||||
"pynvml",
|
||||
"python-multipart",
|
||||
"pyzmq>=25.1.2",
|
||||
"sentencepiece",
|
||||
"soundfile==0.13.1",
|
||||
"scipy",
|
||||
"timm==1.0.16",
|
||||
"tiktoken",
|
||||
"torchao==0.9.0",
|
||||
"transformers==4.56.1",
|
||||
"uvicorn",
|
||||
"uvloop",
|
||||
"xgrammar==0.1.24",
|
||||
]
|
||||
|
||||
srt = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.3.9.post2",
|
||||
"torch==2.8.0",
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
"cuda-python",
|
||||
"flashinfer_python==0.3.1",
|
||||
]
|
||||
|
||||
blackwell = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.3.9.post2",
|
||||
"torch==2.8.0",
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
"cuda-python",
|
||||
"flashinfer_python==0.3.1",
|
||||
"nvidia-cutlass-dsl==4.1.0",
|
||||
]
|
||||
|
||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
|
||||
srt_hip = [
|
||||
"sglang[runtime_common]",
|
||||
"torch",
|
||||
"petit_kernel==0.0.2",
|
||||
"wave-lang==3.7.0",
|
||||
]
|
||||
|
||||
# https://docs.sglang.ai/platforms/cpu_server.html
|
||||
srt_cpu = ["sglang[runtime_common]", "intel-openmp"]
|
||||
|
||||
# https://docs.sglang.ai/platforms/ascend_npu.html
|
||||
srt_npu = ["sglang[runtime_common]"]
|
||||
|
||||
# xpu is not enabled in public vllm and torch whl,
|
||||
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
|
||||
srt_xpu = ["sglang[runtime_common]"]
|
||||
|
||||
# For Intel Gaudi(device : hpu) follow the installation guide
|
||||
# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
|
||||
srt_hpu = ["sglang[runtime_common]"]
|
||||
|
||||
openai = ["openai==1.99.1", "tiktoken"]
|
||||
anthropic = ["anthropic>=0.20.0"]
|
||||
litellm = ["litellm>=1.0.0"]
|
||||
torch_memory_saver = ["torch_memory_saver==0.0.8"]
|
||||
decord = ["decord"]
|
||||
test = [
|
||||
"accelerate",
|
||||
"expecttest",
|
||||
"jsonlines",
|
||||
"matplotlib",
|
||||
"pandas",
|
||||
"peft",
|
||||
"sentence_transformers",
|
||||
"pytest",
|
||||
"tabulate",
|
||||
]
|
||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[torch_memory_saver]", "sglang[decord]"]
|
||||
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
||||
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
||||
all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
||||
all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
||||
all_npu = ["sglang[srt_npu]", "sglang[openai]", "sglang[anthropic]", "sglang[decord]"]
|
||||
|
||||
dev = ["sglang[all]", "sglang[test]"]
|
||||
dev_hip = ["sglang[all_hip]", "sglang[test]"]
|
||||
dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
|
||||
dev_hpu = ["sglang[all_hpu]", "sglang[test]"]
|
||||
dev_cpu = ["sglang[all_cpu]", "sglang[test]"]
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://github.com/sgl-project/sglang"
|
||||
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
"sglang" = [
|
||||
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
|
||||
"srt/layers/quantization/configs/*.json",
|
||||
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
exclude = [
|
||||
"assets*",
|
||||
"benchmark*",
|
||||
"docs*",
|
||||
"dist*",
|
||||
"playground*",
|
||||
"scripts*",
|
||||
"tests*",
|
||||
]
|
||||
|
||||
[tool.wheel]
|
||||
exclude = [
|
||||
"assets*",
|
||||
"benchmark*",
|
||||
"docs*",
|
||||
"dist*",
|
||||
"playground*",
|
||||
"scripts*",
|
||||
"tests*",
|
||||
]
|
||||
|
||||
[tool.codespell]
|
||||
ignore-words-list = "ans, als, hel, boostrap, childs, te, vas, hsa, ment"
|
||||
skip = "*.json,*.jsonl,*.patch,*.txt"
|
||||
16
python/sglang/README.md
Normal file
16
python/sglang/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# Code Structures
|
||||
|
||||
- `eval`: The evaluation utilities.
|
||||
- `lang`: The frontend language.
|
||||
- `srt`: The backend engine for running local models. (SRT = SGLang Runtime).
|
||||
- `test`: The test utilities.
|
||||
- `api.py`: The public APIs.
|
||||
- `bench_offline_throughput.py`: Benchmark the performance in the offline mode.
|
||||
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
|
||||
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
|
||||
- `bench_serving.py`: Benchmark online serving with dynamic requests.
|
||||
- `check_env.py`: Check the environment variables and dependencies.
|
||||
- `global_config.py`: The global configs and constants.
|
||||
- `launch_server.py`: The entry point for launching the local server.
|
||||
- `utils.py`: Common utilities.
|
||||
- `version.py`: Version info.
|
||||
83
python/sglang/__init__.py
Normal file
83
python/sglang/__init__.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# SGLang public APIs
#
# This module re-exports the frontend-language primitives, the backend
# endpoints, and (lazily) the heavy runtime-engine entry points.

# Frontend Language APIs
from sglang.global_config import global_config
from sglang.lang.api import (
    Engine,
    Runtime,
    assistant,
    assistant_begin,
    assistant_end,
    flush_cache,
    function,
    gen,
    gen_int,
    gen_string,
    get_server_info,
    image,
    select,
    separate_reasoning,
    set_default_backend,
    system,
    system_begin,
    system_end,
    user,
    user_begin,
    user_end,
    video,
)
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.choices import (
    greedy_token_selection,
    token_length_normalized,
    unconditional_likelihood_normalized,
)

# Lazy import some libraries
from sglang.utils import LazyImport
from sglang.version import __version__

# Third-party backends are loaded lazily so their SDKs are only imported on use.
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")

# Runtime Engine APIs
# NOTE(review): this rebinds the `Engine` imported from sglang.lang.api above,
# so the exported `Engine` is the lazily-loaded runtime engine — confirm this
# shadowing is intentional.
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
Engine = LazyImport("sglang.srt.entrypoints.engine", "Engine")

__all__ = [
    "Engine",
    "Runtime",
    "assistant",
    "assistant_begin",
    "assistant_end",
    "flush_cache",
    "function",
    "gen",
    "gen_int",
    "gen_string",
    "get_server_info",
    "image",
    "select",
    "separate_reasoning",
    "set_default_backend",
    "system",
    "system_begin",
    "system_end",
    "user",
    "user_begin",
    "user_end",
    "video",
    "RuntimeEndpoint",
    "greedy_token_selection",
    "token_length_normalized",
    "unconditional_likelihood_normalized",
    "ServerArgs",
    "Anthropic",
    "LiteLLM",
    "OpenAI",
    "VertexAI",
    "global_config",
    "__version__",
]
|
||||
452
python/sglang/bench_offline_throughput.py
Normal file
452
python/sglang/bench_offline_throughput.py
Normal file
@@ -0,0 +1,452 @@
|
||||
"""
|
||||
Benchmark the throughput in the offline mode.
|
||||
It accepts server arguments (the same as launch_server.py) and benchmark arguments (the same as bench_serving.py).
|
||||
|
||||
# Usage
|
||||
## Sharegpt dataset with default args
|
||||
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10
|
||||
|
||||
## Random dataset with default args
|
||||
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.bench_serving import (
|
||||
DatasetRow,
|
||||
get_dataset,
|
||||
get_tokenizer,
|
||||
sample_random_requests,
|
||||
set_ulimit,
|
||||
)
|
||||
from sglang.lang.backend.runtime_endpoint import Runtime
|
||||
from sglang.srt.entrypoints.engine import Engine
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
@dataclasses.dataclass
class BenchArgs:
    """Benchmark arguments for the offline throughput test.

    Field defaults are the single source of truth; `add_cli_args` registers the
    matching CLI flags and `from_cli_args` maps a parsed namespace back to an
    instance.
    """

    backend: str = "engine"
    result_filename: str = ""
    dataset_name: str = "sharegpt"
    dataset_path: str = ""
    num_prompts: int = 1000
    sharegpt_output_len: Optional[int] = None
    sharegpt_context_len: Optional[int] = None
    random_input_len: int = 1024
    random_output_len: int = 1024
    random_range_ratio: float = 0.0
    gsp_num_groups: int = 64
    gsp_prompts_per_group: int = 16
    gsp_system_prompt_len: int = 2048
    gsp_question_len: int = 128
    gsp_output_len: int = 256
    seed: int = 1
    disable_ignore_eos: bool = False
    extra_request_body: Optional[str] = None
    apply_chat_template: bool = False
    profile: bool = False
    skip_warmup: bool = False
    do_not_exit: bool = False
    prompt_suffix: str = ""

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all benchmark flags on `parser` with defaults taken from the dataclass."""
        parser.add_argument("--backend", type=str, default=BenchArgs.backend)
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument(
            "--dataset-name",
            type=str,
            # Reference the dataclass default instead of repeating the literal.
            default=BenchArgs.dataset_name,
            choices=["sharegpt", "random", "generated-shared-prefix"],
            help="Name of the dataset to benchmark on.",
        )
        parser.add_argument(
            "--dataset-path", type=str, default="", help="Path to the dataset."
        )
        parser.add_argument(
            "--num-prompts",
            type=int,
            default=BenchArgs.num_prompts,
            help="Number of prompts to process. Default is 1000.",
        )
        parser.add_argument(
            "--sharegpt-output-len",
            type=int,
            default=BenchArgs.sharegpt_output_len,
            help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
        )
        parser.add_argument(
            "--sharegpt-context-len",
            type=int,
            default=BenchArgs.sharegpt_context_len,
            help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
        )
        parser.add_argument(
            "--random-input-len",
            type=int,
            default=BenchArgs.random_input_len,
            help="Number of input tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-output-len",
            type=int,
            default=BenchArgs.random_output_len,
            help="Number of output tokens per request, used only for random dataset.",
        )
        parser.add_argument(
            "--random-range-ratio",
            type=float,
            default=BenchArgs.random_range_ratio,
            help="Range of sampled ratio of input/output length, "
            "used only for random dataset.",
        )
        # Fixed: the original implicit string concatenations in the help texts
        # below dropped the space ("usedonly ...") and referred to a dataset
        # name ("generate-shared-prefix") that does not match the registered
        # choice "generated-shared-prefix".
        parser.add_argument(
            "--gsp-num-groups",
            type=int,
            default=BenchArgs.gsp_num_groups,
            help="Number of groups with shared prefix, used "
            "only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-prompts-per-group",
            type=int,
            default=BenchArgs.gsp_prompts_per_group,
            help="Number of prompts per group of shared prefix, used "
            "only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-system-prompt-len",
            type=int,
            default=BenchArgs.gsp_system_prompt_len,
            help="System prompt length, used only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-question-len",
            type=int,
            default=BenchArgs.gsp_question_len,
            help="Question length, used only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gsp-output-len",
            type=int,
            default=BenchArgs.gsp_output_len,
            help="Target length in tokens for outputs in generated-shared-prefix dataset",
        )
        parser.add_argument(
            "--seed", type=int, default=BenchArgs.seed, help="The random seed."
        )
        parser.add_argument(
            "--disable-ignore-eos",
            action="store_true",
            help="Disable ignore EOS token",
        )
        parser.add_argument(
            "--extra-request-body",
            metavar='{"key1": "value1", "key2": "value2"}',
            type=str,
            default=BenchArgs.extra_request_body,
            help="Append given JSON object to the request payload. You can use this to specify "
            "additional generate params like sampling params.",
        )
        parser.add_argument(
            "--apply-chat-template",
            action="store_true",
            help="Apply chat template",
        )
        parser.add_argument(
            "--profile",
            action="store_true",
            help="Use Torch Profiler. The endpoint must be launched with "
            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
        )
        parser.add_argument(
            "--skip-warmup",
            action="store_true",
            help="Skip the warmup batches.",
        )
        parser.add_argument(
            "--do-not-exit",
            action="store_true",
            help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
        )
        parser.add_argument(
            "--prompt-suffix",
            type=str,
            default=BenchArgs.prompt_suffix,
            help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from a parsed namespace; argparse dests mirror the field names."""
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||
|
||||
|
||||
def throughput_test_once(
    backend_name: str,
    backend,
    reqs: List[DatasetRow],
    ignore_eos: bool,
    extra_request_body: Dict,
    profile: bool,
):
    """Run one batch of requests through `backend` and return throughput metrics.

    Args:
        backend_name: Either "engine" or "runtime"; "runtime" responses are
            JSON strings and get decoded below.
        backend: Object exposing `generate`, `get_server_info`, and (when
            profiling) `start_profile`/`stop_profile`.
        reqs: Dataset rows with `.prompt`, `.prompt_len`, `.output_len`.
        ignore_eos: Forwarded into each request's sampling params.
        extra_request_body: Extra keys merged into every sampling-params dict.
        profile: If True, wrap the generation in a torch-profiler session
            (requires SGLANG_TORCH_PROFILER_DIR in the environment).

    Returns:
        Dict of measurement results; fields initialized to -1 are overwritten
        after the run completes.
    """
    measurement_results = {
        "backend": backend_name,
        "successful_requests": len(reqs),
        "total_latency": -1,
        "total_input_tokens": sum(r.prompt_len for r in reqs),
        "total_output_tokens": -1,
        "request_throughput": -1,
        "input_throughput": -1,
        "output_throughput": -1,
        "total_throughput": -1,
    }

    prompt = [r.prompt for r in reqs]
    sampling_params = [
        {
            "temperature": 0,
            "max_new_tokens": r.output_len,
            "ignore_eos": ignore_eos,
            **extra_request_body,
        }
        for r in reqs
    ]

    if profile:
        assert (
            "SGLANG_TORCH_PROFILER_DIR" in os.environ
        ), "Please set SGLANG_TORCH_PROFILER_DIR."
        os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
        backend.start_profile()

    # Only the generate call is timed; profiler start/stop stays outside.
    st = time.perf_counter()
    gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
    latency = time.perf_counter() - st

    if profile:
        # Snapshot the directory before stop_profile so the trace written by
        # stop_profile shows up as a "new" file for monitor_trace_file.
        # (NOTE: `dir` shadows the builtin; kept as-is in this doc pass.)
        dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
        known_files = set(os.listdir(dir))
        backend.stop_profile()
        monitor_trace_file(known_files, dir)

    if backend_name == "runtime":
        # The runtime backend returns a JSON string instead of Python objects.
        gen_out = json.loads(gen_out)

    server_info = backend.get_server_info()

    measurement_results["total_latency"] = latency
    measurement_results["total_output_tokens"] = sum(
        o["meta_info"]["completion_tokens"] for o in gen_out
    )
    measurement_results["request_throughput"] = (
        measurement_results["successful_requests"] / latency
    )
    measurement_results["input_throughput"] = (
        measurement_results["total_input_tokens"] / latency
    )
    measurement_results["output_throughput"] = (
        measurement_results["total_output_tokens"] / latency
    )
    measurement_results["total_throughput"] = (
        measurement_results["total_input_tokens"]
        + measurement_results["total_output_tokens"]
    ) / latency

    # Some backends return a coroutine here; resolve it synchronously.
    if inspect.isawaitable(server_info):
        server_info = asyncio.run(server_info)

    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
        "last_gen_throughput"
    ]

    return measurement_results
|
||||
|
||||
|
||||
def monitor_trace_file(known_files, directory, interval=1):
    """Block until a trace file newly appearing in `directory` stops growing.

    Polls the directory every `interval` seconds. For each file not present in
    `known_files`, its size is polled until it stops increasing (trace fully
    written) or the file disappears. Returns once at least one new file has
    stabilized.
    """

    def _wait_until_stable(path, name):
        # Poll the file size; True once it stops growing, False if it vanishes.
        last_size = 0
        while True:
            try:
                size = os.path.getsize(path)
            except FileNotFoundError:
                print(f"File {name} is no longer accessible.")
                return False
            if size > last_size:
                last_size = size
            else:
                return True
            time.sleep(interval)

    print(f"Monitoring {directory} for new trace files...")
    while True:
        time.sleep(interval)
        fresh = set(os.listdir(directory)) - known_files
        stabilized = False
        for name in fresh:
            print(f"New file detected: {name}")
            if _wait_until_stable(os.path.join(directory, name), name):
                stabilized = True
        if stabilized:
            return
|
||||
|
||||
|
||||
def throughput_test(
    server_args: ServerArgs,
    bench_args: BenchArgs,
):
    """Launch the selected backend, run the offline throughput benchmark, and
    print/append the results.

    Args:
        server_args: Arguments used to construct the Engine/Runtime backend.
        bench_args: Benchmark configuration (dataset, prompt counts, profiling).

    Returns:
        The measurement-results dict from the main benchmark run.

    Raises:
        ValueError: If `bench_args.backend` is not "engine" or "runtime".
    """
    if bench_args.backend == "engine":
        backend = Engine(**dataclasses.asdict(server_args))
        if not backend:
            raise ValueError("Please provide valid engine arguments")
    elif bench_args.backend == "runtime":
        backend = Runtime(**dataclasses.asdict(server_args))
    else:
        raise ValueError('Please set backend to either "engine" or "runtime"')

    tokenizer_id = server_args.tokenizer_path or server_args.model_path
    tokenizer = get_tokenizer(tokenizer_id)

    # Set global environments
    set_ulimit()
    random.seed(bench_args.seed)
    np.random.seed(bench_args.seed)

    # Parse args
    extra_request_body = {}
    if bench_args.extra_request_body:
        # BUG FIX: this previously read the module-global `args` (only bound
        # under __main__), raising NameError when called as a library function.
        extra_request_body = json.loads(bench_args.extra_request_body)

    # Read dataset
    input_requests = get_dataset(bench_args, tokenizer)

    warmup_requests = sample_random_requests(
        input_len=256,
        output_len=16,
        num_prompts=min(bench_args.num_prompts, 16),
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path=bench_args.dataset_path,
    )

    # Warm up (never profiled) so compilation/caching costs don't skew results.
    if not bench_args.skip_warmup:
        logging.info("\nWarmup...")
        throughput_test_once(
            backend_name=bench_args.backend,
            backend=backend,
            reqs=warmup_requests,
            ignore_eos=not bench_args.disable_ignore_eos,
            extra_request_body=extra_request_body,
            profile=False,
        )
        time.sleep(0.5)

    logging.info("\nBenchmark...")
    result = throughput_test_once(
        backend_name=bench_args.backend,
        backend=backend,
        reqs=input_requests,
        ignore_eos=not bench_args.disable_ignore_eos,
        extra_request_body=extra_request_body,
        profile=bench_args.profile,
    )
    backend.shutdown()

    # Append (not overwrite) so sweeps accumulate one JSON line per run.
    if bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            fout.write(json.dumps(result) + "\n")

    print(
        "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")
    )
    print("{:<40} {:<10}".format("Backend:", result["backend"]))
    print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"]))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"]))
    print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"]))
    print(
        "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Last generation throughput (tok/s):", result["last_gen_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", result["request_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", result["input_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", result["output_throughput"]
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total token throughput (tok/s):", result["total_throughput"]
        )
    )
    print("=" * 50)

    return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
BenchArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# handling ModelScope model downloads
|
||||
if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
|
||||
if os.path.exists(args.model_path):
|
||||
print(f"Using local model path: {args.model_path}")
|
||||
else:
|
||||
try:
|
||||
from modelscope import snapshot_download
|
||||
|
||||
print(f"Using ModelScope to download model: {args.model_path}")
|
||||
|
||||
# download the model and replace args.model_path
|
||||
args.model_path = snapshot_download(
|
||||
args.model_path,
|
||||
)
|
||||
print(f"Model downloaded to: {args.model_path}")
|
||||
except Exception as e:
|
||||
print(f"ModelScope download failed: {str(e)}")
|
||||
raise e
|
||||
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
bench_args = BenchArgs.from_cli_args(args)
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, server_args.log_level.upper()),
|
||||
format="%(message)s",
|
||||
)
|
||||
|
||||
throughput_test(server_args, bench_args)
|
||||
|
||||
while bench_args.do_not_exit:
|
||||
pass
|
||||
665
python/sglang/bench_one_batch.py
Normal file
665
python/sglang/bench_one_batch.py
Normal file
@@ -0,0 +1,665 @@
|
||||
"""
|
||||
Benchmark the latency of running a single static batch without a server.
|
||||
|
||||
This script does not launch a server and uses the low-level APIs.
|
||||
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
|
||||
|
||||
# Usage (latency test)
|
||||
## with dummy weights:
|
||||
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
|
||||
## sweep through multiple data points and store (append) the results in a jsonl file:
|
||||
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
|
||||
## run with profiling:
|
||||
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
|
||||
# Usage (correctness test):
|
||||
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
|
||||
|
||||
## Reference output (of the correctness test above, can be gpu dependent):
|
||||
input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]]
|
||||
|
||||
prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
|
||||
[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633],
|
||||
[ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]],
|
||||
device='cuda:0')
|
||||
|
||||
prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141],
|
||||
[-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781],
|
||||
[-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]],
|
||||
device='cuda:0')
|
||||
|
||||
========== Prompt 0 ==========
|
||||
<s> The capital of France is Paris.
|
||||
The capital of the United States is Washington, D.C.
|
||||
|
||||
|
||||
========== Prompt 1 ==========
|
||||
<s> The capital of the United Kindom is London.
|
||||
The capital of the United Kingdom is London.
|
||||
The capital of the
|
||||
|
||||
========== Prompt 2 ==========
|
||||
<s> Today is a sunny day and I like to go for a walk in the park.
|
||||
I'm going to the park
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import dataclasses
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
|
||||
from sglang.srt.entrypoints.engine import _set_envs_and_config
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.layers.moe import initialize_moe_config
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.managers.scheduler import Scheduler
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import (
|
||||
configure_logger,
|
||||
get_bool_env_var,
|
||||
kill_process_tree,
|
||||
require_mlp_sync,
|
||||
require_mlp_tp_gather,
|
||||
set_gpu_proc_affinity,
|
||||
suppress_other_loggers,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class BenchArgs:
    """Arguments for the single-static-batch latency benchmark.

    Field defaults double as the CLI defaults registered in `add_cli_args`;
    `from_cli_args` converts a parsed namespace back into an instance, casting
    each value to the type of the field's default (so `nargs="+"` lists become
    tuples).
    """

    run_name: str = "default"
    batch_size: Tuple[int] = (1,)
    input_len: Tuple[int] = (1024,)
    output_len: Tuple[int] = (16,)
    prompt_filename: str = ""
    result_filename: str = "result.jsonl"
    correctness_test: bool = False
    # This is only used for correctness test
    cut_len: int = 4
    log_decode_step: int = 0
    profile: bool = False
    profile_record_shapes: bool = False
    profile_filename_prefix: str = "profile"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register the latency-benchmark flags on `parser`."""
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument(
            "--prompt-filename", type=str, default=BenchArgs.prompt_filename
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
        parser.add_argument(
            "--log-decode-step",
            type=int,
            default=BenchArgs.log_decode_step,
            help="Log decode latency by step, default is set to zero to disable.",
        )
        parser.add_argument(
            "--profile", action="store_true", help="Use Torch Profiler."
        )
        parser.add_argument(
            "--profile-record-shapes",
            action="store_true",
            help="Record tensor shapes in profiling results.",
        )
        parser.add_argument(
            "--profile-filename-prefix",
            type=str,
            default=BenchArgs.profile_filename_prefix,
            help="Prefix of the profiling file names. The full profiling result file(s) be "
            '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from parsed CLI args, casting each value to the
        type of the corresponding field's default value."""
        kwargs = {}
        for field in dataclasses.fields(cls):
            caster = type(field.default)
            kwargs[field.name] = caster(getattr(args, field.name))
        return cls(**kwargs)
|
||||
|
||||
|
||||
def load_model(server_args, port_args, tp_rank):
    """Construct a ModelRunner and tokenizer for one tensor-parallel rank.

    Args:
        server_args: Server configuration (model path, tp/ep sizes, etc.).
        port_args: Port configuration; only `nccl_port` is read here.
        tp_rank: This process's tensor-parallel rank; also used as the GPU id.

    Returns:
        Tuple of (model_runner, tokenizer).
    """
    suppress_other_loggers()
    # Only rank 0 prints, to avoid duplicated output across TP ranks.
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    # Map the TP rank onto its expert-parallel rank (tp_size must be a
    # multiple of ep_size for this integer division to be meaningful).
    moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)

    model_config = ModelConfig.from_server_args(server_args)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        moe_ep_rank=moe_ep_rank,
        moe_ep_size=server_args.ep_size,
        pp_rank=0,  # pipeline parallelism is not used in this benchmark
        pp_size=1,
        nccl_port=port_args.nccl_port,
        server_args=server_args,
    )
    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    # Make sure all ranks finish loading before returning.
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer
|
||||
|
||||
|
||||
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
    """Tokenize the correctness-test prompts and build truncated prefill requests.

    Each request initially carries only the first `bench_args.cut_len` tokens of
    its prompt; the remainder is appended later by
    prepare_extend_inputs_for_correctness_test to exercise the extend path.

    Args:
        bench_args: Benchmark args; only `cut_len` is read.
        tokenizer: Tokenizer with an `encode(str) -> list[int]` method.
        custom_prompts: Optional prompt list; falls back to built-in prompts.

    Returns:
        Tuple of (full `input_ids` per prompt, list of truncated `Req`s).
    """
    prompts = (
        custom_prompts
        if custom_prompts
        else [
            "The capital of France is",
            "The capital of the United Kindom is",
            "Today is a sunny day and I like",
        ]
    )
    input_ids = [tokenizer.encode(p) for p in prompts]
    # NOTE(review): this passes the class-level default `BenchArgs.output_len`
    # (a tuple like (16,)), not a per-run value from `bench_args` — confirm
    # SamplingParams handles this as intended.
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(prompts)):
        # Each prompt must be longer than the cut point to leave an extend part.
        assert len(input_ids[i]) > bench_args.cut_len

        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(
            rid=i,
            origin_input_text=prompts[i],
            origin_input_ids=tmp_input_ids,
            sampling_params=sampling_params,
        )
        req.prefix_indices = []
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    return input_ids, reqs
|
||||
|
||||
|
||||
def prepare_extend_inputs_for_correctness_test(
    bench_args, input_ids, reqs, model_runner
):
    """Extend each Req with the tokens beyond ``cut_len`` for the second prefill.

    The first ``cut_len`` tokens were already prefilled, so ``prefix_indices``
    points at their KV slots in the request-to-token pool and only the
    remaining tokens count toward ``extend_input_len``.
    """
    for i in range(len(reqs)):
        req = reqs[i]
        req.fill_ids += input_ids[i][bench_args.cut_len :]
        # Reuse the KV cache written by the first prefill pass.
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, : bench_args.cut_len
        ]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
    return reqs
||||
def prepare_synthetic_inputs_for_latency_test(
    batch_size, input_len, custom_inputs=None
):
    """Create Req objects for the latency benchmark.

    When ``custom_inputs`` is given it is used as-is (NOTE(review): in that
    case ``input_len`` is ignored, so the reported input length may not match
    the actual prompt lengths — confirm with callers). Otherwise random token
    ids of shape (batch_size, input_len) are generated.
    """
    input_ids = (
        custom_inputs
        if custom_inputs
        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
    )
    # Greedy decoding; output length comes from the BenchArgs class default.
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=BenchArgs.output_len,
    )

    reqs = []
    for i in range(len(input_ids)):
        req = Req(
            rid=i,
            origin_input_text="",
            origin_input_ids=list(input_ids[i]),
            sampling_params=sampling_params,
        )
        req.prefix_indices = []  # no KV cache reuse in the latency test
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    return reqs
||||
@torch.no_grad
def extend(reqs, model_runner):
    """Run one prefill (extend) forward pass and sample the next tokens.

    Returns:
        (next_token_ids, next_token_logits, batch): sampled token ids, the
        logits they were sampled from, and the ScheduleBatch to feed into
        subsequent ``decode`` steps.
    """
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
        tree_cache=None,  # no radix-cache sharing in this benchmark
        model_config=model_runner.model_config,
        enable_overlap=False,
        spec_algorithm=SpeculativeAlgorithm.NONE,
    )
    batch.prepare_for_extend()
    # Needed when the server args require synchronized MLP batches (DP attn).
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits, batch
||||
@torch.no_grad
def decode(input_token_ids, batch, model_runner):
    """Run one decode step: feed the previously sampled tokens, sample again.

    Mutates ``batch`` in place (``prepare_for_decode`` advances its state).
    """
    batch.output_ids = input_token_ids
    batch.prepare_for_decode()
    # Needed when the server args require synchronized MLP batches (DP attn).
    _maybe_prepare_mlp_sync_batch(batch, model_runner)
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    logits_output, _ = model_runner.forward(forward_batch)
    next_token_ids = model_runner.sample(logits_output, forward_batch)
    return next_token_ids, logits_output.next_token_logits
||||
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
    """Prepare the batch for MLP sync when the server configuration needs it.

    No-op unless ``require_mlp_sync`` says the current server args require
    synchronized MLP batches across ranks.
    """
    if require_mlp_sync(model_runner.server_args):
        Scheduler.prepare_mlp_sync_batch_raw(
            batch,
            dp_size=model_runner.server_args.dp_size,
            attn_tp_size=1,  # this benchmark never splits attention across TP
            tp_group=model_runner.tp_group,
            get_idle_batch=None,
            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            speculative_num_draft_tokens=None,
            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
            disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
        )
||||
def _read_prompts_from_file(prompt_file, rank_print):
|
||||
"""Read custom prompts from the file specified by `--prompt-filename`."""
|
||||
if not prompt_file:
|
||||
return []
|
||||
if not os.path.exists(prompt_file):
|
||||
rank_print(
|
||||
f"Custom prompt file {prompt_file} not found. Using default inputs..."
|
||||
)
|
||||
return []
|
||||
with open(prompt_file, "r") as pf:
|
||||
return pf.readlines()
|
||||
|
||||
|
||||
def _save_profile_trace_results(profiler, filename):
    """Export `profiler` as a chrome trace to `filename` and print a summary.

    The parent directory of `filename` is created if it does not exist.
    """
    trace_dir = os.path.dirname(os.path.abspath(filename))
    os.makedirs(trace_dir, exist_ok=True)
    profiler.export_chrome_trace(filename)
    summary_table = profiler.key_averages(group_by_input_shape=True).table(
        sort_by="self_cpu_time_total"
    )
    print(summary_table)
|
||||
def correctness_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    """Sanity-check model output: prefill, extend with KV cache, then decode.

    Prints the decoded continuation of each prompt so a human can judge
    whether the two-phase prefill (controlled by ``--cut-len``) produces
    sensible text. Runs in every TP process; only rank 0 prints.
    """
    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs
    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    input_ids, reqs = prepare_inputs_for_correctness_test(
        bench_args, tokenizer, custom_prompts
    )
    rank_print(f"\n{input_ids=}\n")

    if bench_args.cut_len > 0:
        # Prefill: first `cut_len` tokens only, to populate the KV cache.
        next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
        rank_print(f"prefill logits (first half): {next_token_logits} \n")

        # Prepare extend inputs
        reqs = prepare_extend_inputs_for_correctness_test(
            bench_args, input_ids, reqs, model_runner
        )

    # Extend (prefill w/ KV cache)
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print(f"prefill logits (final): {next_token_logits} \n")

    # Decode autoregressively, collecting one token per request per step.
    output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
    for _ in range(bench_args.output_len[0] - 1):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        next_token_ids_list = next_token_ids.tolist()
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids_list[i])

    # Print output texts
    for i in range(len(reqs)):
        rank_print(f"========== Prompt {i} ==========")
        rank_print(tokenizer.decode(output_ids[i]), "\n")
||||
def synchronize(device):
|
||||
torch.get_device_module(device).synchronize()
|
||||
|
||||
|
||||
def latency_test_run_once(
    run_name,
    model_runner,
    rank_print,
    reqs,
    batch_size,
    input_len,
    output_len,
    device,
    log_decode_step,
    profile,
    profile_record_shapes,
    profile_filename_prefix,
):
    """Benchmark one (batch_size, input_len, output_len) configuration.

    Measures prefill latency/throughput and per-step decode latency, and can
    dump torch-profiler chrome traces for the prefill and for one decode step
    in the middle of generation.

    Returns:
        A dict of measurement results, or None when the configuration does
        not fit in the KV token pool.
    """
    # Skip configurations whose KV usage would exceed the token pool.
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
        rank_print(
            f"skipping ({batch_size}, {input_len}, {output_len}) due to max batch size limit"
        )
        return

    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool_allocator.clear()

    measurement_results = {
        "run_name": run_name,
        "batch_size": batch_size,
        "input_len": input_len,
        "output_len": output_len,
    }

    tot_latency = 0

    profiler = None
    if profile:
        profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            with_stack=True,
            record_shapes=profile_record_shapes,
        )
        profiler.start()

    # Prefill
    synchronize(device)
    tic = time.perf_counter()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.perf_counter() - tic
    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["prefill_latency"] = prefill_latency
    measurement_results["prefill_throughput"] = throughput

    if profile:
        profiler.stop()
        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
        _save_profile_trace_results(profiler, profile_filename)
        rank_print(
            f"torch profiler chrome trace for prefill saved to {profile_filename}"
        )

    # Decode
    decode_latencies = []
    for i in range(output_len - 1):
        synchronize(device)
        # Profile a single decode step halfway through generation, once the
        # run has warmed up. (Matches only when output_len is even, since i
        # is an int and output_len / 2 is a float.)
        if profile and i == output_len / 2:
            profiler = None
            profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                record_shapes=profile_record_shapes,
            )
            profiler.start()

        tic = time.perf_counter()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.perf_counter() - tic
        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
        # Always log the first few steps; afterwards log every log_decode_step.
        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
            rank_print(
                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
            )

        if profile and i == output_len / 2:
            profiler.stop()
            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
            _save_profile_trace_results(profiler, profile_filename)
            rank_print(
                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
            )

    # Record decode timing from 2nd output
    if output_len > 1:
        med_decode_latency = np.median(decode_latencies)
        med_decode_throughput = batch_size / med_decode_latency
        rank_print(
            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
        )
        measurement_results["median_decode_latency"] = med_decode_latency
        measurement_results["median_decode_throughput"] = med_decode_throughput

    throughput = (input_len + output_len) * batch_size / tot_latency
    rank_print(
        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
    )
    measurement_results["total_latency"] = tot_latency
    measurement_results["overall_throughput"] = throughput
    return measurement_results
|
||||
def latency_test(
    server_args,
    port_args,
    bench_args,
    tp_rank,
):
    """Run the latency sweep over all (batch_size, input_len, output_len) combos.

    Warms up with the first configuration (shortened decode), runs every
    combination once, and appends JSONL results on rank 0 when
    ``--result-filename`` is set.
    """
    initialize_moe_config(server_args)

    # Set CPU affinity
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank)

    # Configure the logger
    configure_logger(server_args, prefix=f" TP{tp_rank}")
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

    # Prepare inputs for warm up
    reqs = prepare_synthetic_inputs_for_latency_test(
        bench_args.batch_size[0], bench_args.input_len[0]
    )

    # Warm up
    rank_print("Warmup ...")
    latency_test_run_once(
        bench_args.run_name,
        model_runner,
        rank_print,
        reqs,
        bench_args.batch_size[0],
        bench_args.input_len[0],
        min(32, bench_args.output_len[0]),  # shorter decoding to speed up the warmup
        server_args.device,
        log_decode_step=0,
        profile=False,
        profile_record_shapes=False,
        profile_filename_prefix="",  # not used
    )

    rank_print("Benchmark ...")

    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
    custom_input_len = len(custom_inputs)

    # Run the sweep
    result_list = []
    for bs, il, ol in itertools.product(
        bench_args.batch_size, bench_args.input_len, bench_args.output_len
    ):
        # Align the custom prompts (if any) to the current batch size by
        # truncating, or padding with copies of the last prompt.
        bs_aligned_inputs = []
        if custom_inputs:
            if custom_input_len == bs:
                bs_aligned_inputs = custom_inputs
            elif custom_input_len > bs:
                rank_print(
                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
                    f"Using the first {bs} prompts."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
            else:
                rank_print(
                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
                    f"Pad to the desired batch_size with the last prompt."
                )
                bs_aligned_inputs = copy.deepcopy(custom_inputs)
                bs_aligned_inputs.extend(
                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
                )

        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
        ret = latency_test_run_once(
            bench_args.run_name,
            model_runner,
            rank_print,
            reqs,
            bs,
            il,
            ol,
            server_args.device,
            bench_args.log_decode_step,
            # Only rank 0 profiles; other ranks pass None (falsy).
            bench_args.profile if tp_rank == 0 else None,
            bench_args.profile_record_shapes if tp_rank == 0 else None,
            bench_args.profile_filename_prefix,
        )
        if ret is not None:
            result_list.append(ret)

    # Write results in jsonlines format on rank 0.
    if tp_rank == 0 and bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            for result in result_list:
                fout.write(json.dumps(result) + "\n")

    if server_args.tp_size > 1:
        destroy_distributed_environment()
|
||||
def main(server_args, bench_args):
    """Entry point: dispatch to the correctness or latency test.

    Spawns one process per TP rank when tp_size > 1; otherwise runs the test
    in-process as rank 0.

    Raises:
        ValueError: when no ``--model-path`` was provided.
    """
    # Size CUDA graphs for the largest batch in the sweep.
    server_args.cuda_graph_max_bs = max(bench_args.batch_size)

    _set_envs_and_config(server_args)

    if server_args.model_path:
        if bench_args.correctness_test:
            work_func = correctness_test
        else:
            work_func = latency_test
    else:
        raise ValueError(
            "Provide --model-path for running the tests or "
            "provide --result-filename for plotting the results"
        )

    port_args = PortArgs.init_new(server_args)

    if server_args.tp_size == 1:
        work_func(server_args, port_args, bench_args, 0)
    else:
        workers = []
        for tp_rank in range(server_args.tp_size):
            proc = multiprocessing.Process(
                target=work_func,
                args=(
                    server_args,
                    port_args,
                    bench_args,
                    tp_rank,
                ),
            )
            proc.start()
            workers.append(proc)

        for proc in workers:
            proc.join()

        # Terminate every worker as a safeguard. The original code called
        # terminate() once on the leaked loop variable, which only covered
        # the last worker.
        for proc in workers:
            proc.terminate()
||||
if __name__ == "__main__":
    # Parse the combined server + benchmark CLI arguments.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        main(server_args, bench_args)
    finally:
        # Multi-process runs may leave worker processes behind; kill the
        # whole tree but keep this parent alive to finish cleanup.
        if server_args.tp_size != 1:
            kill_process_tree(os.getpid(), include_parent=False)
439
python/sglang/bench_one_batch_server.py
Normal file
439
python/sglang/bench_one_batch_server.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
Benchmark the latency of running a single batch with a server.
|
||||
|
||||
This script launches a server and uses the HTTP interface.
|
||||
It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
|
||||
|
||||
Usage:
|
||||
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
|
||||
|
||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
|
||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import itertools
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.bench_serving import get_tokenizer, sample_random_requests
|
||||
from sglang.profiler import run_profile
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import is_blackwell, kill_process_tree
|
||||
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
||||
|
||||
|
||||
@dataclasses.dataclass
class BenchArgs:
    """CLI arguments for the one-batch server benchmark."""

    run_name: str = "default"  # label written into the result file
    batch_size: Tuple[int] = (1,)  # batch sizes to sweep
    input_len: Tuple[int] = (1024,)  # prompt lengths to sweep
    output_len: Tuple[int] = (16,)  # generation lengths to sweep
    temperature: float = 0.0
    return_logprob: bool = False
    client_stream_interval: int = 1
    input_len_step_percentage: float = 0.0
    result_filename: str = "result.jsonl"
    base_url: str = ""  # when set, benchmark an already-running server
    skip_warmup: bool = False
    show_report: bool = False
    profile: bool = False
    profile_steps: int = 3
    profile_by_stage: bool = False
    dataset_path: str = ""

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register all benchmark flags on `parser`."""
        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
        parser.add_argument(
            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
        )
        parser.add_argument(
            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
        )
        parser.add_argument(
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
        parser.add_argument("--return-logprob", action="store_true")
        parser.add_argument(
            "--client-stream-interval",
            type=int,
            default=BenchArgs.client_stream_interval,
        )
        parser.add_argument(
            "--input-len-step-percentage",
            type=float,
            default=BenchArgs.input_len_step_percentage,
        )
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
        parser.add_argument("--skip-warmup", action="store_true")
        parser.add_argument("--show-report", action="store_true")
        parser.add_argument("--profile", action="store_true")
        parser.add_argument(
            "--profile-steps", type=int, default=BenchArgs.profile_steps
        )
        parser.add_argument("--profile-by-stage", action="store_true")
        parser.add_argument(
            "--dataset-path",
            type=str,
            default=BenchArgs.dataset_path,
            help="Path to the dataset.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from parsed argparse values.

        Casts each value with the type of the field's default, so e.g. the
        list produced by ``nargs="+"`` becomes a tuple again.
        """
        # use the default value's type to cast the args into correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )
|
||||
|
||||
def launch_server_internal(server_args):
    """Run the HTTP server in this process, killing child processes on exit.

    A bare try/finally replaces the original ``except Exception as e:
    raise e`` clause, which was a no-op re-raise with identical behavior.
    """
    try:
        launch_server(server_args)
    finally:
        # Ensure scheduler/tokenizer subprocesses do not outlive the server.
        kill_process_tree(os.getpid(), include_parent=False)
||||
|
||||
def launch_server_process(server_args: ServerArgs):
    """Start the server in a subprocess and wait until it answers /v1/models.

    Polls the health endpoint every 10 seconds for up to 10 minutes.

    Returns:
        (proc, base_url) once the server responds with HTTP 200.

    Raises:
        RuntimeError: if the server process dies before becoming healthy.
        TimeoutError: if the server is not responsive within the timeout.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = 600

    start_time = time.time()
    while time.time() - start_time < timeout:
        # Fail fast instead of polling a dead server for the full timeout.
        if not proc.is_alive():
            raise RuntimeError("Server process terminated before becoming healthy.")
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            # Bound each probe so a hung connection cannot stall the loop.
            response = requests.get(
                f"{base_url}/v1/models", headers=headers, timeout=10
            )
            if response.status_code == 200:
                return proc, base_url
        except requests.RequestException:
            pass
        time.sleep(10)
    raise TimeoutError("Server failed to start within the timeout period.")
||||
|
||||
def run_one_case(
    url: str,
    batch_size: int,
    input_len: int,
    output_len: int,
    temperature: float,
    return_logprob: bool,
    stream_interval: int,
    input_len_step_percentage: float,
    run_name: str,
    result_filename: str,
    tokenizer,
    profile: bool = False,
    profile_steps: int = 3,
    profile_by_stage: bool = False,
    dataset_path: str = "",
):
    """Send one batched /generate request and measure TTFT and throughput.

    Streams the response, records the time-to-first-token and total latency,
    optionally triggers a server-side profile, appends a JSON line to
    ``result_filename``, and returns the measured metrics as a tuple.
    """
    # Start from a cold cache so runs are comparable.
    requests.post(url + "/flush_cache")
    input_requests = sample_random_requests(
        input_len=input_len,
        output_len=output_len,
        num_prompts=batch_size,
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path=dataset_path,
        random_sample=True,
        return_text=False,
    )

    # Hard-coded debug toggle for exercising the structured-output path.
    use_structured_outputs = False
    if use_structured_outputs:
        texts = []
        for _ in range(batch_size):
            texts.append(
                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
                * 50
                + "Assistant:"
            )
        json_schema = "$$ANY$$"
    else:
        json_schema = None

    profile_link = None
    if profile:
        profile_link: str = run_profile(
            url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
        )

    tic = time.perf_counter()
    response = requests.post(
        url + "/generate",
        json={
            "input_ids": [req.prompt for req in input_requests],
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
                "ignore_eos": True,
                "json_schema": json_schema,
                "stream_interval": stream_interval,
            },
            "return_logprob": return_logprob,
            "stream": True,
        },
        stream=True,
    )

    # The TTFT of the last request in the batch
    # NOTE(review): ttft stays 0.0 if no chunk reports completion_tokens == 1
    # (e.g. with stream_interval > 1), which would divide by zero in
    # input_throughput below — confirm.
    ttft = 0.0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            if "error" in data:
                raise RuntimeError(f"Request has failed. {data}.")

            # Requests must either still be running or have stopped at the
            # length limit (ignore_eos=True rules out EOS stops).
            assert (
                data["meta_info"]["finish_reason"] is None
                or data["meta_info"]["finish_reason"]["type"] == "length"
            )
            if data["meta_info"]["completion_tokens"] == 1:
                ttft = time.perf_counter() - tic

    latency = time.perf_counter() - tic
    input_throughput = batch_size * input_len / ttft
    output_throughput = batch_size * output_len / (latency - ttft)
    overall_throughput = batch_size * (input_len + output_len) / latency

    server_info = requests.get(url + "/get_server_info").json()
    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]

    print(f"batch size: {batch_size}")
    print(f"input_len: {input_len}")
    print(f"output_len: {output_len}")
    print(f"latency: {latency:.2f} s")
    print(f"ttft: {ttft:.2f} s")
    print(f"last generation throughput: {last_gen_throughput:.2f} tok/s")
    print(f"input throughput: {input_throughput:.2f} tok/s")
    if output_len != 1:
        print(f"output throughput: {output_throughput:.2f} tok/s")

    if result_filename:
        with open(result_filename, "a") as fout:
            res = {
                "run_name": run_name,
                "batch_size": batch_size,
                "input_len": input_len,
                "output_len": output_len,
                "latency": round(latency, 4),
                "output_throughput": round(output_throughput, 2),
                "overall_throughput": round(overall_throughput, 2),
                "last_gen_throughput": round(last_gen_throughput, 2),
            }
            fout.write(json.dumps(res) + "\n")

    return (
        batch_size,
        latency,
        ttft,
        input_throughput,
        output_throughput,
        overall_throughput,
        last_gen_throughput,
        acc_length,
        profile_link if profile else None,
    )
||||
|
||||
def get_report_summary(
    result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
):
    """Render the benchmark results as a GitHub-flavored markdown table.

    Cost columns assume a fixed hourly GPU price ($4/h on Blackwell, $2/h
    otherwise) multiplied by the TP size, and 70% input-side utilization.
    """
    # Imported lazily so the benchmark itself does not require tabulate.
    import tabulate

    summary = (
        f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
    )

    headers = [
        "batch size",
        "latency (s)",
        "input throughput (tok/s)",
        "output throughput (tok/s)",
        "acc length",
        "ITL (ms)",
        "input cost ($/1M)",
        "output cost ($/1M)",
    ]
    if bench_args.profile:
        headers.append("profile")
    rows = []

    for (
        batch_size,
        latency,
        ttft,
        input_throughput,
        output_throughput,
        _,
        _,
        acc_length,
        trace_link,
    ) in result:
        if is_blackwell():
            hourly_cost_per_gpu = 4  # $4/hour for one B200
        else:
            hourly_cost_per_gpu = 2  # $2/hour for one H100

        hourly_cost = hourly_cost_per_gpu * server_args.tp_size
        input_util = 0.7
        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
        # Inter-token latency: seconds per token per request, in milliseconds.
        itl = 1 / (output_throughput / batch_size) * 1000
        input_cost = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost
        output_cost = 1e6 / output_throughput / 3600 * hourly_cost
        row = [
            batch_size,
            latency,
            input_throughput,
            output_throughput,
            accept_length,
            itl,
            input_cost,
            output_cost,
        ]
        if trace_link:
            row.append(f"[Profile]({trace_link})")
        rows.append(row)

    summary += tabulate.tabulate(
        rows, headers=headers, tablefmt="github", floatfmt=".2f"
    )
    return summary
||||
|
||||
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
    """Launch (or attach to) a server and run the full benchmark sweep.

    Runs the plain sweep first, then, if profiling is requested, repeats the
    sweep with profiling enabled and merges the profile links into the
    results. Any server launched here is killed in the ``finally`` block.
    """
    if bench_args.base_url:
        proc, base_url = None, bench_args.base_url
    else:
        proc, base_url = launch_server_process(server_args)

    server_info = requests.get(base_url + "/get_server_info").json()
    # NOTE(review): `tokenizer_path` is left unbound if the response has
    # neither key, which would raise UnboundLocalError below — confirm the
    # server-info schema guarantees one of the two.
    if "tokenizer_path" in server_info:
        tokenizer_path = server_info["tokenizer_path"]
    elif "prefill" in server_info:
        # PD-disaggregated deployment: read the path from the prefill worker.
        tokenizer_path = server_info["prefill"][0]["tokenizer_path"]
    tokenizer = get_tokenizer(tokenizer_path)

    # warmup
    if not bench_args.skip_warmup:
        print("=" * 8 + " Warmup Begin " + "=" * 8)
        run_one_case(
            base_url,
            batch_size=16,
            input_len=1024,
            output_len=16,
            temperature=bench_args.temperature,
            return_logprob=bench_args.return_logprob,
            stream_interval=bench_args.client_stream_interval,
            input_len_step_percentage=bench_args.input_len_step_percentage,
            run_name="",
            result_filename="",
            tokenizer=tokenizer,
            dataset_path=bench_args.dataset_path,
        )
        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

    # benchmark
    result = []
    bench_result = []
    try:
        for bs, il, ol in itertools.product(
            bench_args.batch_size, bench_args.input_len, bench_args.output_len
        ):
            result.append(
                run_one_case(
                    base_url,
                    bs,
                    il,
                    ol,
                    temperature=bench_args.temperature,
                    return_logprob=bench_args.return_logprob,
                    stream_interval=bench_args.client_stream_interval,
                    input_len_step_percentage=bench_args.input_len_step_percentage,
                    run_name=bench_args.run_name,
                    result_filename=bench_args.result_filename,
                    tokenizer=tokenizer,
                )
            )

        if bench_args.profile:
            try:
                # Re-run the sweep with profiling on; keep only the profile
                # link (the last tuple element) from each profiled run and
                # splice it into the corresponding un-profiled result.
                for bs, il, ol in itertools.product(
                    bench_args.batch_size, bench_args.input_len, bench_args.output_len
                ):
                    bench_result.append(
                        (
                            run_one_case(
                                base_url,
                                bs,
                                il,
                                ol,
                                temperature=bench_args.temperature,
                                return_logprob=bench_args.return_logprob,
                                stream_interval=bench_args.client_stream_interval,
                                input_len_step_percentage=bench_args.input_len_step_percentage,
                                run_name=bench_args.run_name,
                                result_filename=bench_args.result_filename,
                                tokenizer=tokenizer,
                                profile=bench_args.profile,
                                profile_steps=bench_args.profile_steps,
                                profile_by_stage=bench_args.profile_by_stage,
                            )[-1],
                        )
                    )
                result = [t1[:-1] + t2 for t1, t2 in zip(result, bench_result)]
            except Exception as e:
                # Profiling is best-effort: keep the plain results on failure.
                print(f"Error profiling, there will be no profile trace dump: {e}")
    finally:
        if proc:
            kill_process_tree(proc.pid)

    print(f"\nResults are saved to {bench_args.result_filename}")

    if not bench_args.show_report:
        return

    summary = get_report_summary(result, server_args, bench_args)
    print(summary)

    if is_in_ci():
        write_github_step_summary(summary)
|
||||
def main():
    """Parse CLI arguments and run the one-batch server benchmark."""
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    parsed = parser.parse_args()
    run_benchmark(
        ServerArgs.from_cli_args(parsed),
        BenchArgs.from_cli_args(parsed),
    )
||||
if __name__ == "__main__":
    # Script entry point.
    main()
||||
2403
python/sglang/bench_serving.py
Normal file
2403
python/sglang/bench_serving.py
Normal file
File diff suppressed because it is too large
Load Diff
305
python/sglang/check_env.py
Normal file
305
python/sglang/check_env.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""Check environment configurations and dependency versions."""
|
||||
|
||||
import importlib.metadata
|
||||
import os
|
||||
import resource
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
|
||||
def is_cuda_v2():
    """Return True when this torch build ships a CUDA runtime."""
    cuda_version = torch.version.cuda
    return cuda_version is not None
|
||||
# List of packages to check versions.
# Note: the original list contained "torchao" twice; the duplicate is removed
# since get_package_versions() stores results in a dict keyed by name anyway.
PACKAGE_LIST = [
    "sglang",
    "sgl_kernel",
    "flashinfer_python",
    "triton",
    "transformers",
    "torchao",
    "numpy",
    "aiohttp",
    "fastapi",
    "hf_transfer",
    "huggingface_hub",
    "interegular",
    "modelscope",
    "orjson",
    "outlines",
    "packaging",
    "psutil",
    "pydantic",
    "python-multipart",
    "pyzmq",
    "uvicorn",
    "uvloop",
    "vllm",
    "xgrammar",
    "openai",
    "tiktoken",
    "anthropic",
    "litellm",
    "decord",
]
|
||||
|
||||
def get_package_versions(packages):
    """Return a mapping of package name -> installed version string.

    Version specifiers (``==``, ``>=``, ``<=``) are stripped from each entry
    before the metadata lookup; packages that are not installed map to
    "Module Not Found".
    """
    versions = {}
    for spec in packages:
        name = spec.split("==")[0].split(">=")[0].split("<=")[0]
        try:
            versions[name] = importlib.metadata.version(name)
        except ModuleNotFoundError:
            versions[name] = "Module Not Found"
    return versions
||||
|
||||
def get_cuda_info():
    """
    Get CUDA/ROCm-related information if available.

    Always returns a dict (possibly empty on CPU-only builds) so that
    callers can safely ``dict.update`` with the result.
    """
    if is_cuda_v2():
        cuda_info = {"CUDA available": torch.cuda.is_available()}

        if cuda_info["CUDA available"]:
            cuda_info.update(_get_gpu_info())
            cuda_info.update(_get_cuda_version_info())

        return cuda_info
    elif is_hip():
        cuda_info = {"ROCM available": torch.cuda.is_available()}

        if cuda_info["ROCM available"]:
            cuda_info.update(_get_gpu_info())
            cuda_info.update(_get_cuda_version_info())

        return cuda_info
    # Bug fix: previously fell through and returned None on builds that are
    # neither CUDA nor ROCm, which crashed check_env()'s env_info.update(...).
    return {}
|
||||
|
||||
|
||||
def _get_gpu_info():
|
||||
"""
|
||||
Get information about available GPUs.
|
||||
"""
|
||||
devices = defaultdict(list)
|
||||
capabilities = defaultdict(list)
|
||||
for k in range(torch.cuda.device_count()):
|
||||
devices[torch.cuda.get_device_name(k)].append(str(k))
|
||||
capability = torch.cuda.get_device_capability(k)
|
||||
capabilities[f"{capability[0]}.{capability[1]}"].append(str(k))
|
||||
|
||||
gpu_info = {}
|
||||
for name, device_ids in devices.items():
|
||||
gpu_info[f"GPU {','.join(device_ids)}"] = name
|
||||
|
||||
if len(capabilities) == 1:
|
||||
# All GPUs have the same compute capability
|
||||
cap, gpu_ids = list(capabilities.items())[0]
|
||||
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
|
||||
else:
|
||||
# GPUs have different compute capabilities
|
||||
for cap, gpu_ids in capabilities.items():
|
||||
gpu_info[f"GPU {','.join(gpu_ids)} Compute Capability"] = cap
|
||||
|
||||
return gpu_info
|
||||
|
||||
|
||||
def _get_cuda_version_info():
    """
    Get CUDA (or ROCm) version information.

    Returns a dict containing the toolkit home directory plus, when that
    directory exists, compiler and driver version entries.
    """
    if is_cuda_v2():
        # Imported lazily so this module loads on torch builds without CUDA.
        from torch.utils.cpp_extension import CUDA_HOME

        cuda_info = {"CUDA_HOME": CUDA_HOME}

        # Only probe nvcc / the driver when the toolkit directory is real.
        if CUDA_HOME and os.path.isdir(CUDA_HOME):
            cuda_info.update(_get_nvcc_info())
            cuda_info.update(_get_cuda_driver_version())

        return cuda_info
    elif is_hip():
        from torch.utils.cpp_extension import ROCM_HOME as ROCM_HOME

        cuda_info = {"ROCM_HOME": ROCM_HOME}

        if ROCM_HOME and os.path.isdir(ROCM_HOME):
            cuda_info.update(_get_nvcc_info())
            cuda_info.update(_get_cuda_driver_version())

        return cuda_info
    else:
        # Neither backend: report an empty CUDA_HOME.
        cuda_info = {"CUDA_HOME": ""}
        return cuda_info
|
||||
|
||||
|
||||
def _get_nvcc_info():
    """
    Get NVCC (or hipcc) version information.

    Returns {"NVCC": ...} on CUDA, {"HIPCC": ...} on ROCm, and the string
    "Not Available" when the compiler cannot be invoked.
    """
    if is_cuda_v2():
        from torch.utils.cpp_extension import CUDA_HOME

        try:
            nvcc = os.path.join(CUDA_HOME, "bin/nvcc")
            nvcc_output = (
                subprocess.check_output(f'"{nvcc}" -V', shell=True)
                .decode("utf-8")
                .strip()
            )
            # Keep only the "Cuda compilation tools ..." line, dropping the
            # trailing "Build ..." line from `nvcc -V` output.
            return {
                "NVCC": nvcc_output[
                    nvcc_output.rfind("Cuda compilation tools") : nvcc_output.rfind(
                        "Build"
                    )
                ].strip()
            }
        except subprocess.SubprocessError:
            return {"NVCC": "Not Available"}
    elif is_hip():
        from torch.utils.cpp_extension import ROCM_HOME

        try:
            hipcc = os.path.join(ROCM_HOME, "bin/hipcc")
            hipcc_output = (
                subprocess.check_output(f'"{hipcc}" --version', shell=True)
                .decode("utf-8")
                .strip()
            )
            # Slice out the "HIP version ..." segment of the banner.
            return {
                "HIPCC": hipcc_output[
                    hipcc_output.rfind("HIP version") : hipcc_output.rfind("AMD clang")
                ].strip()
            }
        except subprocess.SubprocessError:
            return {"HIPCC": "Not Available"}
    else:
        return {"NVCC": "Not Available"}
|
||||
|
||||
|
||||
def _get_cuda_driver_version():
    """
    Get the CUDA (or ROCm) driver version via the vendor SMI tool.

    Multi-GPU hosts with heterogeneous drivers report all distinct versions.
    """
    versions = set()
    if is_cuda_v2():
        try:
            output = subprocess.check_output(
                [
                    "nvidia-smi",
                    "--query-gpu=driver_version",
                    "--format=csv,noheader,nounits",
                ]
            )
            # One line per GPU; dedupe identical driver versions.
            versions = set(output.decode().strip().split("\n"))
            if len(versions) == 1:
                return {"CUDA Driver Version": versions.pop()}
            else:
                return {"CUDA Driver Versions": ", ".join(sorted(versions))}
        # NOTE(review): a missing nvidia-smi binary raises FileNotFoundError,
        # which this except does not cover — confirm whether that can occur
        # on CUDA-build torch hosts.
        except subprocess.SubprocessError:
            return {"CUDA Driver Version": "Not Available"}
    elif is_hip():
        try:
            output = subprocess.check_output(
                [
                    "rocm-smi",
                    "--showdriverversion",
                    "--csv",
                ]
            )
            versions = set(output.decode().strip().split("\n"))
            # Drop the CSV header line before extracting the value.
            versions.discard("name, value")
            # NOTE(review): pop() raises KeyError on unexpected empty output;
            # only SubprocessError is caught here.
            ver = versions.pop()
            ver = ver.replace('"Driver version", ', "").replace('"', "")

            return {"ROCM Driver Version": ver}
        except subprocess.SubprocessError:
            return {"ROCM Driver Version": "Not Available"}
    else:
        return {"CUDA Driver Version": "Not Available"}
|
||||
|
||||
|
||||
def get_gpu_topology():
    """
    Get the GPU topology matrix from the vendor SMI tool.

    Returns the tool's stdout prefixed with a newline, or None when the
    topology cannot be queried (no GPU backend, tool missing, or tool error).
    """
    if is_cuda_v2():
        cmd = ["nvidia-smi", "topo", "-m"]
    elif is_hip():
        cmd = ["rocm-smi", "--showtopotype"]
    else:
        return None

    try:
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
    except (subprocess.SubprocessError, FileNotFoundError):
        # Bug fix: a missing SMI binary raises FileNotFoundError, which
        # subprocess.SubprocessError does not cover and previously crashed
        # the whole environment check.
        return None
    return "\n" + result.stdout if result.returncode == 0 else None
|
||||
|
||||
|
||||
def get_hypervisor_vendor():
    """Return the hypervisor vendor reported by ``lscpu``, or None on bare
    metal or on any failure (missing lscpu, non-Linux host, parse issue)."""
    try:
        output = subprocess.check_output(["lscpu"], text=True)
        for line in output.split("\n"):
            if "Hypervisor vendor:" in line:
                return line.split(":")[1].strip()
        return None
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
        # are no longer swallowed; this is a best-effort probe, so any
        # ordinary failure still maps to None.
        return None
|
||||
|
||||
|
||||
def check_env():
    """
    Check and print environment information.

    Collects interpreter, CUDA/ROCm, package-version, topology, hypervisor,
    and ulimit details, then prints them one "key: value" pair per line.
    """
    env_info = OrderedDict()
    env_info["Python"] = sys.version.replace("\n", "")
    # NOTE(review): get_cuda_info() can return None on builds that are
    # neither CUDA nor ROCm, which would make this update() raise — confirm.
    env_info.update(get_cuda_info())
    env_info["PyTorch"] = torch.__version__
    env_info.update(get_package_versions(PACKAGE_LIST))

    gpu_topo = get_gpu_topology()
    if gpu_topo:
        if is_cuda_v2():
            env_info["NVIDIA Topology"] = gpu_topo
        elif is_hip():
            env_info["AMD Topology"] = gpu_topo

    hypervisor_vendor = get_hypervisor_vendor()
    if hypervisor_vendor:
        env_info["Hypervisor vendor"] = hypervisor_vendor

    # Soft open-file limit; low values are a common cause of serving issues.
    ulimit_soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
    env_info["ulimit soft"] = ulimit_soft

    for k, v in env_info.items():
        print(f"{k}: {v}")


if __name__ == "__main__":
    check_env()
|
||||
184
python/sglang/compile_deep_gemm.py
Normal file
184
python/sglang/compile_deep_gemm.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Compile DeepGEMM Kernels for a model with specify server arguments
|
||||
|
||||
This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
|
||||
It accepts server arguments (the same as launch_server.py).
|
||||
|
||||
Usage:
|
||||
python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.srt.warmup import warmup
|
||||
|
||||
# Use "spawn" so the server subprocess starts with a clean CUDA context.
multiprocessing.set_start_method("spawn", force=True)

# Reduce warning
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
# Force enable deep gemm
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CompileArgs:
    """CLI arguments controlling the DeepGEMM compilation run."""

    # Maximum seconds to wait for the server / compilation to finish.
    timeout: int = 3600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register this dataclass's fields on an argparse parser."""
        parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a CompileArgs from parsed CLI args.

        Each value is cast with the type of the matching field's default so
        string inputs become the correct type.
        """
        kwargs = {}
        for field in dataclasses.fields(cls):
            caster = type(field.default)
            kwargs[field.name] = caster(getattr(args, field.name))
        return cls(**kwargs)
|
||||
|
||||
|
||||
@warmup("compile-deep-gemm")
async def warm_up_compile(
    disaggregation_mode: str, tokenizer_manager: TokenizerManager
):
    """Issue one tiny deterministic generation so DeepGEMM kernels are
    captured and compiled during server warmup."""
    print("\nGenerate warm up request for compiling DeepGEMM...\n")
    generate_req_input = GenerateReqInput(
        input_ids=[0, 1, 2, 3],
        sampling_params={
            "temperature": 0.0,
            "max_new_tokens": 8,
            "ignore_eos": True,
        },
    )
    # Disaggregated (prefill/decode split) deployments need fake bootstrap
    # metadata so the request is routable.
    if disaggregation_mode != "null":
        generate_req_input.bootstrap_room = 0
        generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST

    # Pull a single item from the async generator to drive the request.
    await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
|
||||
|
||||
|
||||
def launch_server_internal(server_args):
    """Run the HTTP server in this process, always killing its child
    processes on exit so the parent can detect termination.

    Exceptions from launch_server still propagate to the caller.
    """
    try:
        launch_server(server_args)
    finally:
        # The original's ``except Exception as e: raise e`` added nothing —
        # exceptions propagate either way, and this cleanup runs regardless.
        kill_process_tree(os.getpid(), include_parent=False)
|
||||
|
||||
|
||||
def launch_server_process_and_send_one_request(
    server_args: ServerArgs, compile_args: CompileArgs
):
    """Start the server in a subprocess, wait for it to become healthy, and
    trigger one generation to drive DeepGEMM capture.

    Rank-0 sends the sync generation request; other ranks wait for rank-0's
    process to exit.  Returns the server process; raises TimeoutError when
    the server never becomes ready within compile_args.timeout seconds.
    """
    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
    proc.start()
    base_url = f"http://{server_args.host}:{server_args.port}"
    timeout = compile_args.timeout

    start_time = time.perf_counter()
    while time.perf_counter() - start_time < timeout:
        try:
            headers = {
                "Content-Type": "application/json; charset=utf-8",
            }
            if server_args.node_rank == 0:
                response = requests.get(f"{base_url}/v1/models", headers=headers)
            else:
                # This http api is created by launch_dummy_health_check_server for none-rank0 node.
                response = requests.get(f"{base_url}/health", headers=headers)
            if response.status_code == 200:
                # Rank-0 node send a request to sync with other node and then return.
                if server_args.node_rank == 0:
                    response = requests.post(
                        f"{base_url}/generate",
                        json={
                            "input_ids": [0, 1, 2, 3],
                            "sampling_params": {
                                "max_new_tokens": 8,
                                "temperature": 0,
                            },
                        },
                        timeout=600,
                    )
                    if response.status_code != 200:
                        error = response.json()
                        raise RuntimeError(f"Sync request failed: {error}")
                # Other nodes should wait for the exit signal from Rank-0 node.
                else:
                    start_time_waiting = time.perf_counter()
                    while proc.is_alive():
                        if time.perf_counter() - start_time_waiting < timeout:
                            time.sleep(10)
                        else:
                            raise TimeoutError("Waiting for main node timeout!")
                return proc
        except requests.RequestException:
            # Server not up yet (connection refused, etc.); poll again.
            pass
        time.sleep(10)
    raise TimeoutError(
        "DeepGEMM Kernels compilation timeout."
        "\n\nFeel free and please restart the command."
    )
|
||||
|
||||
|
||||
def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
    """Mutate ``server_args`` in place with settings suited to a one-shot
    DeepGEMM compilation run."""
    # Disable cuda graph and torch compile to save time
    server_args.disable_cuda_graph = True
    server_args.enable_torch_compile = False
    # (Dropped the stray f-prefix: the string has no placeholders.)
    print("Disable CUDA Graph and Torch Compile to save time...")

    # Set watchdog timeout to compile_args.timeout because compilation will take a long time
    server_args.watchdog_timeout = compile_args.timeout
    # Request the DeepGEMM warmup hook at server startup.
    server_args.warmups = "compile-deep-gemm"
|
||||
|
||||
|
||||
def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
    """Drive the full compile flow: launch the server, send the capture
    request, then tear the server process tree down."""
    print(
        "Begin DeepGEMM Kernels compilation...\n"
        "It may take a long time and timeout maybe raised "
        "while the compilation is still in progress.\n"
        "Just feel free to restart the command "
        "until the compilation is fully finished.\n"
    )

    proc = launch_server_process_and_send_one_request(server_args, compile_args)

    print("\nDeepGEMM Kernels compilation finished successfully.")

    # Sleep for safety
    time.sleep(10)
    if proc.is_alive():
        # This is the rank0 node.
        kill_process_tree(proc.pid)
    else:
        # Process already exited; killing its (gone) tree may fail — ignore.
        try:
            kill_process_tree(proc.pid)
        except Exception:
            pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Accept the same CLI surface as launch_server.py plus --timeout.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    CompileArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    compile_args = CompileArgs.from_cli_args(args)

    refine_server_args(server_args, compile_args)

    run_compile(server_args, compile_args)
|
||||
315
python/sglang/eval/llama3_eval.py
Normal file
315
python/sglang/eval/llama3_eval.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# Adapt from https://github.com/fw-ai/llm_eval_meta
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
import openai
|
||||
from datasets import load_dataset
|
||||
from openai import AsyncOpenAI
|
||||
from tqdm import tqdm
|
||||
|
||||
# Mapping providers to their clients and models
|
||||
# provider -> model-size key -> HF model id.  All providers currently serve
# the same Llama 3.1 instruct checkpoints; only the endpoint differs.
provider_to_models = {
    "b10": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "oai": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
    "sgl": {
        "8b": "meta-llama/Llama-3.1-8B-Instruct",
        "70b": "meta-llama/Llama-3.1-70B-Instruct",
        "405b": "meta-llama/Llama-3.1-405B-Instruct",
    },
}
|
||||
|
||||
|
||||
async def fetch_responses(
    client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens
):
    """Query one completion and cache it as ``response_<index>.pkl``.

    Skips work when the cache file already exists, so interrupted runs can
    be resumed.  A BadRequestError result is cached as the string
    "bad_response".
    """
    output_file = os.path.join(output_dir, f"response_{index}.pkl")
    if os.path.exists(output_file):
        print(f"File {output_file} already exists, skipping.")
        return

    async with semaphore:
        response = await client.completions.create(
            model=provider_to_models[provider][model_size],
            prompt=prompt,
            temperature=0.0,
            max_tokens=max_tokens,
        )
        if isinstance(response, openai.BadRequestError):
            with open(output_file, "wb") as f:
                pickle.dump("bad_response", f)
            # Bug fix: previously fell through to the assert below and
            # crashed instead of just recording the bad response.
            return
        assert isinstance(response, openai.types.completion.Completion)
        # Save response to a file
        with open(output_file, "wb") as f:
            pickle.dump(response, f)
|
||||
|
||||
|
||||
# Per-eval-set generation budget (max_new_tokens).
TASK_TO_MAX_TOKENS = {
    "evals__mmlu__details": 1,
    "evals__mmlu__0_shot__cot__details": 1024,
    # Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing
    "evals__mmlu_pro__details": 2048,
    "evals__gsm8k__details": 1024,
}

# Short CLI task name -> full Meta eval-set name.
TASK_TO_EVAL_SET = {
    "mmlu": "evals__mmlu__details",
    "mmlu_cot": "evals__mmlu__0_shot__cot__details",
    "mmlu_pro": "evals__mmlu_pro__details",
    "gsm8k": "evals__gsm8k__details",
}
|
||||
|
||||
|
||||
class CustomAsyncHTTPXClient(httpx.AsyncClient):
    """httpx client that rewrites every request URL to the Baseten predict
    endpoint for the model id in the MODEL_ID environment variable."""

    async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
        # Baseten exposes a single predict URL rather than per-route paths.
        request.url = httpx.URL(
            f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict"
        )
        return await super().send(request, *args, **kwargs)
|
||||
|
||||
|
||||
def get_client(provider):
    """Return an AsyncOpenAI client configured for the given provider key
    ("oai", "b10", or "sgl")."""
    # Bug fix: ``provider not in "b10"`` was a substring test (so e.g.
    # provider "b1" wrongly skipped the default key); compare equality.
    if provider != "b10":
        if os.getenv("OPENAI_API_KEY") is None:
            os.environ["OPENAI_API_KEY"] = "EMPTY"
    return {
        "oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"),
        "b10": AsyncOpenAI(
            api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}",
            base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict",
            http_client=CustomAsyncHTTPXClient(),
        ),
        "sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"),
    }[provider]
|
||||
|
||||
|
||||
# Define the benchmark function
|
||||
async def benchmark(args):
|
||||
ds = load_dataset(
|
||||
"meta-llama/Llama-3.1-405B-Instruct-evals",
|
||||
f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}",
|
||||
)
|
||||
semaphore = asyncio.Semaphore(args.concurrency) # Limit to 16 concurrent tasks
|
||||
|
||||
if args.num_examples is None:
|
||||
args.num_examples = len(ds["latest"]["input_final_prompts"])
|
||||
prompts = ds["latest"]["input_final_prompts"][: args.num_examples]
|
||||
|
||||
# Create the output directory if it does not exist
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
tasks = []
|
||||
# Create the tasks with tqdm progress bar
|
||||
max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]]
|
||||
client = get_client(args.provider)
|
||||
for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")):
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
fetch_responses(
|
||||
client,
|
||||
f"<|begin_of_text|>{prompt[0]}",
|
||||
semaphore,
|
||||
idx,
|
||||
args.provider,
|
||||
args.model_size,
|
||||
args.output_dir,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Run the tasks with tqdm progress bar
|
||||
for future in tqdm(
|
||||
asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks"
|
||||
):
|
||||
await future
|
||||
|
||||
|
||||
def get_mmlu_answer(response):
    """Extract a bare letter answer (e.g. "A") from a completion, or None
    when no response is available."""
    if response is None:
        return None
    raw_text = response.choices[0].text
    return raw_text.strip().upper().replace(".", "")
|
||||
|
||||
|
||||
def get_mmlu_cot_answer(response):
    """Pull the final answer out of a chain-of-thought completion.

    Tries several phrasings in order and returns None when none match.
    """
    text = response.choices[0].text

    match = re.search(r"The best answer is (.+)\.?", text)
    if match:
        # This phrasing may wrap the letter in markdown bold — strip "*" too.
        return match.group(1).replace(".", "").replace("*", "")

    for pattern in (
        r"the best answer is (.+)\.?",
        r"The correct answer is (.+)\.?",
        r"the correct answer is (.+)\.?",
    ):
        match = re.search(pattern, text)
        if match:
            return match.group(1).replace(".", "")
    return None
|
||||
|
||||
|
||||
def get_answer_gsm8k(response):
    """Extract the final GSM8K answer, stripping "%" and "$" symbols.

    Returns None when the expected phrasing is absent.
    """
    match = re.search(r"The final answer is (.+)\.?", response.choices[0].text)
    if not match:
        return None
    answer = match.group(1)
    for symbol in ("%", "$"):
        answer = answer.replace(symbol, "")
    return answer
|
||||
|
||||
|
||||
# Eval-set name -> function that parses the model's answer out of a response.
TASK_TO_ANSWER_EXTRACTOR = {
    "evals__mmlu__details": get_mmlu_answer,
    "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer,
    "evals__gsm8k__details": get_answer_gsm8k,
    "evals__mmlu_pro__details": get_mmlu_cot_answer,
}
|
||||
|
||||
|
||||
def get_dataset_from_task(task, response_path, model_size):
    """Load the reference eval dataset matching ``model_size``.

    Rows for 8B/70B models are reordered to follow the 405B set's prompt
    order (matched by prompt hash), since responses were generated against
    the 405B prompt ordering.
    """
    ds_405b = load_dataset(
        f"meta-llama/Llama-3.1-405B-Instruct-evals",
        f"Llama-3.1-405B-Instruct-{task}",
    )
    # Canonical prompt order, identified by prompt hash.
    ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]]

    if "70b" in model_size or "8b" in model_size:
        if "70" in model_size:
            ref_model_ds = load_dataset(
                f"meta-llama/Llama-3.1-70B-Instruct-evals",
                f"Llama-3.1-70B-Instruct-{task}",
            )
        else:
            ref_model_ds = load_dataset(
                f"meta-llama/Llama-3.1-8B-Instruct-evals",
                f"Llama-3.1-8B-Instruct-{task}",
            )

        # Re-sort this model's rows into the 405B prompt order.
        hash_to_row = {}
        for row in ref_model_ds["latest"]:
            hash_to_row[row["input_final_prompts_hash"][0]] = row
        reordered_rows = []
        for prompt_hash in ds_405b_hash_order:
            reordered_rows.append(hash_to_row[prompt_hash])
        ref_model_ds["latest"] = reordered_rows
        return ref_model_ds

    return ds_405b
|
||||
|
||||
|
||||
def analyze(task, response_path, model_size):
    """Score cached responses against the reference dataset and print
    macro/micro accuracy, alongside Meta's reported per-row correctness."""
    ds = get_dataset_from_task(task, response_path, model_size)

    responses = []
    total = len(ds["latest"])

    for i in range(0, total):
        # NOTE(review): the file handle from open() is never closed here —
        # consider a context manager.
        response = pickle.load(
            open(os.path.join(response_path, f"response_{i}.pkl"), "rb")
        )
        responses.append(response)

    @dataclass
    class Stats:
        # Per-subtask tallies; ``meta_correct`` counts Meta's own grading.
        correct: int = 0
        total: int = 0
        meta_correct: int = 0

        average: float = None

    subtask_name_to_stats = defaultdict(lambda: Stats())

    for response, ds_row in zip(responses, ds["latest"]):
        model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response)

        subtask = ds_row["subtask_name"]

        is_eval_correct = model_answer in ds_row["input_correct_responses"]
        if is_eval_correct:
            subtask_name_to_stats[subtask].correct += 1

        if ds_row["is_correct"]:
            subtask_name_to_stats[subtask].meta_correct += 1

        subtask_name_to_stats[subtask].total += 1

    micro_stats = Stats()
    for subtask, stats in subtask_name_to_stats.items():
        stats.average = stats.correct / stats.total
        stats.meta_average = stats.meta_correct / stats.total

        micro_stats.correct += stats.correct
        micro_stats.total += stats.total
        micro_stats.meta_correct += stats.meta_correct

    micro_stats.average = micro_stats.correct / micro_stats.total
    micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total

    # Macro = mean over subtasks; micro = pooled over all rows.
    print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()]))
    print(
        "Meta Macro average",
        np.mean([x.meta_average for x in subtask_name_to_stats.values()]),
    )
    print("Micro average", micro_stats.average)
    print("Meta Micro average", micro_stats.meta_average)
|
||||
|
||||
|
||||
# Entry point for the script
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Script to run model with specified parameters."
    )
    parser.add_argument(
        "--model-size",
        type=str,
        default="8b",
        help="Size of the model (e.g., 8b or 70b)",
    )
    parser.add_argument(
        "--provider",
        type=str,
        default="sgl",
        help="Provider name (e.g., sgl, oai, b10)",
    )
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        help="Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)",
    )
    parser.add_argument(
        "--num-examples", type=int, default=None, help="Number of examples to process"
    )
    parser.add_argument("--concurrency", type=int, default=16)
    parser.add_argument(
        "--output-dir",
        type=str,
        default="tmp-output-dir",
        help="Directory to save responses",
    )

    args = parser.parse_args()
    # Generate (cached to disk), score, then clean up the cache directory.
    asyncio.run(benchmark(args))
    analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size)
    shutil.rmtree("tmp-output-dir", ignore_errors=True)
|
||||
164
python/sglang/eval/loogle_eval.py
Normal file
164
python/sglang/eval/loogle_eval.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import openai
|
||||
import torch
|
||||
from bert_score import BERTScorer
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_client(api_url: str) -> openai.AsyncOpenAI:
    """Create an AsyncOpenAI client for the given base URL, defaulting the
    API key to "EMPTY" for local servers."""
    if "OPENAI_API_KEY" not in os.environ:
        os.environ["OPENAI_API_KEY"] = "EMPTY"
    return openai.AsyncOpenAI(base_url=api_url)
|
||||
|
||||
|
||||
def get_dataset():
    """Load the LooGLE long-dependency QA test split."""
    return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")
|
||||
|
||||
|
||||
async def fetch_response(
    client: openai.AsyncOpenAI,
    context: str,
    question: str,
    semaphore: asyncio.Semaphore,
    index: int,
    model: str,
    output_dir: Path,
):
    """Request one chat completion and cache it as ``response_<index>.pkl``.

    Existing cache files are skipped so interrupted runs resume cheaply;
    BadRequestError results are cached as {"error": message}.
    """
    output_file = output_dir / f"response_{index}.pkl"
    if output_file.exists():
        return

    prompt = (
        "Please answer the question based on the long texts below.\n"
        f"{context}\n"
        f"Question: {question}\n"
        "Answer:"
    )
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    async with semaphore:
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                max_tokens=512,
            )
        except openai.BadRequestError as e:
            # Cache the failure so analyse() can skip this index.
            with open(output_file, "wb") as f:
                pickle.dump({"error": str(e)}, f)
            return

    with open(output_file, "wb") as f:
        pickle.dump(response, f)
|
||||
|
||||
|
||||
async def benchmark(args):
    """Fan out up to ``args.num_prompts`` LooGLE requests with bounded
    concurrency, caching each response to ``args.output_dir``."""
    dataset = get_dataset()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    client = get_client(args.api_url)
    semaphore = asyncio.Semaphore(args.max_concurrency)

    tasks: List[asyncio.Task] = []
    for idx, ex in enumerate(dataset):
        if idx >= args.num_prompts:
            break
        tasks.append(
            asyncio.create_task(
                fetch_response(
                    client,
                    ex["context"],
                    ex["question"],
                    semaphore,
                    idx,
                    args.model,
                    output_dir,
                )
            )
        )

    # Drain the tasks, showing progress as each finishes.
    for _ in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
    ):
        await _
|
||||
|
||||
|
||||
def analyse(args):
    """Score cached responses against LooGLE references with BERTScore and
    print the average F1.

    Raises FileNotFoundError when a response pickle is missing; entries
    cached as {"error": ...} (failed requests) are skipped.
    """
    dataset = get_dataset()
    output_dir = Path(args.output_dir)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = BERTScorer(lang="en", device=device)

    hyps: List[str] = []
    refs: List[str] = []
    for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
        if idx >= args.num_prompts:
            break
        pkl_file = output_dir / f"response_{idx}.pkl"
        if not pkl_file.exists():
            raise FileNotFoundError(pkl_file)

        # Bug fix: use a context manager so the file handle is closed
        # promptly instead of leaking until garbage collection.
        with open(pkl_file, "rb") as f:
            response = pickle.load(f)
        if isinstance(response, dict) and "error" in response:
            continue

        hyps.append(response.choices[0].message.content.strip())
        refs.append(ex["answer"])

    if not hyps:
        print("No valid responses to score!")
        return

    # Score in batches to bound memory use.
    batch_size = 64
    all_f1: List[float] = []
    for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
        h_batch = hyps[i : i + batch_size]
        r_batch = refs[i : i + batch_size]
        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
        all_f1.extend(float(x) for x in f1_scores)

    avg = sum(all_f1) / len(all_f1)
    print(f"Average BERTScore (F1): {avg:.2%}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run benchmark and evaluation in one go."
    )
    parser.add_argument(
        "--api-url",
        default="http://127.0.0.1:30000/v1",
        help="OpenAI‑compatible API base URL",
    )
    parser.add_argument(
        "--model",
        default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        help="Model name or ID, only used for model name",
    )
    parser.add_argument(
        "--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
    )
    parser.add_argument(
        "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
    )
    parser.add_argument(
        "--num-prompts", type=int, default=10000, help="Number of prompts to run"
    )
    args = parser.parse_args()

    # Generate (cached to disk), then score the cache.
    asyncio.run(benchmark(args))

    analyse(args)
|
||||
53
python/sglang/global_config.py
Normal file
53
python/sglang/global_config.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Global configurations"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class GlobalConfig:
    """
    Store some global constants.

    See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
    many global runtime arguments as well.
    """

    def __init__(self):
        # Verbosity level
        # 0: do not output anything
        # 2: output final text after every run
        self.verbosity = 0

        # Default backend of the language
        self.default_backend = None

        # Runtime constants: New generation token ratio estimation
        self.default_init_new_token_ratio = float(
            os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
        )
        self.default_min_new_token_ratio_factor = float(
            os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
        )
        self.default_new_token_ratio_decay_steps = float(
            os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
        )
        # In seconds. Set if you observe high memory accumulation over a
        # long serving period.
        self.torch_empty_cache_interval = float(
            os.environ.get("SGLANG_EMPTY_CACHE_INTERVAL", -1)
        )
        # Runtime constants: others
        self.retract_decode_steps = 20
        # Bug fix: the env override used to remain a *string* (os.environ
        # values are str); cast to int so size arithmetic works either way.
        self.flashinfer_workspace_size = int(
            os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
        )

        # Output tokenization configs
        self.skip_special_tokens_in_output = True
        self.spaces_between_special_tokens_in_out = True

        # Language frontend interpreter optimization configs
        self.enable_precache_with_tracing = True
        self.enable_parallel_encoding = True


global_config = GlobalConfig()
|
||||
286
python/sglang/lang/api.py
Normal file
286
python/sglang/lang/api.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""Public APIs of the language."""
|
||||
|
||||
import re
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
|
||||
from sglang.lang.ir import (
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
SglFunction,
|
||||
SglGen,
|
||||
SglImage,
|
||||
SglRoleBegin,
|
||||
SglRoleEnd,
|
||||
SglSelect,
|
||||
SglSeparateReasoning,
|
||||
SglVideo,
|
||||
)
|
||||
|
||||
|
||||
def function(
    func: Optional[Callable] = None, num_api_spec_tokens: Optional[int] = None
):
    """Decorator turning a Python function into an SglFunction.

    Works both bare (``@function``) and parameterized
    (``@function(num_api_spec_tokens=...)``).
    """
    if func:
        # Bare usage: wrap immediately.
        return SglFunction(func, num_api_spec_tokens=num_api_spec_tokens)

    def decorator(inner):
        return SglFunction(inner, num_api_spec_tokens=num_api_spec_tokens)

    return decorator
|
||||
|
||||
|
||||
def Runtime(*args, **kwargs):
    """Construct a frontend Runtime; the import is deferred so that merely
    importing sglang does not pull in the runtime-endpoint dependencies."""
    # Avoid importing unnecessary dependency
    from sglang.lang.backend.runtime_endpoint import Runtime

    return Runtime(*args, **kwargs)
|
||||
|
||||
|
||||
def Engine(*args, **kwargs):
    """Construct an offline inference Engine; imported lazily so importing
    sglang stays lightweight when the server stack is not needed."""
    # Avoid importing unnecessary dependency
    from sglang.srt.entrypoints.engine import Engine

    return Engine(*args, **kwargs)
|
||||
|
||||
|
||||
def set_default_backend(backend: BaseBackend):
    """Register *backend* as the process-wide default used by calls that
    do not pass a backend explicitly."""
    global_config.default_backend = backend
|
||||
|
||||
|
||||
def flush_cache(backend: Optional[BaseBackend] = None):
    """Flush the cache of *backend*, falling back to the default backend.

    Returns False when no backend is available; otherwise forwards the
    backend's own flush_cache() result.
    """
    target = backend or global_config.default_backend
    if target is None:
        return False

    # A Runtime wraps the real backend behind its `endpoint` attribute.
    target = getattr(target, "endpoint", target)
    return target.flush_cache()
|
||||
|
||||
|
||||
def get_server_info(backend: Optional[BaseBackend] = None):
    """Fetch server metadata from *backend*, falling back to the default.

    Returns None when no backend is configured.
    """
    target = backend or global_config.default_backend
    if target is None:
        return None

    # A Runtime wraps the real backend behind its `endpoint` attribute.
    target = getattr(target, "endpoint", target)
    return target.get_server_info()
|
||||
|
||||
|
||||
def gen(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    min_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
    dtype: Optional[Union[type, str]] = None,
    choices: Optional[List[str]] = None,
    choices_method: Optional[ChoicesSamplingMethod] = None,
    regex: Optional[str] = None,
    json_schema: Optional[str] = None,
):
    """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""

    # A `choices` list turns the call into a constrained selection.
    if choices:
        return SglSelect(
            name,
            choices,
            0.0 if temperature is None else temperature,
            token_length_normalized if choices_method is None else choices_method,
        )

    # Validate the regex eagerly so a bad pattern fails at construction time.
    # re.compile raises re.error on its own; the previous
    # `try: ... except re.error as e: raise e` wrapper was a no-op.
    if regex is not None:
        re.compile(regex)

    return SglGen(
        name,
        max_tokens,
        min_tokens,
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        dtype,
        regex,
        json_schema,
    )
|
||||
|
||||
|
||||
def gen_int(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
):
    """gen() specialized to integer output (dtype=int).

    The positional argument order below must stay aligned with SglGen's
    constructor signature.
    """
    return SglGen(
        name,
        max_tokens,
        None,  # min_tokens
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        int,  # dtype: constrain decoding to integers
        None,  # regex
    )
|
||||
|
||||
|
||||
def gen_string(
    name: Optional[str] = None,
    max_tokens: Optional[int] = None,
    n: Optional[int] = None,
    stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    min_p: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    ignore_eos: Optional[bool] = None,
    return_logprob: Optional[bool] = None,
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
):
    """gen() specialized to quoted-string output (dtype=str).

    The positional argument order below must stay aligned with SglGen's
    constructor signature.
    """
    return SglGen(
        name,
        max_tokens,
        None,  # min_tokens
        n,
        stop,
        stop_token_ids,
        temperature,
        top_p,
        top_k,
        min_p,
        frequency_penalty,
        presence_penalty,
        ignore_eos,
        return_logprob,
        logprob_start_len,
        top_logprobs_num,
        return_text_in_logprobs,
        str,  # dtype: constrain decoding to a string literal
        None,  # regex
    )
|
||||
|
||||
|
||||
def image(expr: SglExpr):
    """Wrap *expr* as an image input in the prompt."""
    return SglImage(expr)
|
||||
|
||||
|
||||
def video(path: str, num_frames: int):
    """Reference a video input by file *path* with *num_frames* frames."""
    return SglVideo(path, num_frames)
|
||||
|
||||
|
||||
def select(
    name: Optional[str] = None,
    choices: Optional[List[str]] = None,
    temperature: float = 0.0,
    choices_method: ChoicesSamplingMethod = token_length_normalized,
):
    """Ask the model to pick one of *choices*.

    Raises:
        ValueError: if *choices* is None. (A plain ``assert`` would be
            stripped under ``python -O``, so validate explicitly.)
    """
    if choices is None:
        raise ValueError("select() requires a non-None `choices` list")
    return SglSelect(name, choices, temperature, choices_method)
|
||||
|
||||
|
||||
def _role_common(name: str, expr: Optional[SglExpr] = None):
    """Wrap *expr* (if any) between begin/end markers for role *name*."""
    parts = [SglRoleBegin(name), SglRoleEnd(name)]
    if expr is not None:
        parts.insert(1, expr)
    return SglExprList(parts)
|
||||
|
||||
|
||||
def system(expr: Optional[SglExpr] = None):
    """Open (or fully wrap *expr* in) a system-role segment."""
    return _role_common("system", expr)
|
||||
|
||||
|
||||
def user(expr: Optional[SglExpr] = None):
    """Open (or fully wrap *expr* in) a user-role segment."""
    return _role_common("user", expr)
|
||||
|
||||
|
||||
def assistant(expr: Optional[SglExpr] = None):
    """Open (or fully wrap *expr* in) an assistant-role segment."""
    return _role_common("assistant", expr)
|
||||
|
||||
|
||||
def system_begin():
    """Begin a system segment; pair with system_end()."""
    return SglRoleBegin("system")
|
||||
|
||||
|
||||
def system_end():
    """Close a segment opened by system_begin()."""
    return SglRoleEnd("system")
|
||||
|
||||
|
||||
def user_begin():
    """Begin a user segment; pair with user_end()."""
    return SglRoleBegin("user")
|
||||
|
||||
|
||||
def user_end():
    """Close a segment opened by user_begin()."""
    return SglRoleEnd("user")
|
||||
|
||||
|
||||
def assistant_begin():
    """Begin an assistant segment; pair with assistant_end()."""
    return SglRoleBegin("assistant")
|
||||
|
||||
|
||||
def assistant_end():
    """Close a segment opened by assistant_begin()."""
    return SglRoleEnd("assistant")
|
||||
|
||||
|
||||
def separate_reasoning(
    expr: Optional[SglExpr] = None, model_type: Optional[str] = None
):
    """Run *expr* and then split its reasoning part out via SglSeparateReasoning.

    NOTE(review): when *expr* is None the returned list still contains the
    None placeholder — confirm SglExprList tolerates that.
    """
    return SglExprList([expr, SglSeparateReasoning(model_type, expr=expr)])
|
||||
73
python/sglang/lang/backend/anthropic.py
Normal file
73
python/sglang/lang/backend/anthropic.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError as e:
|
||||
anthropic = e
|
||||
|
||||
|
||||
class Anthropic(BaseBackend):
    """Backend that routes generation through the Anthropic Messages API."""

    def __init__(self, model_name, *args, **kwargs):
        super().__init__()

        # `anthropic` holds the ImportError when the package is missing;
        # surface it only when this backend is actually instantiated.
        if isinstance(anthropic, Exception):
            raise anthropic

        self.model_name = model_name
        self.chat_template = get_chat_template("claude")
        self.client = anthropic.Anthropic(*args, **kwargs)

    def get_chat_template(self):
        return self.chat_template

    @staticmethod
    def _split_messages(s: StreamExecutor):
        """Build (system, messages) from the executor state.

        The Anthropic API takes the system prompt as a separate argument,
        not as a leading message, so a leading system message is popped off.
        """
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        if messages and messages[0]["role"] == "system":
            system = messages.pop(0)["content"]
        else:
            system = ""
        return system, messages

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Blocking generation; returns (completion_text, meta_info)."""
        system, messages = self._split_messages(s)

        ret = self.client.messages.create(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        )
        comp = ret.content[0].text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Streaming generation; yields (text_chunk, meta_info) pairs."""
        system, messages = self._split_messages(s)

        with self.client.messages.stream(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        ) as stream:
            for text in stream.text_stream:
                yield text, {}
|
||||
82
python/sglang/lang/backend/base_backend.py
Normal file
82
python/sglang/lang/backend/base_backend.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
|
||||
class BaseBackend:
    """Abstract interface for language-frontend backends.

    Subclasses must implement the generation/select methods; the cache and
    program-lifecycle hooks default to no-ops so simple backends only
    implement what they need.
    """

    def __init__(self) -> None:
        # Backends that can concatenate cached requests and append to them
        # set this to True (e.g. a runtime endpoint).
        self.support_concate_and_append = False
        self.chat_template = get_chat_template("default")

    def get_model_name(self):
        """Return the backend's model identifier."""
        raise NotImplementedError()

    def get_chat_template(self):
        """Return the chat template used to format role messages."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Hint the backend to cache *prefix_str*. No-op by default."""
        pass

    def uncache_prefix(self, rid: str):
        """Release a cached prefix by request id. No-op by default."""
        pass

    def end_request(self, rid: Union[str, List[str]]):
        """Finish one or more in-flight requests. No-op by default."""
        pass

    def begin_program(self, s: StreamExecutor):
        """Hook called when a program starts executing. No-op by default."""
        pass

    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
        """Hook called when a program finishes. No-op by default."""
        pass

    def commit_lazy_operations(self, s: StreamExecutor):
        """Flush buffered operations to the backend. No-op by default."""
        pass

    def fork_program(
        self,
        src: StreamExecutor,
        dst: List[StreamExecutor],
        position_ids_offset: Optional[List[int]] = None,
    ):
        """Fork *src* executor state into the *dst* executors. No-op by default."""
        pass

    def fill_image(self, s: StreamExecutor):
        """Attach pending image data for *s*. No-op by default."""
        pass

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Generate a completion; implementations return (text, meta_info)."""
        raise NotImplementedError()

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion; implementations yield (text, meta_info) pairs."""
        raise NotImplementedError()

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: Optional[ChoicesSamplingMethod] = None,
    ) -> ChoicesDecision:
        """Pick one of *choices* and return the decision."""
        raise NotImplementedError()

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Concatenate cached requests into *dst_rid* (optional capability)."""
        raise NotImplementedError()

    def shutdown(self):
        """Release backend resources. No-op by default."""
        pass

    def flush_cache(self):
        """Clear the backend's cache. No-op by default."""
        pass

    def get_server_info(self):
        """Return server metadata, if the backend exposes any."""
        pass
|
||||
90
python/sglang/lang/backend/litellm.py
Normal file
90
python/sglang/lang/backend/litellm.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Mapping, Optional
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import litellm
|
||||
except ImportError as e:
|
||||
litellm = e
|
||||
litellm.num_retries = 1
|
||||
|
||||
|
||||
class LiteLLM(BaseBackend):
    """Backend that proxies generation through the LiteLLM unified client."""

    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = litellm.num_retries,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        super().__init__()

        # `litellm` holds the ImportError when the package is missing.
        if isinstance(litellm, Exception):
            raise litellm

        self.model_name = model_name

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        # Forwarded verbatim to every litellm.completion() call.
        self.client_params = {
            "api_key": api_key,
            "organization": organization,
            "base_url": base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
        }

    def get_chat_template(self):
        return self.chat_template

    @staticmethod
    def _messages_from_executor(s: StreamExecutor):
        """Use the structured chat history when present, else wrap the raw text."""
        if s.messages_:
            return s.messages_
        return [{"role": "user", "content": s.text_}]

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Blocking generation; returns (completion_text, meta_info)."""
        ret = litellm.completion(
            model=self.model_name,
            messages=self._messages_from_executor(s),
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        comp = ret.choices[0].message.content

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Streaming generation; yields (text_chunk, meta_info) pairs."""
        ret = litellm.completion(
            model=self.model_name,
            messages=self._messages_from_executor(s),
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in ret:
            text = chunk.choices[0].delta.content
            if text is not None:
                yield text, {}
|
||||
475
python/sglang/lang/backend/openai.py
Normal file
475
python/sglang/lang/backend/openai.py
Normal file
@@ -0,0 +1,475 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import openai
|
||||
import tiktoken
|
||||
except ImportError as e:
|
||||
openai = tiktoken = e
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_logit_bias_int(tokenizer):
    """Build a logit-bias map restricting sampling to integer-like tokens.

    Collects token ids whose decoded text is all digits (or a lone space,
    usable as a terminator) and biases them strongly, plus the end-of-text
    token so generation can stop.

    Args:
        tokenizer: a tiktoken encoding; uses its private ``_mergeable_ranks``
            and ``_special_tokens`` tables.

    Returns:
        dict mapping token id -> bias value (100), at most 300 entries.
    """
    int_token_ids = []

    for _token, token_id in tokenizer._mergeable_ranks.items():
        text = tokenizer.decode([token_id])
        # NOTE: all(...) over an empty string is True, so zero-length
        # decodes are admitted too (matches the original behavior).
        if text == " " or all(c.isdigit() for c in text):
            int_token_ids.append(token_id)
            if len(int_token_ids) >= 300:  # OpenAI caps logit_bias at 300 entries
                break

    # Keep only 299 digit tokens so the end-of-text entry fits under the cap.
    mask = {t: 100 for t in int_token_ids[:299]}
    mask[tokenizer._special_tokens["<|endoftext|>"]] = 100
    return mask
|
||||
|
||||
|
||||
INSTRUCT_MODEL_NAMES = [
|
||||
"gpt-3.5-turbo-instruct",
|
||||
]
|
||||
|
||||
|
||||
@dataclasses.dataclass
class TokenUsage:
    """Running totals of tokens consumed through an API backend."""

    prompt_tokens: int
    completion_tokens: int

    def reset(self):
        """Zero both counters."""
        self.prompt_tokens = 0
        self.completion_tokens = 0
|
||||
|
||||
|
||||
class OpenAI(BaseBackend):
    """Backend for the OpenAI (or Azure OpenAI) completion and chat APIs."""

    def __init__(
        self,
        model_name: str,
        is_chat_model: Optional[bool] = None,
        chat_template: Optional[ChatTemplate] = None,
        is_azure: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()

        # `openai` holds the ImportError when the package is missing.
        if isinstance(openai, Exception):
            raise openai

        if is_azure:
            self.client = openai.AzureOpenAI(*args, **kwargs)
        else:
            self.client = openai.OpenAI(*args, **kwargs)

        self.model_name = model_name
        try:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Unknown model: fall back to the generic cl100k encoding.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.logit_bias_int = create_logit_bias_int(self.tokenizer)

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        if is_chat_model is not None:
            self.is_chat_model = is_chat_model
        else:
            # Only the legacy instruct models use the plain-completion API.
            self.is_chat_model = model_name not in INSTRUCT_MODEL_NAMES

        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]

        # Usage accounting across calls.
        self.token_usage = TokenUsage(0, 0)

        # API speculative execution state.
        # TODO(ying): This does not support multi-threading (run_batch)
        self.spec_kwargs = {}
        self.spec_format = []
        self.spec_max_num_tries = 3

    def get_chat_template(self):
        return self.chat_template

    def _prepare_spec_execution(
        self,
        sampling_params: SglSamplingParams,
        num_api_spec_tokens: int,
        spec_var_name: str,
    ):
        """Defer this gen() call for speculative execution.

        Records the sampling parameters and the variable slot; the real API
        call happens later in role_end_generate().
        """
        if "max_tokens" not in self.spec_kwargs:
            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
        else:
            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens

        params = sampling_params.to_openai_kwargs()
        for key, value in params.items():
            if key in ["stop"]:
                continue
            if key in ["max_tokens"]:
                warnings.warn(
                    "The parameter max_tokens will be overwritten by speculated number of tokens."
                )
                continue
            if key not in self.spec_kwargs:
                self.spec_kwargs[key] = value
            else:
                # All spec'd gens must share the same sampling parameters.
                assert (
                    value == self.spec_kwargs[key]
                ), "sampling parameters should be consistent if turn on api speculative execution."
        self.spec_format.append(
            {"text": "", "stop": params["stop"], "name": spec_var_name}
        )
        return "", {}

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
        spec_var_name: str = None,
    ):
        """Generate text, an int, or a quoted string depending on dtype."""
        if sampling_params.dtype is None:
            if self.is_chat_model:
                if s.num_api_spec_tokens is None:
                    if not s.text_.endswith(self.chat_prefix):
                        raise RuntimeError(
                            "This use case is not supported if api speculative execution is off. "
                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
                            "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
                        )
                    prompt = s.messages_
                else:
                    return self._prepare_spec_execution(
                        sampling_params, s.num_api_spec_tokens, spec_var_name
                    )
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            # Reasoning models take max_completion_tokens instead of max_tokens.
            if (
                self.model_name.startswith("o1")
                or self.model_name.startswith("o3")
                or "o1" in self.model_name
            ):
                kwargs.pop("max_tokens", None)
            else:
                kwargs.pop("max_completion_tokens", None)

            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
            # Keep the returned list (or string) as is.
        elif sampling_params.dtype in [str, "str", "string"]:
            assert (
                not self.is_chat_model
            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            # Open the quote ourselves and stop at the closing quote.
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_ + '"',
                stop='"',
                **kwargs,
            )
            # Wrap each element in quotes if we have a list.
            if isinstance(comp, list):
                comp = ['"' + x + '"' for x in comp]
            else:
                comp = '"' + comp + '"'
        elif sampling_params.dtype in [int, "int"]:
            assert (
                not self.is_chat_model
            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            # Bias sampling toward digit tokens; a space terminates the number.
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_,
                logit_bias=self.logit_bias_int,
                stop=[" "],
                **kwargs,
            )
            # Leave as a list if that's what is returned.
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

        return comp, {}

    def spec_fill(self, value: str):
        """Append literal text to the speculative-execution template."""
        assert self.is_chat_model
        self.spec_format.append({"text": value, "stop": None, "name": None})

    def spec_pattern_match(self, comp):
        """Try to split *comp* along the recorded template; fill in variables.

        Returns True on a full match, False if the completion diverges.
        """
        for i, term in enumerate(self.spec_format):
            text = term["text"]
            if text != "":
                # Literal segment: must match exactly at the front.
                if comp.startswith(text):
                    comp = comp[len(text) :]
                else:
                    return False
            else:
                # Variable segment: capture up to its stop string.
                pos = comp.find(term["stop"])
                if pos != -1:
                    term["text"] = comp[:pos]
                    comp = comp[pos:]
                else:
                    # The last variable may consume the remainder.
                    if i == len(self.spec_format) - 1:
                        term["text"] = comp
                    else:
                        return False
        return True

    def role_end_generate(
        self,
        s: StreamExecutor,
    ):
        """Execute the deferred speculative gen() calls at the end of a role."""
        if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
            return

        comp = ""
        if not all(x["name"] is None for x in self.spec_format):
            # TODO(ying): throw errors or warnings
            for i in range(self.spec_max_num_tries):
                comp = openai_completion(
                    client=self.client,
                    token_usage=self.token_usage,
                    is_chat=self.is_chat_model,
                    model=self.model_name,
                    prompt=s.messages_,
                    **self.spec_kwargs,
                )
                # Use a string for pattern matching.
                comp_for_match = comp[0] if isinstance(comp, list) else comp
                if self.spec_pattern_match(comp_for_match):
                    break

        for term in self.spec_format:
            s.text_ += term["text"]
            name = term["name"]
            if name is not None:
                s.variables[name] = term["text"]
                s.meta_info[name] = {}
                s.variable_event[name].set()

        # Reset speculative state for the next role block.
        self.spec_kwargs = {}
        self.spec_format = []

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if sampling_params.dtype is None:
            if self.is_chat_model:
                if not s.text_.endswith(self.chat_prefix):
                    raise RuntimeError(
                        "This use case is not supported. "
                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
                    )
                prompt = s.messages_
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            generator = openai_completion_stream(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
            return generator
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        """Note: `choices_method` is not used by the OpenAI backend."""
        if self.is_chat_model:
            raise NotImplementedError(
                "select/choices is not supported for chat models. "
                "Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
            )

        n_choices = len(choices)
        token_ids = [self.tokenizer.encode(x) for x in choices]
        scores = [0] * n_choices
        valid = [len(x) > 0 for x in token_ids]
        prompt_tokens = self.tokenizer.encode(s.text_)

        max_len = max([len(x) for x in token_ids])
        for step in range(max_len):
            # Build logit bias that forces one of the still-valid next tokens.
            logit_bias = {}
            for i in range(n_choices):
                if valid[i]:
                    logit_bias[token_ids[i][step]] = 100

            # Call API
            ret = self.client.completions.create(
                model=self.model_name,
                prompt=prompt_tokens,
                logit_bias=logit_bias,
                max_tokens=1,
                temperature=temperature,
            )
            ret_str = ret.choices[0].text
            ret_token = self.tokenizer.encode(ret_str)[0]
            self.token_usage.prompt_tokens += ret.usage.prompt_tokens
            # BUG FIX: accumulate completion tokens (was `=`, which discarded
            # the running total and disagreed with the prompt_tokens line).
            self.token_usage.completion_tokens += ret.usage.completion_tokens

            # TODO:
            # 1. return logits as the scores
            # 2. compute logits of the full choice
            # 3. consider chunk-based decoding

            # Update valid
            hit = False
            for i in range(n_choices):
                if valid[i]:
                    if step == len(token_ids[i]) - 1:
                        valid[i] = False

                    if ret_token == token_ids[i][step]:
                        scores[i] += 1
                        hit = True
                    else:
                        valid[i] = False
            assert hit

            if np.sum(valid) <= 1:
                break

            prompt_tokens.append(ret_token)

        return ChoicesDecision(
            decision=choices[np.argmax(scores)],
            meta_info={"scores": scores},
        )
|
||||
|
||||
|
||||
def openai_completion(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
) -> Union[str, List[str]]:
    """Call the OpenAI completion or chat API with retry on transient errors.

    Args:
        client: an OpenAI client object.
        token_usage: mutable accumulator with ``prompt_tokens`` /
            ``completion_tokens`` attributes, incremented in place.
        is_chat: route to chat.completions when True, legacy completions
            otherwise.
        retries: number of attempts; transient API errors wait 5s and retry.
        prompt: messages list (chat) or prompt string/list (completion).

    Returns:
        The completion text, or a list of texts when n > 1 (or when a list
        of prompts was given to the non-chat API).
    """
    # EBNF constraints cannot be forwarded to OpenAI; drop with a warning.
    if "ebnf" in kwargs:
        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
        del kwargs["ebnf"]

    # Guard: with retries < 1 the loop body never runs and `comp` would be
    # unbound, surfacing as a confusing NameError.
    if retries < 1:
        raise ValueError("retries must be >= 1")

    for attempt in range(retries):
        try:
            if is_chat:
                # The chat API rejects an explicit stop=None; drop the key.
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                ret = client.chat.completions.create(messages=prompt, **kwargs)
                if len(ret.choices) == 1:
                    comp = ret.choices[0].message.content
                else:
                    comp = [c.message.content for c in ret.choices]
            else:
                ret = client.completions.create(prompt=prompt, **kwargs)
                if isinstance(prompt, (list, tuple)):
                    comp = [c.text for c in ret.choices]
                else:
                    comp = ret.choices[0].text
                    if len(ret.choices) > 1:
                        comp = [c.text for c in ret.choices]

            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            time.sleep(5)
            if attempt == retries - 1:
                raise e
        except Exception as e:
            logger.error(f"RuntimeError {e}.")
            raise e

    return comp
|
||||
|
||||
|
||||
def openai_completion_stream(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
    """Stream a completion; yields (text_chunk, meta_info) pairs.

    Mirrors ``openai_completion`` but uses the streaming API and requests
    the usage totals on the final chunk via ``stream_options``.
    """
    # if "ebnf" is in kwargs, warn and remove
    if "ebnf" in kwargs:
        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
        del kwargs["ebnf"]

    for attempt in range(retries):
        try:
            if is_chat:
                # The chat API rejects an explicit stop=None; drop the key.
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                generator = client.chat.completions.create(
                    messages=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    if len(ret.choices) == 0:
                        continue
                    try:
                        content = ret.choices[0].delta.content
                    except IndexError:
                        content = None
                    yield content or "", {}
            else:
                generator = client.completions.create(
                    prompt=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    if len(ret.choices) == 0:
                        continue
                    content = ret.choices[0].text
                    yield content or "", {}

            # `ret` is the last chunk; with include_usage it carries totals.
            # NOTE(review): if the stream fails mid-way, the retry below
            # re-runs the request and earlier chunks are yielded again to
            # the consumer — confirm this duplication is acceptable.
            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            time.sleep(5)
            if attempt == retries - 1:
                raise e
        except Exception as e:
            logger.error(f"RuntimeError {e}.")
            raise e
|
||||
527
python/sglang/lang/backend/runtime_endpoint.py
Normal file
527
python/sglang/lang/backend/runtime_endpoint.py
Normal file
@@ -0,0 +1,527 @@
|
||||
import atexit
|
||||
import json
|
||||
import multiprocessing
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import (
|
||||
REGEX_BOOL,
|
||||
REGEX_FLOAT,
|
||||
REGEX_INT,
|
||||
REGEX_STR,
|
||||
SglSamplingParams,
|
||||
)
|
||||
from sglang.utils import http_request
|
||||
|
||||
|
||||
class RuntimeEndpoint(BaseBackend):
    """Frontend-language backend that talks to a running SGLang HTTP server.

    Every operation is a plain HTTP call against the server's REST endpoints
    (``/generate``, ``/flush_cache``, ``/get_model_info``, ...). The endpoint
    advertises concatenate-and-append support so the interpreter can fork and
    join streams on the server side.
    """

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        verify: Optional[str] = None,
        chat_template_name: Optional[str] = None,
    ):
        """Connect to the server at `base_url` and fetch its model info.

        Args:
            base_url: Root URL of the SGLang HTTP server.
            api_key: Optional API key forwarded on every request.
            verify: Forwarded to http_request; presumably TLS certificate
                verification (path or flag) — TODO confirm against http_request.
            chat_template_name: Explicit chat template; when None, one is
                inferred from the served model's path.

        Raises:
            RuntimeError: if the initial /get_model_info probe fails.
        """
        super().__init__()
        self.support_concate_and_append = True

        self.base_url = base_url
        self.api_key = api_key
        self.verify = verify

        # Probe the server once at construction time so misconfiguration
        # fails fast instead of on the first generate call.
        res = http_request(
            self.base_url + "/get_model_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        self.model_info = res.json()

        if chat_template_name:
            self.chat_template = get_chat_template(chat_template_name)
        else:
            self.chat_template = get_chat_template_by_model_path(
                self.model_info["model_path"]
            )

    def get_model_name(self):
        """Return the model path reported by the server."""
        return self.model_info["model_path"]

    def flush_cache(self):
        """Ask the server to drop its prefix cache."""
        res = http_request(
            self.base_url + "/flush_cache",
            api_key=self.api_key,
            verify=self.verify,
            method="POST",
        )
        self._assert_success(res)

    def get_server_info(self):
        """Return the server's info dict from /get_server_info."""
        res = http_request(
            self.base_url + "/get_server_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        return res.json()

    def get_chat_template(self):
        """Return the chat template chosen at construction time."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Warm the server-side cache with `prefix_str`.

        Generating zero new tokens forces prefill of the prompt only.
        """
        res = http_request(
            self.base_url + "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def start_profile(self):
        """Start server-side profiling via /start_profile."""
        res = http_request(
            self.base_url + "/start_profile",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def stop_profile(self):
        """Stop server-side profiling via /stop_profile."""
        res = http_request(
            self.base_url + "/stop_profile",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def commit_lazy_operations(self, s: StreamExecutor):
        """Flush the stream's accumulated text to the server (0-token generate)."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def fill_image(self, s: StreamExecutor):
        """Send the stream's text plus attached image for server-side prefill."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
        """Translate `sampling_params.dtype` into a constraining regex.

        Mutates `sampling_params` in place: sets `.regex` and, for numeric
        dtypes, appends whitespace stop strings so generation ends right after
        the value. No-op when dtype is None.

        Raises:
            RuntimeError: for an unrecognized dtype.
        """
        if sampling_params.dtype is None:
            return

        # `stop` may default to an immutable tuple; switch to a list so the
        # extend() calls below are legal.
        if sampling_params.stop == ():
            sampling_params.stop = []

        dtype_regex = None
        if sampling_params.dtype in ["int", int]:

            dtype_regex = REGEX_INT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["float", float]:

            dtype_regex = REGEX_FLOAT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["str", str]:

            dtype_regex = REGEX_STR
        elif sampling_params.dtype in ["bool", bool]:

            dtype_regex = REGEX_BOOL
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        if dtype_regex is not None and sampling_params.regex is not None:
            warnings.warn(
                f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
            )

        # dtype wins over any user-supplied regex (warned above).
        sampling_params.regex = dtype_regex

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run a non-streaming generation for stream `s`.

        Returns:
            (completion_text, meta_info_dict) from the server response.
        """
        self._handle_dtype_to_regex(sampling_params)
        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        # Logprob-related options live at the top level of the request body,
        # not inside sampling_params; copy them over only when set.
        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        self._add_images(s, data)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

        obj = res.json()
        comp = obj["text"]
        return comp, obj["meta_info"]

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a generation for `s`, yielding (text_delta, meta_info) pairs."""
        self._handle_dtype_to_regex(sampling_params)

        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        # Same top-level logprob options as in generate().
        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        data["stream"] = True
        self._add_images(s, data)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            stream=True,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        # `pos` tracks how much of the cumulative text has been yielded; each
        # SSE event carries the full text so far, and we emit only the delta.
        pos = 0

        for chunk in res.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                # NOTE: rebinding `data` here shadows the request payload dict
                # above; it now holds the parsed SSE event.
                data = json.loads(chunk[5:].strip("\n"))
                chunk_text = data["text"][pos:]
                meta_info = data["meta_info"]
                pos += len(chunk_text)
                yield chunk_text, meta_info

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        """Score each choice continuation by prompt logprob and pick one.

        Only greedy selection is supported (temperature must be ~0). Scores
        the prefix once to learn the prompt length, then scores every
        prefix+choice string, corrects for token healing, and optionally
        computes unconditional logprobs for methods that need them.
        """
        assert temperature <= 1e-5

        # Cache common prefix
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        obj = self._generate_http_request(s, data)
        prompt_len = obj["meta_info"]["prompt_tokens"]
        logprob_start_len = max(prompt_len - 2, 0)  # For token healing

        # Compute logprob
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {
                "max_new_tokens": 0,
                "temperature": 0,
            },
            "return_logprob": True,
            "return_text_in_logprobs": True,
            "logprob_start_len": logprob_start_len,
        }
        obj = self._generate_http_request(s, data)

        # Each entry is a list of (logprob, token_id, token_text) tuples.
        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
        normalized_prompt_logprobs = [
            compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
            for r in obj
        ]

        # Remove extra token if no token healing occurred: if the first scored
        # token is still a suffix of the prefix text, drop it and re-average.
        for i in range(len(input_token_logprobs)):
            healed_token_str = input_token_logprobs[i][0][-1]
            if s.text_.endswith(healed_token_str):
                healed_token_logprob = input_token_logprobs[i][0][0]
                normalized_prompt_logprobs[i] = (
                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
                    - healed_token_logprob
                ) / (len(input_token_logprobs[i]) - 1)
                input_token_logprobs[i] = input_token_logprobs[i][1:]

        # Compute unconditional logprobs if required
        if choices_method.requires_unconditional_logprobs:
            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
            data = {
                "input_ids": input_ids,
                "sampling_params": {"max_new_tokens": 0},
                "return_logprob": True,
            }
            obj = self._generate_http_request(s, data)
            unconditional_token_logprobs = [
                r["meta_info"]["input_token_logprobs"] for r in obj
            ]
        else:
            unconditional_token_logprobs = None

        return choices_method(
            choices=choices,
            normalized_prompt_logprobs=normalized_prompt_logprobs,
            input_token_logprobs=input_token_logprobs,
            output_token_logprobs=output_token_logprobs,
            unconditional_token_logprobs=unconditional_token_logprobs,
        )

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Server-side join: append the requests in `src_rids` onto `dst_rid`."""
        res = http_request(
            self.base_url + "/concate_and_append_request",
            json={"src_rids": src_rids, "dst_rid": dst_rid},
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)

    def _generate_http_request(self, s: StreamExecutor, data):
        """POST `data` to /generate (with images attached) and return JSON."""
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(res)
        return res.json()

    def _add_images(self, s: StreamExecutor, data):
        """Attach the stream's image payload to the request dict, if any."""
        if s.images_:
            assert len(s.images_) == 1, "Only support one image."
            # images_ entries look like (path, payload); only the payload is sent.
            data["image_data"] = s.images_[0][1]

    def _assert_success(self, res):
        """Raise RuntimeError with the response body on any non-200 status."""
        if res.status_code != 200:
            try:
                content = res.json()
            except json.JSONDecodeError:
                content = res.text
            raise RuntimeError(content)
|
||||
|
||||
|
||||
def compute_normalized_prompt_logprobs(input_logprobs):
    """Return the mean logprob over the prompt tokens.

    Args:
        input_logprobs: list of ``(logprob, token_id, ...)`` tuples as found
            in the server's ``input_token_logprobs`` meta info. The first
            entry's logprob may be ``None`` (no logprob exists for the very
            first prompt token), so ``None`` entries are skipped.

    Returns:
        The arithmetic mean of the non-None logprobs.
    """
    # Filter with `is not None` rather than truthiness: a logprob of exactly
    # 0.0 (a probability-1 token) is falsy and the old `if x[0]` check would
    # silently drop it, skewing the normalization.
    values = [x[0] for x in input_logprobs if x[0] is not None]
    return sum(values) / len(values)
|
||||
|
||||
|
||||
class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a python program without
    using the command line interface.

    It is mainly used for the frontend language.
    You should use the Engine class if you want to do normal offline processing without the frontend language.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
        # client code without installing SRT server and its dependency if they want.
        from sglang.srt.entrypoints.http_server import launch_server
        from sglang.srt.server_args import ServerArgs
        from sglang.srt.utils import is_port_available

        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports: scan upward from the requested port for the
        # first free one. NOTE(review): if nothing below 40000 is free, the
        # last probed port is used anyway — confirm this is acceptable.
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store pid instead of proc to fix some issues during __delete__
        self.pid = None
        pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)

        # Spawn (not fork) so the child starts from a clean interpreter.
        ctx = multiprocessing.get_context("spawn")
        proc = ctx.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        # Close the parent's writer end so EOFError is raised below if the
        # child dies before signalling readiness.
        pipe_writer.close()
        self.pid = proc.pid

        # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
        atexit.register(self.shutdown)

        # TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        """Kill the server process tree. Safe to call multiple times."""
        from sglang.srt.utils import kill_process_tree

        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def start_profile(self):
        """Start server-side profiling."""
        self.endpoint.start_profile()

    def stop_profile(self):
        """Stop server-side profiling."""
        self.endpoint.stop_profile()

    def cache_prefix(self, prefix: str):
        """Warm the server cache with `prefix`."""
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        """Build and return the HF tokenizer matching the server's config."""
        from sglang.srt.hf_transformers_utils import get_tokenizer

        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        """Async generator that streams generation output for `prompt`.

        Yields text deltas when the server is tokenizing, or raw event dicts
        when the server runs with skip_tokenizer_init (token-id mode).
        """
        if self.server_args.skip_tokenizer_init:
            # Tokenizer-free mode: `prompt` is presumably already token ids
            # (the field name is "input_ids") — TODO confirm caller contract.
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        # Each SSE event carries the cumulative text; `pos` tracks how much
        # has already been yielded so only the delta is emitted.
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            # Token-id mode: no "text" field; pass the event through.
                            yield data

    # Alias kept for API compatibility with the engine interface.
    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        """Blocking generation call.

        Returns:
            The server's JSON response re-serialized as a JSON string
            (callers must json.loads it).
        """
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        # Per-prompt LoRA paths must align one-to-one with the prompts.
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        """Call the /encode endpoint (embeddings); returns a JSON string."""
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        """Fetch /get_server_info asynchronously.

        Raises:
            RuntimeError: with the server's error message on non-200 status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        # Best-effort cleanup; shutdown() is idempotent and also registered
        # with atexit.
        self.shutdown()
|
||||
148
python/sglang/lang/backend/vertexai.py
Normal file
148
python/sglang/lang/backend/vertexai.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
# Optional-dependency guard: keep this module importable without the
# `vertexai` package. On ImportError the exception object is stored in place
# of the GenerativeModel class and re-raised lazily in VertexAI.__init__.
try:
    import vertexai
    from vertexai.preview.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )
except ImportError as e:
    GenerativeModel = e
|
||||
|
||||
|
||||
class VertexAI(BaseBackend):
    """Frontend-language backend for Google Vertex AI generative models.

    Requires the ``GCP_PROJECT_ID`` environment variable (mandatory) and
    honors ``GCP_LOCATION`` (optional). Uses the "default" chat template.
    """

    def __init__(self, model_name, safety_settings=None):
        super().__init__()

        # The module-level import guard stores the ImportError in place of
        # the class when `vertexai` is not installed; surface it here.
        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")
        self.safety_settings = safety_settings

    def get_chat_template(self):
        """Return this backend's chat template."""
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Non-streaming completion; returns (text, empty meta dict)."""
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )

        comp = ret.text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Streaming completion; yields (text_chunk, empty meta dict) pairs."""
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_vertexai_input(self, text, images):
        """Interleave text segments with Image parts.

        Splits `text` on the chat template's image token and inserts one
        Image between consecutive segments.
        NOTE(review): assumes `text` contains exactly one image token per
        entry in `images`; a mismatch raises IndexError on pop.
        """
        input = []
        # split with image token
        text_segs = text.split(self.chat_template.image_token)
        for image_path, image_base64_data in images:
            text_seg = text_segs.pop(0)
            if text_seg != "":
                input.append(text_seg)
            # NOTE(review): despite the name, this is passed to
            # Image.from_bytes — presumably raw bytes, not base64; confirm.
            input.append(Image.from_bytes(image_base64_data))
        text_seg = text_segs.pop(0)
        if text_seg != "":
            input.append(text_seg)
        return input

    def messages_to_vertexai_input(self, messages):
        """Convert OpenAI-style messages to Vertex AI's {role, parts} format.

        System messages are emulated as a user turn plus a canned model
        acknowledgement, since Vertex AI has no system role here.
        """
        vertexai_message = []
        # from openai message format to vertexai message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
            else:
                # List-form content: first element is assumed to be the text part.
                text = msg["content"][0]["text"]

            if msg["role"] == "system":
                warnings.warn("Warning: system prompt is not supported in VertexAI.")
                vertexai_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                vertexai_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue
            # NOTE(review): a role other than system/user/assistant leaves
            # `vertexai_msg` unbound (first iteration: NameError) or re-appends
            # the previous turn — confirm inputs are restricted to these roles.
            if msg["role"] == "user":
                vertexai_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
                vertexai_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }

            # images: any content elements after the first are inlined as JPEG
            # data extracted from the data-URL after the comma.
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
                    vertexai_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )

            vertexai_message.append(vertexai_msg)
        return vertexai_message
|
||||
662
python/sglang/lang/chat_template.py
Normal file
662
python/sglang/lang/chat_template.py
Normal file
@@ -0,0 +1,662 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
|
||||
class ChatTemplateStyle(Enum):
    """How role wrappers are applied when rendering a conversation."""

    PLAIN = auto()  # simple prefix + content + suffix per message
    LLAMA2 = auto()  # Llama-2 style: system prompt folded into the first user turn


@dataclass
class ChatTemplate:
    """A declarative chat template.

    Each role maps to a (prefix, suffix) pair that wraps the message content.
    `get_prompt` renders an OpenAI-style message list into a single string.
    """

    name: str
    default_system_prompt: str
    role_prefix_and_suffix: Dict[str, Tuple[str, str]]
    # Fixed: the default value is a tuple, but the field was annotated
    # `List[str]`. Annotate it as a variadic tuple instead (callers that pass
    # a list still work; annotations are not enforced at runtime).
    stop_str: Tuple[str, ...] = ()
    image_token: str = "<image>"
    audio_token: str = "<audio>"
    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

    def get_prefix_and_suffix(
        self, role: str, hist_messages: List[Dict]
    ) -> Tuple[str, str]:
        """Return the (prefix, suffix) wrapper for `role`.

        For LLAMA2-style templates the system prompt is nested inside the
        first user turn: a leading system message gets the user prefix
        prepended, and the first user message after it drops its own prefix.
        Unknown roles get ("", "").
        """
        prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))

        if self.style == ChatTemplateStyle.LLAMA2:
            if role == "system" and not hist_messages:
                user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
                system_prefix, system_suffix = self.role_prefix_and_suffix.get(
                    "system", ("", "")
                )
                return (user_prefix + system_prefix, system_suffix)
            elif (
                role == "user"
                and len(hist_messages) == 1
                and hist_messages[0]["content"] is not None
            ):
                # The opening user prefix was already emitted with the system
                # message above.
                return ("", suffix)

        return prefix, suffix

    def get_prompt(self, messages: List[Dict]) -> str:
        """Render `messages` (dicts with "role" and "content") into a prompt.

        A system message with content None falls back to
        `default_system_prompt`; messages whose content is still None are
        skipped entirely.
        """
        prompt = ""
        for i, message in enumerate(messages):
            role, content = message["role"], message["content"]
            if role == "system" and content is None:
                content = self.default_system_prompt
            if content is None:
                continue

            prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
            prompt += f"{prefix}{content}{suffix}"
        return prompt
|
||||
|
||||
|
||||
# Global registries: named templates, and model-path matcher functions that
# map a model path to a template name (or None when they don't recognize it).
chat_template_registry: Dict[str, ChatTemplate] = {}
matching_function_registry: List[Callable] = []


def register_chat_template(template):
    """Register `template` under its own name, replacing any previous entry."""
    chat_template_registry[template.name] = template


def register_chat_template_matching_function(func):
    """Register a matcher: ``func(model_path) -> template name or None``."""
    matching_function_registry.append(func)


def get_chat_template(name):
    """Look up a registered template by name (KeyError when unknown)."""
    return chat_template_registry[name]


def get_chat_template_by_model_path(model_path):
    """Return the template of the first matcher that recognizes `model_path`,
    falling back to the "default" template when none does."""
    matched_name = next(
        (
            candidate
            for candidate in (fn(model_path) for fn in matching_function_registry)
            if candidate is not None
        ),
        "default",
    )
    return get_chat_template(matched_name)
|
||||
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="default",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("SYSTEM:", "\n"),
|
||||
"user": ("USER:", "\n"),
|
||||
"assistant": ("ASSISTANT:", "\n"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="claude",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("", ""),
|
||||
"user": ("\n\nHuman: ", ""),
|
||||
"assistant": ("\n\nAssistant:", ""),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="chatml",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
||||
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
||||
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
||||
},
|
||||
style=ChatTemplateStyle.PLAIN,
|
||||
stop_str=("<|im_end|>",),
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="chatml-llava",
|
||||
default_system_prompt="You are a helpful assistant.",
|
||||
role_prefix_and_suffix={
|
||||
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
||||
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
||||
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
||||
},
|
||||
style=ChatTemplateStyle.PLAIN,
|
||||
stop_str=("<|im_end|>",),
|
||||
image_token="<image>\n",
|
||||
)
|
||||
)
|
||||
|
||||
# There is default system prompt for qwen
|
||||
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
|
||||
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="qwen",
|
||||
default_system_prompt="You are a helpful assistant.",
|
||||
role_prefix_and_suffix={
|
||||
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
||||
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
||||
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
||||
},
|
||||
style=ChatTemplateStyle.PLAIN,
|
||||
stop_str=("<|im_end|>",),
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="qwen2-vl",
|
||||
default_system_prompt="You are a helpful assistant.",
|
||||
role_prefix_and_suffix={
|
||||
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
||||
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
||||
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
||||
},
|
||||
style=ChatTemplateStyle.PLAIN,
|
||||
stop_str=("<|im_end|>",),
|
||||
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="vicuna_v1.1",
|
||||
default_system_prompt=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
role_prefix_and_suffix={
|
||||
"system": ("", " "),
|
||||
"user": ("USER:", " "),
|
||||
"assistant": ("ASSISTANT:", "</s>"),
|
||||
},
|
||||
image_token=" <image>\n",
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="llama-2-chat",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
|
||||
"user": ("[INST] ", " [/INST]"),
|
||||
"assistant": ("", " </s><s>"),
|
||||
},
|
||||
style=ChatTemplateStyle.LLAMA2,
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="mistral",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
|
||||
"user": ("[INST] ", " [/INST]"),
|
||||
"assistant": ("", " </s><s>"),
|
||||
},
|
||||
stop_str=("</s>",),
|
||||
image_token="[IMG]",
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="llama-3-instruct",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n",
|
||||
"<|eot_id|>",
|
||||
),
|
||||
"user": (
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
"<|eot_id|>",
|
||||
),
|
||||
"assistant": (
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"<|eot_id|>",
|
||||
),
|
||||
},
|
||||
stop_str=("<|eot_id|>",),
|
||||
image_token="<|image|>",
|
||||
)
|
||||
)
|
||||
|
||||
# https://huggingface.co/openbmb/MiniCPM-V-2_6
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="minicpmv",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("", " "),
|
||||
"user": ("user:", " "),
|
||||
"assistant": ("assistant:", "</s>"),
|
||||
},
|
||||
stop_str=("<|im_end|>", "<|endoftext|>"),
|
||||
image_token="(<image>./</image>)",
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="janus-pro",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": (
|
||||
"",
|
||||
"",
|
||||
),
|
||||
"User": (
|
||||
"<|User|>",
|
||||
"",
|
||||
),
|
||||
"assistant": (
|
||||
"<|Assistant|>",
|
||||
"<|end▁of▁sentence|>",
|
||||
),
|
||||
},
|
||||
stop_str=("<|end▁of▁sentence|>",),
|
||||
image_token="<image_placeholder>\n",
|
||||
)
|
||||
)
|
||||
|
||||
# https://huggingface.co/openbmb/MiniCPM-o-2_6
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="minicpmo",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("", " "),
|
||||
"user": ("user:", " "),
|
||||
"assistant": ("assistant:", "</s>"),
|
||||
},
|
||||
stop_str=("<|im_end|>", "<|endoftext|>"),
|
||||
image_token="(<image>./</image>)",
|
||||
audio_token="(<audio>./</audio>)",
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="janus",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": (
|
||||
"",
|
||||
"",
|
||||
),
|
||||
"user": (
|
||||
"<|User|>",
|
||||
"",
|
||||
),
|
||||
"assistant": (
|
||||
"<|Assistant|>",
|
||||
"<|end▁of▁sentence|>",
|
||||
),
|
||||
},
|
||||
stop_str=("<|end▁of▁sentence|>",),
|
||||
image_token="<image_placeholder>\n",
|
||||
)
|
||||
)
|
||||
|
||||
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="llama-3-instruct-llava",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n",
|
||||
"<|eot_id|>",
|
||||
),
|
||||
"user": (
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n",
|
||||
"<|eot_id|>",
|
||||
),
|
||||
"assistant": (
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"<|eot_id|>",
|
||||
),
|
||||
},
|
||||
stop_str=("<|eot_id|>",),
|
||||
image_token="<image>\n",
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="llama-4",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": (
|
||||
"<|header_start|>system<|header_end|>\n\n",
|
||||
"<|eot|>",
|
||||
),
|
||||
"user": (
|
||||
"<|header_start|>user<|header_end|>\n\n",
|
||||
"<|eot|>",
|
||||
),
|
||||
"assistant": (
|
||||
"<|header_start|>assistant<|header_end|>\n\n",
|
||||
"<|eot|>",
|
||||
),
|
||||
},
|
||||
stop_str=("<|eot|>",),
|
||||
image_token="<|image|>",
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="yi-1.5",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("", ""),
|
||||
"user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
|
||||
"assistant": ("", "<|im_end|>\n"),
|
||||
},
|
||||
style=ChatTemplateStyle.PLAIN,
|
||||
stop_str=("<|im_end|>",),
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="yi-vl",
|
||||
default_system_prompt=(
|
||||
"This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
|
||||
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
|
||||
),
|
||||
role_prefix_and_suffix={
|
||||
"system": ("", "\n\n"),
|
||||
"user": ("### Human:", "\n"),
|
||||
"assistant": ("### Assistant:", "\n"),
|
||||
},
|
||||
image_token=" <image_placeholder>\n",
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="gemma-it",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": ("", ""),
|
||||
"user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
|
||||
"assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
|
||||
},
|
||||
style=ChatTemplateStyle.PLAIN,
|
||||
)
|
||||
)
|
||||
|
||||
# Databricks DBRX instruct: ChatML-style <|im_start|>/<|im_end|> markers with
# the model's long official system prompt as the default.
register_chat_template(
    ChatTemplate(
        name="dbrx-instruct",
        default_system_prompt="You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks — remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>"),
            "user": ("\n<|im_start|>user\n", "<|im_end|>"),
            "assistant": ("\n<|im_start|>assistant\n", "<|im_end|>"),
        },
        stop_str=("<|im_end|>",),
    )
)
|
||||
|
||||
# Cohere Command R: every turn is wrapped in <|START_OF_TURN_TOKEN|> ... <|END_OF_TURN_TOKEN|>
# with a role token (SYSTEM/USER/CHATBOT) directly after the start marker.
register_chat_template(
    ChatTemplate(
        name="c4ai-command-r",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
                "<|END_OF_TURN_TOKEN|>",
            ),
            "user": ("<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|END_OF_TURN_TOKEN|>"),
            "assistant": (
                "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
                "<|END_OF_TURN_TOKEN|>",
            ),
        },
        style=ChatTemplateStyle.PLAIN,
    )
)
|
||||
|
||||
# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# InternVL 2.5: ChatML markers plus the model's Chinese self-introduction as
# the default system prompt.
register_chat_template(
    ChatTemplate(
        name="internvl-2-5",
        default_system_prompt="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        stop_str=["<|im_end|>", "<|action_end|>"],
    )
)
|
||||
|
||||
# Intern-S1: ChatML markers; the default system prompt instructs the model to
# put its reasoning inside <think>...</think> tags.
register_chat_template(
    ChatTemplate(
        name="interns1",
        default_system_prompt="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        stop_str=["<|im_end|>", "<|action_end|>"],
    )
)
|
||||
|
||||
# IBM Granite 3 instruct: roles wrapped as <|start_of_role|>NAME<|end_of_role|>
# and each turn terminated by <|end_of_text|>.
register_chat_template(
    ChatTemplate(
        name="granite-3-instruct",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|start_of_role|>system<|end_of_role|>",
                "<|end_of_text|>",
            ),
            "user": (
                "<|start_of_role|>user<|end_of_role|>",
                "<|end_of_text|>",
            ),
            "assistant": (
                "<|start_of_role|>assistant<|end_of_role|>",
                "<|end_of_text|>",
            ),
        },
        stop_str=("<|end_of_text|>",),
    )
)
|
||||
|
||||
# DeepSeek V3 / R1 chat format: bare system text, <|User|> / <|Assistant|>
# role markers, and the special "<|end▁of▁sentence|>" terminator (note the
# U+2581 characters inside the token — they are part of the literal token).
register_chat_template(
    ChatTemplate(
        name="deepseek-v3",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "",
                "",
            ),
            "user": (
                "<|User|>",
                "",
            ),
            "assistant": (
                "<|Assistant|>",
                "<|end▁of▁sentence|>",
            ),
        },
        stop_str=("<|end▁of▁sentence|>",),
    )
)
|
||||
|
||||
# Reference: https://huggingface.co/docs/transformers/main/model_doc/glm4_v#usage-example
# GLM-4V: <|system|>/<|user|>/<|assistant|> role headers, newline-terminated
# turns, and a dedicated <|image|> placeholder for vision inputs.
register_chat_template(
    ChatTemplate(
        name="glm-4v",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<|system|>\n", "\n"),
            "user": ("<|user|>\n", "\n"),
            "assistant": ("<|assistant|>\n", "\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=["<|user|>", "<|endoftext|>", "<|observation|>"],
        image_token="<|image|>",
    )
)
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_deepseek(model_path: str):
    """Route DeepSeek V3/R1 chat checkpoints (but not base models) to "deepseek-v3"."""
    in_family = re.search(r"deepseek-(v3|r1)", model_path, re.IGNORECASE)
    is_base = re.search(r"base", model_path, re.IGNORECASE)
    if in_family and not is_base:
        return "deepseek-v3"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    """Map any path containing "janus" to the "janus-pro" template."""
    return "janus-pro" if re.search(r"janus", model_path, re.IGNORECASE) else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_dbrx(model_path: str):
    """Match DBRX instruct checkpoints (path must mention both "dbrx" and "instruct")."""
    if re.search(r"dbrx", model_path, re.IGNORECASE) is None:
        return None
    if re.search(r"instruct", model_path, re.IGNORECASE):
        return "dbrx-instruct"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_vicuna(model_path: str):
    """Vicuna, LLaVA-v1.5 and LLaVA-NeXT-Video-7B share the vicuna_v1.1 format."""
    family_pattern = r"vicuna|llava-v1\.5|llava-next-video-7b"
    if re.search(family_pattern, model_path, re.IGNORECASE):
        return "vicuna_v1.1"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
    """Llama-2 chat and CodeLlama instruct checkpoints use the llama-2 chat format."""
    matched = re.search(r"llama-2.*chat|codellama.*instruct", model_path, re.IGNORECASE)
    return "llama-2-chat" if matched else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_mistral(model_path: str):
    """Pixtral and Mistral/Mixtral instruct checkpoints use the "mistral" template."""
    hit = re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE)
    return "mistral" if hit else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_llama3_instruct(model_path: str):
    """Map Llama-3 instruct checkpoints to the "llama-3-instruct" template."""
    return (
        "llama-3-instruct"
        if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE)
        else None
    )
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_chat_ml(model_path: str):
    """Route the ChatML family of models; check order matters (most specific first)."""

    def has(pattern):
        # Case-insensitive substring/regex test against the model path.
        return re.search(pattern, model_path, re.IGNORECASE) is not None

    if has(r"tinyllama"):
        return "chatml"
    if has(r"qwen.*vl"):
        return "qwen2-vl"
    if has(r"glm[-_]?4(\.\d+)?v"):
        return "glm-4v"
    # Plain Qwen chat/instruct — but LLaVA-on-Qwen variants fall through below.
    if has(r"qwen.*(chat|instruct)") and not has(r"llava"):
        return "qwen"
    if has(
        r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2"
    ):
        return "chatml-llava"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_chat_yi(model_path: str):
    """Yi-VL (but not LLaVA variants) vs. Yi-1.5 chat checkpoints."""
    is_yi_vl = re.search(r"yi-vl", model_path, re.IGNORECASE)
    if is_yi_vl and not re.search(r"llava", model_path, re.IGNORECASE):
        return "yi-vl"
    if re.search(r"yi-1\.5.*chat", model_path, re.IGNORECASE):
        return "yi-1.5"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_gemma_it(model_path: str):
    """Gemma instruction-tuned checkpoints use the "gemma-it" template."""
    return "gemma-it" if re.search(r"gemma.*it", model_path, re.IGNORECASE) else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_openbmb_minicpm(model_path: str):
    """Distinguish MiniCPM-V (vision) from MiniCPM-o (omni) checkpoints."""
    if re.search(r"minicpm-v", model_path, re.IGNORECASE):
        return "minicpmv"
    if re.search(r"minicpm-o", model_path, re.IGNORECASE):
        return "minicpmo"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_c4ai_command_r(model_path: str):
    """Cohere Command R checkpoints use the "c4ai-command-r" template."""
    found = re.search(r"c4ai-command-r", model_path, re.IGNORECASE)
    return "c4ai-command-r" if found else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_granite_instruct(model_path: str):
    """IBM Granite instruct checkpoints use the "granite-3-instruct" template."""
    return (
        "granite-3-instruct"
        if re.search(r"granite.*instruct", model_path, re.IGNORECASE)
        else None
    )
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_gemma3_instruct(model_path: str):
    """Gemma 3 checkpoints reuse the "gemma-it" chat template."""
    return "gemma-it" if re.search(r"gemma-3", model_path, re.IGNORECASE) else None
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_internvl_chat(model_path: str):
    """InternVL 2.5 checkpoints use the "internvl-2-5" template."""
    return (
        "internvl-2-5"
        if re.search(r"internvl2_5", model_path, re.IGNORECASE)
        else None
    )
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_interns1_chat(model_path: str):
    """Match Intern-S1 checkpoints to the "interns1" template.

    Both historical path spellings, "intern-s1" and "interns1", map to the
    same template; the original code tested them in two identical branches,
    which is collapsed here into one regex with an optional hyphen.
    """
    if re.search(r"intern-?s1", model_path, re.IGNORECASE):
        return "interns1"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Quick manual check: render a short multi-turn conversation with the
    # llama-2 chat template and print the resulting prompt string.
    turns = [
        ("system", None),  # None means default
        # ("system", "You are a helpful, respectful and honest assistant."),
        ("user", "Hello!"),
        ("assistant", "Hi!"),
        ("user", "What can you do?"),
        ("assistant", "I can chat with you."),
    ]
    messages = [{"role": role, "content": content} for role, content in turns]

    print(get_chat_template("llama-2-chat").get_prompt(messages))
|
||||
164
python/sglang/lang/choices.py
Normal file
164
python/sglang/lang/choices.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
class ChoicesDecision:
    """Result of a choices-selection method: the winning option plus diagnostics."""

    # The selected choice text (one of the candidate strings).
    decision: str
    # Optional per-method debug info (logprob lists, matrices, ...).
    meta_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ChoicesSamplingMethod(ABC):
    """Strategy interface for picking one option among a list of choices."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        # Subclasses that normalize against unconditional likelihoods override
        # this to True so the caller computes the extra logprobs.
        return False

    @abstractmethod
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision: ...
|
||||
|
||||
|
||||
class TokenLengthNormalized(ChoicesSamplingMethod):
    """Pick the choice with the largest length-normalized prompt logprob."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest token length normalized prompt logprob."""
        winner_idx = np.argmax(normalized_prompt_logprobs)
        return ChoicesDecision(
            decision=choices[winner_idx],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
            },
        )
|
||||
|
||||
|
||||
# Module-level singleton: the default choices policy.
token_length_normalized = TokenLengthNormalized()
|
||||
|
||||
|
||||
class GreedyTokenSelection(ChoicesSamplingMethod):
    """Eliminate choices token-by-token, keeping only per-column logprob maxima."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option based on greedy logprob selection. For overlapping options
        where one option is a subset of a longer option, extend the shorter option using
        its average logprob for comparison against the longer option."""

        n_options = len(choices)
        width = max(len(option) for option in input_token_logprobs)
        matrix = self._build_logprob_matrix(input_token_logprobs, width, n_options)
        survivors = self._greedy_selection(matrix, n_options, width)

        return ChoicesDecision(
            decision=choices[survivors[0]],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
                "greedy_logprob_matrix": matrix.tolist(),
            },
        )

    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
        # One row per option; shorter rows are padded with the row's own mean
        # logprob so they remain comparable against longer options.
        matrix = np.zeros((num_options, max_tokens))
        for row, option in enumerate(input_token_logprobs):
            values = [token[0] for token in option]
            matrix[row, : len(values)] = values
            if len(values) < max_tokens:
                matrix[row, len(values) :] = np.mean(values)
        return matrix

    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):
        # Scan columns left to right, keeping only rows tied for the column max;
        # stop as soon as a single survivor remains.
        survivors = np.arange(num_options)
        for col in range(max_tokens):
            column = logprob_matrix[survivors, col]
            survivors = survivors[column == np.max(column)]
            if len(survivors) == 1:
                break
        return survivors
|
||||
|
||||
|
||||
# Module-level singleton for the greedy per-token elimination policy.
greedy_token_selection = GreedyTokenSelection()
|
||||
|
||||
|
||||
class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):
    """Score each choice by its mean (conditional - unconditional) token logprob."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        # This method needs the extra unconditional pass.
        return True

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest average token logprob once normalized by
        the unconditional token logprobs.

        The first unconditional token logprob is assumed to be None. If so, it is
        replaced with 0 for the purposes of normalization."""

        if unconditional_token_logprobs is None:
            raise ValueError(
                "Unconditional token logprobs are required for this method."
            )

        scores = self._normalize_logprobs(
            input_token_logprobs, unconditional_token_logprobs
        )

        return ChoicesDecision(
            decision=choices[np.argmax(scores)],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
                "unconditional_token_logprobs": unconditional_token_logprobs,
                "normalized_unconditional_prompt_logprobs": scores,
            },
        )

    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):
        scores = []
        for conditional, unconditional in zip(
            input_token_logprobs, unconditional_token_logprobs
        ):
            cond_arr = np.array([token[0] for token in conditional])
            uncond_arr = np.array([token[0] for token in unconditional])
            # The leading unconditional logprob may be None; treat it as 0.
            uncond_arr[0] = uncond_arr[0] or 0
            scores.append(float(np.mean(cond_arr - uncond_arr)))
        return scores
|
||||
|
||||
|
||||
# Module-level singleton for the unconditional-likelihood normalization policy.
unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()
|
||||
231
python/sglang/lang/compiler.py
Normal file
231
python/sglang/lang/compiler.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import multiprocessing
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from queue import Queue
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
|
||||
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
|
||||
|
||||
|
||||
def compile_func(function, backend):
    """Trace `function` on `backend` and wrap the trace in a CompiledFunction."""
    traced = function.trace(backend=backend)
    return CompiledFunction(traced, function)
|
||||
|
||||
|
||||
class CompiledFunction:
    """A compiled SGL program: an explicit dataflow graph of traced expressions.

    The tracer yields a chain of ``SglExpr`` nodes ending at ``tracer.last_node``.
    ``build_graph`` walks backwards from that node and materializes the graph,
    ``topological_sort`` orders it dependency-first, and ``run``/``run_batch``
    execute it with one ``StreamExecutor`` per stream id (pid).
    """

    def __init__(self, tracer, function):
        # Keep the original SglFunction so run() can apply its bound arguments.
        self.function = function

        self.last_node = CompGraphNode(tracer.last_node)
        self.expr_to_node = {}  # SglExpr -> CompGraphNode
        self.build_graph(tracer)
        self.topological_sort()

    def build_graph(self, tracer):
        """BFS backwards from the last node, adding two kinds of edges.

        Each node gets (a) a ``prev_node`` edge following the linear chain and
        (b), for ``SglVariable`` nodes, a ``source_node`` edge to the expression
        producing the variable's value.  Stream pids are renamed to a dense
        0..k range in discovery order.
        """
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.nodes[-1]

        rename_pid = {}  # old pid -> dense new pid

        visited = set([tracer.last_node])
        head = 0
        while head < len(self.nodes):
            cur_node = self.nodes[head]

            # add prev node
            prev_node = cur_node.expr.prev_node
            if prev_node is not None:
                if prev_node not in visited:
                    visited.add(prev_node)
                    self.nodes.append(CompGraphNode(prev_node))
                    self.expr_to_node[prev_node] = self.nodes[-1]
                cur_node.prev_node = self.expr_to_node[prev_node]
                self.expr_to_node[prev_node].add_next_node(cur_node)

            # add source node
            if isinstance(cur_node.expr, SglVariable):
                # Prefer the tracer's registered source for named variables.
                if cur_node.expr.name in tracer.variables:
                    source = tracer.variables[cur_node.expr.name].source
                else:
                    source = cur_node.expr.source
                if source not in visited:
                    visited.add(source)
                    self.nodes.append(CompGraphNode(source))
                    self.expr_to_node[source] = self.nodes[-1]
                cur_node.source_node = self.expr_to_node[source]
                self.expr_to_node[source].add_next_node(cur_node)
            head += 1

            # rename pid
            if cur_node.expr.pid not in rename_pid:
                rename_pid[cur_node.expr.pid] = len(rename_pid)
            cur_node.expr.pid = rename_pid[cur_node.expr.pid]

    def topological_sort(self):
        """Kahn's algorithm: reorder ``self.nodes`` dependencies-first."""
        prevd = {}  # node -> number of unsatisfied predecessors (0..2)
        cand = Queue()
        for x in self.nodes:
            prevd[x] = (x.prev_node is not None) + (x.source_node is not None)
            if prevd[x] == 0:
                cand.put(x)
        new_list = []
        while cand.qsize() > 0:
            head = cand.get()
            new_list.append(head)
            for x in head.next_nodes:
                prevd[x] -= 1
                if prevd[x] == 0:
                    cand.put(x)
        self.nodes = new_list

    def print_graph(
        self,
    ):
        """Debug helper: print each node in topological order."""
        for node in self.nodes:
            print(node)

    def run_internal(
        self,
        backend,
        kwargs,
        default_sampling_para,
    ):
        """Execute the graph with one StreamExecutor per stream id.

        Only the stream containing the final node receives the caller's
        arguments; all other streams get empty argument dicts.  Returns the
        ProgramState of the final stream.
        """
        stream_executor_ids = set([x.expr.pid for x in self.nodes])
        stream_executors = {}  # pid -> StreamExecutor
        for x in stream_executor_ids:
            arguments = kwargs if x == self.last_node.expr.pid else {}
            stream_executors[x] = StreamExecutor(
                backend, arguments, default_sampling_para, None, False
            )
        for node in self.nodes:
            se_id = node.expr.pid
            expr = node.expr
            if isinstance(expr, SglVariable):
                # Make a copy for SglVariable
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[
                    node.source_node.expr.pid
                ]
            elif isinstance(expr, SglArgument):
                # Substitute SglArgument
                expr = kwargs[expr.name]
            stream_executors[se_id].submit(expr)
        for stream_executor in stream_executors.values():
            stream_executor.end()
        return ProgramState(stream_executors[self.last_node.expr.pid])

    def run(
        self,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        """Run the compiled program once; ``kwargs`` are the SGL arguments."""
        backend = backend or global_config.default_backend

        # Bound arguments from SglFunction.bind() take precedence.
        kwargs.update(self.function.bind_arguments)

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        return self.run_internal(backend, kwargs, default_sampling_para)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        """Run the compiled program over a batch of argument dicts.

        ``num_threads="auto"`` uses the CPU count, capped at the batch size.
        Returns a list of ProgramState results in input order.
        """
        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        backend = backend or global_config.default_backend

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        # Extract prefix by tracing and cache it
        if len(batch_kwargs) > 1:
            cache_program(self.function, backend)

        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_kwargs))

        if num_threads == 1:
            rets = []
            for arguments in batch_kwargs:
                rets.append(
                    self.run_internal(backend, arguments, default_sampling_para)
                )
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = []
                for arguments in batch_kwargs:
                    futures.append(
                        executor.submit(
                            self.run_internal, backend, arguments, default_sampling_para
                        )
                    )
                rets = [f.result() for f in futures]
            # Wait for the last program so the whole batch is finished on return.
            rets[-1].sync()

        return rets
|
||||
|
||||
|
||||
class CompGraphNode:
    """A node in the compiled dataflow graph wrapping a single expression."""

    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        self.expr = expr
        self.prev_node = prev_node
        self.source_node = source_node
        self.next_nodes = next_nodes or []

    def add_next_node(self, other):
        """Register `other` as a successor of this node."""
        self.next_nodes.append(other)

    def __repr__(self):
        pieces = [f"stream {self.expr.pid:2d}: ", f"%{self.expr.node_id} = "]
        if self.prev_node is not None:
            pieces.append(f"%{self.prev_node.expr.node_id} + ")
        pieces.append(repr(self.expr))
        return "".join(pieces)
|
||||
1060
python/sglang/lang/interpreter.py
Normal file
1060
python/sglang/lang/interpreter.py
Normal file
File diff suppressed because it is too large
Load Diff
635
python/sglang/lang/ir.py
Normal file
635
python/sglang/lang/ir.py
Normal file
@@ -0,0 +1,635 @@
|
||||
"""The intermediate representation."""
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.choices import ChoicesSamplingMethod
|
||||
|
||||
# Regex fragments used for constrained ("dtype") generation of primitive values.
REGEX_INT = r"[-+]?[0-9]+[ \n]*"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
REGEX_BOOL = r"(True|False)"
REGEX_STR = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SglSamplingParams:
    """Sampling parameters shared by the SGLang frontend backends.

    Holds decoding knobs (length limits, temperature, nucleus/top-k, penalties)
    plus logprob-reporting options, and converts them to the keyword formats
    expected by the OpenAI, VertexAI, Anthropic, LiteLLM, and SRT backends.

    Bug fix: ``logprob_start_len``, ``top_logprobs_num`` and
    ``return_text_in_logprobs`` previously defaulted to the one-element tuple
    ``(None,)`` (stray trailing commas), contradicting their ``Optional[...]``
    annotations; they now default to ``None``.
    """

    max_new_tokens: int = 128
    min_new_tokens: int = 0
    n: int = 1
    stop: Union[str, List[str]] = ()
    stop_token_ids: Optional[List[int]] = ()
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1  # -1 means disable
    min_p: float = 0.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    ignore_eos: bool = False
    return_logprob: Optional[bool] = None
    logprob_start_len: Optional[int] = None
    top_logprobs_num: Optional[int] = None
    return_text_in_logprobs: Optional[bool] = None
    json_schema: Optional[str] = None

    # for constrained generation, not included in to_xxx_kwargs
    dtype: Optional[str] = None
    regex: Optional[str] = None

    def clone(self):
        """Return a copy of these parameters.

        NOTE(review): ``dtype`` and ``regex`` are deliberately NOT copied
        (they reset to None), matching the original behavior — confirm intent.
        Keyword arguments are used so the copy cannot silently break if field
        order ever changes.
        """
        return SglSamplingParams(
            max_new_tokens=self.max_new_tokens,
            min_new_tokens=self.min_new_tokens,
            n=self.n,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            ignore_eos=self.ignore_eos,
            return_logprob=self.return_logprob,
            logprob_start_len=self.logprob_start_len,
            top_logprobs_num=self.top_logprobs_num,
            return_text_in_logprobs=self.return_text_in_logprobs,
            json_schema=self.json_schema,
        )

    def to_openai_kwargs(self):
        """Convert to OpenAI chat/completions keyword arguments."""
        # OpenAI does not support top_k, so we drop it here
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the OpenAI backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "max_completion_tokens": self.max_new_tokens,
            "n": self.n,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_vertexai_kwargs(self):
        """Convert to VertexAI generation keyword arguments."""
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the VertexAI backend."
            )
        return {
            "candidate_count": 1,
            "max_output_tokens": self.max_new_tokens,
            "stop_sequences": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            # VertexAI rejects non-positive top_k; pass None to disable.
            "top_k": self.top_k if self.top_k > 0 else None,
        }

    def to_anthropic_kwargs(self):
        """Convert to Anthropic messages keyword arguments."""
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
        if self.regex is not None:
            warnings.warn(
                "Regular expression is not supported in the Anthropic backend."
            )
        return {
            "max_tokens": self.max_new_tokens,
            # Anthropic requires a list of stop sequences.
            "stop_sequences": (
                self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
            ),
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def to_litellm_kwargs(self):
        """Convert to LiteLLM keyword arguments."""
        if self.regex is not None:
            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }

    def to_srt_kwargs(self):
        """Convert to the native SGLang runtime (SRT) sampling-params dict."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "min_new_tokens": self.min_new_tokens,
            "n": self.n,
            "stop": self.stop,
            "stop_token_ids": self.stop_token_ids,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "min_p": self.min_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "ignore_eos": self.ignore_eos,
            "regex": self.regex,
            "json_schema": self.json_schema,
        }
|
||||
|
||||
|
||||
class SglFunction:
|
||||
def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
|
||||
self.func = func
|
||||
self.num_api_spec_tokens = num_api_spec_tokens
|
||||
self.bind_arguments = bind_arguments or {}
|
||||
self.pin_prefix_rid = None
|
||||
|
||||
# Parse arguments
|
||||
argspec = inspect.getfullargspec(func)
|
||||
assert argspec.args[0] == "s", 'The first argument must be "s"'
|
||||
self.arg_names = argspec.args[1:]
|
||||
self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
|
||||
|
||||
def bind(self, **kwargs):
|
||||
assert all(key in self.arg_names for key in kwargs)
|
||||
|
||||
new_bind_dict = {**self.bind_arguments, **kwargs}
|
||||
return SglFunction(self.func, bind_arguments=new_bind_dict)
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args,
|
||||
max_new_tokens: int = 128,
|
||||
n: int = 1,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
ignore_eos: bool = False,
|
||||
return_logprob: Optional[bool] = None,
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
stream: bool = False,
|
||||
backend=None,
|
||||
use_thread: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
from sglang.lang.interpreter import run_program
|
||||
|
||||
# avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
|
||||
if stop is None:
|
||||
stop = []
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
n=n,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
return_text_in_logprobs=return_text_in_logprobs,
|
||||
)
|
||||
backend = backend or global_config.default_backend
|
||||
return run_program(
|
||||
self,
|
||||
backend,
|
||||
args,
|
||||
kwargs,
|
||||
default_sampling_para,
|
||||
stream,
|
||||
use_thread=use_thread,
|
||||
)
|
||||
|
||||
def run_batch(
|
||||
self,
|
||||
batch_kwargs,
|
||||
*,
|
||||
max_new_tokens: int = 128,
|
||||
n: int = 1,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
ignore_eos: bool = False,
|
||||
return_logprob: Optional[bool] = None,
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
backend=None,
|
||||
num_threads: Union[str, int] = "auto",
|
||||
progress_bar: bool = False,
|
||||
generator_style: bool = False,
|
||||
):
|
||||
from sglang.lang.interpreter import run_program_batch
|
||||
|
||||
if stop is None:
|
||||
stop = []
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
|
||||
assert isinstance(batch_kwargs, (list, tuple))
|
||||
if len(batch_kwargs) == 0:
|
||||
return []
|
||||
if not isinstance(batch_kwargs[0], dict):
|
||||
num_programs = len(batch_kwargs)
|
||||
# change the list of argument values to dict of arg_name -> arg_value
|
||||
batch_kwargs = [
|
||||
{self.arg_names[i]: v for i, v in enumerate(arg_values)}
|
||||
for arg_values in batch_kwargs
|
||||
if isinstance(arg_values, (list, tuple))
|
||||
and len(self.arg_names) - len(self.arg_defaults)
|
||||
<= len(arg_values)
|
||||
<= len(self.arg_names)
|
||||
]
|
||||
# Ensure to raise an exception if the number of arguments mismatch
|
||||
if len(batch_kwargs) != num_programs:
|
||||
raise Exception("Given arguments mismatch the SGL function signature")
|
||||
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
n=n,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
return_text_in_logprobs=return_text_in_logprobs,
|
||||
)
|
||||
backend = backend or global_config.default_backend
|
||||
return run_program_batch(
|
||||
self,
|
||||
backend,
|
||||
batch_kwargs,
|
||||
default_sampling_para,
|
||||
num_threads,
|
||||
progress_bar,
|
||||
generator_style=generator_style,
|
||||
)
|
||||
|
||||
def trace(self, *, backend=None, **kwargs):
    """Symbolically trace this program instead of executing it.

    Falls back to ``global_config.default_backend`` when no backend is given.
    """
    from sglang.lang.tracer import trace_program

    return trace_program(self, kwargs, backend or global_config.default_backend)
|
||||
|
||||
def cache(self, backend=None):
    """Pre-compute and cache this program on the given backend.

    Falls back to ``global_config.default_backend`` when no backend is given.
    """
    from sglang.lang.interpreter import cache_program

    return cache_program(self, backend or global_config.default_backend)
|
||||
|
||||
def compile(self, *, backend=None):
    """Compile this program for repeated execution.

    NOTE(review): unlike trace()/cache(), this does not fall back to
    global_config.default_backend — confirm that is intentional.
    """
    from sglang.lang.compiler import compile_func

    return compile_func(self, backend)
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Run the program, or trace it when invoked inside an active tracing scope."""
    from sglang.lang.tracer import TracingScope

    scope = TracingScope.get_current_scope()
    if scope is None:
        return self.run(*args, **kwargs)
    # Inside a trace: reuse the tracer's backend and trace instead of running.
    kwargs["backend"] = scope.tracer_state.backend
    return self.trace(*args, **kwargs)
|
||||
|
||||
|
||||
class SglExpr:
    """Base class of all SGL IR nodes.

    Every node gets a process-unique, monotonically increasing ``node_id``.
    Nodes form a chain through ``prev_node`` (set by the interpreter/tracer),
    and ``pid`` records the program state that executed the node.
    """

    # Global counter used to assign unique node ids.
    node_ct = 0

    def __init__(self):
        self.node_id = SglExpr.node_ct
        self.prev_node = None
        self.pid = None
        SglExpr.node_ct += 1

    def __add__(self, other):
        """Concatenate ``self + other``; plain strings are wrapped as constants."""
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr)

        return self.concatenate_ir(self, other)

    def __radd__(self, other):
        """Concatenate ``other + self`` (e.g. ``"text" + expr``)."""
        if isinstance(other, str):
            other = SglConstantText(other)
        assert isinstance(other, SglExpr), f"{other}"

        return self.concatenate_ir(other, self)

    def concatenate_ir(self, a, b):
        """Join two expressions into one flat SglExprList (no nested lists)."""
        if isinstance(a, SglExprList):
            if isinstance(b, SglExprList):
                return SglExprList(a.expr_list + b.expr_list)
            else:
                return SglExprList(a.expr_list + [b])
        elif isinstance(b, SglExprList):
            return SglExprList([a] + b.expr_list)

        return SglExprList([a, b])

    def print_graph_dfs(self):
        """Render the dependency graph rooted at this node as text.

        Dependencies (``prev_node``, list members, variable sources) are
        printed before the node itself; ``visited`` prevents duplicates.
        Returns the accumulated multi-line string.
        """
        # ret is a one-element list so the closure can mutate the string.
        ret = [""]
        visited = set()

        def dfs_print(x):
            if x is None or x in visited:
                return
            visited.add(x)

            # Print dependency
            if x.prev_node is not None:
                dfs_print(x.prev_node)

            if isinstance(x, SglExprList):
                for y in x.expr_list:
                    dfs_print(y)
            # elif isinstance(x, SglRole):
            #     dfs_print(x.expr)
            elif isinstance(x, SglVariable):
                dfs_print(x.source)

            # Print the node itself
            if isinstance(x, (SglFork, SglGetForkItem)):
                ret[0] += f"%{x.node_id} = {x}\n"
            else:
                if x.prev_node is not None:
                    ret[0] += (
                        f"%{x.node_id} = %{x.prev_node.node_id} + " + str(x) + "\n"
                    )
                else:
                    ret[0] += f"%{x.node_id} = " + str(x) + "\n"

        dfs_print(self)
        return ret[0]
|
||||
|
||||
|
||||
class SglExprList(SglExpr):
    """A flat sequence of IR expressions evaluated in order."""

    def __init__(self, expr_list: List[SglExpr]):
        super().__init__()
        self.expr_list = expr_list

    def __repr__(self):
        return "ExprList({})".format(self.expr_list)
|
||||
|
||||
|
||||
class SglArgument(SglExpr):
    """An IR placeholder for a user-supplied program argument."""

    def __init__(self, name: str, value: str):
        super().__init__()
        self.name = name
        self.value = value

    def __repr__(self):
        return f"Argument(name={self.name}, value={repr(self.value)})"

    # Delegate container-style access to the wrapped value so traced code can
    # treat the argument like the underlying object.
    def __len__(self):
        return len(self.value)

    def __getitem__(self, i):
        return self.value[i]

    # NOTE(review): __int__/__bool__ return self.value unconverted. Python
    # requires these to return int/bool, so for other value types they raise
    # TypeError — confirm whether that failure is the intended tracer signal.
    def __int__(self):
        return self.value

    def __bool__(self):
        return self.value

    def __format__(self, *args):
        # Forbid f-string interpolation: it would stringify the placeholder at
        # trace time and silently break tracing.
        raise TypeError(
            "Cannot put argument inside a f-string. "
            "This is not compatible with the tracer. "
        )
|
||||
|
||||
|
||||
class SglImage(SglExpr):
    """An IR node that embeds an image, referenced by its path or URL."""

    def __init__(self, path: str):
        # Initialize SglExpr bookkeeping (node_id/prev_node/pid). Every other
        # SglExpr subclass calls this; omitting it left SglImage nodes without
        # the attributes that print_graph_dfs and the tracer rely on.
        super().__init__()
        self.path = path

    def __repr__(self) -> str:
        return f"SglImage({self.path})"
|
||||
|
||||
|
||||
class SglVideo(SglExpr):
    """An IR node that embeds a video (path/URL plus frame count)."""

    def __init__(self, path: str, num_frames: int):
        # Initialize SglExpr bookkeeping (node_id/prev_node/pid). Every other
        # SglExpr subclass calls this; omitting it left SglVideo nodes without
        # the attributes that print_graph_dfs and the tracer rely on.
        super().__init__()
        self.path = path
        self.num_frames = num_frames

    def __repr__(self) -> str:
        return f"SglVideo({self.path}, {self.num_frames})"
|
||||
|
||||
|
||||
class SglGen(SglExpr):
    """An IR node representing one model-generation call.

    All sampling knobs are bundled into a SglSamplingParams instance; ``name``
    is the variable under which the generated text is stored in the state.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stop_token_ids: Optional[List[int]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        min_p: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        ignore_eos: Optional[bool] = None,
        return_logprob: Optional[bool] = None,
        logprob_start_len: Optional[int] = None,
        top_logprobs_num: Optional[int] = None,
        return_text_in_logprobs: Optional[bool] = None,
        dtype: Optional[type] = None,
        regex: Optional[str] = None,
        json_schema: Optional[str] = None,
    ):
        """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""
        super().__init__()
        self.name = name
        # Pass every knob through unchanged; None means "use backend default".
        self.sampling_params = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            n=n,
            stop=stop,
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            ignore_eos=ignore_eos,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            return_text_in_logprobs=return_text_in_logprobs,
            dtype=dtype,
            regex=regex,
            json_schema=json_schema,
        )

    def __repr__(self):
        return f"Gen('{self.name}')"
|
||||
|
||||
|
||||
class SglConstantText(SglExpr):
    """A literal text segment in the IR."""

    def __init__(self, value: str):
        super().__init__()
        self.value = value

    def __repr__(self):
        return "Constant({!r})".format(self.value)
|
||||
|
||||
|
||||
class SglRoleBegin(SglExpr):
    """Marker node that opens a chat role section (e.g. user/assistant)."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleBegin({})".format(self.role)
|
||||
|
||||
|
||||
class SglRoleEnd(SglExpr):
    """Marker node that closes a chat role section."""

    def __init__(self, role: str):
        super().__init__()
        self.role = role

    def __repr__(self):
        return "RoleEnd({})".format(self.role)
|
||||
|
||||
|
||||
class SglSelect(SglExpr):
    """IR node asking the model to pick one of *choices*."""

    def __init__(
        self,
        name: str,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ):
        super().__init__()
        self.name = name
        self.choices = choices
        self.temperature = temperature
        self.choices_method = choices_method

    def __repr__(self):
        return "Select({}, choices={}, choices_method={})".format(
            self.name, self.choices, self.choices_method
        )
|
||||
|
||||
|
||||
class SglFork(SglExpr):
    """IR node that forks the current state into *number* branches."""

    def __init__(self, number: int, position_ids_offset=None):
        super().__init__()
        self.number = number
        self.position_ids_offset = position_ids_offset

    def __repr__(self):
        return "Fork(%{}, number={}, position_ids_offset={})".format(
            self.prev_node.node_id, self.number, self.position_ids_offset
        )
|
||||
|
||||
|
||||
class SglGetForkItem(SglExpr):
    """IR node selecting branch *index* from a preceding fork."""

    def __init__(self, index: int):
        super().__init__()
        self.index = index

    def __repr__(self):
        return "GetForkItem(%{}, index={})".format(self.prev_node.node_id, self.index)
|
||||
|
||||
|
||||
class SglVariable(SglExpr):
    """IR node referencing a named value produced by another node."""

    def __init__(self, name: str, source):
        super().__init__()
        self.name = name
        self.source = source

    def __repr__(self):
        return "Variable('{}', source=%{})".format(self.name, self.source.node_id)
|
||||
|
||||
|
||||
class SglVarScopeBegin(SglExpr):
    """Marker node that opens a named variable scope."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeBegin('{}')".format(self.name)
|
||||
|
||||
|
||||
class SglVarScopeEnd(SglExpr):
    """Marker node that closes a named variable scope."""

    def __init__(self, name: str):
        super().__init__()
        self.name = name

    def __repr__(self):
        return "VarScopeEnd('{}')".format(self.name)
|
||||
|
||||
|
||||
class SglConcateAndAppend(SglExpr):
    """IR node that concatenates several program states and appends the result."""

    def __init__(self, states):
        super().__init__()
        self.states = states

    def __repr__(self):
        return "ConcatenateAndAppend('{}')".format(self.states)
|
||||
|
||||
|
||||
class SglCommitLazy(SglExpr):
    """Marker node that commits lazily buffered operations.

    Carries no payload; the inherited SglExpr constructor is sufficient.
    """

    def __repr__(self):
        return "CommitLazy()"
|
||||
|
||||
|
||||
class SglSeparateReasoning(SglExpr):
    """Wrap a generation/selection so its reasoning text is stored separately.

    ``name`` becomes ``<inner name>_reasoning_content`` for the wrapped
    SglGen/SglSelect (recursing into SglExprList members).
    """

    def __init__(self, model_type: str, expr: SglExpr):
        super().__init__()
        self.model_type = model_type

        self.expr = expr
        self.name = None
        self._process_expr(expr)

    def process_name_for_reasoning(self, name):
        """Derive the variable name that stores the reasoning content.

        Raises:
            ValueError: if *name* is empty/None.
        """
        if not name:
            raise ValueError("name must be provided")
        return f"{name}_reasoning_content"

    def _process_expr(self, expr):
        # The SglGen and SglSelect branches were identical; merge them into a
        # single isinstance check. Behavior is unchanged.
        if isinstance(expr, (SglGen, SglSelect)):
            self.name = self.process_name_for_reasoning(expr.name)
        elif isinstance(expr, SglExprList):
            for x in expr.expr_list:
                self._process_expr(x)

    def __repr__(self):
        return f"SeparateReasoning(model_type={self.model_type}, name={self.name})"
|
||||
279
python/sglang/lang/tracer.py
Normal file
279
python/sglang/lang/tracer.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Tracing a program."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
|
||||
from sglang.lang.ir import (
|
||||
SglArgument,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
SglFork,
|
||||
SglGen,
|
||||
SglGetForkItem,
|
||||
SglRoleBegin,
|
||||
SglRoleEnd,
|
||||
SglSelect,
|
||||
SglVariable,
|
||||
SglVarScopeBegin,
|
||||
SglVarScopeEnd,
|
||||
)
|
||||
|
||||
|
||||
class StopTracing(Exception):
    """Raised internally to abort tracing early (e.g. prefix-only traces)."""
|
||||
|
||||
|
||||
def extract_prefix_by_tracing(program, backend):
    """Trace *program* with placeholder arguments and return its leading
    constant-text prefix (usable for prefix caching).

    Tracing is aborted at the first operation that needs real argument
    values; whatever constant text was emitted before that point is the
    prefix.
    """
    # Create dummy arguments
    dummy_arguments = {name: SglArgument(name, None) for name in program.arg_names}
    arguments = dummy_arguments
    arguments.update(program.bind_arguments)

    # Trace
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
    except (StopTracing, TypeError, AttributeError):
        # Placeholder arguments can break the user function in arbitrary
        # ways; these exceptions just end prefix extraction early.
        # Some exceptions may not be caught
        pass

    # Run and cache prefix
    # Concatenate leading constant-text nodes; stop at the first non-constant.
    prefix = ""
    for expr in tracer.flatten_nodes():
        if isinstance(expr, SglConstantText):
            prefix += expr.value
        else:
            break
    return prefix
|
||||
|
||||
|
||||
def trace_program(program, arguments, backend):
    """Fully trace *program* and return the populated TracerProgramState.

    Missing arguments are replaced by SglArgument placeholders; bound
    arguments always take precedence. A dummy BaseBackend is used when no
    backend is supplied.
    """
    # Create dummy backend
    if backend is None:
        backend = BaseBackend()

    # Create dummy arguments
    dummy_arguments = {
        name: SglArgument(name, None)
        for name in program.arg_names
        if name not in arguments
    }
    arguments.update(dummy_arguments)
    arguments.update(program.bind_arguments)

    # Trace
    tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
    with TracingScope(tracer):
        tracer.ret_value = program.func(tracer, **arguments)
    return tracer
|
||||
|
||||
|
||||
class TracerProgramState(ProgramState):
    """A ProgramState that records IR nodes instead of executing them.

    Nodes are appended to ``self.nodes`` and chained via ``prev_node``.
    When ``only_trace_prefix`` is set, any operation that cannot be resolved
    statically raises StopTracing to end the trace.
    """

    def __init__(self, backend, arguments, only_trace_prefix):
        # Unique id stamped onto every node executed by this state.
        self.pid = uuid.uuid4().hex
        self.backend = backend
        self.arguments: Dict[str, Any] = arguments
        self.only_trace_prefix = only_trace_prefix

        # Unwrap endpoint-style backends to the underlying backend object.
        if hasattr(backend, "endpoint"):
            self.backend = backend.endpoint

        self.nodes = []
        self.last_node = None
        self.variables = {}
        self.ret_value = None

        # For completion

        # For chat
        self.messages_ = []
        self.cur_role = None
        self.chat_template = self.backend.get_chat_template()

        # For multi states
        self.child_states = []

        # Register with the enclosing tracing scope (and its ancestors).
        cur_scope = TracingScope.get_current_scope()
        if cur_scope is not None:
            cur_scope.add_child_state(self)

    ##################################
    ########### Public API ###########
    ##################################

    def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
        """Record a fork node and return a group of *size* child states.

        Each child starts from a GetForkItem node and copies this state's
        variables, chat messages, role, and template. Not allowed in
        prefix-only traces.
        """
        assert size >= 1

        if self.only_trace_prefix:
            raise StopTracing()

        fork_node = SglFork(size)
        fork_node.prev_node = self.last_node

        states = [
            TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
            for _ in range(size)
        ]

        for i in range(size):
            node = SglGetForkItem(i)
            node.prev_node = fork_node
            states[i].last_node = node
            states[i].variables = dict(self.variables)
            states[i].messages_ = list(self.messages_)
            states[i].cur_role = self.cur_role
            states[i].chat_template = self.chat_template

        state_group = ProgramStateGroup(states, self)

        return state_group

    ##################################
    ########## Internal API ##########
    ##################################

    def _append_node(self, other: SglExpr):
        # Link the node into the chain and advance the tail pointer.
        self.nodes.append(other)
        other.prev_node = self.last_node
        self.last_node = other

    def _execute(self, other: SglExpr):
        """Dispatch one expression to the matching _execute_* recorder."""
        if isinstance(other, str):
            other = SglConstantText(other)

        other.pid = self.pid

        if isinstance(other, SglConstantText):
            self._execute_fill(other)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglVarScopeBegin):
            # NOTE(review): _execute_var_scope_begin is not defined in this
            # class — presumably inherited from ProgramState; confirm.
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        else:
            # Unknown node kinds end a prefix-only trace; otherwise they are
            # recorded verbatim.
            if self.only_trace_prefix:
                raise StopTracing()
            else:
                self._append_node(other)

        return self

    def __iadd__(self, other):
        # ``state += expr`` records the expression.
        self._execute(other)
        return self

    def _execute_fill(self, expr: SglConstantText):
        if isinstance(expr, str):
            expr = SglConstantText(expr)
        self._append_node(expr)

    def _execute_gen(self, expr: SglGen):
        # Unnamed generations are auto-named gen_<index>.
        name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_select(self, expr: SglSelect):
        # Unnamed selections are auto-named select_<index>.
        name = (
            expr.name if expr.name is not None else "select_" + str(len(self.variables))
        )
        new_node = SglVariable(name, source=expr)
        self.variables[name] = new_node
        self._append_node(expr)

    def _execute_role_begin(self, expr: SglRoleBegin):
        assert self.cur_role is None, "Nested roles are not allowed."

        if len(self.messages_) == 0 and expr.role != "system":
            # Insert default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        # Only the role prefix is emitted here; the suffix is emitted by
        # _execute_role_end.
        self._execute_fill(prefix)

    def _execute_role_end(self, expr: SglRoleEnd):
        prefix, suffix = self.chat_template.get_prefix_and_suffix(
            expr.role, self.messages_
        )

        self._execute_fill(suffix)

        # Record an empty message so get_prefix_and_suffix sees the history.
        self.messages_.append({"role": expr.role, "content": ""})

        self.cur_role = None

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        # Bind the scope name to whatever node closed the scope.
        new_node = SglVariable(expr.name, source=self.last_node)
        self.variables[expr.name] = new_node

    def get_var(self, name):
        """Look up *name* among arguments first, then traced variables."""
        ret = self.arguments.get(name, None)
        if ret is not None:
            return ret

        v = self.variables[name]
        return SglVariable(v.name, v.source)

    def flatten_nodes(self):
        """Return all recorded nodes with SglExprList containers flattened."""
        def traverse(cur):
            if isinstance(cur, SglExprList):
                for child in cur.expr_list:
                    traverse(child)
            else:
                ret.append(cur)

        ret = []
        for x in self.nodes:
            traverse(x)
        return ret

    def __del__(self):
        # Explicitly override ProgramState cleanup with a no-op: a tracer
        # never owns live backend resources.
        pass
|
||||
|
||||
|
||||
class TracingScope:
    """Context manager tracking the innermost active tracer.

    Scopes form a linked list via ``last_scope`` so nesting unwinds correctly
    and child states propagate to every enclosing scope.
    """

    # The innermost active scope, shared process-wide.
    cur_scope = None

    def __init__(self, tracer_state: TracerProgramState):
        self.tracer_state = tracer_state
        # Remember what we shadow so __exit__ can restore it.
        self.last_scope = TracingScope.cur_scope

    def __enter__(self):
        TracingScope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        TracingScope.cur_scope = self.last_scope

    @staticmethod
    def get_current_scope():
        """Return the innermost active scope, or None when not tracing."""
        return TracingScope.cur_scope

    def add_child_state(self, state: TracerProgramState):
        """Register *state* with this scope and every ancestor scope."""
        scope = self
        while scope is not None:
            scope.tracer_state.child_states.append(state)
            scope = scope.last_scope
|
||||
16
python/sglang/launch_server.py
Normal file
16
python/sglang/launch_server.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Launch the inference server."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.server_args import prepare_server_args
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
|
||||
if __name__ == "__main__":
    # Parse CLI flags into server arguments.
    server_args = prepare_server_args(sys.argv[1:])

    try:
        launch_server(server_args)
    finally:
        # Ensure worker subprocesses do not outlive the server process.
        kill_process_tree(os.getpid(), include_parent=False)
|
||||
166
python/sglang/profiler.py
Normal file
166
python/sglang/profiler.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Run live profiling.
|
||||
|
||||
Usage:
|
||||
python3 -m sglang.profiler
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
PARENT_FOLDER = "/tmp/sglang-profile"
|
||||
|
||||
|
||||
def _run_profile(
    url: Optional[str],
    num_steps: int,
    activities: List[str],
    output_dir: Optional[str] = None,
    profile_name: Optional[str] = None,
    profile_by_stage: bool = False,
) -> str:
    """Ask a running server to profile *num_steps* forward steps.

    Creates ``<output_dir>/[profile_name/]<timestamp>/``, snapshots the
    server args there, posts /start_profile (which blocks until traces are
    flushed), and returns the trace directory path.

    NOTE(review): *url* is annotated Optional but is concatenated directly
    (``url + "/get_server_info"``) — passing None raises TypeError; confirm
    callers always provide a URL.
    """
    if output_dir is None:
        output_dir = PARENT_FOLDER

    output_dir = os.path.normpath(output_dir)
    output_dir = os.path.abspath(output_dir)
    output_dir = Path(output_dir)

    # Add "profile_name/timestamp" to the path.
    if profile_name:
        output_dir = output_dir / profile_name
    output_dir = output_dir / str(time.time())
    output_dir.mkdir(exist_ok=True, parents=True)

    print(f"Dump profiling traces to {output_dir}")
    print(
        f"Waiting for {num_steps} steps and the trace to be flushed.... ({profile_by_stage=})"
    )

    # Dump server args.
    file_path = Path(output_dir) / "server_args.json"
    if not file_path.exists():
        response = requests.get(url + "/get_server_info")
        response.raise_for_status()
        server_args_data = response.json()
        with open(file_path, "w") as file:
            file.write(json.dumps(server_args_data))

    # Start profiler. The API replies when all steps are processed
    # and files are generated.
    json_data = {
        "output_dir": str(output_dir),
        "num_steps": str(num_steps),
        "activities": activities,
        "profile_by_stage": profile_by_stage,
    }

    response = requests.post(url=url + "/start_profile", json=json_data)
    response.raise_for_status()

    trace_link = str(output_dir)
    return trace_link
|
||||
|
||||
|
||||
def run_profile(
    url: Optional[str],
    num_steps: int,
    activities: List[str],
    output_dir: Optional[str] = None,
    profile_name: Optional[str] = None,
    profile_by_stage: bool = False,
):
    """Public wrapper around _run_profile; returns the trace directory path."""
    # step based profile will self terminate on num_steps constraints
    return _run_profile(
        url, num_steps, activities, output_dir, profile_name, profile_by_stage
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fixed description: the old text ("Benchmark the online serving
    # throughput.") was copy-pasted from the benchmark script.
    parser = ArgumentParser(description="Run live profiling against a running SGLang server.")
    parser.add_argument(
        "--url",
        type=str,
        default="http://localhost:30000",
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Profile directory to dump profile traces.",
    )
    parser.add_argument(
        "--profile-name",
        type=str,
        default=None,
        help="The name of this profile run.",
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        default=5,
        help="The number of forward steps to profile.",
    )
    parser.add_argument(
        "--profile-by-stage",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        # Fixed: help was copy-pasted from --num-steps.
        help="Whether to profile each forward stage separately.",
    )
    parser.add_argument(
        "--cpu",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=True,
        help="Whether to profile CPU activity",
    )
    parser.add_argument(
        "--gpu",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=True,
        help="Whether to profile GPU activity",
    )
    parser.add_argument(
        "--mem",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        # Fixed: help text was missing the verb "profile".
        help="Whether to profile memory usage (https://pytorch.org/memory_viz)",
    )
    parser.add_argument(
        "--rpd",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
    )

    args = parser.parse_args()
    # Translate boolean flags into the activity list the server expects.
    activities = []
    if args.cpu:
        activities.append("CPU")
    if args.gpu:
        activities.append("GPU")
    if args.mem:
        activities.append("MEM")
    if args.rpd:
        activities.append("RPD")
    run_profile(
        args.url,
        args.num_steps,
        activities,
        args.output_dir,
        args.profile_name,
        args.profile_by_stage,
    )
|
||||
177
python/sglang/srt/_custom_ops.py
Normal file
177
python/sglang/srt/_custom_ops.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
use_vllm_custom_allreduce = get_bool_env_var(
|
||||
"USE_VLLM_CUSTOM_ALLREDUCE", default="false"
|
||||
)
|
||||
|
||||
# Select and wrap the custom all-reduce kernel implementation for the current
# platform. NOTE(review): indentation below is reconstructed from a mangled
# source — verify branch nesting against the original file.
if not is_hpu():
    # ROCm does not use vllm custom allreduce
    if use_vllm_custom_allreduce and not is_hip():
        try:
            import vllm._C
        except ImportError as e:
            logger.warning("Failed to import from vllm._C with %r", e)
    else:
        try:
            import sgl_kernel
        except ImportError as e:
            logger.warning("Failed to import from custom_ar with %r", e)


if not is_hip() and not is_npu():
    # CUDA path: pick the op table once, then expose thin wrappers around it.
    if use_vllm_custom_allreduce:
        custom_op = torch.ops._C_custom_ar
    else:
        custom_op = sgl_kernel.allreduce

    # custom allreduce
    def init_custom_ar(
        ipc_tensors: List[torch.Tensor],
        rank_data: torch.Tensor,
        rank: int,
        full_nvlink: bool,
    ) -> int:
        # Returns an opaque handle used by the other wrappers.
        return custom_op.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)

    def all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        reg_buffer: int,
        reg_buffer_sz_bytes: int,
    ) -> None:
        custom_op.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)

    def dispose(fa: int) -> None:
        custom_op.dispose(fa)

    def meta_size() -> int:
        return custom_op.meta_size()

    def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
        return custom_op.register_buffer(fa, ipc_tensors)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
        return custom_op.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[List[int]], offsets: List[List[int]]
    ) -> None:
        custom_op.register_graph_buffers(fa, handles, offsets)

else:
    # ROCM custom allreduce

    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        return sgl_kernel.allreduce.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        sgl_kernel.allreduce.dispose(fa)

    def meta_size() -> int:
        return sgl_kernel.allreduce.meta_size()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        return sgl_kernel.allreduce.allocate_meta_buffer(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)

    # ROCM custom quick allreduce

    def init_custom_qr(
        rank: int, world_size: int, qr_max_size: Optional[int] = None
    ) -> int:
        # NOTE(review): the call passes (world_size, rank) in swapped order
        # relative to this wrapper's signature — presumably matching the
        # kernel API; confirm.
        return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)

    def qr_get_handle(fa: int) -> torch.Tensor:
        return sgl_kernel.allreduce.qr_get_handle(fa)

    def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
        sgl_kernel.allreduce.qr_open_handles(fa, handles)

    def qr_all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        quant_level: int,
        cast_bf2half: bool,
    ) -> None:
        sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)

    def qr_destroy(fa: int) -> None:
        sgl_kernel.allreduce.qr_destroy(fa)

    def qr_max_size() -> int:
        return sgl_kernel.allreduce.qr_max_size()


def mscclpp_generate_unique_id() -> bytes:
    return sgl_kernel.allreduce.mscclpp_generate_unique_id()


def mscclpp_init_context(
    unique_id: bytes,
    rank: int,
    world_size: int,
    scratch: torch.Tensor,
    put_buffer: torch.Tensor,
    nranks_per_node: int,
    rank_to_node: List[int],
    rank_to_ib: List[int],
    context_selection: int,
) -> int:
    return sgl_kernel.allreduce.mscclpp_init_context(
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        rank_to_node,
        rank_to_ib,
        context_selection,
    )


def mscclpp_allreduce(
    context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
    return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
|
||||
100
python/sglang/srt/aio_rwlock.py
Normal file
100
python/sglang/srt/aio_rwlock.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import asyncio
|
||||
|
||||
|
||||
class RWLock:
    """An asyncio readers-writer lock with writer preference.

    Multiple readers may hold the lock concurrently; a writer holds it
    exclusively. New readers block while any writer is waiting, so a steady
    stream of readers cannot starve writers.
    """

    def __init__(self):
        # Protects internal state
        self._lock = asyncio.Lock()

        # Condition variable used to wait for state changes
        self._cond = asyncio.Condition(self._lock)

        # Number of readers currently holding the lock
        self._readers = 0

        # Whether a writer is currently holding the lock
        self._writer_active = False

        # How many writers are queued waiting for a turn
        self._waiting_writers = 0

    @property
    def reader_lock(self):
        """
        A context manager for acquiring a shared (reader) lock.

        Example:
            async with rwlock.reader_lock:
                # read-only access
        """
        return _ReaderLock(self)

    @property
    def writer_lock(self):
        """
        A context manager for acquiring an exclusive (writer) lock.

        Example:
            async with rwlock.writer_lock:
                # exclusive access
        """
        return _WriterLock(self)

    async def acquire_reader(self):
        """Block until a shared slot is free, then take it."""
        async with self._lock:
            # Wait until there is no active writer or waiting writer
            # to ensure fairness.
            while self._writer_active or self._waiting_writers > 0:
                await self._cond.wait()
            self._readers += 1

    async def release_reader(self):
        """Release one shared slot; wake waiters when the last reader leaves."""
        async with self._lock:
            self._readers -= 1
            # If this was the last reader, wake up anyone waiting
            # (potentially a writer or new readers).
            if self._readers == 0:
                self._cond.notify_all()

    async def acquire_writer(self):
        """Block until exclusive access is possible, then take it."""
        async with self._lock:
            # Increment the count of writers waiting
            self._waiting_writers += 1
            try:
                # Wait while either a writer is active or readers are present
                while self._writer_active or self._readers > 0:
                    await self._cond.wait()
                self._writer_active = True
            finally:
                # Decrement waiting writers only after we've acquired the writer lock
                self._waiting_writers -= 1

    async def release_writer(self):
        """Release exclusive access and wake all waiters."""
        async with self._lock:
            self._writer_active = False
            # Wake up anyone waiting (readers or writers)
            self._cond.notify_all()
|
||||
|
||||
|
||||
class _ReaderLock:
    """Async context manager that holds an ``RWLock`` in shared mode."""

    def __init__(self, rwlock: RWLock):
        # Keep a reference to the owning lock; all work is delegated to it.
        self._rwlock = rwlock

    async def __aenter__(self):
        await self._rwlock.acquire_reader()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Always release, even when the body raised.
        await self._rwlock.release_reader()
|
||||
|
||||
|
||||
class _WriterLock:
    """Async context manager that holds an ``RWLock`` exclusively."""

    def __init__(self, rwlock: RWLock):
        # Keep a reference to the owning lock; all work is delegated to it.
        self._rwlock = rwlock

    async def __aenter__(self):
        await self._rwlock.acquire_writer()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Always release, even when the body raised.
        await self._rwlock.release_writer()
|
||||
137
python/sglang/srt/bench_utils.py
Normal file
137
python/sglang/srt/bench_utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
import sys
from contextlib import nullcontext
from typing import Optional

import torch
|
||||
|
||||
|
||||
# NOTE copied and modified from DeepGEMM
class suppress_stdout_stderr:
    """Silence both Python-level and file-descriptor-level output.

    Redirects fds 1/2 with ``os.dup2`` *and* swaps ``sys.stdout`` /
    ``sys.stderr``, so prints coming from C extensions are suppressed too.
    """

    def __enter__(self):
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        # The live descriptor numbers (normally 1 and 2).
        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        # Duplicates that keep the original targets reachable for restore.
        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        # Python-level stream objects, restored on exit.
        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        # Point the real descriptors at /dev/null ...
        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        # ... and swap the Python-level streams as well.
        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        # Undo in reverse order of __enter__.
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
|
||||
|
||||
|
||||
# NOTE copied and modified from DeepGEMM
def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: Optional[str] = None,
    flush_l2: bool = True,
    with_multiple_kernels: bool = False,
):
    """Profile ``fn`` with torch.profiler (Kineto) and return avg kernel times.

    Args:
        fn: Zero-argument callable that launches the kernel(s) to measure.
        kernel_names: A kernel-name substring, or a tuple of substrings, to
            look up in the profiler table.
        num_tests: Number of timed invocations per profiler step.
        suppress_kineto_output: Silence the profiler's stdout/stderr chatter.
        trace_path: If not None, export a chrome trace to this path.
        flush_l2: Flush the GPU L2 cache (8 GB memset) before every call.
        with_multiple_kernels: Allow a name to match more than one table row.

    Returns:
        Average time in seconds for a single name, or a tuple of averages
        when ``kernel_names`` is a tuple. Returns 1 when running under
        Nsight Systems (profiling is delegated to nsys).
    """
    # torch.profiler conflicts with Nsight Systems; detect nsys via env var.
    using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0))

    # By default, flush L2 with an excessive 8GB memset to give the GPU some
    # (literal) chill time without full idle.
    flush_l2_size = int(8e9 // 4)

    # Warm-up call outside the profiler (some auto-tuning kernels print).
    fn()

    # Profile
    suppress = (
        suppress_stdout_stderr
        if suppress_kineto_output and not using_nsys
        else nullcontext
    )
    with suppress():
        schedule = (
            torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
            if not using_nsys
            else None
        )
        profiler = (
            torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
            )
            if not using_nsys
            else nullcontext()
        )
        with profiler:
            # Two passes: one warmup step and one active step (see schedule).
            for _ in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(
                            flush_l2_size, dtype=torch.int, device="cuda"
                        ).zero_()
                    fn()

                if not using_nsys:
                    profiler.step()

    # Return 1 if using Nsight Systems
    if using_nsys:
        return 1

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)
    prof_lines = (
        profiler.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all(isinstance(name, str) for name in kernel_names)
    if not with_multiple_kernels:
        # Each name must match exactly one row, otherwise the averaging
        # below would silently mix unrelated kernels.
        for name in kernel_names:
            assert (
                sum(name in line for line in prof_lines) == 1
            ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += (
                            float(time_str.replace(unit, "")) / scale * int(num_str)
                        )
                        total_num += int(num_str)
                        break
        # Guard the division: a name matching no timed rows previously
        # raised an opaque ZeroDivisionError.
        assert total_num > 0, f"Kernel {name} not found in the profiling table"
        kernel_times.append(total_time / total_num)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
|
||||
29
python/sglang/srt/configs/__init__.py
Normal file
29
python/sglang/srt/configs/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Model/processor config classes re-exported so callers can simply do
# `from sglang.srt.configs import <Name>`.
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.configs.step3_vl import (
    Step3TextConfig,
    Step3VisionEncoderConfig,
    Step3VLConfig,
)

# Public API of this package.
__all__ = [
    "ExaoneConfig",
    "ChatGLMConfig",
    "DbrxConfig",
    "DeepseekVL2Config",
    "LongcatFlashConfig",
    "MultiModalityConfig",
    "KimiVLConfig",
    "MoonViTConfig",
    "Step3VLConfig",
    "Step3TextConfig",
    "Step3VisionEncoderConfig",
    "Qwen3NextConfig",
]
|
||||
78
python/sglang/srt/configs/chatglm.py
Normal file
78
python/sglang/srt/configs/chatglm.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/ChatGLM2-6B
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/chatglm.py
|
||||
|
||||
# ChatGLM2 and ChatGLM3 share the same config.
|
||||
# ChatGLM4 is officially supported by Huggingface
|
||||
# transformers >= 4.46.0 is required
|
||||
# https://huggingface.co/docs/transformers/en/model_doc/glm
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class ChatGLMConfig(PretrainedConfig):
    """HF-style configuration shared by ChatGLM2/ChatGLM3 checkpoints."""

    model_type = "chatglm"
    # Map standard HF attribute names onto ChatGLM's native field names.
    attribute_map = {
        "num_hidden_layers": "num_layers",
        "n_head_kv": "multi_query_group_num",
    }

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        interleaved_qkv=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        **kwargs
    ):
        # Architecture sizes.
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        # It is to be compatible with long lora.
        self.max_position_embeddings = seq_length
        # Regularization / normalization.
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.post_layer_norm = post_layer_norm
        # Linear/attention layout options.
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        # Quantization / prefix-tuning options.
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        self.interleaved_qkv = interleaved_qkv
        super().__init__(**kwargs)
|
||||
279
python/sglang/srt/configs/dbrx.py
Normal file
279
python/sglang/srt/configs/dbrx.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# Adapted from
|
||||
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/dbrx.py
|
||||
"""Dbrx configuration."""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)

# Kept for parity with the upstream HF config module; intentionally empty.
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}  # type: ignore
|
||||
|
||||
|
||||
class DbrxAttentionConfig(PretrainedConfig):
    """Configuration for the [`DbrxAttention`] sub-module.

    Configuration objects inherit from [`PretrainedConfig`]; see its
    documentation for the inherited behavior.

    Args:
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            Dropout probability for the attention layers.
        clip_qkv (`float`, *optional*, defaults to None):
            If not `None`, clip queries, keys, and values to this value.
        kv_n_heads (Optional[int]): For grouped_query_attention only; the
            number of key/value heads.
        rope_theta (float): The base frequency for rope.
    """

    def __init__(
        self,
        attn_pdrop: float = 0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        # Drop the injected "model_type" key, then fail loudly on any
        # leftover kwargs so typos in checkpoint configs are caught.
        kwargs.pop("model_type", None)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        """Load the attention sub-config from a (possibly full) Dbrx config."""
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # A full "dbrx" config nests this sub-config under "attn_config".
        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["attn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all configurations of "
                "models and can yield errors.",
                config_dict["model_type"],
                cls.model_type,
            )

        return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class DbrxFFNConfig(PretrainedConfig):
    """Configuration for the [`DbrxFFN`] (mixture-of-experts) sub-module.

    Configuration objects inherit from [`PretrainedConfig`]; see its
    documentation for the inherited behavior.

    Args:
        ffn_act_fn (dict, optional): Activation function spec; a dict with a
            'name' key plus any extra keyword arguments.
        ffn_hidden_size (int, optional): Hidden size of the feedforward network.
        moe_num_experts (int, optional): Number of experts in the MoE layer.
        moe_top_k (int, optional): Number of experts routed per token.
        moe_jitter_eps (float, optional): Jitter epsilon for the MoE layer.
        moe_loss_weight (float, optional): Loss weight for the MoE layer.
        moe_normalize_expert_weights (float, optional): Normalization factor
            for the expert weights.
        uniform_expert_assignment (bool, optional): Assign experts uniformly;
            benchmarking only.
    """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        super().__init__()
        # Mutable default handled the safe way: build it per call.
        self.ffn_act_fn = ffn_act_fn if ffn_act_fn is not None else {"name": "silu"}
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment

        # Drop the injected "model_type" key, then fail loudly on leftovers.
        kwargs.pop("model_type", None)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        """Load the FFN sub-config from a (possibly full) Dbrx config."""
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        # A full "dbrx" config nests this sub-config under "ffn_config".
        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["ffn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all "
                "configurations of models and can yield errors.",
                config_dict["model_type"],
                cls.model_type,
            )

        return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class DbrxConfig(PretrainedConfig):
    """Configuration for a [`DbrxModel`]; defines the model architecture.

    Configuration objects inherit from [`PretrainedConfig`]; see its
    documentation for the inherited behavior.

    Args:
        d_model (`int`, *optional*, defaults to 6144):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 48):
            Number of attention heads per attention layer.
        n_layers (`int`, *optional*, defaults to 40):
            Number of hidden layers in the Transformer encoder.
        max_seq_len (`int`, *optional*, defaults to 32768):
            Maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 100352):
            Vocabulary size of the Dbrx model.
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            Dropout applied to attention output before the residual add.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            Dropout probability for the embedding layer.
        attn_config (`dict`, *optional*):
            Configuration of the attention module (dict or
            `DbrxAttentionConfig`).
        ffn_config (`dict`, *optional*):
            Configuration of the FFN module (dict or `DbrxFFNConfig`).
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether to return the last key/value attentions.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Stddev of the truncated-normal weight initializer.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether to return router logits (enables the auxiliary loss).
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            Aux-loss factor for the total loss.

    Example:
    ```python
    >>> from transformers import DbrxConfig, DbrxModel

    >>> configuration = DbrxConfig()
    >>> model = DbrxModel(configuration)
    >>> configuration = model.config
    ```
    """

    model_type = "dbrx"
    # Map standard HF attribute names onto Dbrx's native field names.
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        # Sub-configs may arrive as None, plain dicts (from JSON), or
        # already-built config objects; normalize all three forms.
        if attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config

        if ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        else:
            self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        # Dbrx never ties input/output embeddings; reject attempts early.
        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError("tie_word_embeddings is not supported for Dbrx models.")

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
||||
688
python/sglang/srt/configs/deepseekvl2.py
Normal file
688
python/sglang/srt/configs/deepseekvl2.py
Normal file
@@ -0,0 +1,688 @@
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
LlamaTokenizerFast,
|
||||
PretrainedConfig,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
|
||||
def select_best_resolution(image_size, candidate_resolutions):
    """Pick the candidate (width, height) that best fits ``image_size``.

    Used for cropping: prefers the candidate preserving the most effective
    (non-upscaled) resolution, breaking ties by the least wasted area.
    Returns None when ``candidate_resolutions`` is empty.
    """
    orig_w, orig_h = image_size
    best_fit = None
    best_effective = 0
    least_wasted = float("inf")

    for cand_w, cand_h in candidate_resolutions:
        # Scale the image to fit inside the candidate without distortion.
        scale = min(cand_w / orig_w, cand_h / orig_h)
        scaled_w = int(orig_w * scale)
        scaled_h = int(orig_h * scale)
        # Upscaling past the original resolution adds no information.
        effective = min(scaled_w * scaled_h, orig_w * orig_h)
        wasted = (cand_w * cand_h) - effective

        is_better = effective > best_effective or (
            effective == best_effective and wasted < least_wasted
        )
        if is_better:
            best_effective = effective
            least_wasted = wasted
            best_fit = (cand_w, cand_h)

    return best_fit
|
||||
|
||||
|
||||
class DictOutput(object):
    """Mixin giving an object dict-like access to its instance attributes."""

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def __contains__(self, key):
        return key in self.__dict__

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value
|
||||
|
||||
|
||||
@dataclass
class VLChatProcessorOutput(DictOutput):
    """Single-sample processor output for DeepSeek-VL2."""

    input_ids: torch.LongTensor
    target_ids: torch.LongTensor
    # rename from "images" to "pixel_values" for compatibility
    pixel_values: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_spatial_crop: torch.LongTensor

    def __len__(self):
        # Sample length is the number of input tokens.
        return len(self.input_ids)
|
||||
|
||||
|
||||
class ImageTransform(object):
    """To-tensor (plus optional normalize) transform for PIL images."""

    def __init__(
        self,
        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        normalize: bool = True,
    ):
        self.mean = mean
        self.std = std
        self.normalize = normalize

        # torchvision is heavy and optional; only load it when this
        # transform is actually constructed.
        try:
            import torchvision.transforms as T

            # FIXME: add version check for gguf
        except ImportError as err:
            raise ImportError(
                "Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
            ) from err

        steps = [T.ToTensor()]
        if normalize:
            steps.append(T.Normalize(mean, std))
        self.transform = T.Compose(steps)

    def __call__(self, pil_img: Image.Image):
        return self.transform(pil_img)
|
||||
|
||||
|
||||
class DeepseekVLV2Processor(ProcessorMixin):
|
||||
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
|
||||
attributes = ["tokenizer"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
candidate_resolutions: Tuple[Tuple[int, int]],
|
||||
patch_size: int,
|
||||
downsample_ratio: int,
|
||||
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
|
||||
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
|
||||
normalize: bool = True,
|
||||
image_token: str = "<image>",
|
||||
pad_token: str = "<|▁pad▁|>",
|
||||
add_special_token: bool = False,
|
||||
sft_format: str = "deepseek",
|
||||
mask_prompt: bool = True,
|
||||
ignore_id: int = -100,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
self.candidate_resolutions = candidate_resolutions
|
||||
self.image_size = candidate_resolutions[0][0]
|
||||
self.patch_size = patch_size
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.normalize = normalize
|
||||
self.downsample_ratio = downsample_ratio
|
||||
|
||||
self.image_transform = ImageTransform(
|
||||
mean=image_mean, std=image_std, normalize=normalize
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
# must set this,padding side with make a difference in batch inference
|
||||
self.tokenizer.padding_side = "left"
|
||||
|
||||
# add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
|
||||
if tokenizer.pad_token is None:
|
||||
self.tokenizer.add_special_tokens({"pad_token": pad_token})
|
||||
|
||||
# add image token
|
||||
image_token_id = self.tokenizer.vocab.get(image_token)
|
||||
if image_token_id is None:
|
||||
special_tokens = [image_token]
|
||||
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||
self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||
self.image_token_id = self.tokenizer.vocab.get(image_token)
|
||||
|
||||
# add five special tokens for grounding-related tasks
|
||||
# <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
|
||||
special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
|
||||
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||
self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||
|
||||
# add special tokens for SFT data
|
||||
special_tokens = ["<|User|>", "<|Assistant|>"]
|
||||
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||
self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||
|
||||
self.image_token = image_token
|
||||
self.pad_token = pad_token
|
||||
self.add_special_token = add_special_token
|
||||
self.sft_format = sft_format
|
||||
self.mask_prompt = mask_prompt
|
||||
self.ignore_id = ignore_id
|
||||
|
||||
super().__init__(
|
||||
tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
|
||||
"""play the role of format_messages_v2 and get_images_info in the last version"""
|
||||
tokenized_data = []
|
||||
masked_tokenized_data = [] # labels
|
||||
images_list = []
|
||||
images_seq_mask = []
|
||||
images_spatial_crop = []
|
||||
|
||||
image_index = 0
|
||||
image_token_cnt = messages.count(self.image_token)
|
||||
tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
|
||||
messages,
|
||||
pil_images[image_index : image_index + image_token_cnt],
|
||||
bos=True,
|
||||
eos=True,
|
||||
cropping=len(pil_images) <= 2,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
image_index = image_token_cnt
|
||||
tokenized_data += tokenized_str
|
||||
if self.mask_prompt:
|
||||
masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
|
||||
else:
|
||||
masked_tokenized_data += tokenized_str
|
||||
images_list += images
|
||||
images_seq_mask += seq_mask
|
||||
images_spatial_crop += spatial_crop
|
||||
|
||||
assert len(tokenized_data) == len(
|
||||
images_seq_mask
|
||||
), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
|
||||
|
||||
return (
|
||||
tokenized_data,
|
||||
masked_tokenized_data,
|
||||
images_list,
|
||||
images_seq_mask,
|
||||
images_spatial_crop,
|
||||
)
|
||||
|
||||
@property
|
||||
def bos_id(self):
|
||||
return self.tokenizer.bos_token_id
|
||||
|
||||
@property
|
||||
def eos_id(self):
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
@property
|
||||
def pad_id(self):
|
||||
return self.tokenizer.pad_token_id
|
||||
|
||||
def encode(self, text: str, bos: bool = True, eos: bool = False):
|
||||
t = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
if bos:
|
||||
t = [self.bos_id] + t
|
||||
if eos:
|
||||
t = t + [self.eos_id]
|
||||
|
||||
return t
|
||||
|
||||
def decode(self, t: List[int], **kwargs) -> str:
|
||||
return self.tokenizer.decode(t, **kwargs)
|
||||
|
||||
def process_one(
|
||||
self,
|
||||
prompt: str = None,
|
||||
conversations: List[Dict[str, str]] = None,
|
||||
images: List[Image.Image] = None,
|
||||
apply_sft_format: bool = False,
|
||||
inference_mode: bool = True,
|
||||
system_prompt: str = "",
|
||||
max_req_input_len: int = -1,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
prompt (str): the formatted prompt;
|
||||
conversations (List[Dict]): conversations with a list of messages;
|
||||
images (List[ImageType]): the list of images;
|
||||
apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt;
|
||||
if conversations is not None, then it will always apply the SFT format to conversations;
|
||||
inference_mode (bool): if True, then remove the last eos token;
|
||||
system_prompt (str): the system prompt;
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
outputs (BaseProcessorOutput): the output of the processor,
|
||||
- input_ids (torch.LongTensor): [N + image tokens]
|
||||
- target_ids (torch.LongTensor): [N + image tokens]
|
||||
- images (torch.FloatTensor): [n_images, 3, H, W]
|
||||
- image_id (int): the id of the image token
|
||||
- num_image_tokens (List[int]): the number of image tokens
|
||||
"""
|
||||
|
||||
assert (
|
||||
prompt is None or conversations is None
|
||||
), "prompt and conversations cannot be used at the same time."
|
||||
|
||||
(
|
||||
tokenized_str,
|
||||
masked_tokenized_str,
|
||||
images_list,
|
||||
images_seq_mask,
|
||||
images_spatial_crop,
|
||||
) = self.format_messages_v2(conversations, images, max_req_input_len)
|
||||
|
||||
assert (
|
||||
len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
|
||||
), (
|
||||
f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
|
||||
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
|
||||
)
|
||||
|
||||
input_ids = torch.LongTensor(tokenized_str)
|
||||
target_ids = torch.LongTensor(masked_tokenized_str)
|
||||
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
|
||||
|
||||
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
|
||||
target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
|
||||
self.ignore_id
|
||||
)
|
||||
input_ids[input_ids < 0] = self.pad_id
|
||||
|
||||
if inference_mode:
|
||||
assert input_ids[-1] == self.eos_id
|
||||
input_ids = input_ids[:-1]
|
||||
target_ids = target_ids[:-1]
|
||||
images_seq_mask = images_seq_mask[:-1]
|
||||
|
||||
if len(images_list) == 0:
|
||||
images = torch.zeros((1, 3, self.image_size, self.image_size))
|
||||
images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
|
||||
else:
|
||||
images = torch.stack(images_list, dim=0)
|
||||
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
|
||||
|
||||
images_spatial_crop = torch.stack(
|
||||
[images_spatial_crop], dim=0
|
||||
) # stack the tensor to make it a batch of 1
|
||||
|
||||
prepare = VLChatProcessorOutput(
|
||||
input_ids=input_ids,
|
||||
target_ids=target_ids,
|
||||
pixel_values=images,
|
||||
images_seq_mask=images_seq_mask,
|
||||
images_spatial_crop=images_spatial_crop,
|
||||
)
|
||||
|
||||
return prepare
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
prompt: str = None,
|
||||
conversations: List[Dict[str, str]] = None,
|
||||
images: List[Image.Image] = None,
|
||||
apply_sft_format: bool = False,
|
||||
inference_mode: bool = True,
|
||||
system_prompt: str = "",
|
||||
max_req_input_len: int = -1,
|
||||
**kwargs,
|
||||
):
|
||||
prepare = self.process_one(
|
||||
prompt=prompt,
|
||||
conversations=conversations,
|
||||
images=images,
|
||||
apply_sft_format=apply_sft_format,
|
||||
inference_mode=inference_mode,
|
||||
system_prompt=system_prompt,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
return prepare
|
||||
|
||||
def find_all_indices(self, messages, target_value):
|
||||
indices = []
|
||||
for index, item in enumerate(messages):
|
||||
if item == target_value:
|
||||
indices.append(index)
|
||||
return indices
|
||||
|
||||
    def tokenize_with_images(
        self,
        conversation: str,
        images: List[Image.Image],
        bos: bool = True,
        eos: bool = True,
        cropping: bool = True,
        max_req_input_len: int = -1,
    ):
        """Tokenize text with <image> tags.

        Splits ``conversation`` on ``self.image_token`` and, for each
        (text segment, image) pair, appends the encoded text followed by the
        image-placeholder tokens for a global view plus cropped local tiles.

        Returns:
            Tuple of (tokenized_str, images_list, images_seq_mask,
            images_spatial_crop) where ``images_seq_mask`` marks which token
            positions are image placeholders and ``images_spatial_crop``
            records [num_width_tiles, num_height_tiles] per image.
        """
        images_list, images_seq_mask, images_spatial_crop = [], [], []
        text_splits = conversation.split(self.image_token)
        tokenized_str = []
        # NOTE(review): zip stops at the shorter of (text_splits, images), so
        # this assumes len(images) matches the number of <image> tags; extra
        # middle splits would be silently dropped — confirm against callers.
        for text_sep, image in zip(text_splits, images):
            """encode text_sep"""
            tokenized_sep = self.encode(text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            """select best resolution for anyres"""
            # When cropping is disabled, the image is padded to a single
            # global view of self.image_size x self.image_size (one tile).
            if cropping:
                best_width, best_height = select_best_resolution(
                    image.size, self.candidate_resolutions
                )
            else:
                best_width, best_height = self.image_size, self.image_size
            # print(image.size, (best_width, best_height)) # check the select_best_resolutions func

            """process the global view"""
            # Pad (not stretch) to square, filling with the transform's mean color.
            global_view = ImageOps.pad(
                image,
                (self.image_size, self.image_size),
                color=tuple(int(x * 255) for x in self.image_transform.mean),
            )
            images_list.append(self.image_transform(global_view))

            """process the local views"""
            # Pad to the chosen resolution, then crop it into a grid of
            # image_size x image_size tiles (row-major order).
            local_view = ImageOps.pad(
                image,
                (best_width, best_height),
                color=tuple(int(x * 255) for x in self.image_transform.mean),
            )
            for i in range(0, best_height, self.image_size):
                for j in range(0, best_width, self.image_size):
                    images_list.append(
                        self.image_transform(
                            local_view.crop(
                                (j, i, j + self.image_size, i + self.image_size)
                            )
                        )
                    )

            """record height / width crop num"""
            num_width_tiles, num_height_tiles = (
                best_width // self.image_size,
                best_height // self.image_size,
            )
            images_spatial_crop.append([num_width_tiles, num_height_tiles])

            """add image tokens"""
            # Tokens per tile side after patchification and downsampling.
            h = w = math.ceil(
                (self.image_size // self.patch_size) / self.downsample_ratio
            )
            # global views tokens h * (w + 1), 1 is for line separator
            tokenized_image = [self.image_token_id] * h * (w + 1)
            # add a separator between global and local views
            tokenized_image += [self.image_token_id]
            # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
            tokenized_image += (
                [self.image_token_id]
                * (num_height_tiles * h)
                * (num_width_tiles * w + 1)
            )

            tokenized_str += tokenized_image
            images_seq_mask += [True] * len(tokenized_image)
            # print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens

        """process the last text split"""
        tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
        # deal with video, limit with request len
        if max_req_input_len > -1:
            if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
                # Truncate the image/text prefix so the trailing text still fits.
                # NOTE(review): the extra 1024 appears to reserve headroom
                # (generation/video tokens?) — confirm the intended budget.
                rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
                tokenized_str = tokenized_str[:rest]
                images_seq_mask = images_seq_mask[:rest]
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        """add the bos and eos tokens"""
        if bos:
            tokenized_str = [self.bos_id] + tokenized_str
            images_seq_mask = [False] + images_seq_mask
        if eos:
            tokenized_str = tokenized_str + [self.eos_id]
            images_seq_mask = images_seq_mask + [False]

        assert len(tokenized_str) == len(
            images_seq_mask
        ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"

        return tokenized_str, images_list, images_seq_mask, images_spatial_crop
|
||||
|
||||
|
||||
class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
    """Configuration of the DeepSeek-VL2 vision encoder (SigLIP backbone)."""

    model_type: str = "vision"

    # Class-level defaults mirror the siglip_large_patch16_384 backbone.
    model_name: str = "siglip_large_patch16_384"
    image_size: int = 384
    patch_size: int = 16
    width: int = 1024
    layers: int = 24
    heads: int = 16
    mlp_ratio: int = 4
    global_pool: str = "map"
    ignore_head: bool = True
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False
    weight_init: str = "skip"
    deterministic: bool = False
    num_recomputing_layers: int = 0

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: int = 384,
        patch_size: int = 16,
        width: int = 1024,
        layers: int = 24,
        heads: int = 16,
        mlp_ratio: int = 4,
        global_pool: str = "map",
        ignore_head: bool = True,
        class_token: bool = False,
        num_classes: int = 0,
        use_checkpoint: bool = False,
        **kwargs,
    ):
        # Persist every explicit argument on the instance, then hand the
        # remaining keyword arguments over to PretrainedConfig.
        for field_name, field_value in (
            ("model_name", model_name),
            ("image_size", image_size),
            ("patch_size", patch_size),
            ("width", width),
            ("layers", layers),
            ("heads", heads),
            ("mlp_ratio", mlp_ratio),
            ("global_pool", global_pool),
            ("ignore_head", ignore_head),
            ("class_token", class_token),
            ("num_classes", num_classes),
            ("use_checkpoint", use_checkpoint),
        ):
            setattr(self, field_name, field_value)

        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
    """Configuration of the MLP projector mapping vision features into the LM."""

    model_type = "mlp_projector"
    projector_type: str = "downsample_mlp_gelu"
    input_dim: int = 1152
    n_embed: int = 2048
    depth: int = 2
    mlp_ratio: int = 1
    downsample_ratio: int = 2
    token_pooling: bool = False

    def __init__(
        self,
        projector_type: str = "downsample_mlp_gelu",
        input_dim: int = 1152,
        n_embed: int = 2048,
        depth: int = 2,
        mlp_ratio: int = 1,
        downsample_ratio: int = 2,
        **kwargs,
    ):
        # Persist every explicit argument on the instance, then defer the
        # remaining keyword arguments to PretrainedConfig.
        for field_name, field_value in (
            ("projector_type", projector_type),
            ("input_dim", input_dim),
            ("n_embed", n_embed),
            ("depth", depth),
            ("mlp_ratio", mlp_ratio),
            ("downsample_ratio", downsample_ratio),
        ):
            setattr(self, field_name, field_value)

        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class DeepseekV2Config(PretrainedConfig):
    """Configuration for the DeepSeek-V2 MoE language model.

    Used here as the language-backbone configuration of DeepSeek-VL2.
    Attribute names and defaults follow the upstream DeepSeek-V2
    configuration, so serialized checkpoints load without translation.
    """

    model_type = "deepseek_v2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=None,
        n_routed_experts=None,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="gready",
        n_group=None,
        topk_group=None,
        num_experts_per_tok=None,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func="softmax",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        use_mla=True,
        **kwargs,
    ):
        # --- core transformer dimensions ---
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # --- MoE expert layout and routing ---
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        # --- MLA (multi-head latent attention) low-rank / per-head dims ---
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        # NOTE(review): "gready" (sic) is the upstream DeepSeek default value;
        # do not "correct" the spelling without checking shipped checkpoints.
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        # Coerce to float: serialized configs may carry this as a string/int.
        self.rms_norm_eps = float(rms_norm_eps)
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_mla = use_mla

        # Token ids and embedding tying are handled by PretrainedConfig.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
||||
|
||||
|
||||
class DeepseekVL2Config(PretrainedConfig):
    """Top-level DeepSeek-VL2 configuration bundling the vision encoder,
    MLP projector, and DeepSeek-V2 language-model sub-configurations."""

    model_type = "deepseek_vl_v2"
    vision_config: DeepseekVL2VisionEncoderConfig
    projector_config: DeepseekVL2MlpProjectorConfig
    language_config: DeepseekV2Config

    tile_tag: str = "2D"
    global_view_pos: str = "head"
    candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),)

    def __init__(
        self,
        tile_tag: str = "tile_tag",
        global_view_pos: str = "head",
        candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),),
        **kwargs,
    ):
        # NOTE(review): the "tile_tag" default string differs from the
        # class-level annotation default ("2D"); confirm callers always pass
        # this explicitly before relying on the default.
        super().__init__(**kwargs)

        # Build each sub-config from the matching dict in kwargs; a missing
        # section falls back to that sub-config's defaults.
        self.vision_config = DeepseekVL2VisionEncoderConfig(
            **kwargs.get("vision_config", {})
        )
        self.projector_config = DeepseekVL2MlpProjectorConfig(
            **kwargs.get("projector_config", {})
        )

        language_config = kwargs.get("language_config", {})
        # The language section may already be a constructed config object.
        self.language_config = (
            language_config
            if isinstance(language_config, DeepseekV2Config)
            else DeepseekV2Config(**language_config)
        )

        self.tile_tag = tile_tag
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions
        self.architectures = ["DeepseekVL2ForCausalLM"]
|
||||
|
||||
|
||||
# Register the processor so AutoProcessor.from_pretrained resolves
# DeepseekVL2Config checkpoints to DeepseekVLV2Processor.
AutoProcessor.register(DeepseekVL2Config, DeepseekVLV2Processor)
|
||||
17
python/sglang/srt/configs/device_config.py
Normal file
17
python/sglang/srt/configs/device_config.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeviceConfig:
    """Resolved torch device the runtime executes on.

    Validates the device string and exposes both the raw type string
    (``device_type``) and the constructed ``torch.device`` (``device``).
    """

    device: Optional[torch.device]

    def __init__(self, device: str = "cuda") -> None:
        # Guard clause: reject anything outside the supported backends.
        if device not in ("cuda", "xpu", "hpu", "cpu", "npu"):
            raise RuntimeError(f"Not supported device type: {device}")
        self.device_type = device
        self.device = torch.device(self.device_type)
|
||||
195
python/sglang/srt/configs/exaone.py
Normal file
195
python/sglang/srt/configs/exaone.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The LG AI Research EXAONE Lab. All rights reserved.
|
||||
# Copyright 2024 The LG CNS AI Engineering Team.
|
||||
# Copyright 2023-2024 SGLang Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" EXAONE model configuration """
|
||||
from typing import Any, Dict
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, Any] = {}
|
||||
|
||||
|
||||
# ruff: noqa: E501
|
||||
class ExaoneConfig(PretrainedConfig):
    r"""Configuration class for :class:`~transformers.ExaoneModel`.

    Stores the hyper-parameters defining an EXAONE architecture; instantiating
    with the defaults yields a configuration similar to the original Exaone
    model. Inherits from :class:`~transformers.PretrainedConfig` — see its
    documentation for the generic options it controls.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 102400):
            Vocabulary size — the number of distinct token ids accepted by
            :class:`~transformers.ExaoneModel`.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 2048):
            Maximum sequence length the model might be used with.
        hidden_size (:obj:`int`, `optional`, defaults to 2048):
            Dimensionality of the encoder layers and the pooler layer.
        num_layers (:obj:`int`, `optional`, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 32):
            Number of attention heads per attention layer in the decoder.
        num_key_value_heads (:obj:`int`, `optional`):
            Number of key/value heads for Grouped Query Attention. Equal to
            ``num_attention_heads`` means MHA; 1 means MQA; anything between
            is GQA (see https://arxiv.org/pdf/2305.13245.pdf). Defaults to
            ``num_attention_heads``.
        intermediate_size (:obj:`int`, `optional`, defaults to ``hidden_size * 4``):
            Dimensionality of the feed-forward layer in the Transformer encoder.
        activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"silu"`):
            Non-linear activation function used in the decoder.
        rope_theta (:obj:`float`, `optional`, defaults to 10000.0):
            Base period of the RoPE embeddings.
        rope_scaling (:obj:`Dict`, `optional`):
            RoPE scaling configuration. Expected fields include ``rope_type``
            (one of 'default', 'linear', 'dynamic', 'yarn', 'longrope',
            'llama3'), ``factor``, ``original_max_position_embeddings``,
            ``attention_factor``, ``beta_fast``/``beta_slow`` (yarn),
            ``short_factor``/``long_factor`` (longrope), and
            ``low_freq_factor``/``high_freq_factor`` (llama3). Update
            ``max_position_embeddings`` accordingly when applying a new type.
        embed_dropout (:obj:`float`, `optional`, defaults to 0.0):
            Dropout probability for fully connected layers in the embeddings,
            encoder, and pooler.
        attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
            Dropout ratio for the attention probabilities.
        layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5):
            Epsilon used by the layer-normalization layers.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            Std-dev of the truncated-normal initializer for weight matrices.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether the model should return the last key/value attentions;
            only relevant when ``configs.is_decoder=True``.
        bos_token_id (:obj:`int`, `optional`, defaults to 0):
            Beginning-of-stream token id.
        eos_token_id (:obj:`int`, `optional`, defaults to 2):
            End-of-stream token id.
        tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to tie input and output word embeddings.
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Trade compute for memory during the backward pass.

    Example::

        >>> from transformers import EXAONEModel, ExaoneConfig

        >>> # Initializing a EXAONE configuration
        >>> configuration = ExaoneConfig()

        >>> # Initializing a model from configuration
        >>> model = EXAONEModel(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.configs
    """

    model_type = "exaone"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_hidden_layers": "num_layers"}

    def __init__(
        self,
        vocab_size=102400,
        max_position_embeddings=2048,
        hidden_size=2048,
        num_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        intermediate_size=None,
        activation_function="silu",
        rope_theta=10000.0,
        rope_scaling=None,
        embed_dropout=0.0,
        attention_dropout=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=0,
        eos_token_id=2,
        tie_word_embeddings=True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_layers
        # GQA fallback: unspecified key/value heads means plain MHA.
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        # Feed-forward width defaults to four times the hidden size.
        self.intermediate_size = (
            intermediate_size if intermediate_size else hidden_size * 4
        )
        self.activation_function = activation_function
        self.embed_dropout = embed_dropout
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )
|
||||
706
python/sglang/srt/configs/internvl.py
Normal file
706
python/sglang/srt/configs/internvl.py
Normal file
@@ -0,0 +1,706 @@
|
||||
import copy
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import sentencepiece as spm
|
||||
from transformers import (
|
||||
TOKENIZER_MAPPING,
|
||||
GptOssConfig,
|
||||
LlamaConfig,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
Qwen2Config,
|
||||
Qwen3Config,
|
||||
Qwen3MoeConfig,
|
||||
)
|
||||
|
||||
from sglang.utils import logger
|
||||
|
||||
# Copied from: https://github.com/OpenGVLab/InternVL/blob/34a81000402bf8f716bab8c9b57aff1f6b436bd0/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py#L21
|
||||
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {}
|
||||
|
||||
|
||||
# Modified from transformers.model.llama.configuration_llama.LlamaConfig
|
||||
class InternLM2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
    an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 103168):
            Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented
            by the `inputs_ids` passed when calling [`InternLM2Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. For more
            details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will
            default to `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something
            large just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary with exactly two fields, `type` (one of `"linear"`, `"dynamic"`) and `factor`
            (a float/int >= 1); validated in `_rope_scaling_validation`.
    """

    model_type = "internlm2"
    _auto_class = "AutoConfig"

    def __init__(  # pylint: disable=W0102
        self,
        vocab_size=103168,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        bias=True,
        rope_theta=10000,
        rope_scaling=None,
        attn_implementation="eager",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.bias = bias

        # GQA fallback: unspecified key/value heads means plain MHA.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()

        self.attn_implementation = attn_implementation
        if self.attn_implementation is None:
            self.attn_implementation = "eager"
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        Raises:
            ValueError: if `rope_scaling` is not a two-field dict with a valid
                `type` ('linear' or 'dynamic') and a numeric `factor` >= 1.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            # Fixed duplicated word ("with with") in the original message.
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if (
            rope_scaling_factor is None
            or not isinstance(rope_scaling_factor, (float, int))
            or rope_scaling_factor < 1.0
        ):
            raise ValueError(
                f"`rope_scaling`'s factor field must be a float|int >= 1, got {rope_scaling_factor=}, {type(rope_scaling_factor)=}"
            )
        # Removed a dead trailing re-assignment of the local
        # `rope_scaling_factor` to float() — it had no observable effect.
|
||||
|
||||
|
||||
class InternVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
|
||||
instantiate a vision encoder according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of color channels in the input images (e.g., 3 for RGB).
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The size (resolution) of each patch.
|
||||
image_size (`int`, *optional*, defaults to 224):
|
||||
The size (resolution) of each image.
|
||||
qkv_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a bias to the queries and values in the self-attention layers.
|
||||
hidden_size (`int`, *optional*, defaults to 3200):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_attention_heads (`int`, *optional*, defaults to 25):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (`int`, *optional*, defaults to 12800):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
qk_normalization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the queries and keys in the self-attention layers.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 48):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
use_flash_attn (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use flash attention mechanism.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
||||
The epsilon used by the layer normalization layers.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate for stochastic depth.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
initializer_factor (`float`, *optional*, defaults to 0.1):
|
||||
A factor for layer scale.
|
||||
"""
|
||||
|
||||
model_type = "intern_vit_6b"
|
||||
|
||||
def __init__(
    self,
    num_channels=3,
    patch_size=14,
    image_size=224,
    qkv_bias=False,
    hidden_size=3200,
    num_attention_heads=25,
    intermediate_size=12800,
    qk_normalization=True,
    num_hidden_layers=48,
    use_flash_attn=True,
    hidden_act="gelu",
    layer_norm_eps=1e-6,
    dropout=0.0,
    drop_path_rate=0.0,
    attention_dropout=0.0,
    initializer_range=0.02,
    initializer_factor=0.1,
    **kwargs,
):
    """Store the InternViT hyper-parameters; unrecognized kwargs go to the base config."""
    super().__init__(**kwargs)

    # Input / patch-embedding geometry.
    self.num_channels = num_channels
    self.patch_size = patch_size
    self.image_size = image_size

    # Transformer width and depth.
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads

    # Attention options.
    self.qkv_bias = qkv_bias
    self.qk_normalization = qk_normalization
    self.use_flash_attn = use_flash_attn

    # Activation and normalization.
    self.hidden_act = hidden_act
    self.layer_norm_eps = layer_norm_eps

    # Regularization.
    self.dropout = dropout
    self.attention_dropout = attention_dropout
    self.drop_path_rate = drop_path_rate

    # Weight initialization.
    self.initializer_range = initializer_range
    self.initializer_factor = initializer_factor
|
||||
|
||||
@classmethod
def from_pretrained(
    cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
    """Load a vision config, unwrapping the `vision_config` sub-dict when the
    checkpoint is a composite InternVL chat config."""
    config_dict, kwargs = cls.get_config_dict(
        pretrained_model_name_or_path, **kwargs
    )

    # Composite checkpoints nest the vision tower settings under "vision_config".
    config_dict = config_dict.get("vision_config", config_dict)

    # Warn (but proceed) when the checkpoint declares a different model type.
    mismatched_type = (
        "model_type" in config_dict
        and hasattr(cls, "model_type")
        and config_dict["model_type"] != cls.model_type
    )
    if mismatched_type:
        logger.warning(
            f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
            f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
        )

    return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class InternVLChatConfig(PretrainedConfig):
    """Composite configuration for InternVL chat models.

    Couples an `InternVisionConfig` (vision tower) with one of the supported
    language-model configs, plus the glue options used by the chat model
    (dynamic tiling, pixel-shuffle downsampling, LoRA ranks, ...).
    """

    model_type = "internvl_chat"
    is_composition = True

    # Maps the model class name found in ``llm_config["architectures"][0]``
    # to the config class used to parse the LLM sub-config.
    _LLM_ARCH_TO_CONFIG = {
        "LlamaForCausalLM": LlamaConfig,
        "InternLM2ForCausalLM": InternLM2Config,
        "Qwen2ForCausalLM": Qwen2Config,
        "Qwen3MoeForCausalLM": Qwen3MoeConfig,
        "Qwen3ForCausalLM": Qwen3Config,
        "GptOssForCausalLM": GptOssConfig,
    }

    def __init__(
        self,
        vision_config=None,
        llm_config=None,
        use_backbone_lora=0,
        use_llm_lora=0,
        pad2square=False,
        select_layer=-1,
        force_image_size=None,
        downsample_ratio=0.5,
        template=None,
        dynamic_image_size=False,
        use_thumbnail=False,
        ps_version="v1",
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        **kwargs,
    ):
        """Build the composite config.

        Args:
            vision_config: kwargs dict for `InternVisionConfig`; a default
                `InternVisionModel` stub is used when None.
            llm_config: kwargs dict for the language model; its
                ``architectures`` list selects the config class. Defaults to
                an `InternLM2ForCausalLM` stub when None.

        Raises:
            ValueError: if ``llm_config["architectures"][0]`` names an
                unsupported (or missing) architecture.
        """
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {"architectures": ["InternVisionModel"]}
            logger.info(
                "vision_config is None. Initializing the InternVisionConfig with default values."
            )

        if llm_config is None:
            llm_config = {"architectures": ["InternLM2ForCausalLM"]}
            # Fixed message: the previous text claimed a LlamaConfig default,
            # but the default architecture is InternLM2ForCausalLM.
            logger.info(
                "llm_config is None. Initializing the InternLM2Config with default values."
            )

        self.vision_config = InternVisionConfig(**vision_config)

        # Resolve the LLM config class via the dispatch table. A missing or
        # empty "architectures" entry previously crashed with a TypeError
        # (None is not subscriptable); now it raises a clear ValueError.
        architectures = llm_config.get("architectures") or [None]
        llm_arch = architectures[0]
        config_cls = self._LLM_ARCH_TO_CONFIG.get(llm_arch)
        if config_cls is None:
            raise ValueError("Unsupported architecture: {}".format(llm_arch))
        self.llm_config = config_cls(**llm_config)

        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.pad2square = pad2square
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.ps_version = ps_version  # pixel shuffle version
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch

        # Mirror the LLM hidden size at the top level for convenience.
        self.hidden_size = self.llm_config.hidden_size
        # By default, we use tie_word_embeddings=False for models of all sizes.
        self.tie_word_embeddings = False
        self.llm_config.tie_word_embeddings = self.tie_word_embeddings

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        # Sub-configs serialize through their own to_dict().
        output["vision_config"] = self.vision_config.to_dict()
        output["llm_config"] = self.llm_config.to_dict()
        output["model_type"] = self.__class__.model_type
        output["use_backbone_lora"] = self.use_backbone_lora
        output["use_llm_lora"] = self.use_llm_lora
        output["select_layer"] = self.select_layer
        output["force_image_size"] = self.force_image_size
        output["downsample_ratio"] = self.downsample_ratio
        output["template"] = self.template
        output["dynamic_image_size"] = self.dynamic_image_size
        output["use_thumbnail"] = self.use_thumbnail
        output["ps_version"] = self.ps_version
        output["min_dynamic_patch"] = self.min_dynamic_patch
        output["max_dynamic_patch"] = self.max_dynamic_patch

        return output
|
||||
|
||||
|
||||
# # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
|
||||
# class InternLM2TokenizerFast(PreTrainedTokenizerFast):
|
||||
# vocab_files_names = VOCAB_FILES_NAMES
|
||||
# slow_tokenizer_class = InternLM2Tokenizer
|
||||
# padding_side = 'left'
|
||||
# model_input_names = ['input_ids', 'attention_mask']
|
||||
# _auto_class = 'AutoTokenizer'
|
||||
#
|
||||
# def __init__(
|
||||
# self,
|
||||
# vocab_file,
|
||||
# unk_token='<unk>',
|
||||
# bos_token='<s>',
|
||||
# eos_token='</s>',
|
||||
# pad_token='</s>',
|
||||
# sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
# add_bos_token=True,
|
||||
# add_eos_token=False,
|
||||
# decode_with_prefix_space=False,
|
||||
# clean_up_tokenization_spaces=False,
|
||||
# **kwargs,
|
||||
# ):
|
||||
# super().__init__(
|
||||
# vocab_file=vocab_file,
|
||||
# unk_token=unk_token,
|
||||
# bos_token=bos_token,
|
||||
# eos_token=eos_token,
|
||||
# pad_token=pad_token,
|
||||
# sp_model_kwargs=sp_model_kwargs,
|
||||
# add_bos_token=add_bos_token,
|
||||
# add_eos_token=add_eos_token,
|
||||
# decode_with_prefix_space=decode_with_prefix_space,
|
||||
# clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
# **kwargs,
|
||||
# )
|
||||
# self._add_bos_token = add_bos_token
|
||||
# self._add_eos_token = add_eos_token
|
||||
# self.update_post_processor()
|
||||
# self.vocab_file = vocab_file
|
||||
#
|
||||
# @property
|
||||
# def can_save_slow_tokenizer(self) -> bool:
|
||||
# return os.path.isfile(self.vocab_file) if self.vocab_file else False
|
||||
#
|
||||
# def update_post_processor(self):
|
||||
# """
|
||||
# Updates the underlying post processor with the current `bos_token` and `eos_token`.
|
||||
# """
|
||||
# bos = self.bos_token
|
||||
# bos_token_id = self.bos_token_id
|
||||
# if bos is None and self.add_bos_token:
|
||||
# raise ValueError('add_bos_token = True but bos_token = None')
|
||||
#
|
||||
# eos = self.eos_token
|
||||
# eos_token_id = self.eos_token_id
|
||||
# if eos is None and self.add_eos_token:
|
||||
# raise ValueError('add_eos_token = True but eos_token = None')
|
||||
#
|
||||
# single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
|
||||
# pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
|
||||
#
|
||||
# special_tokens = []
|
||||
# if self.add_bos_token:
|
||||
# special_tokens.append((bos, bos_token_id))
|
||||
# if self.add_eos_token:
|
||||
# special_tokens.append((eos, eos_token_id))
|
||||
# self._tokenizer.post_processor = processors.TemplateProcessing(
|
||||
# single=single, pair=pair, special_tokens=special_tokens
|
||||
# )
|
||||
#
|
||||
# @property
|
||||
# def add_eos_token(self):
|
||||
# return self._add_eos_token
|
||||
#
|
||||
# @property
|
||||
# def add_bos_token(self):
|
||||
# return self._add_bos_token
|
||||
#
|
||||
# @add_eos_token.setter
|
||||
# def add_eos_token(self, value):
|
||||
# self._add_eos_token = value
|
||||
# self.update_post_processor()
|
||||
#
|
||||
# @add_bos_token.setter
|
||||
# def add_bos_token(self, value):
|
||||
# self._add_bos_token = value
|
||||
# self.update_post_processor()
|
||||
#
|
||||
# def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
# if not self.can_save_slow_tokenizer:
|
||||
# raise ValueError(
|
||||
# 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
|
||||
# 'tokenizer.'
|
||||
# )
|
||||
#
|
||||
# if not os.path.isdir(save_directory):
|
||||
# logger.error(f'Vocabulary path ({save_directory}) should be a directory')
|
||||
# return
|
||||
# out_vocab_file = os.path.join(
|
||||
# save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
|
||||
# )
|
||||
#
|
||||
# if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||
# copyfile(self.vocab_file, out_vocab_file)
|
||||
#
|
||||
# return (out_vocab_file,)
|
||||
|
||||
|
||||
# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
|
||||
class InternLM2Tokenizer(PreTrainedTokenizer):
    """
    Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids", "attention_mask"]
    _auto_class = "AutoTokenizer"

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="</s>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        decode_with_prefix_space=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        # Fix: removed a stray debug print ("register succeed") that was
        # emitted to stdout on every instantiation.
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.decode_with_prefix_space = decode_with_prefix_space
        # Load the SentencePiece model before super().__init__, which may
        # query the vocabulary through the properties below.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        self._no_prefix_space_tokens = None
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def no_prefix_space_tokens(self):
        # Lazily computed set of token ids whose decoded text must NOT get a
        # leading space: every piece that does not start with the "▁" marker.
        if self._no_prefix_space_tokens is None:
            vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
            self._no_prefix_space_tokens = {
                i for i, tok in enumerate(vocab) if not tok.startswith("▁")
            }
        return self._no_prefix_space_tokens

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    @property
    def bos_token_id(self) -> Optional[int]:
        # Delegates to the SentencePiece model rather than the base class.
        return self.sp_model.bos_id()

    @property
    def eos_token_id(self) -> Optional[int]:
        # Delegates to the SentencePiece model rather than the base class.
        return self.sp_model.eos_id()

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Returns a tokenized string."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def _maybe_add_prefix_space(self, tokens, decoded):
        # Restore the leading space that SentencePiece encodes as a "▁"
        # marker on the first token.
        if tokens and tokens[0] not in self.no_prefix_space_tokens:
            return " " + decoded
        else:
            return decoded

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        out_string = self.clean_up_tokenization(out_string)
        out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
        # Drop the artificial leading character introduced above.
        return out_string[1:]

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )

        # Copy the original file when it exists; otherwise re-serialize the
        # in-memory SentencePiece model.
        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # Optionally wrap the sequence(s) with BOS/EOS per the instance flags.
        if self.add_bos_token:
            bos_token_ids = [self.bos_token_id]
        else:
            bos_token_ids = []

        output = bos_token_ids + token_ids_0

        if token_ids_1 is not None:
            output = output + token_ids_1

        if self.add_eos_token:
            output = output + [self.eos_token_id]

        return output

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
        use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        eos = [self.eos_token_id]

        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
||||
|
||||
|
||||
# Register the slow InternLM2 tokenizer for InternVL chat configs so that
# AutoTokenizer can resolve it; exist_ok=True avoids an error if the mapping
# was already registered (e.g. on repeated import).
TOKENIZER_MAPPING.register(
    InternVLChatConfig, (InternLM2Tokenizer, None), exist_ok=True
)
|
||||
634
python/sglang/srt/configs/janus_pro.py
Normal file
634
python/sglang/srt/configs/janus_pro.py
Normal file
@@ -0,0 +1,634 @@
|
||||
# Adapted from:
|
||||
# https://github.com/deepseek-ai/Janus/tree/main/janus/models
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
BatchFeature,
|
||||
LlamaConfig,
|
||||
LlamaTokenizerFast,
|
||||
PretrainedConfig,
|
||||
ProcessorMixin,
|
||||
)
|
||||
from transformers.image_utils import to_numpy_array
|
||||
|
||||
from sglang.srt.configs.utils import register_image_processor, register_processor
|
||||
from sglang.srt.multimodal.mm_utils import expand2square
|
||||
|
||||
|
||||
class DictToObject(dict):
    """A dict whose keys are also readable as attributes, recursively.

    Nested dict values are wrapped in `DictToObject`, so ``obj.a.b`` works
    for ``{"a": {"b": ...}}`` while plain ``obj["a"]`` access is preserved.
    """

    def __init__(self, dictionary):
        # Bug fix: the original called ``super(self).__init__(dictionary)``,
        # which raises TypeError (super() takes a type, not an instance).
        super().__init__(dictionary)

        for key, value in dictionary.items():
            if isinstance(value, dict):
                value = DictToObject(value)
            setattr(self, key, value)
|
||||
|
||||
|
||||
class VisionConfig(PretrainedConfig):
    """Holds the vision-encoder class name (`cls`) and its constructor
    kwargs (`params`)."""

    model_type = "vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        cls_field = kwargs.get("cls", "")
        # A class object may be passed instead of its name; store the name.
        self.cls = cls_field if isinstance(cls_field, str) else cls_field.__name__
        self.params = kwargs.get("params", {})
|
||||
|
||||
|
||||
class GenAlignerConfig(PretrainedConfig):
    """Holds the generation-aligner class name (`cls`) and its constructor
    kwargs (`params`)."""

    model_type = "gen_aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        cls_field = kwargs.get("cls", "")
        # A class object may be passed instead of its name; store the name.
        self.cls = cls_field if isinstance(cls_field, str) else cls_field.__name__
        self.params = kwargs.get("params", {})
|
||||
|
||||
|
||||
class GenHeadConfig(PretrainedConfig):
    """Holds the generation-head class name (`cls`) and its constructor
    kwargs (`params`)."""

    model_type = "gen_head"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        cls_field = kwargs.get("cls", "")
        # A class object may be passed instead of its name; store the name.
        self.cls = cls_field if isinstance(cls_field, str) else cls_field.__name__
        self.params = kwargs.get("params", {})
|
||||
|
||||
|
||||
class AlignerConfig(PretrainedConfig):
    """Holds the understanding-aligner class name (`cls`) and its
    constructor kwargs (`params`)."""

    model_type = "aligner"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        cls_field = kwargs.get("cls", "")
        # A class object may be passed instead of its name; store the name.
        self.cls = cls_field if isinstance(cls_field, str) else cls_field.__name__
        self.params = kwargs.get("params", {})
|
||||
|
||||
|
||||
class GenVisionConfig(PretrainedConfig):
    """Holds the generative vision module's class name (`cls`) and its
    constructor kwargs (`params`)."""

    model_type = "gen_vision"
    cls: str = ""
    params = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        cls_field = kwargs.get("cls", "")
        # A class object may be passed instead of its name; store the name.
        self.cls = cls_field if isinstance(cls_field, str) else cls_field.__name__
        self.params = kwargs.get("params", {})
|
||||
|
||||
|
||||
@dataclass
class SigLIPVisionCfg:
    """Architecture hyper-parameters for a SigLIP-style vision encoder."""

    width: int = 1152  # transformer embedding width
    layers: Union[Tuple[int, int, int, int], int] = 27  # depth (or per-stage depths)
    heads: int = 16  # attention heads per layer
    patch_size: int = 14  # square patch edge, in pixels
    image_size: Union[Tuple[int, int], int] = 336  # input resolution
    global_pool: str = "map"  # pooling mode — presumably attention-map pooling; confirm against the model impl
    mlp_ratio: float = 3.7362  # MLP hidden dim as a multiple of width
    class_token: bool = False  # whether a class token is prepended
    num_classes: int = 0  # classifier head size (0 = no head)
    use_checkpoint: bool = False  # NOTE(review): looks like gradient checkpointing — confirm
|
||||
|
||||
|
||||
class MultiModalityConfig(PretrainedConfig):
    """Top-level Janus config bundling the vision, aligner, and generation
    sub-configs together with the language-model config."""

    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig

    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig

    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Each sub-config is parsed from its (possibly absent) kwargs dict.
        self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
        self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
        self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
        self.gen_aligner_config = GenAlignerConfig(**kwargs.get("gen_aligner_config", {}))
        self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))

        # The language config may already arrive as a parsed LlamaConfig.
        language_config = kwargs.get("language_config", {})
        if not isinstance(language_config, LlamaConfig):
            language_config = LlamaConfig(**language_config)
        self.language_config = language_config
|
||||
|
||||
|
||||
class VLMImageProcessor(BaseImageProcessor):
    """Image processor for Janus VLMs: resizes with aspect ratio preserved,
    pads to a square, rescales to [0, 1], and optionally normalizes."""

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        # Padding color for the square expansion: mid-gray when no mean is
        # given, otherwise the mean color scaled back to 0-255.
        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """Resize keeping aspect ratio, pad to a square, and convert to CHW.

        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """

        width, height = pil_img.size
        max_size = max(width, height)

        # Scale the longer edge to self.image_size; clamp both edges to
        # min_size so tiny inputs stay valid. size is [H, W].
        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            # print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        # Inner helper (shadows the method name on purpose): accepts either
        # a single int (short-edge resize) or an [H, W] pair.
        def resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        ):
            if isinstance(size, int):
                w, h = pil_img.size
                if (w <= h and w == size) or (h <= w and h == size):
                    return pil_img
                if w < h:
                    ow = size
                    oh = int(size * h / w)
                else:
                    oh = size
                    ow = int(size * w / h)
                size = (ow, oh)
            else:
                # PIL expects (W, H); our size list is [H, W].
                size = (size[1], size[0])

            return pil_img.resize(
                size, resample=interpolation, reducing_gap=None if antialias else 3.0
            )

        pil_img = resize(
            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
        )

        # Pad to a square using the background color, then to numpy.
        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        """Process one image or a list of images into a `BatchFeature` with
        `pixel_values` of shape [n_images, 3, image_size, image_size]."""
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        if not isinstance(images, list):
            images = [images]
        images: List[np.ndarray] = [self.resize(image) for image in images]
        # Drop any alpha channel, keeping the first 3 channels.
        images = [image[:3, ...] for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]
        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        # Canonical CHW output shape for a single processed image.
        return [3, self.image_size, self.image_size]
|
||||
|
||||
|
||||
class DictOutput(object):
    """Mixin exposing an object's attributes through a dict-style
    interface (keys/items/indexing/membership)."""

    def keys(self):
        return vars(self).keys()

    def items(self):
        return vars(self).items()

    def __contains__(self, key):
        return key in vars(self)

    def __getitem__(self, item):
        return vars(self)[item]

    def __setitem__(self, key, value):
        vars(self)[key] = value
|
||||
|
||||
|
||||
@dataclass
class VLChatProcessorOutput(DictOutput):
    """Processor output for a single prompt with its images."""

    sft_format: str  # the formatted prompt string
    input_ids: torch.Tensor  # token ids with image tokens spliced in
    pixel_values: torch.Tensor  # processed images; [n_images, 3, H, W] per process_one
    num_image_tokens: torch.IntTensor  # image-token count for each image

    def __len__(self):
        # Length of the token sequence.
        return len(self.input_ids)
|
||||
|
||||
|
||||
@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    """Batched, padded form of `VLChatProcessorOutput`."""

    sft_format: List[str]  # formatted prompt per sample
    input_ids: torch.Tensor  # padded token ids for the batch
    pixel_values: torch.Tensor  # stacked image tensors
    attention_mask: torch.Tensor  # 1 for real tokens, 0 for padding — presumably; confirm against batchify
    images_seq_mask: torch.BoolTensor  # marks image-token positions in input_ids — TODO confirm
    images_emb_mask: torch.BoolTensor  # marks valid image embeddings — TODO confirm
|
||||
|
||||
|
||||
# FIXME: had to place Official Processor here, since image_processor module would not be imported in all threads,
|
||||
# hence AutoProcessor registration would not be affective in some cases
|
||||
class VLChatProcessor(ProcessorMixin):
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
||||
def __init__(
    self,
    image_processor: VLMImageProcessor,
    tokenizer: LlamaTokenizerFast,
    image_tag: str = "<image_placeholder>",
    image_start_tag: str = "<begin_of_image>",
    image_end_tag: str = "<end_of_image>",
    pad_tag: str = "<|▁pad▁|>",
    num_image_tokens: int = 576,
    add_special_token: bool = False,
    sft_format: str = "deepseek",
    mask_prompt: bool = True,
    ignore_id: int = -100,
    **kwargs,
):
    """Bind the image processor and tokenizer, making sure the image
    placeholder tag exists in the tokenizer's vocabulary.

    Note: ``mask_prompt`` is accepted but not stored by this constructor.
    """
    self.image_processor = image_processor
    self.tokenizer = tokenizer

    # Register the placeholder as an additional special token if missing.
    image_id = self.tokenizer.vocab.get(image_tag)
    if image_id is None:
        special_tokens = [image_tag]
        special_tokens_dict = {"additional_special_tokens": special_tokens}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        # print(f"Add image tag = {image_tag} to the tokenizer")

    self.image_tag = image_tag
    self.image_start_tag = image_start_tag
    self.image_end_tag = image_end_tag
    self.pad_tag = pad_tag

    self.num_image_tokens = num_image_tokens
    self.add_special_token = add_special_token
    self.sft_format = sft_format
    self.ignore_id = ignore_id

    super().__init__(
        image_processor,
        tokenizer,
        **kwargs,
    )
|
||||
|
||||
@property
def image_token(self):
    """The placeholder tag string marking an image in the prompt."""
    return self.image_tag
|
||||
|
||||
@property
def image_id(self) -> int:
    """Vocabulary id of the image placeholder tag."""
    image_id = self.tokenizer.vocab.get(self.image_tag)
    return image_id
|
||||
|
||||
@property
def image_start_id(self):
    """Vocabulary id of the begin-of-image tag (None if not in vocab)."""
    image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
    return image_start_id
|
||||
|
||||
@property
def image_end_id(self):
    """Vocabulary id of the end-of-image tag (None if not in vocab)."""
    image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
    return image_end_id
|
||||
|
||||
@property
def image_start_token(self):
    """The begin-of-image tag string."""
    return self.image_start_tag
|
||||
|
||||
@property
def image_end_token(self):
    """The end-of-image tag string."""
    return self.image_end_tag
|
||||
|
||||
@property
def pad_id(self):
    """Vocabulary id of the padding tag (None if not in vocab)."""
    pad_id = self.tokenizer.vocab.get(self.pad_tag)
    return pad_id
|
||||
|
||||
def add_image_token(
    self,
    image_indices: List[int],
    input_ids: torch.LongTensor,
):
    """Expand each image placeholder into <boi> + num_image_tokens image
    ids + <eoi> inside the token sequence.

    Args:
        image_indices (List[int]): [index_0, index_1, ..., index_j]
        input_ids (torch.LongTensor): [N]

    Returns:
        input_ids (torch.LongTensor): [N + image tokens]
        num_image_tokens (torch.IntTensor): [n_images]
    """

    input_slices = []

    start = 0
    for index in image_indices:
        # With add_special_token, keep the placeholder token itself in the
        # preceding text slice; otherwise cut just before it.
        if self.add_special_token:
            end = index + 1
        else:
            end = index

        # original text tokens
        input_slices.append(input_ids[start:end])

        # add boi, image tokens, eoi and set the mask as False
        input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
        input_slices.append(
            self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
        )
        input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
        # Skip past the original placeholder token.
        start = index + 1

    # the left part
    input_slices.append(input_ids[start:])

    # concat all slices
    input_ids = torch.cat(input_slices, dim=0)
    num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

    return input_ids, num_image_tokens
|
||||
|
||||
def process_one(
|
||||
self,
|
||||
prompt: str = None,
|
||||
images: List[Image] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
prompt (str): the formatted prompt;
|
||||
images (List[ImageType]): the list of images;
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
outputs (BaseProcessorOutput): the output of the processor,
|
||||
- input_ids (torch.LongTensor): [N + image tokens]
|
||||
- target_ids (torch.LongTensor): [N + image tokens]
|
||||
- images (torch.FloatTensor): [n_images, 3, H, W]
|
||||
- image_id (int): the id of the image token
|
||||
- num_image_tokens (List[int]): the number of image tokens
|
||||
"""
|
||||
|
||||
sft_format = prompt
|
||||
# tokenize
|
||||
input_ids = self.tokenizer.encode(sft_format)
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
|
||||
# add image tokens to the input_ids
|
||||
image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)
|
||||
image_indices = image_token_mask.nonzero()
|
||||
input_ids, num_image_tokens = self.add_image_token(
|
||||
image_indices=image_indices,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
|
||||
# load images
|
||||
images_outputs = self.image_processor(images, return_tensors="pt")
|
||||
|
||||
prepare = VLChatProcessorOutput(
|
||||
sft_format=sft_format,
|
||||
input_ids=input_ids,
|
||||
pixel_values=images_outputs.pixel_values,
|
||||
num_image_tokens=num_image_tokens,
|
||||
)
|
||||
|
||||
return prepare
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
prompt: str = None,
|
||||
conversations: List[Dict[str, str]] = None,
|
||||
images: List[Image] = None,
|
||||
force_batchify: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
prompt (str): the formatted prompt;
|
||||
conversations (List[Dict]): conversations with a list of messages;
|
||||
images (List[ImageType]): the list of images;
|
||||
force_batchify (bool): force batchify the inputs;
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
outputs (BaseProcessorOutput): the output of the processor,
|
||||
- input_ids (torch.LongTensor): [N + image tokens]
|
||||
- images (torch.FloatTensor): [n_images, 3, H, W]
|
||||
- image_id (int): the id of the image token
|
||||
- num_image_tokens (List[int]): the number of image tokens
|
||||
"""
|
||||
|
||||
prepare = self.process_one(
|
||||
prompt=prompt, conversations=conversations, images=images
|
||||
)
|
||||
|
||||
if force_batchify:
|
||||
prepare = self.batchify([prepare])
|
||||
|
||||
return prepare
|
||||
|
||||
    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Left-pads every sample to the longest sequence in the batch and stacks
        per-sample tensors into batch-major tensors, building the attention,
        image-sequence and image-embedding masks along the way.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """

        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        # First pass: collect per-sample lengths to size the padded tensors.
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        # Allocate at least one image slot even for text-only batches.
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        # Second pass: copy each sample into the rightmost slots (left-padding).
        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                # Mark the valid embedding slots of each image as True.
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares
|
||||
|
||||
|
||||
class VLMImageProcessorConfig(PretrainedConfig):
    """HF-style config holding resize/normalization parameters for the
    DeepSeek-VLM image processor."""

    model_type = "deepseek_vlm"
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        # NOTE(review): mean/std defaults look like the CLIP normalization
        # constants — confirm against the pretrained checkpoint.
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        self.image_size = image_size
        self.min_size = min_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize

        super().__init__(**kwargs)
|
||||
|
||||
|
||||
# Register the chat/image processors so they are discoverable from the
# MultiModalityConfig model config class.
register_processor(MultiModalityConfig, VLChatProcessor)
register_image_processor(MultiModalityConfig, VLMImageProcessor)
|
||||
38
python/sglang/srt/configs/kimi_vl.py
Normal file
38
python/sglang/srt/configs/kimi_vl.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
|
||||
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
||||
|
||||
|
||||
class KimiVLConfig(PretrainedConfig):
    """Configuration for Kimi-VL: a MoonViT vision tower paired with a
    DeepSeek-V2 text backbone."""

    model_type = "kimi_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, MoonViTConfig]] = None,
        text_config: Optional[Union[dict, DeepseekV2Config]] = None,
        ignore_index: int = -100,
        media_placeholder_token_id: int = 163605,
        pad_token_id: int = 0,
        **kwargs
    ):
        # Each sub-config may arrive as a ready object, a plain dict, or be
        # omitted entirely; normalize all three cases to config objects.
        if isinstance(vision_config, dict):
            vision_config = MoonViTConfig(**vision_config)
        elif vision_config is None:
            vision_config = MoonViTConfig()
        self.vision_config = vision_config

        if isinstance(text_config, dict):
            text_config = DeepseekV2Config(**text_config)
        elif text_config is None:
            text_config = DeepseekV2Config()
        self.text_config = text_config

        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id

        super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
32
python/sglang/srt/configs/kimi_vl_moonvit.py
Normal file
32
python/sglang/srt/configs/kimi_vl_moonvit.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class MoonViTConfig(PretrainedConfig):
    """Configuration for the MoonViT vision tower used by Kimi-VL."""

    model_type = "moonvit"

    def __init__(
        self,
        patch_size: int = 14,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: int = 64,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 27,
        hidden_size: int = 1152,
        intermediate_size: int = 4304,
        merge_kernel_size: tuple[int, int] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        # Positional embedding config: initial grid size of the learned
        # position embeddings.
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = init_pos_emb_width
        # Transformer config
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Patch merger config: kernel over neighbouring patches.
        self.merge_kernel_size = merge_kernel_size
|
||||
89
python/sglang/srt/configs/load_config.py
Normal file
89
python/sglang/srt/configs/load_config.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
import enum
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoadFormat(str, enum.Enum):
    """Weight-loading formats; str-valued so CLI strings round-trip directly."""

    AUTO = "auto"  # prefer safetensors, fall back to pytorch .bin
    PT = "pt"  # pytorch .bin checkpoints
    SAFETENSORS = "safetensors"
    NPCACHE = "npcache"  # pytorch weights plus a numpy cache for faster reloads
    DUMMY = "dummy"  # random weights, mainly for profiling
    SHARDED_STATE = "sharded_state"
    GGUF = "gguf"
    BITSANDBYTES = "bitsandbytes"  # nf4-quantized weights
    MISTRAL = "mistral"
    LAYERED = "layered"
    JAX = "jax"
    REMOTE = "remote"
|
||||
|
||||
|
||||
@dataclass
class LoadConfig:
    """
    download_dir: Directory to download and load the weights, default to the
        default cache directory of huggingface.
    load_format: The format of the model weights to load:
        "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is
            not available.
        "pt" will load the weights in the pytorch bin format.
        "safetensors" will load the weights in the safetensors format.
        "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
        "dummy" will initialize the weights with random values, which is
            mainly for profiling.
        "bitsandbytes" will load nf4 type weights.
    ignore_patterns: The list of patterns to ignore when loading the model.
        Default to "original/**/*" to avoid repeated loading of llama's
        checkpoints.
    decryption_key_file: If set, decrypts the output files with a password read
        from this file (after PBKDF2).
    """

    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None
    decryption_key_file: Optional[str] = None

    def __post_init__(self):
        # Normalize model_loader_extra_config to a dict *in place*. The
        # previous code computed `self.model_loader_extra_config or {}` into a
        # local but never stored it back, so an explicitly-passed None
        # remained None after construction.
        if self.model_loader_extra_config is None:
            self.model_loader_extra_config = {}
        if isinstance(self.model_loader_extra_config, str):
            self.model_loader_extra_config = json.loads(self.model_loader_extra_config)
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info(
                "Ignoring the following patterns when downloading weights: %s",
                self.ignore_patterns,
            )
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        """Normalize a string load_format to a LoadFormat member and reject
        formats unsupported on ROCm. Raises ValueError on unknown formats."""
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        # LoadFormat(...) raises ValueError for unrecognized format strings.
        self.load_format = LoadFormat(load_format)

        # Currently empty: no formats are blocked on ROCm; kept as an
        # extension point.
        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f
                for f in LoadFormat.__members__
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}"
            )
|
||||
104
python/sglang/srt/configs/longcat_flash.py
Normal file
104
python/sglang/srt/configs/longcat_flash.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
FLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
||||
|
||||
|
||||
class LongcatFlashConfig(PretrainedConfig):
    """Configuration for LongCat-Flash, an MLA + MoE causal LM (with optional
    zero-experts and NextN/MTP draft layers)."""

    model_type = "longcat_flash"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=6144,
        # `intermediate_size` (HF name) wins over the legacy `ffn_hidden_size`
        # when both are given; same for num_hidden_layers vs num_layers.
        intermediate_size=None,
        ffn_hidden_size=12288,
        expert_ffn_hidden_size=2048,
        num_layers=28,
        num_hidden_layers=None,
        num_attention_heads=64,
        ep_size=1,
        # MLA projection dimensions.
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=128,
        qk_nope_head_dim=128,
        v_head_dim=128,
        # MoE routing.
        n_routed_experts=512,
        moe_topk=12,
        norm_topk_prob=False,
        max_position_embeddings=131072,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mla_scale_q_lora=True,
        mla_scale_kv_lora=True,
        torch_dtype="bfloat16",
        params_dtype="bfloat16",
        # NOTE(review): "rounter" spelling matches upstream checkpoints;
        # do not rename without a compatibility shim.
        rounter_params_dtype="float32",
        router_bias=False,
        topk_method=None,
        routed_scaling_factor=6.0,
        zero_expert_num=256,
        zero_expert_type="identity",
        nextn_use_scmoe=False,
        num_nextn_predict_layers=1,
        **kwargs,
    ):
        # Token ids, dtypes and router/NextN options are forwarded to
        # PretrainedConfig so they are serialized with the config.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            torch_dtype=torch_dtype,
            params_dtype=params_dtype,
            rounter_params_dtype=rounter_params_dtype,
            topk_method=topk_method,
            router_bias=router_bias,
            nextn_use_scmoe=nextn_use_scmoe,
            num_nextn_predict_layers=num_nextn_predict_layers,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = (
            num_hidden_layers if num_hidden_layers is not None else num_layers
        )
        self.intermediate_size = (
            intermediate_size if intermediate_size is not None else ffn_hidden_size
        )
        self.moe_intermediate_size = expert_ffn_hidden_size
        self.num_attention_heads = num_attention_heads
        self.ep_size = ep_size
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.n_routed_experts = n_routed_experts
        self.moe_topk = moe_topk
        self.norm_topk_prob = norm_topk_prob
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mla_scale_q_lora = mla_scale_q_lora
        self.mla_scale_kv_lora = mla_scale_kv_lora
        self.zero_expert_num = zero_expert_num
        self.zero_expert_type = zero_expert_type
        self.routed_scaling_factor = routed_scaling_factor
        # Activation is fixed; not configurable for this architecture.
        self.hidden_act = "silu"
||||
811
python/sglang/srt/configs/model_config.py
Normal file
811
python/sglang/srt/configs/model_config.py
Normal file
@@ -0,0 +1,811 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from enum import Enum, IntEnum, auto
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
get_config,
|
||||
get_context_length,
|
||||
get_generation_config,
|
||||
get_hf_text_config,
|
||||
get_sparse_attention_config,
|
||||
)
|
||||
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||
from sglang.utils import is_in_ci
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AttentionArch(IntEnum):
    """Attention architecture of the text backbone."""

    # Multi-head Latent Attention (selected below for DeepSeek-V2/V3-style
    # architectures with compressed KV projections).
    MLA = auto()
    # Standard multi-head attention.
    MHA = auto()
|
||||
|
||||
class ModelImpl(str, Enum):
    """Which modeling implementation to use for a given architecture."""

    AUTO = "auto"  # choose automatically
    SGLANG = "sglang"  # native sglang model definition
    TRANSFORMERS = "transformers"  # HF transformers modeling code
|
||||
|
||||
class ModelConfig:
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
trust_remote_code: bool = True,
|
||||
revision: Optional[str] = None,
|
||||
context_length: Optional[int] = None,
|
||||
model_override_args: str = "{}",
|
||||
is_embedding: Optional[bool] = None,
|
||||
enable_multimodal: Optional[bool] = None,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
override_config_file: Optional[str] = None,
|
||||
is_draft_model: bool = False,
|
||||
hybrid_kvcache_ratio: Optional[float] = None,
|
||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||
) -> None:
|
||||
# Parse args
|
||||
self.model_path = model_path
|
||||
self.revision = revision
|
||||
self.quantization = quantization
|
||||
self.model_impl = model_impl
|
||||
|
||||
self.maybe_pull_model_tokenizer_from_remote()
|
||||
self.model_override_args = json.loads(model_override_args)
|
||||
kwargs = {}
|
||||
if override_config_file and override_config_file.strip():
|
||||
kwargs["_configuration_file"] = override_config_file.strip()
|
||||
|
||||
self.hf_config = get_config(
|
||||
self.model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
model_override_args=self.model_override_args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.hf_generation_config = get_generation_config(
|
||||
self.model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.attention_chunk_size = getattr(
|
||||
self.hf_text_config, "attention_chunk_size", None
|
||||
)
|
||||
self.is_hybrid = is_hybrid_model(
|
||||
self.hf_config.architectures,
|
||||
hybrid_kvcache_ratio=hybrid_kvcache_ratio,
|
||||
context_length=context_length,
|
||||
attention_chunk_size=self.attention_chunk_size,
|
||||
)
|
||||
if self.is_hybrid is not None:
|
||||
self.swa_attention_layer_ids, self.full_attention_layer_ids = (
|
||||
get_hybrid_layer_ids(
|
||||
self.hf_config.architectures, self.hf_text_config.num_hidden_layers
|
||||
)
|
||||
)
|
||||
|
||||
if enable_multimodal is None:
|
||||
mm_disabled_models = [
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Llama4ForConditionalGeneration",
|
||||
"Step3VLForConditionalGeneration",
|
||||
]
|
||||
if self.hf_config.architectures[0] in mm_disabled_models:
|
||||
enable_multimodal = False
|
||||
logger.info(
|
||||
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
|
||||
)
|
||||
else:
|
||||
enable_multimodal = True
|
||||
|
||||
if (
|
||||
is_draft_model
|
||||
and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||
):
|
||||
self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
|
||||
|
||||
if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
|
||||
|
||||
if (
|
||||
is_draft_model
|
||||
and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
|
||||
):
|
||||
self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
|
||||
self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers
|
||||
|
||||
if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
|
||||
self.hf_config.architectures[0] = "MiMoMTP"
|
||||
if (
|
||||
is_draft_model
|
||||
and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
|
||||
):
|
||||
self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"
|
||||
|
||||
if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM":
|
||||
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
|
||||
|
||||
# Check model type
|
||||
self.is_generation = is_generation_model(
|
||||
self.hf_config.architectures, is_embedding
|
||||
)
|
||||
self.is_multimodal = enable_multimodal and is_multimodal_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_image_gen = enable_multimodal and is_image_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_audio_model = enable_multimodal and is_audio_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_multimodal_chunked_prefill_supported = (
|
||||
enable_multimodal
|
||||
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
|
||||
)
|
||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
|
||||
# Derive context length
|
||||
derived_context_len = get_context_length(self.hf_text_config)
|
||||
if context_length is not None:
|
||||
if context_length > derived_context_len:
|
||||
reason = "Target model's" if is_draft_model else "User-specified"
|
||||
msg = (
|
||||
f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
||||
f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
|
||||
)
|
||||
if (
|
||||
get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
|
||||
or is_in_ci() # FIXME: fix this special case
|
||||
):
|
||||
logger.warning(msg)
|
||||
self.context_len = context_length
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
|
||||
)
|
||||
else:
|
||||
self.context_len = context_length
|
||||
else:
|
||||
self.context_len = derived_context_len
|
||||
|
||||
# Unify the config keys for hf_text_config
|
||||
self.head_dim = getattr(
|
||||
self.hf_text_config,
|
||||
"head_dim",
|
||||
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
|
||||
)
|
||||
|
||||
# FIXME: temporary special judge for MLA architecture
|
||||
if (
|
||||
"DeepseekV2ForCausalLM" in self.hf_config.architectures
|
||||
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
|
||||
or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
|
||||
or "LongcatFlashForCausalLM" in self.hf_config.architectures
|
||||
or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
|
||||
):
|
||||
self.head_dim = 256
|
||||
self.attention_arch = AttentionArch.MLA
|
||||
self.kv_lora_rank = self.hf_config.kv_lora_rank
|
||||
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
|
||||
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
||||
self.v_head_dim = self.hf_config.v_head_dim
|
||||
|
||||
# Handle rope scaling with yarn
|
||||
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
|
||||
if self.hf_config.rope_scaling:
|
||||
mscale_all_dim = self.hf_config.rope_scaling.get(
|
||||
"mscale_all_dim", False
|
||||
)
|
||||
scaling_factor = self.hf_config.rope_scaling["factor"]
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
|
||||
self.head_dim = 128
|
||||
self.attention_arch = AttentionArch.MLA
|
||||
self.kv_lora_rank = self.hf_config.kv_lora_rank
|
||||
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
||||
elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
|
||||
self.hf_text_config, "use_mla", True
|
||||
):
|
||||
self.head_dim = 256
|
||||
self.attention_arch = AttentionArch.MLA
|
||||
self.kv_lora_rank = self.hf_text_config.kv_lora_rank
|
||||
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
|
||||
elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
|
||||
self.head_dim = 256
|
||||
self.attention_arch = AttentionArch.MLA
|
||||
self.kv_lora_rank = self.hf_text_config.kv_lora_rank
|
||||
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
|
||||
self.v_head_dim = self.hf_text_config.v_head_dim
|
||||
self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
|
||||
else:
|
||||
if (
|
||||
"MistralModel" in self.hf_config.architectures
|
||||
or "MixtralForCausalLM" in self.hf_config.architectures
|
||||
or "MistralForCausalLM" in self.hf_config.architectures
|
||||
):
|
||||
if getattr(self, "head_dim", None) is None:
|
||||
self.head_dim = (
|
||||
self.hf_config.hidden_size // self.hf_config.num_attention_heads
|
||||
)
|
||||
# In transformers==4.52.3, the head_dim is null in MistralConfig
|
||||
if (
|
||||
not hasattr(self.hf_text_config, "head_dim")
|
||||
or self.hf_text_config.head_dim is None
|
||||
):
|
||||
setattr(self.hf_text_config, "head_dim", self.head_dim)
|
||||
|
||||
self.attention_arch = AttentionArch.MHA
|
||||
|
||||
self.num_attention_heads = self.hf_text_config.num_attention_heads
|
||||
self.num_key_value_heads = getattr(
|
||||
self.hf_text_config, "num_key_value_heads", None
|
||||
)
|
||||
|
||||
# for Dbrx and MPT models
|
||||
if self.hf_config.model_type in ["dbrx", "mpt"]:
|
||||
self.num_key_value_heads = getattr(
|
||||
self.hf_config.attn_config, "kv_n_heads", None
|
||||
)
|
||||
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
self.hidden_size = self.hf_text_config.hidden_size
|
||||
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
|
||||
self.num_attention_layers = self.num_hidden_layers
|
||||
if "LongcatFlashForCausalLM" in self.hf_config.architectures:
|
||||
self.num_attention_layers = self.num_hidden_layers * 2
|
||||
self.num_nextn_predict_layers = getattr(
|
||||
self.hf_text_config, "num_nextn_predict_layers", None
|
||||
)
|
||||
self.vocab_size = self.hf_text_config.vocab_size
|
||||
|
||||
# Verify quantization
|
||||
self._verify_quantization()
|
||||
|
||||
# Verify dual-chunk attention config
|
||||
self._verify_dual_chunk_attention_config()
|
||||
|
||||
# Cache attributes
|
||||
self.hf_eos_token_id = self.get_hf_eos_token_id()
|
||||
|
||||
# multimodal
|
||||
self.image_token_id = getattr(
|
||||
self.hf_config, "image_token_id", None
|
||||
) or getattr(self.hf_config, "image_token_index", None)
|
||||
|
||||
@staticmethod
|
||||
def from_server_args(
|
||||
server_args: ServerArgs,
|
||||
model_path: str = None,
|
||||
model_revision: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
return ModelConfig(
|
||||
model_path=model_path or server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=model_revision or server_args.revision,
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
enable_multimodal=server_args.enable_multimodal,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||
model_impl=server_args.model_impl,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_total_num_attention_heads(self) -> int:
|
||||
return self.num_attention_heads
|
||||
|
||||
def get_num_attention_heads(self, tensor_parallel_size) -> int:
|
||||
total_num_attention_heads = self.num_attention_heads
|
||||
return max(1, total_num_attention_heads // tensor_parallel_size)
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
||||
def get_total_num_kv_heads(self) -> int:
|
||||
"""Returns the total number of KV heads."""
|
||||
# For GPTBigCode & Falcon:
|
||||
# NOTE: for falcon, when new_decoder_architecture is True, the
|
||||
# multi_query flag is ignored and we use n_head_kv for the number of
|
||||
# KV heads.
|
||||
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
|
||||
new_decoder_arch_falcon = (
|
||||
self.hf_config.model_type in falcon_model_types
|
||||
and getattr(self.hf_config, "new_decoder_architecture", False)
|
||||
)
|
||||
if not new_decoder_arch_falcon and getattr(
|
||||
self.hf_text_config, "multi_query", False
|
||||
):
|
||||
# Multi-query attention, only one KV head.
|
||||
# Currently, tensor parallelism is not supported in this case.
|
||||
return 1
|
||||
|
||||
# For DBRX and MPT
|
||||
if self.hf_config.model_type in ["mpt"]:
|
||||
if "kv_n_heads" in self.hf_config.attn_config:
|
||||
return self.hf_config.attn_config["kv_n_heads"]
|
||||
return self.hf_config.num_attention_heads
|
||||
if self.hf_config.model_type in ["dbrx"]:
|
||||
return getattr(
|
||||
self.hf_config.attn_config,
|
||||
"kv_n_heads",
|
||||
self.hf_config.num_attention_heads,
|
||||
)
|
||||
if self.hf_config.model_type in ["nemotron-nas"]:
|
||||
nkvh = {
|
||||
self.hf_config.num_attention_heads // block.attention.n_heads_in_group
|
||||
for block in self.hf_config.block_configs
|
||||
if not block.attention.no_op
|
||||
}
|
||||
if len(nkvh) == 0:
|
||||
raise RuntimeError("Couldn't determine number of kv heads")
|
||||
if len(nkvh) > 1:
|
||||
raise ValueError(
|
||||
"Variable GQA (VGQA) is not yet supported for nemotron-nas in sglang"
|
||||
)
|
||||
return next(iter(nkvh))
|
||||
|
||||
attributes = [
|
||||
# For Falcon:
|
||||
"n_head_kv",
|
||||
"num_kv_heads",
|
||||
# For LLaMA-2:
|
||||
"num_key_value_heads",
|
||||
# For ChatGLM:
|
||||
"multi_query_group_num",
|
||||
# For Step3
|
||||
"num_attention_groups",
|
||||
]
|
||||
for attr in attributes:
|
||||
num_kv_heads = getattr(self.hf_text_config, attr, None)
|
||||
if num_kv_heads is not None:
|
||||
return num_kv_heads
|
||||
|
||||
# For non-grouped-query attention models, the number of KV heads is
|
||||
# equal to the number of attention heads.
|
||||
return self.hf_text_config.num_attention_heads
|
||||
|
||||
def get_num_kv_heads(self, tensor_parallel_size) -> int:
    """Return the number of KV heads each GPU rank serves.

    The global KV-head count is sharded across tensor-parallel ranks.
    When there are fewer KV heads than ranks, heads are replicated so
    every rank still holds at least one.
    """
    total_num_kv_heads = self.get_total_num_kv_heads()
    per_rank = total_num_kv_heads // tensor_parallel_size
    # Replication case: never return fewer than one head per rank.
    return per_rank if per_rank > 0 else 1
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
def _parse_quant_hf_config(self):
    """Extract the quantization config dict for this checkpoint, if any.

    Checks, in order: `hf_config.quantization_config`,
    `hf_config.compression_config`, and finally a standalone
    `hf_quant_config.json` (probed on the HF Hub for remote repos, or on
    disk for local paths).  Returns ``None`` for unquantized models.
    """
    quant_cfg = getattr(self.hf_config, "quantization_config", None)
    if quant_cfg is None:
        # compressed-tensors uses a "compression_config" key
        quant_cfg = getattr(self.hf_config, "compression_config", None)
    if quant_cfg is None:
        # check if is modelopt or mixed-precision model -- Both of them don't have corresponding field
        # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
        # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
        # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
        is_local = os.path.exists(self.model_path)
        modelopt_quant_config = {"quant_method": "modelopt"}
        if not is_local:
            import huggingface_hub

            try:
                from huggingface_hub import HfApi

                hf_api = HfApi()
                # Remote repo: probe the Hub for the standalone quant config.
                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                    quant_cfg = modelopt_quant_config
            except huggingface_hub.errors.OfflineModeIsEnabled:
                # Best-effort probe: in offline mode, assume no standalone
                # quant config rather than failing model load.
                logger.warning(
                    "Offline mode is enabled, skipping hf_quant_config.json check"
                )
                pass

        elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
            quant_config_file = os.path.join(
                self.model_path, "hf_quant_config.json"
            )
            with open(quant_config_file) as f:
                quant_config_dict = json.load(f)
            json_quant_configs = quant_config_dict["quantization"]
            quant_algo = json_quant_configs.get("quant_algo", None)
            # "MIXED_PRECISION" marks a W4AFP8 checkpoint; any other
            # standalone hf_quant_config.json is treated as modelopt.
            if quant_algo == "MIXED_PRECISION":
                quant_cfg = {"quant_method": "w4afp8"}
            else:
                quant_cfg = modelopt_quant_config
    return quant_cfg
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
def _verify_quantization(self) -> None:
    """Reconcile the CLI `quantization` argument with the checkpoint config.

    Resolves `self.quantization` from the HF quant config when the user did
    not specify one, validates that a user-specified method matches (or is
    declared compatible with) the checkpoint, and raises ValueError on
    unknown or platform-unsupported methods.
    """
    supported_quantization = [*QUANTIZATION_METHODS]
    # Methods usable on ROCm builds.
    rocm_supported_quantization = [
        "awq",
        "gptq",
        "fp8",
        "compressed_tensors",
        "compressed-tensors",
        "fbgemm_fp8",
        "w8a8_fp8",
        "petit_nvfp4",
        "quark",
        "mxfp4",
    ]
    # Methods with tuned kernels; anything else triggers a slowness warning.
    optimized_quantization_methods = [
        "fp8",
        "marlin",
        "modelopt",
        "gptq_marlin_24",
        "gptq_marlin",
        "awq_marlin",
        "fbgemm_fp8",
        "compressed_tensors",
        "compressed-tensors",
        "experts_int8",
        "w8a8_int8",
        "w8a8_fp8",
        "moe_wna16",
        "qoq",
        "w4afp8",
        "petit_nvfp4",
    ]
    # user-specified method -> checkpoint methods it may legitimately override.
    compatible_quantization_methods = {
        "modelopt_fp4": ["modelopt"],
        "petit_nvfp4": ["modelopt"],
        "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
        "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
    }
    if self.quantization is not None:
        self.quantization = self.quantization.lower()

    # Parse quantization method from the HF model config, if available.
    quant_cfg = self._parse_quant_hf_config()

    if quant_cfg is not None:
        quant_method = quant_cfg.get(
            "quant_method", "" if not self.quantization else self.quantization
        ).lower()

        # Detect which checkpoint is it
        # (iterate values directly; the key was previously discarded).
        for method in QUANTIZATION_METHODS.values():
            quantization_override = method.override_quantization_method(
                quant_cfg, self.quantization
            )
            if quantization_override:
                quant_method = quantization_override
                self.quantization = quantization_override
                break

        # Verify quantization configurations.
        if self.quantization is None:
            self.quantization = quant_method
        elif self.quantization != quant_method:
            if (
                self.quantization not in compatible_quantization_methods
                or quant_method
                not in compatible_quantization_methods[self.quantization]
            ):
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization})."
                )

    if self.quantization is not None:
        if self.quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization method: {self.quantization}. Must "
                f"be one of {supported_quantization}."
            )
        if is_hip() and self.quantization not in rocm_supported_quantization:
            raise ValueError(
                f"{self.quantization} quantization is currently not "
                f"supported in ROCm."
            )
        if self.quantization not in optimized_quantization_methods:
            logger.warning(
                "%s quantization is not fully "
                "optimized yet. The speed can be slower than "
                "non-quantized models.",
                self.quantization,
            )
|
||||
|
||||
def _verify_dual_chunk_attention_config(self) -> None:
    """Attach the sparse-attention config to `dual_chunk_attention_config`.

    No-op when the HF config has no dual-chunk attention section or when no
    sparse attention config can be loaded from the model path.
    """
    if not hasattr(self.hf_config, "dual_chunk_attention_config"):
        return
    # Try loading the sparse attention config
    sparse_attn_config = get_sparse_attention_config(self.model_path)
    if not sparse_attn_config:
        return
    dca_config = self.hf_config.dual_chunk_attention_config
    dca_config["sparse_attention_config"] = sparse_attn_config
    # Enable sparse attention by default unless the config already decided.
    dca_config.setdefault("sparse_attention_enabled", True)
|
||||
|
||||
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
    """Collect EOS token ids from the HF config and generation config.

    Each source may store either a single int or a list of ints; the union
    of both sources is returned (possibly an empty set).
    """

    def _as_set(ids):
        # Normalize "int or list of int" to a set.
        return {ids} if isinstance(ids, int) else set(ids)

    raw_ids = getattr(self.hf_config, "eos_token_id", None)
    eos_ids = _as_set(raw_ids) if raw_ids is not None else set()

    if self.hf_generation_config:
        generation_ids = getattr(self.hf_generation_config, "eos_token_id", None)
        if generation_ids:
            eos_ids |= _as_set(generation_ids)
    return eos_ids
|
||||
|
||||
def maybe_pull_model_tokenizer_from_remote(self) -> None:
    """
    Pull the model config files to a temporary
    directory in case of remote.

    On a remote `model_path`, downloads `*config.json`, remembers the
    remote URL in `self.model_weights`, and repoints `self.model_path`
    at the local download directory.  No-op for local paths.
    """
    from sglang.srt.connector import create_remote_connector
    from sglang.srt.utils import is_remote_url

    if is_remote_url(self.model_path):
        logger.info("Pulling model configs from remote...")
        # BaseConnector implements __del__() to clean up the local dir.
        # Since config files need to exist all the time, so we DO NOT use
        # with statement to avoid closing the client.
        client = create_remote_connector(self.model_path)
        # Fix: the original re-tested `is_remote_url(self.model_path)` here,
        # inside the branch where it is already known true — redundant check
        # removed.
        client.pull_files(allow_pattern=["*config.json"])
        self.model_weights = self.model_path
        self.model_path = client.get_local_dir()
|
||||
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
# Mapping from user-facing dtype strings (as accepted by the `dtype`
# argument) to torch dtypes; "half"/"float" are accepted aliases.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}
|
||||
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested dtype against the checkpoint's dtype.

    "auto" keeps the checkpoint dtype, except float32 checkpoints which are
    downcast (to bfloat16 for Gemma models, float16 otherwise).  Any cast
    relative to the checkpoint dtype is logged; float16<->bfloat16 casts
    get a warning.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if isinstance(config_dtype, str):
        config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    elif not isinstance(dtype, str):
        raise ValueError(f"Unknown dtype: {dtype}")
    else:
        dtype = dtype.lower()
        if dtype != "auto":
            # Explicit dtype string: must be one of the known aliases.
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        elif config_dtype != torch.float32:
            # "auto" keeps whatever the checkpoint was trained in.
            torch_dtype = config_dtype
        elif config.model_type.startswith("gemma"):
            # "gemma" has no version suffix; "gemma2"/"gemma3" do.
            gemma_version = "" if config.model_type == "gemma" else config.model_type[5]
            logger.info(
                f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                "of float16 by default. Please specify `dtype` if you "
                "want to use float16."
            )
            torch_dtype = torch.bfloat16
        else:
            # Following the common practice, we use float16 for float32
            # models.
            torch_dtype = torch.float16

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

    return torch_dtype
|
||||
|
||||
|
||||
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    """Return True if the model should be served as a generative model.

    We have two ways to determine whether a model is a generative model:
    1. Check the model architecture against known non-generative
       (embedding / reward / classification) architectures.
    2. Check the `is_embedding` server argument.
    """
    # Fix: the original was a 14-clause `or`-chain of `in` tests; a set
    # makes the intent explicit and each membership test O(1).
    non_generation_archs = {
        "LlamaEmbeddingModel",
        "MistralModel",
        "LlamaForSequenceClassification",
        "LlamaForSequenceClassificationWithNormal_Weights",
        "InternLM2ForRewardModel",
        "Qwen2ForRewardModel",
        "Qwen2ForSequenceClassification",
        "Qwen3ForSequenceClassification",
        "CLIPModel",
        "BertModel",
        "Contriever",
        "BertForSequenceClassification",
        "XLMRobertaModel",
        "XLMRobertaForSequenceClassification",
    }
    if any(arch in non_generation_archs for arch in model_architectures):
        return False
    return not is_embedding
|
||||
|
||||
|
||||
# HF `architectures` entries that accept non-text (image/video/audio)
# inputs; used by `is_multimodal_model` below.
multimodal_model_archs = [
    "CLIPModel",
    "DeepseekVL2ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3nForConditionalGeneration",
    "Glm4vForConditionalGeneration",
    "Glm4vMoeForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "LlavaLlamaForCausalLM",
    "Llama4ForConditionalGeneration",
    "LlavaMistralForCausalLM",
    "LlavaQwenForCausalLM",
    "LlavaForConditionalGeneration",
    "LlavaVidForCausalLM",
    "MiniCPMO",
    "MiniCPMV",
    "Mistral3ForConditionalGeneration",
    "MultiModalityCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2AudioForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
    "InternS1ForConditionalGeneration",
    "Phi4MMForCausalLM",
    "VILAForConditionalGeneration",
    "Step3VLForConditionalGeneration",
]
|
||||
|
||||
|
||||
def is_multimodal_model(model_architectures: List[str]):
    """Return True if any architecture is a known multimodal architecture."""
    # Fix: `if cond: return True else: return False` collapsed to
    # `return cond`.
    return any(
        multi_model_arch in model_architectures
        for multi_model_arch in multimodal_model_archs
    )
|
||||
|
||||
|
||||
def is_multimodal_gen_model(model_architectures: List[str]):
    # Placeholder: multimodal *generation* models are not supported yet,
    # so this is unconditionally False.
    return False
|
||||
|
||||
|
||||
def is_image_gen_model(model_architectures: List[str]):
    # Placeholder: image-generation models are not supported yet,
    # so this is unconditionally False.
    return False
|
||||
|
||||
|
||||
def is_audio_model(model_architectures: List[str]):
    # Placeholder: audio(-output) models are not supported yet,
    # so this is unconditionally False.
    return False
|
||||
|
||||
|
||||
def is_encoder_decoder_model(model_architectures: List[str]):
    # Mllama is currently the only encoder-decoder architecture handled.
    return "MllamaForConditionalGeneration" in model_architectures
|
||||
|
||||
|
||||
def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
    """Check if chunked prefill is supported for a MultiModal model."""
    # Architectures known to be incompatible with chunked prefill.
    unsupported = [
        "Grok1VForCausalLM",
        "Grok1AForCausalLM",
        "LlavaLlamaForCausalLM",
        "MllamaForConditionalGeneration",
        "CLIPModel",
    ]
    # Fix: `if cond: return False else: return True` collapsed to
    # `return not cond`.
    return not any(
        multi_model_arch in unsupported for multi_model_arch in model_architectures
    )
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention-magnitude correction factor.

    For scales <= 1 (no context extension) no correction is applied; the
    factor grows logarithmically with the scale otherwise.
    """
    if scale > 1:
        return 0.1 * mscale * math.log(scale) + 1.0
    return 1.0
|
||||
|
||||
|
||||
def is_hybrid_model(
    model_architectures: List[str],
    hybrid_kvcache_ratio: Optional[float],
    context_length: Optional[int],
    attention_chunk_size: Optional[int],
):
    """Return the hybrid KV-cache ratio when it applies, else None.

    The ratio applies only to Llama4 (first architecture entry) with a
    positive ratio and a context length exceeding the attention chunk size.
    """
    if hybrid_kvcache_ratio is None:
        return None
    applies = (
        hybrid_kvcache_ratio > 0
        and model_architectures[0] == "Llama4ForConditionalGeneration"
        and context_length > attention_chunk_size
    )
    return hybrid_kvcache_ratio if applies else None
|
||||
|
||||
|
||||
def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int):
    """Split layer indices into (SWA layers, full-attention layers).

    For Llama4, every 4th layer (1-based) is full attention and the rest
    use sliding-window attention; other architectures get (None, None).
    """
    if "Llama4ForConditionalGeneration" not in model_architectures:
        return None, None
    swa_attention_layer_ids = []
    full_attention_layer_ids = []
    for layer_id in range(num_hidden_layers):
        if (layer_id + 1) % 4 == 0:
            full_attention_layer_ids.append(layer_id)
        else:
            swa_attention_layer_ids.append(layer_id)
    return swa_attention_layer_ids, full_attention_layer_ids
|
||||
326
python/sglang/srt/configs/qwen3_next.py
Normal file
326
python/sglang/srt/configs/qwen3_next.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Qwen3Hybrid model configuration"""
|
||||
|
||||
import enum
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
from sglang.srt.distributed.utils import divide
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# NOTE: HybridLayerType
|
||||
class HybridLayerType(enum.Enum):
    """Layer categories used by hybrid models (values match config strings)."""

    full_attention = "attention"
    swa_attention = "swa_attention"
    linear_attention = "linear_attention"
    mamba2 = "mamba"
|
||||
|
||||
|
||||
class Qwen3NextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
|
||||
Qwen3-Next model according to the specified arguments, defining the model architecture.
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of
|
||||
Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 151936):
|
||||
Vocabulary size of the model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids`.
|
||||
hidden_size (`int`, *optional*, defaults to 2048):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 5632):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 48):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 2):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `2`.
|
||||
hidden_act (`str`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
partial_rotary_factor (`float`, *optional*, defaults to 0.25):
|
||||
Percentage of the query and keys which will have rotary embedding.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
head_dim (`int`, *optional*, defaults to 256):
|
||||
Projection weights dimension in multi-head attention.
|
||||
linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
|
||||
Kernel size of the convolution used in linear attention layers.
|
||||
linear_key_head_dim (`int`, *optional*, defaults to 128):
|
||||
Dimension of each key head in linear attention.
|
||||
linear_value_head_dim (`int`, *optional*, defaults to 128):
|
||||
Dimension of each value head in linear attention.
|
||||
linear_num_key_heads (`int`, *optional*, defaults to 16):
|
||||
Number of key heads used in linear attention layers.
|
||||
linear_num_value_heads (`int`, *optional*, defaults to 32):
|
||||
Number of value heads used in linear attention layers.
|
||||
decoder_sparse_step (`int`, *optional*, defaults to 1):
|
||||
The frequency of the MoE layer.
|
||||
moe_intermediate_size (`int`, *optional*, defaults to 512):
|
||||
Intermediate size of the routed expert.
|
||||
shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
|
||||
Intermediate size of the shared expert.
|
||||
num_experts_per_tok (`int`, *optional*, defaults to 10):
|
||||
Number of selected experts.
|
||||
num_experts (`int`, *optional*, defaults to 512):
|
||||
Number of routed experts.
|
||||
norm_topk_prob (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the topk probabilities.
|
||||
output_router_logits (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the router logits should be returned by the model. Enabling this will also
|
||||
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
|
||||
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
|
||||
The aux loss factor for the total loss.
|
||||
mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
|
||||
Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock
|
||||
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
|
||||
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
|
||||
layer_types (`list[str]`, *optional*, defaults to None):
|
||||
Types of each layer (attention or linear).
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen3NextModel, Qwen3NextConfig
|
||||
|
||||
>>> # Initializing a Qwen3Next style configuration
|
||||
>>> configuration = Qwen3NextConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
|
||||
>>> model = Qwen3NextModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
"""
|
||||
|
||||
model_type = "qwen3_next"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
    self,
    vocab_size=151936,
    hidden_size=2048,
    intermediate_size=5632,
    num_hidden_layers=48,
    num_attention_heads=16,
    num_key_value_heads=2,
    hidden_act="silu",
    max_position_embeddings=32768,
    initializer_range=0.02,
    rms_norm_eps=1e-6,
    use_cache=True,
    tie_word_embeddings=False,
    rope_theta=10000.0,
    rope_scaling=None,
    partial_rotary_factor=0.25,
    attention_bias=False,
    attention_dropout=0.0,
    head_dim=256,
    linear_conv_kernel_dim=4,
    linear_key_head_dim=128,
    linear_value_head_dim=128,
    linear_num_key_heads=16,
    linear_num_value_heads=32,
    decoder_sparse_step=1,
    moe_intermediate_size=512,
    shared_expert_intermediate_size=512,
    num_experts_per_tok=10,
    num_experts=512,
    norm_topk_prob=True,
    output_router_logits=False,
    router_aux_loss_coef=0.001,
    mlp_only_layers=None,
    layer_types=None,
    **kwargs,
):
    """Initialize the config; see the class docstring for argument details.

    Fix: `mlp_only_layers` previously used a mutable default (`[]`) that
    would be shared across instances; `None` now stands in for an empty
    list, which keeps the stored value backward compatible.
    """
    super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
    self.vocab_size = vocab_size
    self.max_position_embeddings = max_position_embeddings
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.num_key_value_heads = num_key_value_heads
    self.hidden_act = hidden_act
    self.initializer_range = initializer_range
    self.rms_norm_eps = rms_norm_eps
    self.use_cache = use_cache
    self.rope_theta = rope_theta
    self.rope_scaling = rope_scaling
    self.partial_rotary_factor = partial_rotary_factor
    self.attention_bias = attention_bias
    self.attention_dropout = attention_dropout
    self.head_dim = head_dim
    # Validates `rope_scaling` contents (raises on malformed configs).
    rope_config_validation(self)

    # Linear attention (gated delta net, GDN) parameters.
    self.linear_conv_kernel_dim = linear_conv_kernel_dim
    self.linear_key_head_dim = linear_key_head_dim
    self.linear_value_head_dim = linear_value_head_dim
    self.linear_num_key_heads = linear_num_key_heads
    self.linear_num_value_heads = linear_num_value_heads

    # MoE arguments
    self.decoder_sparse_step = decoder_sparse_step
    self.moe_intermediate_size = moe_intermediate_size
    self.shared_expert_intermediate_size = shared_expert_intermediate_size
    self.num_experts_per_tok = num_experts_per_tok
    self.num_experts = num_experts
    self.norm_topk_prob = norm_topk_prob
    self.output_router_logits = output_router_logits
    self.router_aux_loss_coef = router_aux_loss_coef
    self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
    # NOTE(review): `layer_types` is accepted but never stored, and the
    # `layers_block_type` property reads `self.full_attention_interval`,
    # which is not set here — presumably both arrive from the checkpoint's
    # config.json via **kwargs; confirm upstream.
|
||||
|
||||
@property
def layers_block_type(self):
    # Per-layer type list: every `full_attention_interval`-th layer
    # (1-based) is full attention, all others are linear (GDN) attention.
    # NOTE(review): `full_attention_interval` is not set in `__init__`;
    # presumably it comes from the checkpoint config via **kwargs — confirm.
    layer_type_list = []

    for l in range(self.num_hidden_layers):
        if (l + 1) % self.full_attention_interval == 0:
            layer_type_list.append(HybridLayerType.full_attention.value)
        else:
            layer_type_list.append(HybridLayerType.linear_attention.value)

    return layer_type_list
|
||||
|
||||
@property
def linear_layer_ids(self):
    """Indices of layers using linear (GDN) attention."""
    linear = HybridLayerType.linear_attention.value
    return [
        idx for idx, block_type in enumerate(self.layers_block_type) if block_type == linear
    ]
|
||||
|
||||
@property
def full_attention_layer_ids(self):
    """Indices of layers using full (softmax) attention."""
    full = HybridLayerType.full_attention.value
    return [
        idx for idx, block_type in enumerate(self.layers_block_type) if block_type == full
    ]
|
||||
|
||||
@property
def hybrid_gdn_params(self):
    """Per-rank cache geometry for the gated-delta-net (linear) layers.

    Returns:
        Tuple of (conv_state_shape, temporal_state_shape, conv_dtype,
        ssm_dtype, mamba_layers): shapes are sharded by the attention TP
        size; mamba_layers lists the linear-attention layer indices.
    """
    world_size = get_attention_tp_size()
    # Conv input width: key-head dim counted twice (Q and K projections)
    # plus the value-head width.
    conv_dim = (
        self.linear_key_head_dim * self.linear_num_key_heads * 2
        + self.linear_value_head_dim * self.linear_num_value_heads
    )
    conv_state_shape = (
        divide(conv_dim, world_size),
        self.linear_conv_kernel_dim - 1,
    )

    temporal_state_shape = (
        divide(self.linear_num_value_heads, world_size),
        self.linear_key_head_dim,
        self.linear_value_head_dim,
    )
    conv_dtype = torch.bfloat16
    dtype_map = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    # Fix: `os.environ["SGLANG_MAMBA_SSM_DTYPE"]` raised a bare KeyError
    # when the variable was unset (and on typos).  Default to "float32"
    # and fail with a descriptive error for unsupported values.
    ssm_dtype_str = os.environ.get("SGLANG_MAMBA_SSM_DTYPE", "float32")
    if ssm_dtype_str not in dtype_map:
        raise ValueError(
            f"Unsupported SGLANG_MAMBA_SSM_DTYPE: {ssm_dtype_str!r}; "
            f"expected one of {sorted(dtype_map)}"
        )
    ssm_dtype = dtype_map[ssm_dtype_str]
    mamba_layers = self.linear_layer_ids
    return (
        conv_state_shape,
        temporal_state_shape,
        conv_dtype,
        ssm_dtype,
        mamba_layers,
    )
|
||||
|
||||
@property
def mamba_cache_per_req(self):
    """Bytes of conv + SSM state one request needs across all GDN layers."""
    conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
        self.hybrid_gdn_params
    )
    conv_bytes = int(np.prod(conv_state_shape)) * conv_dtype.itemsize
    ssm_bytes = int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
    return (conv_bytes + ssm_bytes) * len(mamba_layers)
|
||||
172
python/sglang/srt/configs/step3_vl.py
Normal file
172
python/sglang/srt/configs/step3_vl.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class Step3VisionEncoderConfig(PretrainedConfig):
    """Configuration for the Step3 vision encoder tower."""

    model_type = "step3_vision_encoder"

    def __init__(
        self,
        hidden_size=1792,
        intermediate_size=3072,
        output_hidden_size=4096,
        num_hidden_layers=63,
        num_attention_heads=16,
        num_channels=3,
        image_size=728,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        # Transformer tower geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.output_hidden_size = output_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Input image handling.
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        # Normalization / activation.
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Step3TextConfig(PretrainedConfig):
    """Configuration for the Step3 text tower (MoE causal LM)."""

    model_type = "step3_text"
    architectures = ["Step3TextForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 7168,
        intermediate_size: int = 18432,
        num_attention_heads: int = 64,
        num_attention_groups: int = 1,
        num_hidden_layers: int = 61,
        max_seq_len: int = 65536,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 5120,
        moe_num_experts: int = 48,
        moe_top_k: int = 3,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 65536,
        share_expert_dim: int = 5120,
        share_q_dim: int = 2048,
        head_dim: int = 256,
        norm_expert_weight: bool = False,
        # Layers 4..59 (inclusive) use MoE. `tuple(range(4, 60))` is
        # value-identical to the original 56-element literal; the annotation
        # is also corrected to the variable-length form `tuple[int, ...]`.
        moe_layers_enum: tuple[int, ...] = tuple(range(4, 60)),
        **kwargs,
    ) -> None:
        # Core transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        # Mixture-of-experts settings.
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        # Rotary-embedding settings.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embedding = max_position_embedding
        # Shared-expert / attention projection dims.
        self.share_expert_dim = share_expert_dim
        self.share_q_dim = share_q_dim
        self.head_dim = head_dim
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum

        super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Step3VLConfig(PretrainedConfig):
    """Top-level Step3 vision-language config tying the vision and text towers together."""

    model_type = "step3_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3TextConfig]] = None,
        understand_projector_stride: int = 1,
        projector_bias: bool = True,
        image_token_id: int = 128001,
        **kwargs,
    ) -> None:
        def _coerce(cfg, cfg_cls):
            # Accept None (all defaults), a kwargs dict, or a ready-built config.
            if cfg is None:
                return cfg_cls()
            if isinstance(cfg, dict):
                return cfg_cls(**cfg)
            return cfg

        self.vision_config = _coerce(vision_config, Step3VisionEncoderConfig)
        self.text_config = _coerce(text_config, Step3TextConfig)

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # Mirror the text tower's hidden size at the top level.
        self.hidden_size = self.text_config.hidden_size
        self.image_token_id = image_token_id

        super().__init__(**kwargs)
|
||||
156
python/sglang/srt/configs/update_config.py
Normal file
156
python/sglang/srt/configs/update_config.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
DEFAULT_MOE_PADDING_SIZE = 32
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
|
||||
|
||||
def may_get_weight_block_size(model_config, load_config):
    """Return the quantization `weight_block_size` if the model's quant config
    declares one, else None.

    Imports are deferred to avoid import cycles with the model loader.
    """
    from sglang.srt.model_loader.loader import _get_quantization_config
    from sglang.srt.model_loader.utils import get_model_architecture

    model_class, _ = get_model_architecture(model_config)
    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})

    quant_config = _get_quantization_config(
        model_config, load_config, packed_modules_mapping
    )

    # getattr with a default collapses the original hasattr/getattr pair and
    # also covers quant_config being None.
    return getattr(quant_config, "weight_block_size", None)
|
||||
|
||||
|
||||
def get_moe_padding_size(weight_block_size):
    """Alignment unit for MoE intermediate sizes: the quantization block edge
    when block-wise quant is active, otherwise DEFAULT_MOE_PADDING_SIZE."""
    if weight_block_size is None:
        return DEFAULT_MOE_PADDING_SIZE

    # See NOTE(HandH1998): To ensure proper alignment of the block-wise
    # quantization scales, the output_size of the weights for both the gate
    # and up layers must be divisible by block_n.
    assert (
        len(weight_block_size) == 2
    ), "Only len(weight_block_size) == 2 is supported"
    assert (
        weight_block_size[0] == weight_block_size[1]
    ), "Only weight_block_size[0] == weight_block_size[1] is supported"
    return weight_block_size[0]
|
||||
|
||||
|
||||
def get_num_heads_padding_size(tp_size, weight_block_size):
    """Head-count alignment unit: doubled when block quantization is active
    and tp_size is odd, otherwise tp_size itself."""
    if weight_block_size is not None and tp_size % 2 == 1:
        return tp_size * 2
    return tp_size
|
||||
|
||||
|
||||
def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
    """Pad the config attribute `attr_name` up to a multiple of
    `intermediate_padding_size`, writing the padded value back onto the
    config (and its hf_config / hf_text_config when present).

    Returns the (possibly mutated) model_config.
    """
    # Read the current value: prefer hf_config, then the config object itself.
    # Falling back to the padding size is a no-op, since it is already aligned.
    if hasattr(model_config, "hf_config") and hasattr(
        model_config.hf_config, attr_name
    ):
        attr_value = getattr(model_config.hf_config, attr_name)
    elif hasattr(model_config, attr_name):
        attr_value = getattr(model_config, attr_name)
    else:
        attr_value = intermediate_padding_size

    if attr_value % intermediate_padding_size != 0:
        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

        attr_value = pad_vocab_size(attr_value, intermediate_padding_size)
        if hasattr(model_config, "hf_config"):
            setattr(model_config.hf_config, attr_name, attr_value)
            if hasattr(model_config, "hf_text_config"):
                setattr(model_config.hf_text_config, attr_name, attr_value)
        else:
            setattr(model_config, attr_name, attr_value)

    return model_config
|
||||
|
||||
|
||||
def adjust_config_with_unaligned_cpu_tp(
    model_config: ModelConfig, load_config: LoadConfig, tp_size: int
) -> ModelConfig:
    """Pad head counts and intermediate sizes so the model shards evenly.

    Support the case where the num_attention_heads is not divisible by the
    TP size: records the original head counts on hf_config/hf_text_config,
    pads KV heads (and proportionally the query heads) to a multiple of the
    padding size, and pads MoE/MLP intermediate sizes. Mutates and returns
    `model_config`.
    """
    weight_block_size = may_get_weight_block_size(model_config, load_config)

    # Preserve pre-padding values so downstream code can recover them.
    model_config.hf_config.original_num_attention_heads = (
        model_config.num_attention_heads
    )
    model_config.hf_text_config.original_num_attention_heads = (
        model_config.num_attention_heads
    )

    model_config.hf_config.original_total_num_kv_heads = (
        model_config.get_total_num_kv_heads()
    )
    model_config.hf_text_config.original_total_num_kv_heads = (
        model_config.get_total_num_kv_heads()
    )

    if (
        model_config.num_attention_heads % tp_size != 0
        or model_config.get_total_num_kv_heads() % tp_size != 0
    ):
        # Compute the head_dim using the model_config.num_attention_heads before padding
        if not hasattr(model_config.hf_config, "head_dim"):
            model_config.hf_config.head_dim = (
                model_config.hidden_size // model_config.num_attention_heads
            )

        # GQA ratio is kept constant: query heads scale with padded KV heads.
        query_heads_per_kv = (
            model_config.num_attention_heads // model_config.get_total_num_kv_heads()
        )
        total_kv_heads = model_config.get_total_num_kv_heads()
        from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

        pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
        num_key_value_heads = pad_vocab_size(total_kv_heads, pad_size)

        model_config.num_key_value_heads = num_key_value_heads
        model_config.hf_config.num_key_value_heads = num_key_value_heads
        model_config.hf_text_config.num_key_value_heads = num_key_value_heads

        num_attention_heads = num_key_value_heads * query_heads_per_kv
        model_config.num_attention_heads = num_attention_heads
        model_config.hf_config.num_attention_heads = num_attention_heads
        model_config.hf_text_config.num_attention_heads = num_attention_heads

        # Intermediate sizes must divide evenly across tp ranks too.
        intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
        model_config = update_intermediate_size(
            model_config, "moe_intermediate_size", intermediate_padding_size
        )
        model_config = update_intermediate_size(
            model_config, "intermediate_size", intermediate_padding_size
        )
        model_config = update_intermediate_size(
            model_config, "intermediate_size_mlp", intermediate_padding_size
        )
        # Siglip vision towers carry their own head counts; pad them the same way.
        if (
            hasattr(model_config.hf_config, "vision_config")
            and model_config.hf_config.vision_config.model_type == "siglip_vision_model"
        ):
            model_config.hf_config.vision_config.original_num_attention_heads = (
                model_config.num_attention_heads
            )
            if model_config.hf_config.vision_config.num_attention_heads % tp_size != 0:
                model_config.hf_config.vision_config.head_dim = (
                    model_config.hf_config.vision_config.hidden_size
                    // model_config.hf_config.vision_config.num_attention_heads
                )
                from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size

                pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
                model_config.hf_config.vision_config.num_attention_heads = pad_vocab_size(
                    model_config.hf_config.vision_config.num_attention_heads, pad_size
                )
                model_config.hf_config.vision_config = update_intermediate_size(
                    model_config.hf_config.vision_config,
                    "intermediate_size",
                    intermediate_padding_size,
                )

    return model_config
|
||||
25
python/sglang/srt/configs/utils.py
Normal file
25
python/sglang/srt/configs/utils.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Type
|
||||
|
||||
from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoProcessor,
|
||||
BaseImageProcessor,
|
||||
PretrainedConfig,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
|
||||
def register_image_processor(
    config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
):
    """
    Register a customized HF image processor for `config`, overriding any
    existing HF implementation (exist_ok=True permits re-registration).
    """
    # NOTE(review): the positional None slots map onto processor-class
    # parameters of AutoImageProcessor.register whose roles vary across
    # transformers versions — confirm against the installed signature.
    AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)
|
||||
|
||||
|
||||
def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
    """
    Register a customized HF processor for `config`, overriding any existing
    HF implementation (exist_ok=True permits re-registration).
    """
    AutoProcessor.register(config, processor, exist_ok=True)
|
||||
51
python/sglang/srt/connector/__init__.py
Normal file
51
python/sglang/srt/connector/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import enum
|
||||
import logging
|
||||
|
||||
from sglang.srt.connector.base_connector import (
|
||||
BaseConnector,
|
||||
BaseFileConnector,
|
||||
BaseKVConnector,
|
||||
)
|
||||
from sglang.srt.connector.redis import RedisConnector
|
||||
from sglang.srt.connector.s3 import S3Connector
|
||||
from sglang.srt.utils import parse_connector_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectorType(str, enum.Enum):
    # Connector exposing files on a remote filesystem (e.g. S3).
    FS = "filesystem"
    # Connector backed by a key-value store (e.g. Redis).
    KV = "KV"
|
||||
|
||||
|
||||
def create_remote_connector(url, **kwargs) -> BaseConnector:
    """Instantiate the connector matching the url's scheme (redis or s3).

    Raises ValueError for any unrecognized scheme.
    """
    connector_type = parse_connector_type(url)
    connector_cls = {"redis": RedisConnector, "s3": S3Connector}.get(connector_type)
    if connector_cls is None:
        raise ValueError(f"Invalid connector type: {url}")
    return connector_cls(url)
|
||||
|
||||
|
||||
def get_connector_type(client: BaseConnector) -> ConnectorType:
    """Classify a connector instance as KV-backed or filesystem-backed."""
    # KV is checked first, mirroring the original precedence.
    for base_cls, kind in (
        (BaseKVConnector, ConnectorType.KV),
        (BaseFileConnector, ConnectorType.FS),
    ):
        if isinstance(client, base_cls):
            return kind

    raise ValueError(f"Invalid connector type: {client}")
|
||||
|
||||
|
||||
# Public API of the connector package.
__all__ = [
    "BaseConnector",
    "BaseFileConnector",
    "BaseKVConnector",
    "RedisConnector",
    "S3Connector",
    "ConnectorType",
    "create_remote_connector",
    "get_connector_type",
]
|
||||
111
python/sglang/srt/connector/base_connector.py
Normal file
111
python/sglang/srt/connector/base_connector.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseConnector(ABC):
    """
    Abstract base for remote model-artifact connectors.

    For fs connector such as s3:
    <connector_type>://<path>/<filename>

    For kv connector such as redis:
    <connector_type>://<host>:<port>/<model_name>/keys/<key>
    <connector_type>://<host>:<port>/<model_name>/files/<filename>
    """

    def __init__(self, url: str):
        self.url = url
        self.closed = False
        # Scratch directory that pulled files are materialized into;
        # removed again by close().
        self.local_dir = tempfile.mkdtemp()
        # Chain onto SIGINT/SIGTERM so the temp dir is cleaned up on
        # interruption; any pre-existing handler still runs afterwards.
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))

    def get_local_dir(self):
        """Return the temporary directory files are pulled into."""
        return self.local_dir

    @abstractmethod
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield (weight_name, tensor) pairs for the given rank."""
        raise NotImplementedError()

    @abstractmethod
    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Download remote files into local_dir, filtered by fnmatch patterns."""
        raise NotImplementedError()

    def close(self):
        """Idempotently remove the local scratch directory."""
        if self.closed:
            return

        self.closed = True
        if os.path.exists(self.local_dir):
            shutil.rmtree(self.local_dir)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def _close_by_signal(self, existing_handler=None):
        # Wrap close() so it runs before the previously installed handler.

        def new_handler(signum, frame):
            self.close()
            if existing_handler:
                existing_handler(signum, frame)

        return new_handler
|
||||
|
||||
|
||||
class BaseKVConnector(BaseConnector):
    """Connector backed by a key-value store, holding tensor and string entries."""

    @abstractmethod
    def get(self, key: str) -> Optional[torch.Tensor]:
        """Fetch the tensor stored under `key`, or None if absent."""
        raise NotImplementedError()

    @abstractmethod
    def getstr(self, key: str) -> Optional[str]:
        """Fetch the string stored under `key`, or None if absent."""
        raise NotImplementedError()

    @abstractmethod
    def set(self, key: str, obj: torch.Tensor) -> None:
        """Store a tensor under `key`."""
        raise NotImplementedError()

    @abstractmethod
    def setstr(self, key: str, obj: str) -> None:
        """Store a string under `key`."""
        raise NotImplementedError()

    @abstractmethod
    def list(self, prefix: str) -> List[str]:
        """Return all keys starting with `prefix`."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class BaseFileConnector(BaseConnector):
    """
    Connector backed by a remote filesystem.

    List full file names from remote fs path and filter by allow pattern.

    Args:
        allow_pattern: A list of patterns of which files to pull.

    Returns:
        list[str]: List of full paths allowed by the pattern
    """

    @abstractmethod
    def glob(self, allow_pattern: str) -> List[str]:
        raise NotImplementedError()
|
||||
85
python/sglang/srt/connector/redis.py
Normal file
85
python/sglang/srt/connector/redis.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.connector import BaseKVConnector
|
||||
from sglang.srt.connector.serde import create_serde
|
||||
from sglang.srt.connector.utils import pull_files_from_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisConnector(BaseKVConnector):
    """KV connector backed by a Redis server.

    URL layout: redis://<host>:<port>/<model_name>; tensors are stored under
    <model_name>/keys/rank_<r>/... and auxiliary files under
    <model_name>/files/....
    """

    def __init__(self, url: str):
        # Imported lazily so redis is only required when this connector is used.
        import redis

        super().__init__(url)
        parsed_url = urlparse(url)
        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
        self.model_name = parsed_url.path.lstrip("/")
        # TODO: more serde options
        self.s, self.d = create_serde("safe")

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Fetch and deserialize the tensor at `key`; logs and returns None if missing."""
        val = self.connection.get(key)

        if val is None:
            logger.error("Key %s not found", key)
            return None

        return self.d.from_bytes(val)

    def getstr(self, key: str) -> Optional[str]:
        """Fetch the UTF-8 string at `key`; logs and returns None if missing."""
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None

        return val.decode("utf-8")

    def set(self, key: str, tensor: torch.Tensor) -> None:
        """Serialize and store a tensor under `key`."""
        assert tensor is not None
        self.connection.set(key, self.s.to_bytes(tensor))

    def setstr(self, key: str, obj: str) -> None:
        """Store a string under `key`."""
        self.connection.set(key, obj)

    def list(self, prefix: str) -> List[str]:
        """Return all keys matching `prefix*` via cursor-based SCAN."""
        cursor = 0
        all_keys: List[bytes] = []

        while True:
            ret: Tuple[int, List[bytes]] = self.connection.scan(
                cursor=cursor, match=f"{prefix}*"
            )  # type: ignore
            cursor, keys = ret
            all_keys.extend(keys)
            # SCAN signals completion by returning cursor 0.
            if cursor == 0:
                break

        return [key.decode("utf-8") for key in all_keys]

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield (name, tensor) pairs for this rank, names stripped of the key prefix.

        NOTE: the annotation previously said `bytes`, but `self.get` returns a
        deserialized tensor (matching the base-class contract).
        """
        keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
        for key in keys:
            val = self.get(key)
            key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
            yield key, val

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Download `<model_name>/files/*` entries into the local scratch dir."""
        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)

    def close(self):
        self.connection.close()
        super().close()
|
||||
122
python/sglang/srt/connector/s3.py
Normal file
122
python/sglang/srt/connector/s3.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import fnmatch
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Generator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.connector import BaseFileConnector
|
||||
|
||||
|
||||
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
|
||||
return [
|
||||
path
|
||||
for path in paths
|
||||
if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
|
||||
]
|
||||
|
||||
|
||||
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
|
||||
return [
|
||||
path
|
||||
for path in paths
|
||||
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
|
||||
]
|
||||
|
||||
|
||||
def list_files(
    s3,
    path: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> tuple[str, str, list[str]]:
    """
    List object keys under an S3 path and filter them by fnmatch patterns.

    Args:
        s3: S3 client to use.
        path: The "s3://bucket/prefix" path to list from.
        allow_pattern: A list of patterns of which files to pull.
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple[str, str, list[str]]: the bucket name, the prefix (a dir-like
        string inside the bucket), and the filtered list of object keys.
    """
    bucket_name, _, prefix = path.removeprefix("s3://").partition("/")

    objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
    paths = [obj["Key"] for obj in objects.get("Contents", [])]

    # Keys ending in "/" are directory placeholders; always drop them.
    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)

    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)

    return bucket_name, prefix, paths
|
||||
|
||||
|
||||
class S3Connector(BaseFileConnector):
    """File connector that pulls model artifacts from an S3 bucket."""

    def __init__(self, url: str) -> None:
        # Imported lazily so boto3 is only required when this connector is used.
        import boto3

        super().__init__(url)
        self.client = boto3.client("s3")

    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
        """Return fully-qualified s3:// URLs for keys matching `allow_pattern`."""
        bucket_name, _, paths = list_files(
            self.client, path=self.url, allow_pattern=allow_pattern
        )
        return [f"s3://{bucket_name}/{path}" for path in paths]

    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        """
        Pull files from S3 storage into the temporary directory, preserving
        paths relative to the connector's URL prefix.

        Args:
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.
        """
        bucket_name, base_dir, files = list_files(
            self.client, self.url, allow_pattern, ignore_pattern
        )
        if len(files) == 0:
            return

        for file in files:
            destination_file = os.path.join(self.local_dir, file.removeprefix(base_dir))
            local_dir = Path(destination_file).parent
            os.makedirs(local_dir, exist_ok=True)
            self.client.download_file(bucket_name, file, destination_file)

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Stream weights straight out of the bucket's safetensors files."""
        from sglang.srt.model_loader.weight_utils import (
            runai_safetensors_weights_iterator,
        )

        # only support safetensor files now
        hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
        return runai_safetensors_weights_iterator(hf_weights_files)

    def close(self):
        self.client.close()
        super().close()
|
||||
31
python/sglang/srt/connector/serde/__init__.py
Normal file
31
python/sglang/srt/connector/serde/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# inspired by LMCache
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
|
||||
from sglang.srt.connector.serde.serde import Deserializer, Serializer
|
||||
|
||||
|
||||
def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
    """Build a matching (serializer, deserializer) pair for `serde_type`.

    Only "safe" (safetensors-based) is currently supported.
    """
    if serde_type != "safe":
        raise ValueError(f"Unknown serde type: {serde_type}")
    return SafeSerializer(), SafeDeserializer()
|
||||
|
||||
|
||||
# Public API of the serde subpackage.
__all__ = [
    "Serializer",
    "Deserializer",
    "SafeSerializer",
    "SafeDeserializer",
    "create_serde",
]
|
||||
30
python/sglang/srt/connector/serde/safe_serde.py
Normal file
30
python/sglang/srt/connector/serde/safe_serde.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load, save
|
||||
|
||||
from sglang.srt.connector.serde.serde import Deserializer, Serializer
|
||||
|
||||
|
||||
class SafeSerializer(Serializer):
    """Serializer that packs a tensor into a single-entry safetensors payload."""

    def __init__(self):
        super().__init__()

    def to_bytes(self, t: torch.Tensor) -> bytes:
        # safetensors requires CPU-resident, contiguous tensors; the entry
        # name must match what SafeDeserializer looks up.
        tensor = t.cpu().contiguous()
        return save({"tensor_bytes": tensor})
|
||||
|
||||
|
||||
class SafeDeserializer(Deserializer):
    """Deserializer for payloads produced by SafeSerializer (safetensors)."""

    def __init__(self):
        # TODO: dtype options
        super().__init__(torch.float32)

    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        """Decode a safetensors blob back into the stored tensor."""
        payload = bytes(b)
        return load(payload)["tensor_bytes"]

    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        return self.from_bytes_normal(b)
|
||||
43
python/sglang/srt/connector/serde/serde.py
Normal file
43
python/sglang/srt/connector/serde/serde.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import abc
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Serializer(ABC):
    """Interface for turning tensors into self-describing byte payloads."""

    @abstractmethod
    def to_bytes(self, t: torch.Tensor) -> bytes:
        """
        Serialize a pytorch tensor to bytes. The serialized bytes should contain
        both the data and the metadata (shape, dtype, etc.) of the tensor.

        Input:
            t: the input pytorch tensor, can be on any device, in any shape,
               with any dtype

        Returns:
            bytes: the serialized bytes
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class Deserializer(metaclass=abc.ABCMeta):
    """Interface for reconstructing tensors from serialized byte payloads."""

    def __init__(self, dtype):
        # Target dtype hint; concrete subclasses decide how/whether to use it.
        self.dtype = dtype

    @abstractmethod
    def from_bytes(self, bs: bytes) -> torch.Tensor:
        """
        Deserialize a pytorch tensor from bytes.

        Input:
            bytes: a stream of bytes

        Output:
            torch.Tensor: the deserialized pytorch tensor
        """
        raise NotImplementedError
|
||||
35
python/sglang/srt/connector/utils.py
Normal file
35
python/sglang/srt/connector/utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.connector import BaseConnector
|
||||
|
||||
|
||||
def parse_model_name(url: str) -> str:
    """
    Extract the model name from a connector URL: the URL path without its
    leading slash. Only used for db connectors.
    """
    return urlparse(url).path.lstrip("/")
|
||||
|
||||
|
||||
def pull_files_from_db(
    connector: "BaseConnector",
    model_name: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> None:
    """Download every `<model_name>/files/*` entry from a KV connector into its
    local directory, preserving paths relative to the files/ prefix.

    Args:
        connector: KV connector providing get_local_dir / list / getstr.
        model_name: model whose auxiliary files should be pulled.
        allow_pattern / ignore_pattern: currently unused; kept for interface
            parity with the file-connector pull_files implementations.
    """
    prefix = f"{model_name}/files/"
    # BUG FIX: keep the connector's root directory in a variable that is never
    # reassigned. The original rebound `local_dir` to each file's parent inside
    # the loop, so every file after the first was joined against the previous
    # file's directory and could land in the wrong place.
    root_dir = connector.get_local_dir()
    files = connector.list(prefix)

    for file in files:
        destination_file = os.path.join(root_dir, file.removeprefix(prefix))
        parent_dir = Path(destination_file).parent
        os.makedirs(parent_dir, exist_ok=True)
        with open(destination_file, "wb") as f:
            f.write(connector.getstr(file).encode("utf-8"))
|
||||
3
python/sglang/srt/constants.py
Normal file
3
python/sglang/srt/constants.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# GPU Memory Types
# String tags identifying which category of GPU memory an operation targets.
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
GPU_MEMORY_TYPE_WEIGHTS = "weights"
|
||||
213
python/sglang/srt/constrained/base_grammar_backend.py
Normal file
213
python/sglang/srt/constrained/base_grammar_backend.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
||||
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from threading import Event
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseGrammarObject:
    """Per-request grammar state for constrained decoding.

    Concrete backends override the token/mask methods; the base class only
    tracks the `finished` flag and provides a no-op `copy`.
    """

    def __init__(self):
        self._finished = False

    def accept_token(self, token: int) -> None:
        """
        Accept a token in the grammar.
        """
        raise NotImplementedError()

    def rollback(self, k: int):
        """Undo the last k accepted tokens."""
        raise NotImplementedError()

    def is_terminated(self):
        """Whether the grammar has reached a terminal state (base: never)."""
        return False

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        """Allocate a (batch_size x vocab_size) mask tensor on `device`."""
        raise NotImplementedError()

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Fill row `idx` of the mask with this grammar's allowed tokens."""
        raise NotImplementedError()

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """Move a previously filled mask to `device`."""
        raise NotImplementedError()

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Apply the mask to `logits` in place."""
        raise NotImplementedError()

    def copy(self) -> "BaseGrammarObject":
        # Base implementation shares state; backends override for real copies.
        return self

    @property
    def finished(self):
        return self._finished

    @finished.setter
    def finished(self, finished):
        self._finished = finished

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """
        Try to jump forward in the grammar.

        Returns:
            A jump forward helper which may be used in `jump_forward_str_state`.
            None if the jump forward is not possible.
        """
        raise NotImplementedError()

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """
        Jump forward for the grammar.

        Returns:
            A tuple of the jump forward string and the next state of the grammar
            (which can be used in `jump_and_retokenize` if needed).
        """
        raise NotImplementedError()

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ) -> None:
        """
        Jump forward occurs, and update the grammar state if needed.
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
# Sentinel returned/compared against when grammar compilation fails.
INVALID_GRAMMAR_OBJ = BaseGrammarObject()
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    # Compiled grammar object.
    value: BaseGrammarObject
    # Signals waiters when `value` becomes available.
    event: Event
|
||||
|
||||
|
||||
class BaseGrammarBackend:
    """Base grammar backend: caches compiled grammar objects and compiles new
    ones on a background thread pool.

    Subclasses override the `dispatch_*` hooks for the constraint types they
    support; unsupported types log a warning and yield None.
    """

    def __init__(self):
        self.executor = ThreadPoolExecutor()
        # Maps (key_type, key_string) -> compiled grammar object.
        self.cache: Dict[Tuple[str, str], BaseGrammarObject] = {}

    def _not_supported(self, key_type: str, key_string: str) -> None:
        """Log and return None for constraint types this backend cannot handle."""
        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")

    def dispatch_fallback(
        self, key_type: str, key_string: str
    ) -> Optional[BaseGrammarObject]:
        """
        This function should not be reached in any case.
        """
        raise ValueError(f"Invalid key_type: {key_type}={key_string}")

    def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("json", key_string)

    def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("regex", key_string)

    def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("ebnf", key_string)

    def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
        return self._not_supported("structural_tag", key_string)

    def dispatch_structural_pattern(
        self, key_string: str
    ) -> Optional[BaseGrammarObject]:
        # BUG FIX: _init_value_dispatch routes "structural_pattern" keys here,
        # but no default implementation existed, so backends without an
        # override raised AttributeError instead of logging "unsupported".
        return self._not_supported("structural_pattern", key_string)

    def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
        """Route a (key_type, key_string) pair to the matching dispatch hook."""
        key_type, key_string = key
        if key_type == "json":
            return self.dispatch_json(key_string)
        elif key_type == "regex":
            return self.dispatch_regex(key_string)
        elif key_type == "ebnf":
            return self.dispatch_ebnf(key_string)
        elif key_type == "structural_tag":
            return self.dispatch_structural_tag(key_string)
        elif key_type == "structural_pattern":
            return self.dispatch_structural_pattern(key_string)
        else:
            return self.dispatch_fallback(key_type, key_string)

    def get_cached_or_future_value(self, key: Tuple[str, str]):
        """Return (value, cached): a copied grammar object and True on a cache
        hit, otherwise a Future compiling the grammar and False."""
        value = self.cache.get(key)
        if value:
            return value.copy(), True
        value = self.executor.submit(self._init_value_dispatch, key)
        return value, False

    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
        """Store a compiled grammar object for later cache hits."""
        self.cache[key] = value

    def reset(self):
        """Drop all cached grammar objects."""
        self.cache.clear()
|
||||
|
||||
|
||||
def create_grammar_backend(
    server_args: ServerArgs,
    tokenizer,
    vocab_size: int,
    eos_token_ids: Optional[set] = None,
) -> Optional[BaseGrammarBackend]:
    """Instantiate the grammar backend selected by ``server_args.grammar_backend``.

    Returns None when the backend is "none" and raises ValueError for unknown
    names. When a reasoning parser is configured and the tokenizer exposes
    ``think_end_id``, the chosen backend is wrapped in a ReasonerGrammarBackend
    so constraints only apply after the reasoning section ends.
    """
    name = server_args.grammar_backend

    if name == "none":
        return None

    if name == "outlines":
        from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

        backend = OutlinesGrammarBackend(
            tokenizer,
            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
        )
    elif name == "xgrammar":
        from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

        # Convert Set[int] to List[int] if needed
        eos_list = list(eos_token_ids) if eos_token_ids else None
        backend = XGrammarGrammarBackend(
            tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
        )
    elif name == "llguidance":
        from sglang.srt.constrained.llguidance_backend import GuidanceBackend

        backend = GuidanceBackend(
            tokenizer=tokenizer,
            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
        )
    else:
        raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

    if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
        from sglang.srt.constrained.reasoner_grammar_backend import (
            ReasonerGrammarBackend,
        )

        backend = ReasonerGrammarBackend(backend, tokenizer.think_end_id)

    return backend
|
||||
174
python/sglang/srt/constrained/llguidance_backend.py
Normal file
174
python/sglang/srt/constrained/llguidance_backend.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Constrained decoding with llguidance backend."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from llguidance import LLMatcher, LLTokenizer, StructTag, grammar_from
|
||||
from llguidance.hf import from_tokenizer
|
||||
from llguidance.torch import (
|
||||
allocate_token_bitmask,
|
||||
apply_token_bitmask_inplace,
|
||||
fill_next_token_bitmask,
|
||||
)
|
||||
|
||||
from sglang.srt.constrained.base_grammar_backend import (
|
||||
INVALID_GRAMMAR_OBJ,
|
||||
BaseGrammarBackend,
|
||||
BaseGrammarObject,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuidanceGrammar(BaseGrammarObject):
    """Grammar object backed by an llguidance ``LLMatcher``.

    The matcher consumes generated tokens one by one and produces a packed
    next-token bitmask for each decoding step. A per-object bitmask buffer is
    cached and reused across steps, growing only when the batch grows.
    """

    def __init__(self, llguidance_tokenizer: LLTokenizer, serialized_grammar: str):
        super().__init__()
        self.llguidance_tokenizer = llguidance_tokenizer
        self.serialized_grammar = serialized_grammar

        self.ll_matcher = LLMatcher(
            self.llguidance_tokenizer,
            self.serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )
        self.finished = False
        self.bitmask = None

    def accept_token(self, token: int):
        """Feed one generated token to the matcher; mark finished on error."""
        consumed = self.ll_matcher.consume_token(token)
        if not consumed:
            logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
            self.finished = True

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Write the matcher's next-token bitmask into row `idx` of the mask."""
        if self.ll_matcher.is_stopped():
            self.finished = True

        fill_next_token_bitmask(self.ll_matcher, vocab_mask, idx)

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        cached = self.bitmask
        if cached is not None and cached.shape[0] >= batch_size:
            # Cached buffer is big enough; just narrow it to the batch.
            return cached[:batch_size]
        # (Re)allocate only when the batch outgrows the cached bitmask.
        self.bitmask = allocate_token_bitmask(
            batch_size, self.llguidance_tokenizer.vocab_size
        )
        return self.bitmask

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        apply_token_bitmask_inplace(logits, vocab_mask)

    def copy(self):
        """Create a fresh matcher over the same tokenizer and grammar text."""
        return GuidanceGrammar(self.llguidance_tokenizer, self.serialized_grammar)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        # llguidance computes forced ("fast-forward") token ids directly.
        ff_tokens = self.ll_matcher.compute_ff_tokens()
        return (ff_tokens, "") if ff_tokens else None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        # Jumps are expressed as token ids, so there is no string/state pair.
        return "", -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        pass
|
||||
|
||||
|
||||
class GuidanceBackend(BaseGrammarBackend):
    """Grammar backend backed by the llguidance library.

    Each dispatch_* hook converts its spec (JSON schema, regex, EBNF, or
    structural tag) into a serialized llguidance grammar and wraps it in a
    GuidanceGrammar; any compile failure yields INVALID_GRAMMAR_OBJ instead of
    propagating an exception out of the dispatch path.
    """

    def __init__(
        self,
        tokenizer,
        whitespace_pattern: Optional[str] = None,
        n_vocab: Optional[int] = None,
    ):
        super().__init__()

        self.tokenizer = tokenizer
        # Forwarded as the JSON-schema "whitespace_pattern" default below.
        self.whitespace_pattern = whitespace_pattern
        self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)

    def _from_serialized(self, serialized_grammar) -> Optional[GuidanceGrammar]:
        """Wrap a serialized grammar; return INVALID_GRAMMAR_OBJ on failure."""
        try:
            return GuidanceGrammar(
                llguidance_tokenizer=self.llguidance_tokenizer,
                serialized_grammar=serialized_grammar,
            )
        except Exception as e:
            logger.error(f"Hit invalid grammar: {serialized_grammar=}, {e=}")
            return INVALID_GRAMMAR_OBJ

    def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
        try:
            serialized_grammar = LLMatcher.grammar_from_json_schema(
                key_string,
                defaults={
                    "whitespace_pattern": self.whitespace_pattern,
                },
            )
        except Exception as e:
            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_serialized(serialized_grammar)

    def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
        # Fix: mirror dispatch_json/dispatch_ebnf — an invalid pattern must
        # yield INVALID_GRAMMAR_OBJ instead of raising out of the dispatcher.
        try:
            serialized_grammar = grammar_from("regex", key_string)
        except Exception as e:
            logger.error(f"Hit invalid regex: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_serialized(serialized_grammar)

    def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]:
        try:
            serialized_grammar = grammar_from("ebnf", key_string)
            return self._from_serialized(serialized_grammar)
        except ValueError as e:
            logger.error(f"Hit invalid ebnf: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ

    def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
        try:
            structural_tag = json.loads(key_string)
            tags = [
                StructTag(
                    begin=structure["begin"],
                    grammar=structure["schema"],
                    end=structure["end"],
                    trigger=structural_tag["triggers"][0],  # TODO?
                )
                for structure in structural_tag["structures"]
            ]
            g = StructTag.to_grammar(tags)
            return self._from_serialized(g)
        except Exception as e:
            # Fix: use the module logger (was `logging.error`, which logs via
            # the root logger and bypasses this module's configuration).
            logger.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
|
||||
191
python/sglang/srt/constrained/outlines_backend.py
Normal file
191
python/sglang/srt/constrained/outlines_backend.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Constrained decoding with outlines backend."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import interegular
|
||||
import torch
|
||||
from outlines.fsm.guide import RegexGuide
|
||||
from outlines.models.transformers import TransformerTokenizer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sglang.srt.constrained.base_grammar_backend import (
|
||||
INVALID_GRAMMAR_OBJ,
|
||||
BaseGrammarBackend,
|
||||
BaseGrammarObject,
|
||||
)
|
||||
from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
|
||||
|
||||
try:
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
except ImportError:
|
||||
from outlines_core.fsm.json_schema import build_regex_from_schema
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutlinesGrammar(BaseGrammarObject):
    """Grammar object driven by an outlines ``RegexGuide`` FSM.

    ``state`` tracks the current FSM state; an optional OutlinesJumpForwardMap
    enables jump-forward decoding over forced transitions.
    """

    def __init__(
        self,
        guide: RegexGuide,
        jump_forward_map: Union[OutlinesJumpForwardMap, None],
    ) -> None:
        super().__init__()
        self.guide = guide
        self.jump_forward_map = jump_forward_map
        self.state = 0  # current FSM state; 0 is the guide's initial state
        self.finished = False

    def accept_token(self, token: int):
        """Advance the FSM by one generated token."""
        self.state = self.guide.get_next_state(self.state, token)

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        # Boolean mask: True marks tokens to suppress (see fill_vocab_mask).
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        # Mask is already allocated on the target device; nothing to move.
        return vocab_mask

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Mark disallowed tokens in row `idx`.

        The row is first filled with True (suppress everything), then the
        guide's allowed tokens are scattered back to False.
        """
        tokens = torch.tensor(
            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
        ).to(vocab_mask.device, non_blocking=True)
        vocab_mask = vocab_mask[idx]
        vocab_mask.fill_(1)
        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
        # Suppress masked tokens by writing -inf in place.
        logits.masked_fill_(vocab_mask, float("-inf"))

    def copy(self):
        # NOTE(review): the copy shares `guide`/`jump_forward_map` but starts
        # from state 0 — presumably intentional for fresh requests; confirm.
        return OutlinesGrammar(self.guide, self.jump_forward_map)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
        if not self.jump_forward_map:
            return None

        jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
        if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
            return None

        # preprocess the jump forward string
        suffix_bytes = []
        continuation_range = range(0x80, 0xC0)
        cur_state = self.state
        # Consume leading UTF-8 continuation bytes (0x80-0xBF): they belong to
        # a partially-emitted multi-byte character and must be completed via
        # byte tokens before symbol-level jumping can resume.
        while (
            len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range
        ):
            # continuation bytes
            byte_edge = jump_forward_bytes.pop(0)
            suffix_bytes.append(byte_edge[0])
            cur_state = byte_edge[1]

        # Map raw bytes to the tokenizer's byte-fallback tokens, e.g. "<0x9C>".
        # NOTE(review): assumes the tokenizer exposes such byte tokens — confirm.
        suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
        suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
        return suffix_ids, cur_state

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, cur_state = helper
        return self.jump_forward_map.jump_forward_symbol(cur_state)

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        self.state = next_state
|
||||
|
||||
|
||||
class OutlinesGrammarBackend(BaseGrammarBackend):
    """Grammar backend using outlines' RegexGuide (JSON schemas via regex)."""

    def __init__(
        self,
        tokenizer,
        whitespace_pattern: Optional[str],  # was annotated `bool`; callers pass a regex string or None
    ):
        super().__init__()

        try:
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
        except AttributeError:
            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
            # TransformerTokenizer tries to assign pad_token_id, but these
            # tokenizers expose it as a read-only property; add a setter,
            # construct, then restore the original value everywhere.
            origin_pad_token_id = tokenizer.pad_token_id

            def fset(self, value):
                self._value = value

            type(tokenizer).pad_token_id = property(
                fget=type(tokenizer).pad_token_id.fget, fset=fset
            )
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token = (
                self.outlines_tokenizer.tokenizer.pad_token
            )
            self.outlines_tokenizer.vocabulary = (
                self.outlines_tokenizer.tokenizer.get_vocab()
            )
        self.whitespace_pattern = whitespace_pattern

    def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
        """Compile a regex into a guide; INVALID_GRAMMAR_OBJ on bad syntax."""
        try:
            if hasattr(RegexGuide, "from_regex"):
                # outlines >= 0.1.1
                guide = RegexGuide.from_regex(regex, self.outlines_tokenizer)
            else:
                # outlines <= 0.0.46
                guide = RegexGuide(regex, self.outlines_tokenizer)
        except interegular.patterns.InvalidSyntax as e:
            logger.error(f"Hit invalid regex schema: {regex=}, {e=}")
            return INVALID_GRAMMAR_OBJ

        # Jump-forward is disabled here (no OutlinesJumpForwardMap is built).
        jump_forward_map = None
        return OutlinesGrammar(guide, jump_forward_map)

    def dispatch_ebnf(self, key_string: str):
        # EBNF is not supported by the outlines backend (logs and skips).
        return super().dispatch_ebnf(key_string)

    def dispatch_structural_tag(self, key_string: str):
        # Structural tags are not supported by the outlines backend.
        return super().dispatch_structural_tag(key_string)

    def dispatch_json(self, key_string: str):
        """Compile a JSON schema by translating it to a regex first."""
        try:
            regex = build_regex_from_object(
                key_string,
                whitespace_pattern=self.whitespace_pattern,
            )
        except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._compile_regex(regex)

    def dispatch_regex(self, key_string: str):
        return self._compile_regex(key_string)
|
||||
|
||||
|
||||
def build_regex_from_object(
    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
):
    """Normalize a schema spec to a JSON string and build its regex.

    Accepts a pydantic model class, a plain dict, or an already-serialized
    JSON-schema string, and forwards to outlines' build_regex_from_schema.
    """
    if isinstance(object, type(BaseModel)):
        # A pydantic model class: serialize its generated JSON schema.
        return build_regex_from_schema(
            json.dumps(object.model_json_schema()), whitespace_pattern
        )
    if isinstance(object, Dict):
        return build_regex_from_schema(json.dumps(object), whitespace_pattern)
    # Assumed to already be a JSON-schema string.
    return build_regex_from_schema(object, whitespace_pattern)
|
||||
200
python/sglang/srt/constrained/outlines_jump_forward.py
Normal file
200
python/sglang/srt/constrained/outlines_jump_forward.py
Normal file
@@ -0,0 +1,200 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Faster constrained decoding with jump forward decoding / compressed finite state machine.
|
||||
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
import interegular
|
||||
from interegular import InvalidSyntax
|
||||
from outlines.caching import cache
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
try:
|
||||
# outlines >= 0.1.0
|
||||
from outlines_core.fsm.outlines_core_rs import FSMInfo
|
||||
from outlines_core.fsm.regex import make_byte_level_fsm, make_deterministic_fsm
|
||||
except ImportError:
|
||||
# outlines <= 0.0.46
|
||||
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
||||
|
||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
|
||||
# Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
|
||||
DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class JumpEdge:
    # Single-character symbol for a symbol-level jump, or None when only a
    # byte-level jump exists from this state.
    symbol: Optional[str] = None
    # FSM state reached after consuming `symbol`.
    symbol_next_state: Optional[int] = None
    # Byte value for a byte-level jump, or None.
    byte: Optional[int] = None
    # FSM state reached after consuming `byte`.
    byte_next_state: Optional[int] = None
|
||||
|
||||
|
||||
def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
    """Return outlines' disk-cache decorator, or a no-op decorator when the
    disk cache is disabled via SGLANG_DISABLE_OUTLINES_DISK_CACHE.

    Bug fix: the disabled branch previously returned ``lambda fn: None``,
    which replaced the decorated function with None, so calling it (e.g.
    ``init_state_to_jump_forward(...)``) crashed with a TypeError. A no-op
    decorator must return the function unchanged.
    """
    if not DISABLE_DISK_CACHE:
        return cache(expire, typed, ignore)
    else:
        # No-op decorator: leave the wrapped function as-is.
        return lambda fn: fn
|
||||
|
||||
|
||||
@disk_cache()
def init_state_to_jump_forward(regex_string):
    """Build a state -> JumpEdge map for the regex's deterministic byte FSM.

    A state gets a jump-forward edge only when it has exactly one outgoing
    transition (final states count as having one implicit "terminate" edge,
    so they never qualify). Returns None when the regex cannot be parsed.
    """
    try:
        regex_pattern = interegular.parse_pattern(regex_string)
    except InvalidSyntax as e:
        logger.warning(f"skip invalid regex: {regex_string}, {e=}")
        return

    # Byte-level FSM (keep_utf8) so multi-byte characters can be traversed.
    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
    regex_fsm, _ = make_deterministic_fsm(byte_fsm)

    fsm_info: FSMInfo = regex_fsm.fsm_info

    # Invert the alphabet mapping: transition id -> list of symbols.
    symbol_to_id = fsm_info.alphabet_symbol_mapping
    id_to_symbol = {}
    for symbol, id_ in symbol_to_id.items():
        id_to_symbol.setdefault(id_, []).append(symbol)

    transitions = fsm_info.transitions

    # --- Pass 1: symbol-level (single-character) jump-forward edges. ---
    outgoings_ct = defaultdict(int)
    # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
    for s in fsm_info.finals:
        outgoings_ct[s] = 1

    state_to_jump_forward = {}
    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            # Arbitrarily symbol cannot be recognized as jump forward
            continue

        symbols = id_to_symbol[id_]
        for c in symbols:
            if len(c) > 1:
                # Skip byte level transitions like c = "5E"
                continue

            outgoings_ct[state] += 1
            if outgoings_ct[state] > 1:
                # Second way out of this state: it is no longer forced.
                if state in state_to_jump_forward:
                    del state_to_jump_forward[state]
                break

            state_to_jump_forward[state] = JumpEdge(
                symbol=c,
                symbol_next_state=next_state,
            )

    # --- Pass 2: byte-level jump-forward edges (same counting scheme). ---
    # Process the byte level jump forward
    outgoings_ct = defaultdict(int)
    for s in fsm_info.finals:
        outgoings_ct[s] = 1

    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            continue
        symbols = id_to_symbol[id_]
        for c in symbols:
            byte_ = None
            if len(c) == 1 and ord(c) < 0x80:
                # ASCII character
                byte_ = ord(c)
            elif len(c) > 1:
                # FIXME: This logic is due to the leading \x00
                # https://github.com/outlines-dev/outlines/pull/930
                byte_ = int(symbols[0][1:], 16)

            if byte_ is not None:
                outgoings_ct[state] += 1
                if outgoings_ct[state] > 1:
                    if state in state_to_jump_forward:
                        del state_to_jump_forward[state]
                    break
                # Merge the byte info into an existing symbol edge if present.
                e = state_to_jump_forward.get(state, JumpEdge())
                e.byte = byte_
                e.byte_next_state = next_state
                state_to_jump_forward[state] = e

    return state_to_jump_forward
|
||||
|
||||
|
||||
class OutlinesJumpForwardMap:
    """Walks the forced-transition map produced by init_state_to_jump_forward."""

    def __init__(self, regex_string):
        # May be None when the regex failed to parse (see the builder).
        self.state_to_jump_forward = init_state_to_jump_forward(regex_string)

    def jump_forward_symbol(self, state):
        """Follow symbol-level edges from `state`.

        Returns (forced_string, final_state); the string is empty when no
        symbol-level edge leaves `state`.
        """
        pieces = []
        cur = state
        while cur in self.state_to_jump_forward:
            edge = self.state_to_jump_forward[cur]
            if edge.symbol is None:
                break
            pieces.append(edge.symbol)
            cur = edge.symbol_next_state
        return "".join(pieces), cur

    def jump_forward_byte(self, state):
        """Follow byte-level edges from `state`.

        Returns a list of (byte, next_state) pairs, or None when `state` has
        no jump-forward edge at all.
        """
        if state not in self.state_to_jump_forward:
            return None

        path = []
        cur = state
        while cur in self.state_to_jump_forward:
            edge = self.state_to_jump_forward[cur]
            assert edge.byte is not None and edge.byte_next_state is not None
            path.append((edge.byte, edge.byte_next_state))
            cur = edge.byte_next_state

        return path

    def is_jump_forward_symbol_state(self, state):
        """True when `state` has a symbol-level (not just byte-level) edge."""
        if state not in self.state_to_jump_forward:
            return False
        return self.state_to_jump_forward[state].symbol is not None
|
||||
|
||||
|
||||
def test_main(regex_string):
    """Debug harness: print symbol- and byte-level jump-forward paths."""
    jump_forward_map = OutlinesJumpForwardMap(regex_string)
    for state, e in jump_forward_map.state_to_jump_forward.items():
        if e.symbol is not None:
            # Symbol-level forced path from this state.
            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
            print(f"{state} -> {next_state}", jump_forward_str)
            # Byte-level forced path (assumed present whenever a symbol path
            # is — the builder's byte pass fills `byte` on kept edges).
            bytes_ = jump_forward_map.jump_forward_byte(state)
            print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import outlines

    # Start from a clean outlines disk cache so FSM construction is exercised.
    outlines.caching.clear_cache()
    test_main(r"The google's DNS sever address is " + IP_REGEX)
    # Mixed multi-byte (UTF-8) alternation: exercises byte-level jumps.
    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...

    test_main(r"[-+]?[0-9]+[ ]*")
|
||||
90
python/sglang/srt/constrained/reasoner_grammar_backend.py
Normal file
90
python/sglang/srt/constrained/reasoner_grammar_backend.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The baseclass of a backend for reasoner grammar-guided constrained decoding."""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject
|
||||
|
||||
|
||||
class ReasonerGrammarObject(BaseGrammarObject):
    """Delays a grammar until the model's reasoning section has ended.

    Tokens are unconstrained while `is_in_reasoning` is True; once
    `think_end_id` is observed, every call delegates to the wrapped grammar.
    The end-of-think token itself is never forwarded to the wrapped grammar.
    """

    def __init__(self, grammar: BaseGrammarObject, think_end_id):
        super().__init__()
        self.grammar = grammar
        self.think_end_id = think_end_id
        self.is_in_reasoning = True

    def accept_token(self, token: int):
        if token == self.think_end_id:
            # Transition out of the reasoning phase; do not forward this token.
            self.is_in_reasoning = False
        elif not self.is_in_reasoning:
            self.grammar.accept_token(token)

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        if self.is_in_reasoning:
            # No constraint while reasoning: leave the mask untouched.
            return
        self.grammar.fill_vocab_mask(vocab_mask, idx)

    def move_vocab_mask(self, vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return self.grammar.move_vocab_mask(vocab_mask, device)

    @property
    def apply_vocab_mask(self):
        return self.grammar.apply_vocab_mask

    def copy(self) -> BaseGrammarObject:
        # Note: the copy restarts in the reasoning phase.
        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)

    @property
    def finished(self):
        # Completion state lives on the wrapped grammar.
        return self.grammar.finished

    @finished.setter
    def finished(self, finished):
        self.grammar.finished = finished

    def try_jump_forward(self, tokenizer):
        return self.grammar.try_jump_forward(tokenizer)

    def jump_forward_str_state(self, helper):
        return self.grammar.jump_forward_str_state(helper)

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        return self.grammar.jump_and_retokenize(
            old_output_ids, new_output_ids, next_state
        )
|
||||
|
||||
|
||||
class ReasonerGrammarBackend(BaseGrammarBackend):
    """Backend decorator that wraps compiled grammars in ReasonerGrammarObject."""

    def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
        super().__init__()
        self.grammar_backend = grammar_backend
        self.think_end_id = think_end_id

    def _init_value_dispatch(
        self, key: Tuple[str, str]
    ) -> Optional[ReasonerGrammarObject]:
        inner = self.grammar_backend._init_value_dispatch(key)
        # Propagate compilation failure unchanged; otherwise wrap the grammar
        # so it only applies after the reasoning section ends.
        if inner is None:
            return None
        return ReasonerGrammarObject(inner, self.think_end_id)
|
||||
141
python/sglang/srt/constrained/triton_ops/bitmask_ops.py
Normal file
141
python/sglang/srt/constrained/triton_ops/bitmask_ops.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# Adapt from
|
||||
# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.utils import get_device_core_count
|
||||
|
||||
|
||||
@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    logits_strides,
    bitmask_strides,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
    where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
    the masked logits will be set to -inf.

    Parameters
    ----------
    logits_ptr : tl.tensor
        Pointer to the logits tensor to apply the bitmask to.

    bitmask_ptr : tl.tensor
        Pointer to the bitmask tensor to apply.

    indices_ptr : Optional[tl.tensor]
        Optional pointer to indices tensor specifying which rows to apply the mask to.

    num_rows : int
        Number of rows to process. If indices_ptr is provided, this is the number of unique indices.

    vocab_size : int
        Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the
        same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary.

    logits_strides : int
        Stride between rows in the logits tensor.

    bitmask_strides : int
        Stride between rows in the bitmask tensor.

    NUM_SMS : int
        Number of streaming multiprocessors to use.

    BLOCK_SIZE : int
        Size of processing blocks.
    """

    pid = tl.program_id(0)
    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
    # Each program strides by NUM_SMS over the flattened (row, block) work
    # items, so any grid size covers all num_rows * num_blocks items.
    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
        # Decompose the flat work id into a row and a BLOCK_SIZE-wide column block.
        row_id = work_id // num_blocks
        block_offset = (work_id % num_blocks) * BLOCK_SIZE
        # Indirect row lookup when an indices tensor was provided.
        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
        # Each int32 of the bitmask packs 32 vocab entries.
        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
        vocab_mask = offsets < vocab_size
        packed_bitmask_mask = bitmask_offsets < bitmask_strides
        packed_bitmask = tl.load(
            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
            packed_bitmask_mask,
        )
        # Expand each packed int32 into 32 boolean lanes; bit == 0 means "masked".
        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
        bitmask = bitmask.reshape(BLOCK_SIZE)

        # Write -inf to the masked, in-range positions of the selected row.
        tl.store(
            logits_ptr + batch_id * logits_strides + offsets,
            -float("inf"),
            vocab_mask & bitmask,
        )
|
||||
|
||||
|
||||
def apply_token_bitmask_inplace_triton(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
):
    """Apply a packed 01 token bitmask to `logits` in place (masked -> -inf).

    Args:
        logits: (batch, vocab) or (vocab,) float tensor, modified in place.
        bitmask: int32 tensor packing 32 vocab entries per element; a 0 bit
            marks the token as masked.
        indices: optional row indices selecting which logits rows to mask.
            When omitted, rows of `logits` and `bitmask` correspond 1:1 and
            their batch sizes must match.
    """
    NUM_SMS = get_device_core_count()
    BLOCK_SIZE = 4096
    BITS_PER_BLOCK = 32

    # Check input dtype
    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"

    # Normalize 1-D inputs to logical 2-D shapes (no data movement).
    logits_shape = logits.shape
    bitmask_shape = bitmask.shape
    if logits.ndim == 1:
        logits_shape = (1, logits_shape[0])
    if bitmask.ndim == 1:
        bitmask_shape = (1, bitmask_shape[0])

    # The bitmask may not describe more tokens than the logits can hold.
    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
    assert required_bitmask_width >= bitmask_shape[1], (
        f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
    )

    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)

    if isinstance(indices, (list, torch.Tensor)):
        # torch.as_tensor avoids the extra copy (and UserWarning) that
        # torch.tensor(tensor) performs; the kernel only reads `indices`.
        indices = torch.as_tensor(indices, dtype=torch.int32, device=logits.device)
        num_rows = indices.shape[0]
    else:
        assert (
            logits_shape[0] == bitmask_shape[0]
        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
        num_rows = logits_shape[0]

    if NUM_SMS > 0:
        # One persistent program per SM; the kernel strides over work items.
        grid = (NUM_SMS,)
    else:
        # Unknown core count: launch one program per work item instead.
        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
        grid = (num_rows * num_blocks,)
        NUM_SMS = triton.next_power_of_2(grid[0])

    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        num_rows,
        vocab_size,
        logits_shape[1],
        bitmask_shape[1],
        NUM_SMS,
        BLOCK_SIZE,
        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
        num_stages=3,
    )
|
||||
239
python/sglang/srt/constrained/xgrammar_backend.py
Normal file
239
python/sglang/srt/constrained/xgrammar_backend.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Constrained decoding with xgrammar backend."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from xgrammar import (
|
||||
CompiledGrammar,
|
||||
GrammarCompiler,
|
||||
GrammarMatcher,
|
||||
StructuralTagItem,
|
||||
TokenizerInfo,
|
||||
allocate_token_bitmask,
|
||||
)
|
||||
|
||||
from sglang.srt.constrained.base_grammar_backend import (
|
||||
INVALID_GRAMMAR_OBJ,
|
||||
BaseGrammarBackend,
|
||||
BaseGrammarObject,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# Select the bitmask-application kernel at import time: ROCm (HIP) builds use
# the CUDA/HIP implementation shipped in sgl_kernel, other builds use the
# Triton op from this package.
_is_hip = is_hip()
if _is_hip:
    from sgl_kernel import apply_token_bitmask_inplace_cuda
else:
    from sglang.srt.constrained.triton_ops.bitmask_ops import (
        apply_token_bitmask_inplace_triton,
    )
logger = logging.getLogger(__name__)


# Maximum number of accepted tokens a GrammarMatcher may roll back
# (needed for speculative decoding / jump-forward retokenization).
MAX_ROLLBACK_TOKENS = 200
|
||||
|
||||
|
||||
class XGrammarGrammar(BaseGrammarObject):
    """Stateful wrapper around one xgrammar ``GrammarMatcher``.

    Tracks the tokens accepted so far (kept for debugging) and applies the
    grammar's token bitmask to logits during constrained decoding.
    """

    def __init__(
        self,
        matcher: GrammarMatcher,
        vocab_size: int,
        ctx: CompiledGrammar,
        override_stop_tokens: Optional[Union[List[int], int]],
        key_string: Optional[str] = None,  # TODO (sk): for debugging, remove later
    ) -> None:
        self.matcher = matcher
        self.vocab_size = vocab_size
        self.ctx = ctx
        self.override_stop_tokens = override_stop_tokens
        self.finished = False
        # Tokens accepted since creation; only used for debug messages/repr.
        self.accepted_tokens = []
        self.key_string = key_string

    def accept_token(self, token: int):
        """Feed one sampled token into the matcher.

        Raises:
            ValueError: if the matcher rejects the token (grammar violation).
        """
        if not self.is_terminated():
            accepted = self.matcher.accept_token(token)
            if not accepted:
                # log for debugging
                raise ValueError(
                    f"Tokens not accepted: {token}\n"
                    f"Accepted tokens: {self.accepted_tokens}\n"
                    f"Key string: {self.key_string}"
                )
            else:
                self.accepted_tokens.append(token)

    def rollback(self, k: int):
        """Roll back the last ``k`` accepted tokens.

        Fix: the previous code ran ``self.accepted_tokens[:-k]`` unconditionally;
        for ``k == 0`` that is ``[:0]`` and wrongly wiped the whole log.
        """
        self.matcher.rollback(k)
        if k > 0:
            self.accepted_tokens = self.accepted_tokens[:-k]

    def is_terminated(self):
        """Whether the grammar has reached a terminal state."""
        return self.matcher.is_terminated()

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        # xgrammar allocates the bitmask on CPU; it is moved to `device` later
        # via move_vocab_mask.
        return allocate_token_bitmask(batch_size, vocab_size)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """Write this grammar's next-token bitmask into row ``idx`` of the mask."""
        self.matcher.fill_next_token_bitmask(vocab_mask, idx)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask.to(device, non_blocking=True)

    def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """Mask out disallowed tokens in-place using the platform-specific kernel."""
        if logits.device.type == "cuda":
            if _is_hip:
                apply_token_bitmask_inplace_cuda(logits, vocab_mask)
            else:
                apply_token_bitmask_inplace_triton(logits, vocab_mask)
        elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
            # NOTE(review): `apply_vocab_mask_cpu` is presumably provided by
            # BaseGrammarObject — confirm against the base class.
            self.apply_vocab_mask_cpu(logits, vocab_mask)
        else:
            raise RuntimeError(f"Unsupported device: {logits.device.type}")

    def copy(self):
        """Create a fresh grammar object sharing the compiled context but with
        a brand-new matcher state."""
        matcher = GrammarMatcher(
            self.ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            override_stop_tokens=self.override_stop_tokens,
        )
        return XGrammarGrammar(
            matcher,
            self.vocab_size,
            self.ctx,
            self.override_stop_tokens,
            self.key_string,
        )

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """Return ([], forced_string) when the grammar forces a unique suffix."""
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, data = helper
        # -1: xgrammar does not expose an intermediate state id here.
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        """Re-sync the matcher after a jump-forward changed the tokenization."""
        # k = length of the common prefix of old and new output ids.
        k = 0
        for i, old_id in enumerate(old_output_ids):
            if old_id == new_output_ids[i]:
                k = i + 1
            else:
                break

        # rollback to the last token that is the same
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)

        for i in range(k, len(new_output_ids)):
            assert self.matcher.accept_token(new_output_ids[i])

    def __repr__(self):
        return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
|
||||
|
||||
|
||||
class XGrammarGrammarBackend(BaseGrammarBackend):
    """Grammar backend that compiles JSON-schema / EBNF / regex / structural-tag
    constraints with xgrammar and hands out ``XGrammarGrammar`` objects.

    Fix: the dispatch methods previously logged through ``logging.error`` (the
    root logger) even though this module defines its own ``logger``; they now
    use the module logger so messages carry the module name and honor its config.
    """

    def __init__(
        self,
        tokenizer,
        vocab_size: int,
        model_eos_token_ids: Optional[List[int]] = None,
    ):
        super().__init__()

        if hasattr(tokenizer, "init_xgrammar"):
            # For special tokenizer
            tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
        else:
            # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
            # This ensures consistency between what the model considers EOS and what XGrammar uses
            tokenizer_info = TokenizerInfo.from_huggingface(
                tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
            )
            override_stop_tokens = None

        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
        self.vocab_size = vocab_size
        self.override_stop_tokens = override_stop_tokens

    def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
        """Wrap a compiled grammar in a fresh matcher + grammar object."""
        matcher = GrammarMatcher(
            ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            override_stop_tokens=self.override_stop_tokens,
        )
        return XGrammarGrammar(
            matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
        )

    def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a JSON-schema constraint; returns INVALID_GRAMMAR_OBJ on error."""
        try:
            if key_string == "$$ANY$$":
                # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)
                ctx = self.grammar_compiler.compile_builtin_json_grammar()
            else:
                ctx = self.grammar_compiler.compile_json_schema(schema=key_string)

        except (RuntimeError, json.decoder.JSONDecodeError) as e:
            logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_context(ctx, key_string)

    def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile an EBNF grammar constraint."""
        try:
            ctx = self.grammar_compiler.compile_grammar(key_string)
        except RuntimeError as e:
            logger.error(f"Hit invalid ebnf: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_context(ctx, key_string)

    def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a regex constraint."""
        try:
            ctx = self.grammar_compiler.compile_regex(key_string)
        except RuntimeError as e:
            logger.error(f"Hit invalid regex: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_context(ctx, key_string)

    def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
        """Compile a structural-tag constraint.

        ``key_string`` is JSON with keys "structures" (each having begin/schema/end)
        and "triggers" — assumed from the access pattern below; confirm against callers.
        """
        try:
            structural_tag = json.loads(key_string)
            tags = [
                StructuralTagItem(
                    begin=structure["begin"],
                    schema=json.dumps(structure["schema"]),
                    end=structure["end"],
                )
                for structure in structural_tag["structures"]
            ]
            ctx = self.grammar_compiler.compile_structural_tag(
                tags, structural_tag["triggers"]
            )
        except (RuntimeError, json.decoder.JSONDecodeError) as e:
            logger.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
            return INVALID_GRAMMAR_OBJ
        return self._from_context(ctx, key_string)

    def reset(self):
        """Drop all cached compiled grammars."""
        self.grammar_compiler.clear_cache()
|
||||
102
python/sglang/srt/custom_op.py
Normal file
102
python/sglang/srt/custom_op.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
is_cpu,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
is_npu,
|
||||
is_xpu,
|
||||
)
|
||||
|
||||
# Platform capability flags, resolved once at import time.
# CustomOp.dispatch_forward reads these to pick the forward implementation.
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_is_xpu = is_xpu()
|
||||
|
||||
|
||||
class CustomOp(nn.Module):
    """Op base class whose ``forward`` dispatches to a platform-specific
    implementation (CUDA / HIP / CPU-AMX / NPU / XPU / native PyTorch)."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

        # States for torch.compile
        self._original_forward_method = None
        self.is_torch_compile = False

    def enter_torch_compile(self, num_tokens: int):
        # Skip if Op is already entered compile mode.
        # NOTE(alcanderian): Some Ops(for example RotaryEmbedding) will be reused
        # among layers and `enter_torch_compile` will be called many times.
        # We should prevent `self._original_forward_method` from being overridden when
        # it is not the first time `enter_torch_compile` called.
        if self.is_torch_compile:
            return

        self._original_forward_method = self._forward_method

        # NOTE: Temporarily workaround MoE
        # The performance of torch.compile on this layer is not always good when bs > 1,
        # so we decide to only use torch.compile when bs=1
        cls_name = self.__class__.__name__
        if "FusedMoE" in cls_name:
            if num_tokens == 1:
                from sglang.srt.layers.moe.fused_moe_native import (
                    fused_moe_forward_native,
                )

                self._forward_method = fused_moe_forward_native
        elif "TopK" in cls_name:
            if num_tokens == 1:
                self._forward_method = self.forward_native
        else:
            self._forward_method = self.forward_native
        self.is_torch_compile = True

    def leave_torch_compile(self):
        # Skip if Op is already exited compile mode.
        if not self.is_torch_compile:
            return

        # Restore the implementation saved by enter_torch_compile.
        self._forward_method = self._original_forward_method
        self._original_forward_method = None
        self.is_torch_compile = False

    # Please do not override this method, because `self._forward_method` can change when in torch compile mode
    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_npu(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # HIP shares the CUDA implementation by default.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_hpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        """Select the forward implementation from the module-level platform flags.

        The order encodes priority: CUDA, HIP, CPU (only with AMX support),
        NPU, XPU, then the pure-PyTorch fallback.
        """
        candidates = (
            (_is_cuda, self.forward_cuda),
            (_is_hip, self.forward_hip),
            (_is_cpu and _is_cpu_amx_available, self.forward_cpu),
            (_is_npu, self.forward_npu),
            (_is_xpu, self.forward_xpu),
        )
        for available, impl in candidates:
            if available:
                return impl
        return self.forward_native
|
||||
0
python/sglang/srt/debug_utils/__init__.py
Normal file
0
python/sglang/srt/debug_utils/__init__.py
Normal file
168
python/sglang/srt/debug_utils/dump_comparator.py
Normal file
168
python/sglang/srt/debug_utils/dump_comparator.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import argparse
|
||||
import functools
|
||||
from pathlib import Path
|
||||
|
||||
import polars as pl
|
||||
import torch
|
||||
|
||||
from sglang.srt.debug_utils.dump_loader import find_row, read_meta
|
||||
from sglang.srt.debug_utils.dumper import get_truncated_value
|
||||
|
||||
|
||||
def main(args):
    """Compare dumped tensors under ``args.target_path`` against ``args.baseline_path``.

    Reads per-file metadata from each dump directory, matches every target
    entry to a baseline entry with the same metadata (with the forward-pass id
    shifted by ``baseline_start_id - start_id``), and prints comparison stats.
    """
    df_target = read_meta(args.target_path)
    df_target = df_target.sort("rank", "dump_index")
    # Restrict targets to the requested forward-pass id window.
    df_target = df_target.filter(
        (pl.col("forward_pass_id") >= args.start_id)
        & (pl.col("forward_pass_id") <= args.end_id)
    )
    assert all(
        c in df_target.columns
        for c in ["rank", "forward_pass_id", "dump_index", "name"]
    )

    df_baseline = read_meta(args.baseline_path)
    print("df_target", df_target)
    print("df_baseline", df_baseline)

    for row in df_target.iter_rows(named=True):
        path_target = Path(args.target_path) / row["filename"]

        # Match on every metadata field except pass id (shifted) and the
        # fields that legitimately differ between runs.
        row_baseline = find_row(
            df_baseline,
            conditions=dict(
                forward_pass_id=row["forward_pass_id"]
                - args.start_id
                + args.baseline_start_id,
                **{
                    k: v
                    for k, v in row.items()
                    if k not in ["forward_pass_id", "dump_index", "filename"]
                },
            ),
        )

        if row_baseline is None:
            # No counterpart: still show a sample of the target for inspection.
            print(f"Skip: target={str(path_target)} since no baseline")
            x_target = _load_object(path_target)
            if x_target is not None:
                print(f"x_target(sample)={get_truncated_value(x_target)}")
            continue

        path_baseline = Path(args.baseline_path) / row_baseline["filename"]
        print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
        check_tensor_pair(
            path_baseline=path_baseline, path_target=path_target, name=row["name"]
        )
        print()
|
||||
|
||||
|
||||
def check_tensor_pair(path_baseline, path_target, name=""):
    """Load two dumped tensors and print shape/dtype, summary statistics, and
    absolute/relative differences.

    Fix: the statistics loop previously iterated ``for name, fn in (...)``,
    silently shadowing and destroying the ``name`` parameter; the loop
    variables are renamed to avoid the shadowing.
    """
    x_baseline = _load_object(path_baseline)
    x_target = _load_object(path_target)

    print(
        f"Raw "
        f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
        f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
    )

    # Optional ad-hoc preprocessing hook, then squeeze leading 1-dims so the
    # baseline's shape can match the target's.
    x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
    x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)

    print(
        f"After preprocessor "
        f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
        f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
    )

    # Compare in float32 regardless of the stored dtype.
    x_target = x_target.float()
    x_baseline = x_baseline.float()

    for stat_name, stat_fn in (
        ("mean", torch.mean),
        ("std", torch.std),
        ("min", torch.min),
        ("max", torch.max),
        ("p1", functools.partial(torch.quantile, q=0.01)),
        ("p5", functools.partial(torch.quantile, q=0.05)),
        ("p95", functools.partial(torch.quantile, q=0.95)),
        ("p99", functools.partial(torch.quantile, q=0.99)),
    ):
        value_baseline = stat_fn(x_baseline).item()
        value_target = stat_fn(x_target).item()
        print(
            f"[{stat_name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
        )

    # Element-wise diffs are only meaningful when the shapes agree.
    if x_baseline.shape != x_target.shape:
        print(f"⚠️ Shape mismatch")
        return

    raw_abs_diff = (x_target - x_baseline).abs()

    max_abs_diff = raw_abs_diff.max().item()
    mean_abs_diff = raw_abs_diff.mean().item()
    rel_diff = _calc_rel_diff(x_target, x_baseline)

    needs_print = max_abs_diff > 1e-3

    print(
        "\t".join(
            f"{'❌' if value > 1e-3 else '✅'} {name}={value}"
            for name, value in [
                ("rel_diff", rel_diff),
                ("max_abs_diff", max_abs_diff),
                ("mean_abs_diff", mean_abs_diff),
            ]
        )
    )

    # Dump samples only when the tensors are meaningfully different.
    if needs_print:
        print(f"x_baseline(sample)={get_truncated_value(x_baseline)}")
        print(f"x_target(sample)={get_truncated_value(x_target)}")
|
||||
|
||||
|
||||
def _try_unify_shape(x: torch.Tensor, target_shape):
|
||||
x_shape = x.shape
|
||||
num_dim_to_remove = len(x_shape) - len(target_shape)
|
||||
if (x_shape[num_dim_to_remove:] == target_shape) and all(
|
||||
val == 1 for val in x_shape[:num_dim_to_remove]
|
||||
):
|
||||
out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)
|
||||
print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
|
||||
return out
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# Copied from DeepGEMM
|
||||
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
def _comparison_preprocessor(x_baseline, x_target, name):
|
||||
# can insert arbitrary adhoc postprocessing logic here
|
||||
return x_baseline, x_target
|
||||
|
||||
|
||||
def _load_object(path):
    """Load a dumped object from ``path``; return it on GPU, or None when the
    file does not contain a tensor."""
    x = torch.load(path, weights_only=False)
    if isinstance(x, torch.Tensor):
        return x.cuda()
    print(f"Skip load {path} since {type(x)=} is not a Tensor")
    return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: compare two dump directories produced by
    # sglang.srt.debug_utils.dumper.
    parser = argparse.ArgumentParser()
    parser.add_argument("--baseline-path", type=str)
    parser.add_argument("--target-path", type=str)
    # Forward-pass id window to compare on the target side.
    parser.add_argument("--start-id", type=int, default=0)
    parser.add_argument("--end-id", type=int, default=1000000)
    # Baseline pass id that aligns with --start-id on the target side.
    parser.add_argument("--baseline-start-id", type=int, default=0)
    args = parser.parse_args()
    main(args)
|
||||
97
python/sglang/srt/debug_utils/dump_loader.py
Normal file
97
python/sglang/srt/debug_utils/dump_loader.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import functools
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import polars as pl
|
||||
import torch
|
||||
|
||||
|
||||
class DumpLoader:
    """Loads objects previously saved by the dumper, keyed by dump metadata.

    Enabled only when the environment variable SGLANG_DUMP_LOADER_DIR points
    at a dump directory; ``load`` must not be called otherwise.
    """

    def __init__(self):
        directory = os.environ.get("SGLANG_DUMP_LOADER_DIR")

        self._enable = directory is not None
        if self._enable:
            self._directory = Path(directory)
            # One metadata row per dumped file, parsed from the filenames.
            self._df = read_meta(directory)

    @property
    def enable(self):
        # True iff SGLANG_DUMP_LOADER_DIR was set.
        return self._enable

    def load(self, name, **kwargs):
        """Load the dumped object matching ``name``, the dumper's current
        forward-pass id, and any extra metadata ``kwargs``.

        Raises AssertionError when the loader is disabled or no match exists.
        """
        assert self._enable, "Please call DumpLoader.load only when it is enabled"

        # Imported lazily to avoid a circular import at module load time.
        from sglang.srt.debug_utils.dumper import dumper

        forward_pass_id = dumper._forward_pass_id
        conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs)
        row = find_row(self._df, conditions=conditions)
        assert (
            row is not None
        ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"

        path = self._directory / row["filename"]
        output = torch.load(path, weights_only=False)

        print(
            f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
        )
        return output
|
||||
|
||||
|
||||
def read_meta(directory):
    """Build a metadata DataFrame from a dump directory.

    Each ``*.pt`` filename encodes metadata as ``k=v`` pairs joined by
    ``___`` (see the dumper); every pair becomes a column.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    for p in directory.glob("*.pt"):
        # Parse "k1=v1___k2=v2___...".pt back into a dict of strings.
        full_kwargs = {}
        for kv in p.stem.split("___"):
            k, v = kv.split("=")
            full_kwargs[k] = v
        rows.append(
            {
                "filename": str(p.name),
                **full_kwargs,
            }
        )

    df = pl.DataFrame(rows)
    # These three columns are always written by the dumper; restore int types
    # (everything parsed from the filename starts out as a string).
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
        pl.col("dump_index").cast(int),
    )
    return df
|
||||
|
||||
|
||||
def find_row(df, conditions: Dict[str, Any]):
    """Return the single row of ``df`` matching all ``conditions``, or None.

    Values are coerced to each column's dtype before comparison; asserts that
    at most one row matches.
    """
    df_sub = df.filter(
        functools.reduce(
            lambda a, b: a & b,
            [
                pl.col(col) == _cast_to_polars_dtype(conditions[col], df.schema[col])
                for col in conditions.keys()
            ],
        )
    )
    assert len(df_sub) <= 1
    return df_sub.to_dicts()[0] if len(df_sub) > 0 else None
|
||||
|
||||
|
||||
def _cast_to_polars_dtype(value, target_dtype):
    """Coerce ``value`` (typically a string parsed from a filename) to the
    Python type matching the polars column dtype ``target_dtype``."""
    if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
        return int(value)
    if target_dtype in (pl.Float64, pl.Float32):
        return float(value)
    if target_dtype == pl.Boolean:
        return bool(value)
    if target_dtype == pl.String:
        return str(value)
    # Unknown dtype: pass through unchanged.
    return value
|
||||
|
||||
|
||||
# Module-level singleton; importers share this one loader instance.
dump_loader = DumpLoader()
|
||||
116
python/sglang/srt/debug_utils/dumper.py
Normal file
116
python/sglang/srt/debug_utils/dumper.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class _Dumper:
    """Utility to dump tensors, which can be useful when comparison checking models.

    Example usage:
    dumper.on_forward_pass_start()
    dumper.dump("layer_start__hidden_states", hidden_states, layer_id=self.layer_id)

    Import from non-SGLang system:
    ```
    import sys
    sys.path.append("/YOUR_PATH/sglang/python/sglang/srt/debug_utils")
    from dumper import dumper
    ```

    Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison
    """

    def __init__(self):
        # Do not import `sglang` to make this file standalone
        # Master switch; set SGLANG_DUMPER_ENABLE=0 to disable all dumping.
        self._enable = bool(int(os.environ.get("SGLANG_DUMPER_ENABLE", "1")))
        self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp"))
        # When 0, only print metadata/samples; do not write .pt files.
        self._enable_write_file = bool(
            int(os.environ.get("SGLANG_DUMPER_WRITE_FILE", "1"))
        )
        # Run-unique directory suffix; lazily created on first dump so all
        # ranks agree on it.
        self._partial_name: Optional[str] = None
        self._dump_index = 0
        self._forward_pass_id = 0

    def on_forward_pass_start(self):
        # Must be called once per forward pass before any dump() call.
        self._forward_pass_id += 1
        print(
            f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}"
        )

    def dump(self, name, value, **kwargs):
        """Print a sample of ``value`` and (optionally) save it to disk.

        Extra ``kwargs`` become metadata fields encoded in the filename.
        """
        if not self._enable:
            return

        assert (
            self._forward_pass_id >= 1
        ), "Do you forget to call `dumper.on_forward_pass_start()`?"
        self._dump_index += 1

        if self._partial_name is None:
            self._partial_name = _get_partial_name()

        rank = _get_rank()
        full_kwargs = dict(
            forward_pass_id=self._forward_pass_id,
            rank=rank,
            name=name,
            dump_index=self._dump_index,
            **kwargs,
        )
        # Filename encodes all metadata as k=v pairs; read_meta parses it back.
        full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt"
        path = self._base_dir / f"sglang_dump_{self._partial_name}" / full_filename

        sample_value = get_truncated_value(value)

        print(
            f"[Dumper] [{rank}, {time.time()}] {path} "
            f"type={type(value)} "
            f"shape={value.shape if isinstance(value, torch.Tensor) else None} "
            f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} "
            f"sample_value={sample_value}"
        )

        if self._enable_write_file:
            path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(value, str(path))
|
||||
|
||||
|
||||
def _get_partial_name():
    # Rank 0 generates a timestamp-based name and broadcasts it so that every
    # rank writes into the same dump directory. Collective call: all ranks
    # must reach it together when torch.distributed is initialized.
    rank = _get_rank()
    object_list = [str(time.time()) if rank == 0 else None]
    if dist.is_initialized():
        dist.broadcast_object_list(object_list, device="cuda")
    return object_list[0]
|
||||
|
||||
|
||||
def _get_rank():
|
||||
if dist.is_initialized():
|
||||
return dist.get_rank()
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def get_truncated_value(value):
    """Return a small printable sample of ``value``.

    None and non-tensor objects yield None; tuples are handled element-wise;
    small tensors (< 200 elements) are returned as-is; large tensors keep only
    the first 5 entries along each dimension longer than 200.
    """
    if value is None:
        return None

    if isinstance(value, tuple):
        return [get_truncated_value(item) for item in value]

    if not isinstance(value, torch.Tensor):
        return None

    if value.numel() < 200:
        return value

    keep = tuple(
        slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape
    )
    return value[keep]
|
||||
|
||||
|
||||
# Module-level singleton; import `dumper` and call its methods directly.
dumper = _Dumper()
|
||||
234
python/sglang/srt/debug_utils/text_comparator.py
Normal file
234
python/sglang/srt/debug_utils/text_comparator.py
Normal file
@@ -0,0 +1,234 @@
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import polars as pl
|
||||
|
||||
_DESCRIPTION = """Compare and find differences to benchmark outputs.
|
||||
|
||||
Supported inputs:
|
||||
* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`
|
||||
* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)
|
||||
"""
|
||||
|
||||
|
||||
def main(args):
    """Compare benchmark outputs between baseline and target runs.

    Builds one unified DataFrame of (category, trial, prompt, output, correct)
    rows, computes per-prompt correctness deltas, writes a JSON report to
    ``args.output_path``, and optionally prints details.
    """
    if args.data_type == "simple_evals":
        df_input = _compute_df_input_mode_simple_evals(args)
    else:
        df_input = _transform_df_input(_compute_df_raw(args))

    assert all(
        c in df_input.columns
        for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
    )

    # One row per prompt with correctness on both sides + delta.
    df_meta = _compute_df_meta(df_input)

    df_correctness_per_trial = df_input.group_by(
        "category", "trial_index", maintain_order=True
    ).agg(pl.col("correct").mean())
    df_correctness_delta = (
        df_meta.group_by("correctness_delta").len().sort("correctness_delta")
    )
    # Negative delta: prompt got worse on the target; positive: got better.
    df_good_to_bad = df_meta.filter(pl.col("correctness_delta") < 0)
    df_bad_to_good = df_meta.filter(pl.col("correctness_delta") > 0)

    print(f"Dump output to {args.output_path}")
    Path(args.output_path).write_text(
        json.dumps(
            dict(
                df_meta=df_meta.to_dicts(),
                df_good_to_bad=df_good_to_bad.to_dicts(),
                df_bad_to_good=df_bad_to_good.to_dicts(),
            ),
            indent=4,
        ),
    )

    if not args.disable_print_details:
        # Widen polars' table rendering so full prompts/outputs are visible.
        with pl.Config(
            fmt_str_lengths=10000,
            tbl_cols=-1,
            tbl_rows=-1,
            tbl_width_chars=-1,
            tbl_formatting="UTF8_FULL",
        ):
            print("====== Correctness per trial ======")
            print(df_correctness_per_trial)

            print(
                "====== Correctness Delta (-1.0 means all-right becomes all-wrong) ======"
            )
            print(df_correctness_delta)

            for name, df in [
                ("Good->Bad", df_good_to_bad),
                ("Bad->Good", df_bad_to_good),
            ]:
                print(f"====== Concrete Examples: {name} ======")
                print(df)
|
||||
|
||||
|
||||
def _compute_df_input_mode_simple_evals(args):
    """Build the unified input DataFrame from simple_evals JSON result files."""
    frames = [
        _compute_df_input_one_mode_simple_evals(**info)
        for info in _get_file_infos(args=args)
    ]
    return pl.concat(frames)
|
||||
|
||||
|
||||
def _compute_df_input_one_mode_simple_evals(path, category, trial_index):
    """Parse one simple_evals result file into the unified row schema.

    NOTE(review): assumes the simple_evals JSON layout with
    ``metadata.single_eval_results[*].example_level_metadata`` — confirm
    against the simple_evals output format if it changes.
    """
    data = json.loads(Path(path).read_text())
    rows = []

    for single_eval_result in data["metadata"]["single_eval_results"]:
        prompt = single_eval_result["example_level_metadata"][
            "actual_queried_prompt_messages"
        ]
        score = single_eval_result["score"]
        # Only binary scoring is supported here.
        assert score in {0.0, 1.0}, f"{score=}"

        row = dict(
            category=category,
            trial_index=trial_index,
            # Stable content hash so baseline/target rows for the same prompt align.
            prompt_id=_compute_id_from_object(prompt),
            prompt=json.dumps(prompt),
            output=single_eval_result["example_level_metadata"]["response_text"],
            correct=score == 1.0,
        )
        rows.append(row)

    return pl.DataFrame(rows)
|
||||
|
||||
|
||||
def _compute_id_from_object(obj):
    """Deterministic content hash of ``obj`` (JSON-canonicalized SHA-256 hex)."""
    if isinstance(obj, pl.Series):
        obj = obj.to_list()
    canonical = json.dumps(obj, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _compute_df_raw(args):
    """Read every baseline/target raw-result file and stack them into one frame."""
    frames = [
        _read_df_raw(
            path=info["path"],
            category=info["category"],
            trial_index=info["trial_index"],
        )
        for info in _get_file_infos(args=args)
    ]
    return pl.concat(frames)
|
||||
|
||||
|
||||
def _get_file_infos(args):
|
||||
return [
|
||||
dict(path=path, category=category, trial_index=trial_index)
|
||||
for category, paths in [
|
||||
("baseline", args.baseline_path),
|
||||
("target", args.target_path),
|
||||
]
|
||||
for trial_index, path in enumerate(paths)
|
||||
]
|
||||
|
||||
|
||||
def _read_df_raw(path: str, category: str, trial_index: int):
    """Read one newline-delimited JSON result file and tag each row with its
    category ("baseline"/"target") and trial index."""
    return pl.read_ndjson(path).with_columns(
        category=pl.lit(category), trial_index=trial_index
    )
|
||||
|
||||
|
||||
def _transform_df_input(df: pl.DataFrame):
    """Normalize raw rows into the unified schema, auto-detecting the format.

    "doc_id" column => lm_eval samples; "prompt_id" => SGLang bench output
    (already in the unified schema); anything else is an error.
    """
    if "doc_id" in df.columns:
        print("Transform mode: lm_eval")

        # lm_eval can emit multiple answer-extraction filters; keep only one
        # so each (prompt, trial) appears once.
        filter_names = df["filter"].unique(maintain_order=True).to_list()
        if len(filter_names) > 1:
            filter_name = filter_names[0]
            print(f"Choose {filter_name=} among {filter_names}")
            df = df.filter(pl.col("filter") == filter_name)

        df = df.select(
            pl.col("category"),
            pl.col("trial_index"),
            prompt_id=pl.col("doc_id"),
            prompt=pl.col("arguments").struct.field("gen_args_0").struct.field("arg_0"),
            output=pl.col("resps").list.get(0).list.get(0),
            correct=pl.col("exact_match").cast(bool),
        )

        return df
    elif "prompt_id" in df.columns:
        print("Transform mode: SGLang bench")
        return df
    else:
        raise Exception(
            f"Unknown data: {df.columns}. You may need to set `--data-type` if using e.g. simple_evals."
        )
|
||||
|
||||
|
||||
def _compute_df_meta(df_input: pl.DataFrame):
    """Aggregate the per-row input into one row per prompt with correctness on
    both sides, the correctness delta, and the longest shared output prefix."""
    df_input = df_input.sort("prompt_id", "category", "trial_index")
    df_meta = pl.DataFrame(
        [
            _handle_one_prompt(df_one_prompt)
            for df_one_prompt in df_input.partition_by("prompt_id", maintain_order=True)
        ]
    )
    # target - baseline: negative means the prompt regressed on the target.
    df_meta = df_meta.with_columns(
        correctness_delta=pl.col("correctness_target") - pl.col("correctness_baseline"),
    )
    # Worst regressions first; short shared prefixes (earliest divergence) first.
    df_meta = df_meta.sort("correctness_delta", "output_same_prefix_len")
    return df_meta
|
||||
|
||||
|
||||
def _handle_one_prompt(df_one_prompt: pl.DataFrame):
    """Summarize all baseline/target rows of a single prompt into one dict."""
    # Sanity check: every row of this partition carries the same prompt content.
    assert (
        len(set(_compute_id_from_object(obj) for obj in df_one_prompt["prompt"])) == 1
    )

    df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
    df_target = df_one_prompt.filter(pl.col("category") == "target")

    outputs_baseline = df_baseline["output"].to_list()
    outputs_target = df_target["output"].to_list()

    # Longest common prefix over all baseline x target output pairs — a proxy
    # for how early the generations diverge.
    output_same_prefix_len = max(
        _compute_str_prefix_len(output_baseline, output_target)
        for output_baseline in outputs_baseline
        for output_target in outputs_target
    )

    return dict(
        prompt_id=df_one_prompt[0, "prompt_id"],
        correctness_baseline=df_baseline["correct"].mean(),
        correctness_target=df_target["correct"].mean(),
        output_same_prefix_len=output_same_prefix_len,
        prompt=df_one_prompt[0, "prompt"],
        outputs_baseline=outputs_baseline,
        outputs_target=outputs_target,
    )
|
||||
|
||||
|
||||
def _compute_str_prefix_len(a: str, b: str) -> int:
|
||||
min_len = min(len(a), len(b))
|
||||
for i in range(min_len):
|
||||
if a[i] != b[i]:
|
||||
return i
|
||||
return min_len
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point; see _DESCRIPTION for the supported input formats.
    parser = argparse.ArgumentParser(description=_DESCRIPTION)
    # "auto" detects lm_eval vs SGLang-bench rows; "simple_evals" must be explicit.
    parser.add_argument("--data-type", type=str, default="auto")
    parser.add_argument("--baseline-path", type=str, nargs="+")
    parser.add_argument("--target-path", type=str, nargs="+")
    parser.add_argument(
        "--output-path", type=str, default="/tmp/text_comparator_output.json"
    )
    parser.add_argument("--disable-print-details", action="store_true")
    args = parser.parse_args()
    main(args)
|
||||
6
python/sglang/srt/disaggregation/ascend/__init__.py
Normal file
6
python/sglang/srt/disaggregation/ascend/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from sglang.srt.disaggregation.ascend.conn import (
|
||||
AscendKVBootstrapServer,
|
||||
AscendKVManager,
|
||||
AscendKVReceiver,
|
||||
AscendKVSender,
|
||||
)
|
||||
117
python/sglang/srt/disaggregation/ascend/conn.py
Normal file
117
python/sglang/srt/disaggregation/ascend/conn.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
|
||||
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
||||
from sglang.srt.disaggregation.mooncake.conn import (
|
||||
MooncakeKVBootstrapServer,
|
||||
MooncakeKVManager,
|
||||
MooncakeKVReceiver,
|
||||
MooncakeKVSender,
|
||||
)
|
||||
from sglang.srt.utils import get_local_ip_by_remote
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AscendKVManager(MooncakeKVManager):
    """KV transfer manager for Ascend NPUs.

    Reuses the Mooncake manager's control flow and swaps in an
    ``AscendTransferEngine`` for the data plane.  ``_transfer_data`` and
    ``enable_custom_mem_pool`` are inherited from ``MooncakeKVManager``
    (not defined here) -- NOTE(review): semantics assumed from the base class.
    """

    def init_engine(self):
        """Create the Ascend transfer engine bound to this rank's NPU."""
        # TransferEngine initialized on ascend.
        local_ip = get_local_ip_by_remote()
        self.engine = AscendTransferEngine(
            hostname=local_ip,
            npu_id=self.kv_args.gpu_id,
            disaggregation_mode=self.disaggregation_mode,
        )

    def register_buffer_to_engine(self):
        """Register both KV and aux buffers with the transfer engine."""
        self.engine.batch_register(self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens)
        # The Ascend backend optimizes batch registration for small memory
        # blocks, so the aux buffers are registered as one batch as well.
        self.engine.batch_register(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        )

    def send_kvcache(
        self,
        mooncake_session_id: str,
        prefill_kv_indices: npt.NDArray[np.int32],
        dst_kv_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int32],
        executor: concurrent.futures.ThreadPoolExecutor,
    ):
        """Send the KV cache at ``prefill_kv_indices`` to the decode side.

        Returns 0 on success, or the first non-zero transfer status on failure.
        Contiguous index runs are coalesced into single (src, dst, length)
        transfer blocks to reduce per-transfer overhead.
        """
        # Group by indices: split into runs contiguous on both src and dst.
        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
            prefill_kv_indices, dst_kv_indices
        )

        # One (src base ptr, dst base ptr, per-item length) tuple per layer.
        num_layers = len(self.kv_args.kv_data_ptrs)
        layers_params = [
            (
                self.kv_args.kv_data_ptrs[layer_id],
                dst_kv_ptrs[layer_id],
                self.kv_args.kv_item_lens[layer_id],
            )
            for layer_id in range(num_layers)
        ]

        def set_transfer_blocks(
            src_ptr: int, dst_ptr: int, item_len: int
        ) -> List[Tuple[int, int, int]]:
            # Translate each contiguous index run into a raw (src_addr,
            # dst_addr, byte_length) block for one layer.
            transfer_blocks = []
            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                src_addr = src_ptr + int(prefill_index[0]) * item_len
                dst_addr = dst_ptr + int(decode_index[0]) * item_len
                length = item_len * len(prefill_index)
                transfer_blocks.append((src_addr, dst_addr, length))
            return transfer_blocks

        # Worker function for processing a single layer
        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
            return self._transfer_data(mooncake_session_id, transfer_blocks)

        # Worker function for processing all layers in a batch
        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
            transfer_blocks = []
            for src_ptr, dst_ptr, item_len in layers_params:
                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
            return self._transfer_data(mooncake_session_id, transfer_blocks)

        if self.enable_custom_mem_pool:
            # Per-layer transfers fan out to the executor; fail fast on the
            # first non-zero status and cancel the not-yet-started futures.
            futures = [
                executor.submit(
                    process_layer,
                    src_ptr,
                    dst_ptr,
                    item_len,
                )
                for (src_ptr, dst_ptr, item_len) in layers_params
            ]
            for future in concurrent.futures.as_completed(futures):
                status = future.result()
                if status != 0:
                    # NOTE(review): cancel() only stops futures that have not
                    # started; already-running transfers continue to completion.
                    for f in futures:
                        f.cancel()
                    return status
        else:
            # Combining all layers' params in one batch transfer is more efficient
            # compared to using multiple threads
            return process_layers(layers_params)

        return 0
|
||||
|
||||
|
||||
class AscendKVSender(MooncakeKVSender):
    """KV sender for Ascend; inherits Mooncake sender behavior unchanged."""

    pass
|
||||
|
||||
|
||||
class AscendKVReceiver(MooncakeKVReceiver):
    """KV receiver for Ascend; inherits Mooncake receiver behavior unchanged."""

    pass
|
||||
|
||||
|
||||
class AscendKVBootstrapServer(MooncakeKVBootstrapServer):
    """Bootstrap server for Ascend; inherits Mooncake server behavior unchanged."""

    pass
|
||||
58
python/sglang/srt/disaggregation/ascend/transfer_engine.py
Normal file
58
python/sglang/srt/disaggregation/ascend/transfer_engine.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AscendTransferEngine(MooncakeTransferEngine):
    """Transfer engine backed by the Ascend ``mf_adapter`` package.

    Overrides the Mooncake engine's construction/registration with the
    vendor ``TransferEngine``; imported lazily so the dependency is only
    required when this backend is actually used.
    """

    def __init__(
        self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
    ):
        try:
            from mf_adapter import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
            ) from e

        self.engine = TransferEngine()
        self.hostname = hostname
        self.npu_id = npu_id

        # Centralized storage address of the AscendTransferEngine
        # NOTE(review): ASCEND_MF_STORE_URL is not validated here; if unset,
        # store_url is None and the failure surfaces in initialize().
        self.store_url = os.getenv("ASCEND_MF_STORE_URL")
        if disaggregation_mode == DisaggregationMode.PREFILL:
            self.role = "Prefill"
        elif disaggregation_mode == DisaggregationMode.DECODE:
            self.role = "Decode"
        else:
            logger.error(f"Unsupported DisaggregationMode: {disaggregation_mode}")
            raise ValueError(f"Unsupported DisaggregationMode: {disaggregation_mode}")
        # Session id is this host plus the engine's RPC port.
        self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
        self.initialize()

    def initialize(self) -> None:
        """Initialize the ascend transfer instance; raise on non-zero status."""
        ret_value = self.engine.initialize(
            self.store_url,
            self.session_id,
            self.role,
            self.npu_id,
        )
        if ret_value != 0:
            logger.error("Ascend Transfer Engine initialization failed.")
            raise RuntimeError("Ascend Transfer Engine initialization failed.")

    def batch_register(self, ptrs: List[int], lengths: List[int]) -> None:
        """Register memory regions in one batch; failures are logged, not raised."""
        try:
            ret_value = self.engine.batch_register_memory(ptrs, lengths)
        except Exception:
            # Mark register as failed
            ret_value = -1
        if ret_value != 0:
            # Best-effort: registration failure is only logged at debug level.
            logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
|
||||
8
python/sglang/srt/disaggregation/base/__init__.py
Normal file
8
python/sglang/srt/disaggregation/base/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from sglang.srt.disaggregation.base.conn import (
|
||||
BaseKVBootstrapServer,
|
||||
BaseKVManager,
|
||||
BaseKVReceiver,
|
||||
BaseKVSender,
|
||||
KVArgs,
|
||||
KVPoll,
|
||||
)
|
||||
134
python/sglang/srt/disaggregation/base/conn.py
Normal file
134
python/sglang/srt/disaggregation/base/conn.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
|
||||
|
||||
class KVArgs:
    """Plain attribute bag describing one rank's KV/aux buffer layout and
    parallelism coordinates, passed to KV managers at construction time."""

    # Rank of this engine within its parallel group.
    engine_rank: int
    # Per-layer KV buffer base pointers / total byte lengths / per-item byte lengths.
    kv_data_ptrs: List[int]
    kv_data_lens: List[int]
    kv_item_lens: List[int]
    # Auxiliary (metadata) buffers, same triplet layout as the KV buffers.
    aux_data_ptrs: List[int]
    aux_data_lens: List[int]
    aux_item_lens: List[int]
    # RDMA NIC selection and traffic class.
    ib_device: str
    ib_traffic_class: str
    gpu_id: int
    # for different tp
    decode_tp_size: int
    kv_head_num: int
    page_size: int
    # for pp prefill
    prefill_pp_size: int
    pp_rank: int
    prefill_start_layer: int
    # for system dp
    system_dp_rank: int
|
||||
|
||||
|
||||
class KVPoll:
    """Integer status codes returned by sender/receiver ``poll()``.

    Kept as plain ints (not an Enum) so they can be cheaply exchanged and
    compared across processes.
    """

    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4
|
||||
|
||||
|
||||
class BaseKVManager(ABC):
    """Base class for managing transfers states"""

    @abstractmethod
    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
        is_mla_backend: Optional[bool] = False,
    ):
        """Create a manager for one rank.

        Args:
            args: buffer layout and parallelism info for this rank.
            disaggregation_mode: whether this rank acts as prefill or decode.
            server_args: full server configuration.
            is_mla_backend: True when the attention backend uses MLA.
        """
        ...
|
||||
|
||||
|
||||
class BaseKVSender(ABC):
    """Prefill-side handle that pushes one request's KV cache to a decoder."""

    @abstractmethod
    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: int,
        dest_tp_ranks: List[int],
        pp_rank: int,
    ): ...

    @abstractmethod
    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """
        Notify the decoder server about the kv indices length and aux index
        """
        ...

    @abstractmethod
    def send(self, kv_indices: npt.NDArray[np.int32]):
        """
        Send the kv cache at the given kv indices to the decoder server
        """
        ...

    @abstractmethod
    def poll(self) -> KVPoll:
        """
        Check the status of the kv cache transfer
        """
        ...

    @abstractmethod
    def failure_exception(self):
        """
        Raise an exception if the kv cache transfer fails
        """
        ...
|
||||
|
||||
|
||||
class BaseKVReceiver(ABC):
    """Decode-side handle that receives one request's KV cache from a prefill rank."""

    @abstractmethod
    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ): ...

    @abstractmethod
    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
        """
        Notify the prefill server about the kv indices and aux index
        """
        ...

    @abstractmethod
    def poll(self) -> KVPoll:
        """
        Check the status of the kv cache transfer
        """
        ...

    @abstractmethod
    def failure_exception(self):
        """
        Raise an exception if the kv cache transfer fails
        """
        ...
|
||||
|
||||
|
||||
class BaseKVBootstrapServer(ABC):
    """Rendezvous server that prefill ranks register with and decode ranks query."""

    @abstractmethod
    def __init__(self, host: str, port: int): ...
|
||||
5
python/sglang/srt/disaggregation/common/__init__.py
Normal file
5
python/sglang/srt/disaggregation/common/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from sglang.srt.disaggregation.common.conn import (
|
||||
CommonKVBootstrapServer,
|
||||
CommonKVManager,
|
||||
CommonKVReceiver,
|
||||
)
|
||||
438
python/sglang/srt/disaggregation/common/conn.py
Normal file
438
python/sglang/srt/disaggregation/common/conn.py
Normal file
@@ -0,0 +1,438 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
import threading
|
||||
from functools import cache
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import requests
|
||||
import zmq
|
||||
from aiohttp import web
|
||||
|
||||
from sglang.srt.disaggregation.base.conn import (
|
||||
BaseKVBootstrapServer,
|
||||
BaseKVManager,
|
||||
BaseKVReceiver,
|
||||
BaseKVSender,
|
||||
KVArgs,
|
||||
KVPoll,
|
||||
)
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
format_tcp_address,
|
||||
get_free_port,
|
||||
get_ip,
|
||||
get_local_ip_by_remote,
|
||||
is_valid_ipv6_address,
|
||||
maybe_wrap_ipv6_address,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommonKVManager(BaseKVManager):
    """Shared KV-manager plumbing: bootstrap registration (prefill side) and
    connection bookkeeping (decode side).

    Fixes vs. previous revision:
    - ``_connect`` no longer shadows the imported ``socket`` module with a
      local variable.
    - ``connection_pool`` annotation corrected: it stores a *list* of
      bootstrap-info dicts per key (see ``CommonKVReceiver.__init__``).
    """

    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
        is_mla_backend: Optional[bool] = False,
    ):
        self.kv_args = args
        self.is_mla_backend = is_mla_backend
        self.disaggregation_mode = disaggregation_mode
        # for p/d multi node infer
        self.bootstrap_host = server_args.host
        self.bootstrap_port = server_args.disaggregation_bootstrap_port
        self.dist_init_addr = server_args.dist_init_addr
        self.tp_size = server_args.tp_size
        self.dp_size = server_args.dp_size
        self.enable_dp_attention = server_args.enable_dp_attention
        if not server_args.enable_dp_attention and server_args.dp_size != 1:
            raise ValueError(
                "If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
            )

        self.rank_port = get_free_port()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self._register_to_bootstrap()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # bootstrap_key -> list of bootstrap-info dicts (one per target tp rank).
            self.connection_pool: Dict[str, List[Dict[str, Union[str, int]]]] = {}
            self.prefill_tp_size_table: Dict[str, int] = {}
            self.prefill_dp_size_table: Dict[str, int] = {}
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )

    def _register_to_bootstrap(self):
        """Register KVSender to bootstrap server via HTTP POST."""
        if self.dist_init_addr:
            # multi node: bootstrap server's host is dist_init_addr
            if self.dist_init_addr.startswith("["):  # [ipv6]:port or [ipv6]
                if self.dist_init_addr.endswith("]"):
                    host = self.dist_init_addr
                else:
                    host, _ = self.dist_init_addr.rsplit(":", 1)
            else:
                host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
        else:
            # single node: bootstrap server's host is same as http server's host
            host = self.bootstrap_host
        host = maybe_wrap_ipv6_address(host)

        bootstrap_server_url = f"{host}:{self.bootstrap_port}"
        url = f"http://{bootstrap_server_url}/route"
        payload = {
            "role": "Prefill",
            "tp_size": self.tp_size,
            "dp_size": self.dp_size,
            "rank_ip": get_local_ip_by_remote(),
            "rank_port": self.rank_port,
            "engine_rank": self.kv_args.engine_rank,
        }

        try:
            response = requests.put(url, json=payload)
            if response.status_code == 200:
                logger.debug("Prefill successfully registered to bootstrap server.")
            else:
                logger.error(
                    f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
                )
        except Exception as e:
            logger.error(f"Prefill Failed to register to bootstrap server: {e}")

    # NOTE(review): @cache on an instance method keys on (self, endpoint) and
    # keeps the manager (and its sockets) alive for the process lifetime (B019);
    # acceptable here since managers are long-lived singletons per rank.
    @cache
    def _connect(self, endpoint: str, is_ipv6: bool = False):
        """Return a cached PUSH socket connected to *endpoint*."""
        # Renamed from `socket` to avoid shadowing the stdlib `socket` module
        # imported at the top of this file.
        sock = zmq.Context().socket(zmq.PUSH)
        if is_ipv6:
            sock.setsockopt(zmq.IPV6, 1)
        sock.connect(endpoint)
        return sock
|
||||
|
||||
|
||||
class CommonKVReceiver(BaseKVReceiver):
    """Decode-side receiver: resolves prefill TP/DP topology from the
    bootstrap server, picks the target prefill rank(s) for this decode rank,
    and caches bootstrap infos in the manager's connection pool.

    Fix vs. previous revision: ``_get_prefill_dp_size_from_server`` returned a
    bare ``None`` on failure while its call site unpacks two values, so every
    failure path crashed with ``TypeError`` before the intended ``is None``
    handling ran.  It now returns ``(None, None)`` on failure.
    """

    _ctx = zmq.Context()
    _socket_cache = {}
    _socket_locks = {}
    _global_lock = threading.Lock()

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
        prefill_dp_rank: Optional[int] = None,
    ):
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.kv_mgr = mgr

        if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
            self.prefill_tp_size, self.prefill_dp_size = (
                self._get_prefill_dp_size_from_server()
            )
            if self.prefill_tp_size is None or self.prefill_dp_size is None:
                logger.error(
                    f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
                )
            else:
                self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
                    self.prefill_tp_size
                )
                self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
                    self.prefill_dp_size
                )
        else:
            self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
                self.bootstrap_addr
            ]
            self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
                self.bootstrap_addr
            ]

        # Currently, we don't allow prefill instance and decode instance to
        # have different TP sizes per DP rank, except for models using MLA.
        local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
        prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
        if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
            self.target_tp_rank = (
                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
            )
            self.required_dst_info_num = 1
            self.target_tp_ranks = [self.target_tp_rank]
        elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
            # Decode TP is wider: several decode ranks map onto one prefill rank.
            self.target_tp_rank = (
                self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
            ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
            self.required_dst_info_num = (
                local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
            )
            self.target_tp_ranks = [self.target_tp_rank]
        else:
            assert (
                self.kv_mgr.is_mla_backend
            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"

            # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
            self.target_tp_ranks = [
                rank
                for rank in range(
                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
                    (self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
                    * (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
                )
            ]

            # For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
            # multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
            # or the KVPoll will never be set correctly
            self.target_tp_rank = self.target_tp_ranks[0]
            self.required_dst_info_num = 1

        if prefill_dp_rank is not None:
            logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
            self.prefill_dp_rank = prefill_dp_rank
        else:
            # NOTE(review): assumes bootstrap_room is set when prefill_dp_rank
            # is not; a None bootstrap_room would fail here.
            self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size

        # FIXME: alias here: target_dp_group -> prefill_dp_rank
        self.target_dp_group = self.prefill_dp_rank

        # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
        bootstrap_key = (
            f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
        )

        if bootstrap_key not in self.kv_mgr.connection_pool:
            bootstrap_infos = []
            for target_tp_rank in self.target_tp_ranks:
                bootstrap_info = self._get_bootstrap_info_from_server(
                    target_tp_rank,
                    self.target_dp_group,
                )
                if bootstrap_info is not None:
                    # NOTE: only support MLA for now: select one prefill rank as real rank
                    bootstrap_info["is_dummy"] = not bool(
                        target_tp_rank == self.target_tp_rank
                        or self.target_tp_rank is None
                    )
                    bootstrap_infos.append(bootstrap_info)
                else:
                    logger.error(
                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
                    )
            self.bootstrap_infos = bootstrap_infos

            if len(self.bootstrap_infos) == 0:
                logger.error(
                    f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                )
            else:
                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
                # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
                self._register_kv_args()
        else:
            self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]

        assert len(self.bootstrap_infos) > 0

    def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
        """Fetch the bootstrap info from the bootstrap server."""
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
            response = requests.get(url)
            if response.status_code == 200:
                bootstrap_info = response.json()
                return bootstrap_info
            else:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
        except Exception as e:
            logger.error(f"Error fetching prefill info from bootstrap: {e}")
            return None

    def _get_prefill_dp_size_from_server(self) -> Tuple[Optional[int], Optional[int]]:
        """Fetch (prefill_tp_size, prefill_dp_size) from the bootstrap server.

        Returns (None, None) on any failure so the caller's two-value unpack
        and subsequent ``is None`` checks work as intended.
        """
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
            response = requests.get(url)
            if response.status_code == 200:
                prefill_parallel_info = response.json()
                return int(prefill_parallel_info["prefill_tp_size"]), int(
                    prefill_parallel_info["prefill_dp_size"]
                )
            else:
                logger.error(
                    f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                )
                return None, None
        except Exception as e:
            logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
            return None, None

    @classmethod
    def _connect(cls, endpoint: str, is_ipv6: bool = False):
        """Return (socket, lock) for *endpoint*, creating and caching on first use."""
        with cls._global_lock:
            if endpoint not in cls._socket_cache:
                sock = cls._ctx.socket(zmq.PUSH)
                if is_ipv6:
                    sock.setsockopt(zmq.IPV6, 1)
                sock.connect(endpoint)
                cls._socket_cache[endpoint] = sock
                cls._socket_locks[endpoint] = threading.Lock()
            return cls._socket_cache[endpoint], cls._socket_locks[endpoint]

    @classmethod
    def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
        """Connect to the prefill rank described by *bootstrap_info*."""
        ip_address = bootstrap_info["rank_ip"]
        port = bootstrap_info["rank_port"]
        is_ipv6_address = is_valid_ipv6_address(ip_address)
        sock, lock = cls._connect(
            format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
        )
        return sock, lock

    def _register_kv_args(self):
        # Subclasses that need to push kv_args to the prefill side override this.
        pass

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")
|
||||
|
||||
|
||||
class CommonKVBootstrapServer(BaseKVBootstrapServer):
    """aiohttp-based rendezvous server: prefill ranks PUT their (ip, port)
    under /route; decode ranks GET them back (or GET with -1/-1 to read the
    prefill TP/DP sizes).

    Fixes vs. previous revision:
    - ``is None`` instead of ``== None`` (E711).
    - ``_handle_route_get`` uses ``.get`` lookups so an unknown rank returns
      the intended 404 instead of raising ``KeyError``.
    - ``self._loop`` is initialized to ``None`` so ``close()`` cannot hit
      ``AttributeError`` when called before the server thread has started
      its event loop.
    """

    def __init__(self, host: str, port: int):
        self.host = host
        self.port = port
        self.app = web.Application()
        self.store = dict()
        self.lock = asyncio.Lock()
        self._setup_routes()
        self.tp_size = None
        self.dp_size = None
        self.tp_size_per_dp_rank = None
        self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
        # Set by _run_server on the server thread; None until then so close()
        # is safe to call at any time.
        self._loop = None

        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()

    def run(self):
        self.thread.start()

    def _setup_routes(self):
        self.app.router.add_route("*", "/route", self._handle_route)

    async def _handle_route(self, request: web.Request):
        method = request.method
        if method == "PUT":
            return await self._handle_route_put(request)
        elif method == "GET":
            return await self._handle_route_get(request)
        else:
            return web.Response(
                text="Method not allowed", status=405, content_type="application/json"
            )

    async def _handle_route_put(self, request: web.Request):
        """Record one prefill rank's (ip, port) keyed by (dp_group, tp_rank)."""
        data = await request.json()
        role = data["role"]
        tp_size = data["tp_size"]
        dp_size = data["dp_size"]
        rank_ip = data["rank_ip"]
        rank_port = int(data["rank_port"])
        engine_rank = int(data["engine_rank"])

        # First registrant fixes the cluster-wide sizes.
        if self.tp_size is None:
            self.tp_size = tp_size

        if self.dp_size is None:
            self.dp_size = dp_size

        tp_size_per_dp_rank = tp_size // dp_size
        if self.tp_size_per_dp_rank is None:
            self.tp_size_per_dp_rank = tp_size_per_dp_rank

        # Add lock to make sure thread-safe
        if role == "Prefill":
            dp_group = engine_rank // tp_size_per_dp_rank
            tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank

            async with self.lock:
                if dp_group not in self.prefill_port_table:
                    self.prefill_port_table[dp_group] = {}

                self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
                    "rank_ip": rank_ip,
                    "rank_port": rank_port,
                }
            logger.debug(
                f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
            )

        return web.Response(text="OK", status=200)

    async def _handle_route_get(self, request: web.Request):
        """Return one rank's bootstrap info, or parallel sizes for -1/-1."""
        engine_rank = request.query.get("engine_rank")
        target_dp_group = request.query.get("target_dp_group")
        if not engine_rank or not target_dp_group:
            return web.Response(text="Missing inputs for bootstrap server.", status=400)

        # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
        if int(engine_rank) == -1 and int(target_dp_group) == -1:
            prefill_parallel_info = {
                "prefill_tp_size": self.tp_size,
                "prefill_dp_size": self.dp_size,
            }
            return web.json_response(prefill_parallel_info, status=200)

        # Find corresponding prefill info.  Use .get so an unknown
        # (dp_group, rank) yields the 404 below instead of a KeyError.
        async with self.lock:
            bootstrap_info = self.prefill_port_table.get(int(target_dp_group), {}).get(
                int(engine_rank)
            )

        if bootstrap_info is not None:
            return web.json_response(bootstrap_info, status=200)

        else:
            return web.Response(text="Bootstrap info not Found", status=404)

    def _run_server(self):
        try:
            # Event Loop
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

            self._runner = web.AppRunner(self.app)
            self._loop.run_until_complete(self._runner.setup())

            site = web.TCPSite(self._runner, host=self.host, port=self.port)
            self._loop.run_until_complete(site.start())
            self._loop.run_forever()
        except Exception as e:
            logger.error(f"Server error: {str(e)}")
        finally:
            # Cleanup
            self._loop.run_until_complete(self._runner.cleanup())
            self._loop.close()

    def close(self):
        """Shutdown"""
        if self._loop is not None and self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)
            logger.info("Stopping server loop...")

        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...
|
||||
42
python/sglang/srt/disaggregation/common/utils.py
Normal file
42
python/sglang/srt/disaggregation/common/utils.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
||||
class FastQueue:
    """Minimal unbounded FIFO for producer/consumer threads.

    ``put`` never blocks; ``get`` blocks until an item is available.
    """

    def __init__(self):
        self._items = deque()
        self._not_empty = threading.Condition()

    def put(self, item):
        """Append *item* and wake one thread blocked in get()."""
        with self._not_empty:
            self._items.append(item)
            self._not_empty.notify()

    def get(self):
        """Pop and return the oldest item, blocking while the queue is empty."""
        with self._not_empty:
            # wait_for re-checks the predicate on every wakeup, so spurious
            # notifications are handled.
            self._not_empty.wait_for(lambda: len(self._items) > 0)
            return self._items.popleft()
|
||||
|
||||
|
||||
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int32], dst_indices: npt.NDArray[np.int32]
) -> Tuple[List[List[int]], List[List[int]]]:
    """Split paired index arrays into maximal runs contiguous in BOTH arrays.

    A break is inserted wherever either sequence stops increasing by exactly 1,
    so each returned group can be transferred as one contiguous block on both
    the source and destination side.

    Args:
        src_indices: source indices; must be the same length as dst_indices.
        dst_indices: destination indices paired element-wise with src_indices.

    Returns:
        (src_groups, dst_groups): parallel lists of plain ``list[int]`` groups
        (the arrays are converted via ``tolist()``, not returned as ndarrays —
        the previous annotation claiming ``NDArray`` elements was wrong).
        Both are empty when the inputs are empty.

    Vectorised NumPy implementation.
    """
    if src_indices.size == 0:
        return [], []

    # Positions where either sequence breaks contiguity; +1 converts the
    # np.diff offset back to an index into the original arrays.
    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
    src_groups = np.split(src_indices, brk)
    dst_groups = np.split(dst_indices, brk)

    src_groups = [g.tolist() for g in src_groups]
    dst_groups = [g.tolist() for g in dst_groups]

    return src_groups, dst_groups
|
||||
894
python/sglang/srt/disaggregation/decode.py
Normal file
894
python/sglang/srt/disaggregation/decode.py
Normal file
@@ -0,0 +1,894 @@
|
||||
"""
|
||||
Life cycle of a request in the decode server
|
||||
|
||||
1. PreallocQueue:
|
||||
a. Initialize a receiver for each request
|
||||
b. The request handshakes first, and pre-allocate kv once there is available kv.
|
||||
c. Move the request to TransferQueue.
|
||||
|
||||
2. TransferQueue:
|
||||
a. Poll the receiver to check the transfer state
|
||||
b. If the transfer has finished, move the request to waiting queue
|
||||
|
||||
3. WaitingQueue:
|
||||
a. Use the requests in the queue to construct a PrebuiltExtendBatch
|
||||
b. Skip the prefill forward but only populate metadata
|
||||
|
||||
4. RunningBatch:
|
||||
a. Merge the resolved PrebuiltExtendBatch into running batch to run decoding
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
DisaggregationMode,
|
||||
KVClassType,
|
||||
MetadataBuffers,
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
get_kv_class,
|
||||
is_mla_backend,
|
||||
kv_to_page_indices,
|
||||
poll_and_all_reduce,
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import get_int_env_var, require_mlp_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.managers.scheduler import Scheduler
|
||||
|
||||
CLIP_MAX_NEW_TOKEN = get_int_env_var("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", 4096)
|
||||
|
||||
|
||||
class DecodeReqToTokenPool:
    """Request-to-token mapping pool with extra slots for pre-allocation.

    Differs from ``ReqToTokenPool`` in that it carves out ``pre_alloc_size``
    additional rows: with ``--max-running-requests`` = 8, only the running
    requests are bounded by 8, while pre-allocated / in-transfer requests can
    occupy the extra rows, keeping prefill unblocked.
    """

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        pre_alloc_size: int,
    ):
        saver_adapter = TorchMemorySaverAdapter.create(enable=enable_memory_saver)

        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        self.pre_alloc_size = pre_alloc_size

        total_slots = size + pre_alloc_size
        with saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
            # One row of token indices per (running + pre-allocated) request.
            self.req_to_token = torch.zeros(
                (total_slots, max_context_len),
                dtype=torch.int32,
                device=device,
            )

        self.free_slots = list(range(total_slots))

    def write(self, indices, values):
        """Write token values into the mapping table at *indices*."""
        self.req_to_token[indices] = values

    def available_size(self):
        """Number of currently free request slots."""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> List[int]:
        """Take *need_size* free slots; return None when not enough are free."""
        if need_size > len(self.free_slots):
            return None

        allocated, self.free_slots = (
            self.free_slots[:need_size],
            self.free_slots[need_size:],
        )
        return allocated

    def free(self, free_index: Union[int, List[int]]):
        """Return one slot (int) or several slots (list) to the free pool."""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Mark every slot (including pre-alloc slots) as free again."""
        self.free_slots = list(range(self.size + self.pre_alloc_size))
|
||||
|
||||
|
||||
@dataclass
class DecodeRequest:
    """A decode-side request bundled with its KV-transfer state."""

    # The scheduler request being decoded.
    req: Req
    # Receiver that pulls this request's KV cache from the prefill instance.
    kv_receiver: BaseKVReceiver
    # Set to True once the handshake poll reports KVPoll.WaitingForInput.
    waiting_for_input: bool = False
    # Slot in MetadataBuffers reserved for this request; -1 means unassigned.
    metadata_buffer_index: int = -1
|
||||
|
||||
|
||||
class DecodePreallocQueue:
    """
    Store the requests that are preallocating.

    Requests wait here until the prefill-side handshake completes and enough
    KV/req-slot memory can be reserved; then they move to the transfer queue.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        draft_token_to_kv_pool: Optional[KVCache],
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: MetadataBuffers,
        scheduler: Scheduler,
        transfer_queue: DecodeTransferQueue,
        tree_cache: BasePrefixCache,
        gloo_group: ProcessGroup,
        tp_rank: int,
        tp_size: int,
        dp_size: int,
        gpu_id: int,
        bootstrap_port: int,
        max_total_num_tokens: int,
        prefill_pp_size: int,
        num_reserved_decode_tokens: int,
        transfer_backend: TransferBackend,
    ):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
        self.draft_token_to_kv_pool = draft_token_to_kv_pool
        self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
        self.metadata_buffers = metadata_buffers
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.scheduler = scheduler
        self.transfer_queue = transfer_queue
        self.tree_cache = tree_cache  # this is always a chunk cache
        self.gloo_group = gloo_group
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dp_size = dp_size
        self.gpu_id = gpu_id
        self.bootstrap_port = bootstrap_port
        self.max_total_num_tokens = max_total_num_tokens
        self.prefill_pp_size = prefill_pp_size
        self.num_reserved_decode_tokens = num_reserved_decode_tokens
        self.transfer_backend = transfer_backend
        # Queue for requests pending pre-allocation
        self.queue: List[DecodeRequest] = []
        # Requests retracted from the running batch, to be resumed first.
        self.retracted_queue: List[Req] = []
        # NOTE(review): duplicate assignment — `prefill_pp_size` was already set above.
        self.prefill_pp_size = prefill_pp_size
        self.kv_manager = self._init_kv_manager()

    def _init_kv_manager(self) -> BaseKVManager:
        """Build the transfer-backend KV manager that exposes this rank's KV buffers."""
        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
        kv_args = kv_args_class()

        attn_tp_size = get_attention_tp_size()
        kv_args.engine_rank = self.tp_rank % (attn_tp_size)

        kv_args.decode_tp_size = attn_tp_size
        # Note(shangming): pp is not supported on the decode side yet, so its rank is fixed to 0
        kv_args.pp_rank = 0
        kv_args.system_dp_rank = self.scheduler.dp_rank
        kv_args.prefill_pp_size = self.prefill_pp_size
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )
        if self.draft_token_to_kv_pool is not None:
            # We should also transfer draft model kv cache. The indices are
            # always shared with a target model.
            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
            )
            kv_data_ptrs += draft_kv_data_ptrs
            kv_data_lens += draft_kv_data_lens
            kv_item_lens += draft_kv_item_lens

        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens

        # Auxiliary buffers carry per-request metadata (output ids, logprobs, ...).
        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
            self.metadata_buffers.get_buf_infos()
        )

        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
        kv_args.gpu_id = self.scheduler.gpu_id
        kv_manager_class: Type[BaseKVManager] = get_kv_class(
            self.transfer_backend, KVClassType.MANAGER
        )
        kv_manager: BaseKVManager = kv_manager_class(
            kv_args,
            DisaggregationMode.DECODE,
            self.scheduler.server_args,
            self.is_mla_backend,
        )
        return kv_manager

    def add(self, req: Req, is_retracted: bool = False) -> None:
        """Add a request to the pending queue.

        Retracted requests go to `retracted_queue` without a new receiver;
        fresh requests get a KV receiver (a fake one for FAKE_BOOTSTRAP_HOST).
        """
        if self._check_if_req_exceed_kv_capacity(req):
            return

        if is_retracted:
            self.retracted_queue.append(req)
        else:
            if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
                kv_receiver_class = get_kv_class(
                    TransferBackend.FAKE, KVClassType.RECEIVER
                )
            else:
                kv_receiver_class = get_kv_class(
                    self.transfer_backend, KVClassType.RECEIVER
                )

            kv_receiver = kv_receiver_class(
                mgr=self.kv_manager,
                bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
                bootstrap_room=req.bootstrap_room,
                prefill_dp_rank=req.data_parallel_rank,
            )

            self.queue.append(
                DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
            )

    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
        """Abort and stream out `req` if its input alone cannot fit in the KV pool."""
        if len(req.origin_input_ids) > self.max_total_num_tokens:
            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
            logger.error(message)
            prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST)
            self.scheduler.stream_output([req], req.return_logprob)
            return True
        return False

    def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
        """Add multiple requests to the pending queue."""
        for req in reqs:
            self.add(req, is_retracted=is_retracted)

    def resume_retracted_reqs(self) -> List[Req]:
        """Resume as many retracted requests as memory allows (FIFO); return them."""
        # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible

        # allocate memory
        resumed_reqs = []
        indices_to_remove = set()
        # Retracted requests are excluded from the estimate here because we are
        # re-admitting them now.
        allocatable_tokens = self._allocatable_tokens(count_retracted=False)

        for i, req in enumerate(self.retracted_queue):
            if self.req_to_token_pool.available_size() <= 0:
                break

            required_tokens_for_request = (
                len(req.origin_input_ids)
                + len(req.output_ids)
                + self.num_reserved_decode_tokens
            )
            if required_tokens_for_request > allocatable_tokens:
                break

            resumed_reqs.append(req)
            indices_to_remove.add(i)
            req.is_retracted = False
            self._pre_alloc(req)
            allocatable_tokens -= required_tokens_for_request

            # load from cpu, release the cpu copy
            req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)

        self.retracted_queue = [
            entry
            for i, entry in enumerate(self.retracted_queue)
            if i not in indices_to_remove
        ]

        return resumed_reqs

    def _update_handshake_waiters(self) -> None:
        """Poll all receivers (collectively, via gloo) and update handshake state."""
        if not self.queue:
            return

        # Skip the collective poll if every request already finished handshaking.
        if all(decode_req.waiting_for_input for decode_req in self.queue):
            return

        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Bootstrapping:
                pass
            elif poll == KVPoll.WaitingForInput:
                decode_req.waiting_for_input = True
            elif poll == KVPoll.Failed:
                error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                try:
                    decode_req.kv_receiver.failure_exception()
                except Exception as e:
                    error_message += f" with exception {e}"
                logger.error(error_message)
                # Mark the request aborted; it is removed later in pop_preallocated().
                prepare_abort(
                    decode_req.req,
                    error_message,
                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                )
                if self.scheduler.enable_metrics:
                    self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
            else:
                raise ValueError(f"Unexpected poll case: {poll}")

    def pop_preallocated(self) -> List[DecodeRequest]:
        """Pop the preallocated requests from the pending queue (FIFO)."""
        self._update_handshake_waiters()

        preallocated_reqs = []
        indices_to_remove = set()

        # We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request
        # Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted.
        retractable_tokens = sum(
            len(r.origin_input_ids) + len(r.output_ids)
            for r in self.scheduler.running_batch.reqs
        )
        allocatable_tokens = self._allocatable_tokens(
            retractable_tokens=retractable_tokens, count_retracted=True
        )
        # First, remove all failed requests from the queue
        for i, decode_req in enumerate(self.queue):
            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
                self.scheduler.stream_output(
                    [decode_req.req], decode_req.req.return_logprob
                )
                indices_to_remove.add(i)

        # Then, preallocate the remaining requests if possible
        for i, decode_req in enumerate(self.queue):
            if i in indices_to_remove:
                continue

            # Handshake with the prefill side has not completed yet.
            if not decode_req.waiting_for_input:
                continue

            if self.req_to_token_pool.available_size() <= 0:
                break

            if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                break

            # Memory estimation: don't add if the projected memory cannot be met
            # TODO: add new_token ratio
            origin_input_len = len(decode_req.req.origin_input_ids)
            required_tokens_for_request = (
                origin_input_len + self.num_reserved_decode_tokens
            )

            # Also require that this request could finish (up to the clipped
            # max_new_tokens) even if every other running request were retracted.
            if (
                max(
                    required_tokens_for_request,
                    origin_input_len
                    + min(
                        decode_req.req.sampling_params.max_new_tokens,
                        CLIP_MAX_NEW_TOKEN,
                    )
                    - retractable_tokens,
                )
                > allocatable_tokens
            ):
                break
            if required_tokens_for_request > allocatable_tokens:
                break

            allocatable_tokens -= required_tokens_for_request
            self._pre_alloc(decode_req.req)

            # Gather the freshly-written KV locations for the input tokens.
            kv_indices = (
                self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
                    : len(decode_req.req.origin_input_ids)
                ]
                .cpu()
                .numpy()
            )

            decode_req.metadata_buffer_index = (
                self.req_to_metadata_buffer_idx_allocator.alloc()
            )
            assert decode_req.metadata_buffer_index is not None
            page_indices = kv_to_page_indices(
                kv_indices, self.token_to_kv_pool_allocator.page_size
            )
            # Tell the receiver where to deposit the incoming KV pages/metadata.
            decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
            preallocated_reqs.append(decode_req)
            indices_to_remove.add(i)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        return preallocated_reqs

    @property
    def num_tokens_pre_allocated(self):
        # Tokens already reserved for requests whose KV is still in flight.
        return sum(
            len(decode_req.req.fill_ids) for decode_req in self.transfer_queue.queue
        )

    def _allocatable_tokens(
        self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
    ) -> int:
        """Estimate how many KV tokens can be handed to new pre-allocations.

        Args:
            retractable_tokens: Tokens reclaimable by retracting all running
                requests; enables the "single request can still finish" bound.
            count_retracted: Whether to reserve space for requests waiting in
                `retracted_queue`.
        """
        need_space_for_single_req = (
            max(
                [
                    min(x.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKEN)
                    + len(x.origin_input_ids)
                    - retractable_tokens
                    for x in self.scheduler.running_batch.reqs
                ]
            )
            if retractable_tokens is not None
            and len(self.scheduler.running_batch.reqs) > 0
            else 0
        )

        if self.scheduler.model_config.is_hybrid:
            # Hybrid (full + sliding-window attention) models: bounded by the
            # smaller of the two pools.
            available_size = min(
                self.token_to_kv_pool_allocator.full_available_size(),
                self.token_to_kv_pool_allocator.swa_available_size(),
            )
        else:
            available_size = self.token_to_kv_pool_allocator.available_size()

        allocatable_tokens = available_size - max(
            # preserve some space for future decode
            self.num_reserved_decode_tokens
            * (
                len(self.scheduler.running_batch.reqs)
                + len(self.transfer_queue.queue)
                + len(self.scheduler.waiting_queue)
            ),
            # make sure each request can finish if reach max_tokens with all other requests retracted
            need_space_for_single_req,
        )

        # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
        # the extend batch is not in any queue, so we need to explicitly add the tokens slots here
        if (
            self.scheduler.last_batch
            and self.scheduler.last_batch.forward_mode.is_extend()
        ):
            allocatable_tokens -= self.num_reserved_decode_tokens * len(
                self.scheduler.last_batch.reqs
            )

        if count_retracted:
            allocatable_tokens -= sum(
                [
                    len(req.origin_input_ids)
                    + len(req.output_ids)
                    + self.num_reserved_decode_tokens
                    for req in self.retracted_queue
                ]
            )
        return allocatable_tokens

    def _pre_alloc(self, req: Req) -> torch.Tensor:
        """Pre-allocate the memory for req_to_token and token_kv_pool"""
        req_pool_indices = self.req_to_token_pool.alloc(1)

        assert (
            req_pool_indices is not None
        ), "req_pool_indices is full! There is a bug in memory estimation."

        req.req_pool_idx = req_pool_indices[0]

        if self.token_to_kv_pool_allocator.page_size == 1:
            kv_loc = self.token_to_kv_pool_allocator.alloc(
                len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
            )
        else:
            # Paged allocator: allocate as a fresh extend (no prefix, last_loc=-1).
            num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
            kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
                prefix_lens=torch.tensor(
                    [0],
                    dtype=torch.int64,
                    device=self.token_to_kv_pool_allocator.device,
                ),
                seq_lens=torch.tensor(
                    [num_tokens],
                    dtype=torch.int64,
                    device=self.token_to_kv_pool_allocator.device,
                ),
                last_loc=torch.tensor(
                    [-1],
                    dtype=torch.int64,
                    device=self.token_to_kv_pool_allocator.device,
                ),
                extend_num_tokens=num_tokens,
            )

        assert (
            kv_loc is not None
        ), "KV cache is full! There is a bug in memory estimation."

        self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)

        # populate metadata
        req.fill_ids = req.origin_input_ids + req.output_ids
        req.extend_input_len = len(req.origin_input_ids)

        return kv_loc
|
||||
|
||||
|
||||
class DecodeTransferQueue:
    """
    Store the requests whose KV cache transfer is being polled.
    """

    def __init__(
        self,
        gloo_group: ProcessGroup,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        tp_rank: int,
        metadata_buffers: MetadataBuffers,
        scheduler: Scheduler,
        tree_cache: BasePrefixCache,
    ):
        self.queue: List[DecodeRequest] = []
        self.gloo_group = gloo_group
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.tp_rank = tp_rank
        self.metadata_buffers = metadata_buffers
        self.scheduler = scheduler
        self.tree_cache = tree_cache
        self.spec_algorithm = scheduler.spec_algorithm

    def add(self, decode_req: DecodeRequest) -> None:
        """Enqueue one request whose KV transfer is in flight."""
        self.queue.append(decode_req)

    def extend(self, decode_reqs: List[DecodeRequest]) -> None:
        """Enqueue several requests whose KV transfers are in flight."""
        self.queue.extend(decode_reqs)

    def pop_transferred(self) -> List[Req]:
        """Poll all receivers; return requests whose KV has fully arrived.

        Failed transfers are aborted and streamed out; finished one-token
        requests are completed here instead of being returned.
        """
        if not self.queue:
            return []
        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        transferred_reqs = []
        indices_to_remove = set()
        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Failed:
                error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                try:
                    decode_req.kv_receiver.failure_exception()
                except Exception as e:
                    error_message += f" with exception {e}"
                logger.error(error_message)
                prepare_abort(
                    decode_req.req,
                    error_message,
                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                )
                self.scheduler.stream_output(
                    [decode_req.req], decode_req.req.return_logprob
                )
                # unlock the kv cache or it will have memory leak
                self.tree_cache.cache_finished_req(decode_req.req)
                indices_to_remove.add(i)
                if self.scheduler.enable_metrics:
                    self.scheduler.metrics_collector.increment_transfer_failed_reqs()
                continue
            elif poll == KVPoll.Success:

                # Read the first decoded token (and logprobs / hidden states)
                # from the metadata buffer slot reserved for this request.
                idx = decode_req.metadata_buffer_index
                (
                    output_id,
                    output_token_logprobs_val,
                    output_token_logprobs_idx,
                    output_top_logprobs_val,
                    output_top_logprobs_idx,
                    output_hidden_states,
                ) = self.metadata_buffers.get_buf(idx)

                decode_req.req.output_ids.append(output_id[0].item())
                if not self.spec_algorithm.is_none():
                    # Speculative decoding needs the prefill-side hidden states.
                    decode_req.req.hidden_states_tensor = output_hidden_states
                if decode_req.req.return_logprob:
                    decode_req.req.output_token_logprobs_val.append(
                        output_token_logprobs_val[0].item()
                    )
                    decode_req.req.output_token_logprobs_idx.append(
                        output_token_logprobs_idx[0].item()
                    )
                    decode_req.req.output_top_logprobs_val.append(
                        output_top_logprobs_val[
                            : decode_req.req.top_logprobs_num
                        ].tolist()
                    )
                    decode_req.req.output_top_logprobs_idx.append(
                        output_top_logprobs_idx[
                            : decode_req.req.top_logprobs_num
                        ].tolist()
                    )

                if hasattr(decode_req.kv_receiver, "clear"):
                    decode_req.kv_receiver.clear()

                # special handling for sampling_params.max_new_tokens == 1
                if decode_req.req.sampling_params.max_new_tokens == 1:
                    # finish immediately
                    decode_req.req.check_finished()
                    self.scheduler.stream_output(
                        [decode_req.req], decode_req.req.return_logprob
                    )
                    self.tree_cache.cache_finished_req(decode_req.req)
                else:
                    transferred_reqs.append(decode_req.req)

                indices_to_remove.add(i)
            elif poll in [
                KVPoll.Bootstrapping,
                KVPoll.WaitingForInput,
                KVPoll.Transferring,
            ]:
                # Still in flight; keep polling next round.
                pass
            else:
                raise ValueError(f"Unexpected poll case: {poll}")

        # Release metadata buffer slots for every request leaving the queue
        # (both successes and failures).
        for i in indices_to_remove:
            idx = self.queue[i].metadata_buffer_index
            assert idx != -1
            self.req_to_metadata_buffer_idx_allocator.free(idx)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        return transferred_reqs
|
||||
|
||||
|
||||
class SchedulerDisaggregationDecodeMixin:
    """Scheduler event loops and helpers for the decode side of PD disaggregation."""

    @torch.no_grad()
    def event_loop_normal_disagg_decode(self: Scheduler):
        """A normal scheduler loop for decode worker in disaggregation mode."""

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # polling and allocating kv cache
            self.process_decode_queue()
            batch = self.get_next_disagg_decode_batch_to_run()
            self.cur_batch = batch

            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

            if batch:
                # Generate fake extend output.
                if batch.forward_mode.is_extend():
                    # Note: Logprobs should be handled on the prefill engine.
                    self.stream_output(
                        batch.reqs, any(req.return_logprob for req in batch.reqs)
                    )
                    if prepare_mlp_sync_flag:
                        self._prepare_idle_batch_and_run(None)
                else:
                    if prepare_mlp_sync_flag:
                        self.prepare_mlp_sync_batch(batch)
                    result = self.run_batch(batch)
                    self.process_batch_result(batch, result)
            elif prepare_mlp_sync_flag:
                # No local work: still run an idle batch to stay in sync with peers.
                batch, _ = self._prepare_idle_batch_and_run(None)

            if batch is None and (
                len(self.waiting_queue)
                + len(self.disagg_decode_transfer_queue.queue)
                + len(self.disagg_decode_prealloc_queue.queue)
                == 0
            ):
                # Fully idle: run the watchdog/self-check path.
                self.self_check_during_idle()

            self.last_batch = batch

    @torch.no_grad()
    def event_loop_overlap_disagg_decode(self: Scheduler):
        """Overlap-scheduling variant: results are processed one iteration late."""
        result_queue = deque()
        self.last_batch: Optional[ScheduleBatch] = None
        self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # polling and allocating kv cache
            self.process_decode_queue()
            batch = self.get_next_disagg_decode_batch_to_run()
            self.cur_batch = batch
            last_batch_in_queue = False

            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

            if batch:
                # Generate fake extend output.
                if batch.forward_mode.is_extend():
                    # Note: Logprobs should be handled on the prefill engine.
                    self.stream_output(
                        batch.reqs, any(req.return_logprob for req in batch.reqs)
                    )
                    if prepare_mlp_sync_flag:
                        batch_, result = self._prepare_idle_batch_and_run(
                            None, delay_process=True
                        )
                        if batch_:
                            result_queue.append((batch_.copy(), result))
                            last_batch_in_queue = True
                else:
                    if prepare_mlp_sync_flag:
                        self.prepare_mlp_sync_batch(batch)
                    result = self.run_batch(batch)
                    result_queue.append((batch.copy(), result))

                    if (self.last_batch is None) or (not self.last_batch_in_queue):
                        # Create a dummy first batch to start the pipeline for overlap schedule.
                        # It is now used for triggering the sampling_info_done event.
                        tmp_batch = ScheduleBatch(
                            reqs=None,
                            forward_mode=ForwardMode.DUMMY_FIRST,
                            next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                        )
                        self.set_next_batch_sampling_info_done(tmp_batch)
                    last_batch_in_queue = True

            elif prepare_mlp_sync_flag:
                batch, result = self._prepare_idle_batch_and_run(
                    None, delay_process=True
                )
                if batch:
                    result_queue.append((batch.copy(), result))
                    last_batch_in_queue = True

            # Process the results of the previous batch but skip if the last batch is extend
            if self.last_batch and self.last_batch_in_queue:
                tmp_batch, tmp_result = result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)

            if batch is None and (
                len(self.waiting_queue)
                + len(self.disagg_decode_transfer_queue.queue)
                + len(self.disagg_decode_prealloc_queue.queue)
                == 0
            ):
                self.self_check_during_idle()

            self.last_batch = batch
            self.last_batch_in_queue = last_batch_in_queue

    def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
        """Run an (idle) MLP-sync batch; optionally defer result processing.

        Returns (batch, result); result is None when no batch was produced.
        """
        batch = self.prepare_mlp_sync_batch(batch)
        result = None
        if batch:
            result = self.run_batch(batch)
            if not delay_process:
                self.process_batch_result(batch, result)
        return batch, result

    def get_next_disagg_decode_batch_to_run(
        self: Scheduler,
    ) -> Optional[Tuple[ScheduleBatch, bool]]:
        """Create fake completed prefill if possible and merge with running batch"""
        # Merge the prefill batch into the running batch
        last_batch = self.last_batch
        if last_batch and last_batch.forward_mode.is_extend():
            # chunked prefill doesn't happen in decode instance.
            assert self.chunked_req is None
            # Filter finished batches.
            last_batch.filter_batch()
            if not last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = last_batch
                else:
                    # merge running_batch with prefill batch
                    self.running_batch.merge_batch(last_batch)

        new_prebuilt_batch = self.get_new_prebuilt_batch()

        # Prefer a fresh fake-prefill batch; otherwise continue decoding the
        # running batch (after pruning it).
        ret: Optional[ScheduleBatch] = None
        if new_prebuilt_batch:
            ret = new_prebuilt_batch
        else:
            if self.running_batch.is_empty():
                ret = None
            else:
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None

        return ret

    def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
        """Create a schedulebatch for fake completed prefill"""
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        if len(self.waiting_queue) == 0:
            return None

        curr_batch_size = self.running_batch.batch_size()

        batch_size = min(self.req_to_token_pool.size, self.max_running_requests)

        # Remaining headroom under the running-request budget.
        num_not_used_batch = batch_size - curr_batch_size

        # pop req from waiting queue
        can_run_list: List[Req] = []
        waiting_queue: List[Req] = []

        for i in range(len(self.waiting_queue)):
            req = self.waiting_queue[i]
            # we can only add at least `num_not_used_batch` new batch to the running queue
            if i < num_not_used_batch:
                can_run_list.append(req)
                req.init_next_round_input(self.tree_cache)
            else:
                waiting_queue.append(req)

        self.waiting_queue = waiting_queue
        if len(can_run_list) == 0:
            return None

        # construct a schedule batch with those requests and mark as decode
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )

        # construct fake completed prefill
        new_batch.prepare_for_prebuilt_extend()
        new_batch.process_prebuilt_extend(self.server_args, self.model_config)

        return new_batch

    def process_decode_queue(self: Scheduler):
        """Drive the prealloc and transfer queues; move ready requests to waiting_queue."""
        # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
        resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
        self.waiting_queue.extend(resumed_reqs)
        if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
            # if there are still retracted requests, we do not allocate new requests
            return

        req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
        self.disagg_decode_transfer_queue.extend(req_conns)
        alloc_reqs = (
            self.disagg_decode_transfer_queue.pop_transferred()
        )  # the requests which kv has arrived
        self.waiting_queue.extend(alloc_reqs)
|
||||
--- New file in this commit: python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py (159 lines added, @@ -0,0 +1,159 @@) ---
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.disaggregation.utils import prepare_abort
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
class ScheduleBatchDisaggregationDecodeMixin:
|
||||
|
||||
def prepare_for_prebuilt_extend(self: ScheduleBatch):
    """
    Prepare a prebuilt extend by populating metadata.
    Adapted from .prepare_for_extend().

    Used on the decode side to fake a completed prefill: the KV cache was
    already written during pre-allocation/transfer, so only batch metadata
    (input ids, seq lens, cache locations, sampling info) is built here.
    """

    self.forward_mode = ForwardMode.EXTEND
    reqs = self.reqs
    # Tokens beyond the cached prefix for each request.
    input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
    extend_num_tokens = sum(len(ids) for ids in input_ids)
    seq_lens = []
    pre_lens = []
    req_pool_indices = []

    # Pre-calculate total size
    total_size = sum(req.extend_input_len for req in reqs)
    out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

    # Fill the tensor in one pass
    offset = 0
    for i, req in enumerate(reqs):
        req_pool_indices.append(req.req_pool_idx)

        # KV locations were already written into req_to_token by _pre_alloc.
        chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
            : req.extend_input_len
        ]
        assert (
            offset + req.extend_input_len <= total_size
        ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
        out_cache_loc[offset : offset + req.extend_input_len] = chunk
        offset += req.extend_input_len

        pre_len = len(req.prefix_indices)
        # The last output token is produced by this step, hence the -1.
        seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
        seq_lens.append(seq_len)
        if len(req.output_ids) == 0:
            assert (
                seq_len - pre_len == req.extend_input_len
            ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"

        req.cached_tokens += pre_len - req.already_computed
        req.already_computed = seq_len
        req.is_retracted = False
        pre_lens.append(pre_len)
        req.extend_logprob_start_len = 0

    # Input logprobs are handled on the prefill engine, not here.
    extend_input_logprob_token_ids = None

    # Set fields
    self.input_ids = torch.tensor(
        sum(input_ids, []), dtype=torch.int32, device=self.device
    )
    self.req_pool_indices = torch.tensor(
        req_pool_indices, dtype=torch.int64, device=self.device
    )
    self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
    self.orig_seq_lens = torch.tensor(
        seq_lens, dtype=torch.int32, device=self.device
    )
    self.out_cache_loc = out_cache_loc
    self.seq_lens_sum = sum(seq_lens)

    if self.return_logprob:
        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

    self.extend_num_tokens = extend_num_tokens
    self.prefix_lens = [len(r.prefix_indices) for r in reqs]
    self.extend_lens = [r.extend_input_len for r in reqs]
    self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
    self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
    self.multimodal_inputs = [r.multimodal_inputs for r in reqs]

    # Build sampling info
    self.sampling_info = SamplingBatchInfo.from_schedule_batch(
        self,
        self.model_config.vocab_size,
    )
|
||||
|
||||
def process_prebuilt_extend(
    self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
):
    """Assign the buffered last input id to schedule batch.

    For each request, the last generated token becomes the batch's output id,
    the request's unfinished prefix is kept in the tree cache, and any attached
    grammar is advanced by that token. When speculative decoding is enabled,
    mock EAGLE draft inputs are attached to the batch.
    """
    self.output_ids = []
    for req in self.reqs:
        # The most recently generated token of each request is the batch output.
        self.output_ids.append(req.output_ids[-1])
        # Keep this running request's tokens cached in the radix tree.
        self.tree_cache.cache_unfinished_req(req)
        if req.grammar is not None:
            # FIXME: this try-except block is for handling unexpected xgrammar issue.
            try:
                req.grammar.accept_token(req.output_ids[-1])
            except ValueError as e:
                # Grammar accept_token can raise ValueError if the token is not in the grammar.
                # This can happen if the grammar is not set correctly or the token is invalid.
                error_message = f"Grammar accept_token failed for req {req.rid} with token {req.output_ids[-1]}: {e}"
                # Release the cache entry and abort this request with a 500.
                self.tree_cache.cache_finished_req(req)
                prepare_abort(
                    req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
                )
            req.grammar.finished = req.finished()
    self.output_ids = torch.tensor(self.output_ids, device=self.device)

    # Simulate the eagle run. We add mock data to hidden states for the
    # ease of implementation now meaning the first token will have acc rate
    # of 0.
    if not self.spec_algorithm.is_none():

        b = len(self.reqs)
        # Mock top-k probabilities: strictly decreasing positive values,
        # normalized by the total element count so they sum over the batch.
        topk_p = torch.arange(
            b * server_args.speculative_eagle_topk,
            0,
            -1,
            device=self.device,
            dtype=torch.float32,
        )
        topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
        topk_p /= b * server_args.speculative_eagle_topk
        # Mock top-k token indices (arbitrary distinct ids).
        topk_index = torch.arange(
            b * server_args.speculative_eagle_topk, device=self.device
        )
        topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)

        # Stack each request's buffered hidden state into a (b, ...) tensor.
        hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
        hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)

        # local import to avoid circular import
        from sglang.srt.speculative.eagle_utils import EagleDraftInput

        spec_info = EagleDraftInput(
            topk_p=topk_p,
            topk_index=topk_index,
            hidden_states=hidden_states,
            verified_id=self.output_ids,
        )
        spec_info.prepare_for_extend(self)
        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        self.spec_info = spec_info
|
||||
1
python/sglang/srt/disaggregation/fake/__init__.py
Normal file
1
python/sglang/srt/disaggregation/fake/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender
|
||||
85
python/sglang/srt/disaggregation/fake/conn.py
Normal file
85
python/sglang/srt/disaggregation/fake/conn.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from sglang.srt.disaggregation.base.conn import (
|
||||
BaseKVManager,
|
||||
BaseKVReceiver,
|
||||
BaseKVSender,
|
||||
KVPoll,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
|
||||
class FakeKVSender(BaseKVSender):
    """No-op KV sender used for warmup requests: no data is ever transferred.

    ``poll`` reports ``WaitingForInput`` until ``send`` has been called once,
    then reports ``Success``, so callers observe an instantly completed
    handshake and transfer.
    """

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: int,
        dest_tp_ranks: List[int],
        pp_rank: int,
    ):
        # All constructor arguments are accepted for interface parity and ignored.
        self.has_sent = False

    def poll(self) -> KVPoll:
        if not self.has_sent:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        # Assume transfer completed instantly
        logger.debug("FakeKVSender poll success")
        return KVPoll.Success

    def init(
        self,
        kv_indices: list[int],
        aux_index: Optional[int] = None,
    ):
        # Nothing to set up; just record the call for debugging.
        logger.debug(
            f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
        )

    def send(
        self,
        kv_indices: npt.NDArray[np.int32],
    ):
        # Flip the flag so the next poll() reports success.
        self.has_sent = True
        logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")

    def failure_exception(self):
        raise Exception("Fake KVSender Exception")
|
||||
|
||||
|
||||
class FakeKVReceiver(BaseKVReceiver):
    """No-op KV receiver used for warmup requests: nothing is received.

    ``poll`` reports ``WaitingForInput`` until ``init`` has been called once,
    then reports ``Success``.
    """

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
        prefill_dp_rank: Optional[int] = None,
    ):
        # All constructor arguments are accepted for interface parity and ignored.
        self.has_init = False

    def poll(self) -> KVPoll:
        if not self.has_init:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        # Assume transfer completed instantly
        logger.debug("FakeKVReceiver poll success")
        return KVPoll.Success

    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
        # Flip the flag so the next poll() reports success.
        self.has_init = True
        logger.debug(
            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
        )

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")
|
||||
412
python/sglang/srt/disaggregation/kv_events.py
Normal file
412
python/sglang/srt/disaggregation/kv_events.py
Normal file
@@ -0,0 +1,412 @@
|
||||
"""
|
||||
Copyright 2025 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
KV caching events
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from itertools import count
|
||||
from queue import Queue
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventBatch(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,  # type: ignore[call-arg]
):
    """A timestamped batch of events, optionally tagged with the DP rank that produced it."""

    # Wall-clock timestamp of the batch.
    ts: float
    # Payload; concrete element types are narrowed by subclasses.
    events: list[Any]
    # Attention data-parallel rank that produced the batch; filled in by the
    # publisher when left as None.
    attn_dp_rank: Optional[int] = None
|
||||
|
||||
|
||||
class KVCacheEvent(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,  # type: ignore[call-arg]
    tag=True,  # tag each subclass so event lists decode as a tagged union
):
    """Base class for all KV cache-related events"""
|
||||
|
||||
|
||||
class BlockStored(KVCacheEvent):
    """Event emitted when KV blocks are stored in the cache."""

    # Hashes of the stored blocks.
    block_hashes: list[int]
    # Hash of the preceding block, or None — presumably None for a chain root;
    # verify against the producer.
    parent_block_hash: Optional[int]
    # Token ids covered by the stored blocks.
    token_ids: list[int]
    # Number of tokens per block.
    block_size: int
    # LoRA adapter id — presumably None when no adapter is active; confirm.
    lora_id: Optional[int]
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
    """Event emitted when KV blocks are removed from the cache."""

    # Hashes of the removed blocks.
    block_hashes: list[int]
|
||||
|
||||
|
||||
class AllBlocksCleared(KVCacheEvent):
    """Event emitted when the entire KV cache is cleared."""

    pass
|
||||
|
||||
|
||||
class KVEventBatch(EventBatch):
    """EventBatch whose payload is restricted to KV-cache events (tagged union)."""

    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
|
||||
|
||||
|
||||
class EventPublisher(ABC):
    """
    Lightweight publisher for EventBatch batches with
    support for DP attention.

    In DP attention - each rank has its own Scheduler and
    KV cache instance in order to avoid duplicate events
    and ensure proper event attribution. In our implementation

    - Each DP rank has its own EventPublisher
    - Publishers annotate events with the dp rank
    - This allows consumers to distinguish events from different DP ranks
    """

    def __init__(self, attn_dp_rank: int = 0):
        # Attention data-parallel rank used to tag outgoing event batches.
        self._attn_dp_rank = attn_dp_rank

    @abstractmethod
    def publish(self, events: EventBatch) -> None:
        """Emit events in order.

        Implementations should guarantee at-least-once delivery and
        monotonic ordering (e.g., via sequence numbers).
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Shutdown the publisher."""
|
||||
|
||||
|
||||
class NullEventPublisher(EventPublisher):
    """No-op implementation (default when disabled): events are silently dropped."""

    def publish(self, events) -> None:
        # Intentionally discard the batch.
        pass

    def shutdown(self) -> None:
        # Nothing to clean up.
        pass
|
||||
|
||||
|
||||
class ZmqEventPublisher(EventPublisher):
    """Reliable PUB/ROUTER publisher with an in-memory replay buffer.

    Spawns a separate thread to handle publishing from a queue.

    Parameters
    ----------
    endpoint:
        PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
        connect.
    replay_endpoint:
        Optional ROUTER address for replay requests. When given, subscribers can
        request missed batches by sending the starting sequence number as an
        8-byte big-endian integer.
    buffer_steps:
        Number of past batches to keep for replay.
    hwm:
        ZeroMQ high-water-mark for PUB socket.
    max_queue_size:
        Maximum number of events to buffer in memory.
    topic:
        Topic to publish events to.
    """

    # How long shutdown() waits for the queue to drain / the thread to join.
    SHUTDOWN_TIMEOUT: float = 1.0
    # Sentinel sequence number (-1) marking the end of a replay stream.
    END_SEQ = (-1).to_bytes(8, "big", signed=True)

    def __init__(
        self,
        attn_dp_rank: int,
        endpoint: str = "tcp://*:5557",
        replay_endpoint: Optional[str] = None,
        buffer_steps: int = 10_000,
        hwm: int = 100_000,
        max_queue_size: int = 100_000,
        topic: str = "",
    ) -> None:
        # Storage
        super().__init__(attn_dp_rank)
        self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
        # Ring buffer of (seq, encoded payload) pairs kept for replay.
        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)

        # ZMQ sockets
        self._ctx = zmq.Context.instance()
        self._pub: Optional[zmq.Socket] = None
        self._replay: Optional[zmq.Socket] = None
        self._dp_rank = attn_dp_rank
        # Each DP rank publishes on its own port/suffix to avoid collisions.
        self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
        self._replay_endpoint = self.offset_endpoint_port(
            replay_endpoint, self._dp_rank
        )
        self._hwm = hwm
        self._socket_setup()

        # Payload
        self._seq_gen = count()
        self._topic_bytes = topic.encode("utf-8")

        # Thread
        self._running = True
        logger.info("Starting ZMQ publisher thread")

        self._thread = threading.Thread(
            target=self._publisher_thread, daemon=True, name="zmq-publisher"
        )
        self._thread.start()

        atexit.register(self.shutdown)

    def publish(self, events: EventBatch) -> None:
        """Enqueue a batch for the background thread; raises once shut down."""
        if not self._running:
            raise RuntimeError("Publisher is closed")
        if events.attn_dp_rank is None:
            events.attn_dp_rank = self._dp_rank
        self._event_queue.put(events)

    def shutdown(self) -> None:
        """Stop the publisher thread and clean up resources."""
        self._running = False
        # None is the sentinel that tells the worker thread to exit.
        self._event_queue.put_nowait(None)

        # Best-effort drain: give queued batches a chance to be published.
        start = time.time()
        pending_items = True
        while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
            pending_items = not self._event_queue.empty()
            if pending_items:
                time.sleep(0.1)

        if pending_items:
            logger.warning(
                "Warning: Queue still has %s items after %s seconds timeout",
                self._event_queue.qsize(),
                self.SHUTDOWN_TIMEOUT,
            )

        if self._thread.is_alive():
            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)

        # Clean up ZMQ resources
        try:
            if self._pub is not None:
                self._pub.close(linger=0)
            if self._replay is not None:
                self._replay.close(linger=0)
        finally:
            pass  # Do not terminate context; other sockets may use it

    def _socket_setup(self) -> None:
        """Initialize sockets
        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
        """
        if self._pub is None:
            self._pub = self._ctx.socket(zmq.PUB)
            self._pub.set_hwm(self._hwm)
            # Heuristic: bind if wildcard / * present, else connect.
            # bind stable, connect volatile convention
            if (
                "*" in self._endpoint
                or "::" in self._endpoint
                or self._endpoint.startswith("ipc://")
                or self._endpoint.startswith("inproc://")
            ):
                self._pub.bind(self._endpoint)
            else:
                self._pub.connect(self._endpoint)

        # Set up replay socket: use ROUTER
        # 1) handles multiple REQ clients (identities)
        # 2) lets us send back one request → many replies (streamed events)
        # 3) works in our non‑blocking poll loop alongside PUB
        if self._replay_endpoint is not None:
            self._replay = self._ctx.socket(zmq.ROUTER)
            self._replay.bind(self._replay_endpoint)

    def _publisher_thread(self) -> None:
        """Background thread that processes the event queue."""
        # Encoder is created in-thread; msgspec encoders are not shared across threads.
        self._pack = msgspec.msgpack.Encoder()

        assert self._pub is not None  # narrows type for mypy

        # Keep running until shutdown AND the queue is drained (or the sentinel arrives).
        while self._running or self._event_queue.qsize() > 0:
            # --- replay (non-critical) ---------------------------------
            if self._replay is not None and self._replay.poll(0):
                try:
                    self._service_replay()
                except Exception as e:
                    logger.exception("Error in replay: %s", e)

            # --- main queue (critical) ---------------------------------
            try:
                event = self._event_queue.get(timeout=0.1)
                if event is None:
                    break  # Sentinel received, exit thread
            except queue.Empty:
                continue

            try:
                seq = next(self._seq_gen)

                payload = self._pack.encode(event)
                seq_bytes = seq.to_bytes(8, "big")
                # Frame layout: (topic, 8-byte big-endian seq, msgpack payload).
                self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))

                # Remember the encoded payload so late subscribers can replay it.
                self._buffer.append((seq, payload))
                self._event_queue.task_done()

            except Exception as e:
                # Publishing failed; back-off a bit to avoid a tight error loop
                logger.exception("Error in publisher thread: %s", e)
                time.sleep(0.1)

    def _service_replay(self) -> None:
        """If a replay request is waiting, send buffered batches."""
        assert self._replay is not None  # narrows type for mypy

        frame = self._replay.recv_multipart()
        if len(frame) != 3:
            logger.warning("Invalid replay request: %s", frame)
            return
        client_id, _, start_seq_bytes = frame
        start_seq = int.from_bytes(start_seq_bytes, "big")

        for seq, buf in self._buffer:
            if seq >= start_seq:
                # [identity, empty_delim, seq_bytes, payload]
                # (identity, empty_delim) are stripped off by the router
                # receiving payload is (seq_bytes, payload)
                self._replay.send_multipart(
                    (client_id, b"", seq.to_bytes(8, "big"), buf)
                )
        # Send end of sequence marker
        # receiving payload is (-1, b""")
        self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))

    @staticmethod
    def offset_endpoint_port(
        endpoint: Optional[str], data_parallel_rank: int
    ) -> Optional[str]:
        """Helper function to offset the port in an endpoint by
        the data parallel rank.

        Args:
            endpoint: The endpoint string
                (e.g., "tcp://*:5557" or "inproc://cache")
            data_parallel_rank: The data parallel rank to offset by

        Returns:
            The endpoint with the port offset by data_parallel_rank
            or suffix appended
        """
        # Do nothing if input is None or data_parallel_rank is 0
        if not endpoint or data_parallel_rank == 0:
            return endpoint

        if "inproc" in endpoint:
            # inproc has no port; disambiguate with a rank suffix instead.
            return f"{endpoint}_dp{data_parallel_rank}"
        if "tcp" in endpoint:
            if endpoint and ":" in endpoint:
                # Get everything after the last colon (the port)
                last_colon_idx = endpoint.rfind(":")
                base_addr = endpoint[:last_colon_idx]
                base_port = int(endpoint[last_colon_idx + 1 :])
                new_port = base_port + data_parallel_rank
                return f"{base_addr}:{new_port}"
            # NOTE(review): "tcp://..." always contains ":", so this fallback
            # looks unreachable in practice — kept for safety.
            return endpoint
        raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
|
||||
|
||||
|
||||
class KVEventsConfig(BaseModel):
    """Configuration for KV event publishing."""

    publisher: str = "null"
    """The publisher to use for publishing kv events. Can be "null", "zmq".
    """

    endpoint: str = "tcp://*:5557"
    """The zmq endpoint to use for publishing kv events.
    """

    replay_endpoint: Optional[str] = None
    """The zmq endpoint to use for replaying kv events.
    """

    buffer_steps: int = 10_000
    """The number of steps to cache for replay endpoint. Will only save
    events from the last N steps for the replay endpoint.
    """

    hwm: int = 100_000
    """The zmq high water mark for the event publisher. After queueing N events,
    events will start dropping if the consumer is not keeping up.
    """

    max_queue_size: int = 100_000
    """The maximum number of events to queue while waiting for publishing.
    """

    topic: str = ""
    """The topic to use for the event publisher. Consumers can subscribe to
    this topic to receive events.
    """

    @classmethod
    def from_cli(cls, cli_value: str) -> "KVEventsConfig":
        """Parse the CLI value (a JSON object string) for the event publisher config.

        Raises pydantic's ValidationError when the JSON is malformed or
        contains invalid fields.
        """
        # Use `cls` (not the class name) so subclasses parse into their own type.
        return cls.model_validate_json(cli_value)
|
||||
|
||||
|
||||
class EventPublisherFactory:
    """Maps publisher names ("null", "zmq", ...) to EventPublisher constructors."""

    _registry: dict[str, Callable[..., EventPublisher]] = {
        "null": NullEventPublisher,
        "zmq": ZmqEventPublisher,
    }

    @classmethod
    def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:
        """Register a custom publisher constructor under a unique name."""
        if name in cls._registry:
            raise KeyError(f"publisher '{name}' already registered")
        cls._registry[name] = ctor

    @classmethod
    def create(cls, config: Optional[str], attn_dp_rank: int = 0) -> EventPublisher:
        """Create publisher from a config mapping."""
        # No config at all means event publishing is disabled.
        if not config:
            return NullEventPublisher()
        parsed = KVEventsConfig.from_cli(config)
        kwargs = parsed.model_dump()

        kind = kwargs.pop("publisher", "null")
        try:
            constructor = cls._registry[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown event publisher '{kind}'") from exc
        return constructor(attn_dp_rank=attn_dp_rank, **kwargs)
|
||||
6
python/sglang/srt/disaggregation/mini_lb.py
Normal file
6
python/sglang/srt/disaggregation/mini_lb.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# This module is intentionally a stub: importing it always fails with a
# pointer to the relocated implementation in the 'sglang_router' package.
raise RuntimeError(
    """The 'mini_lb' module has been relocated to the 'sglang_router' package.
We recommend installing 'sglang-router' with Rust support for optimal performance.
If you encounter issues building the router with Rust, set the environment variable
'SGLANG_ROUTER_BUILD_NO_RUST=1' and add '--mini-lb' to the command line to use the Python version of 'mini_lb'."""
)
|
||||
6
python/sglang/srt/disaggregation/mooncake/__init__.py
Normal file
6
python/sglang/srt/disaggregation/mooncake/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from sglang.srt.disaggregation.mooncake.conn import (
|
||||
MooncakeKVBootstrapServer,
|
||||
MooncakeKVManager,
|
||||
MooncakeKVReceiver,
|
||||
MooncakeKVSender,
|
||||
)
|
||||
1704
python/sglang/srt/disaggregation/mooncake/conn.py
Normal file
1704
python/sglang/srt/disaggregation/mooncake/conn.py
Normal file
File diff suppressed because it is too large
Load Diff
164
python/sglang/srt/disaggregation/mooncake/transfer_engine.py
Normal file
164
python/sglang/srt/disaggregation/mooncake/transfer_engine.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MooncakeTransferEngine:
    """Thin wrapper around mooncake's TransferEngine for one-sided writes.

    Registers/deregisters local memory regions with the engine and performs
    synchronous (single or batched) writes to peer buffers addressed by a
    session id of the form "<ip>:<rpc_port>". Registration and transfer
    failures are reported via return codes and debug logs, not exceptions.
    """

    def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
        try:
            from mooncake.engine import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run SGLang with MooncakeTransferEngine."
            ) from e

        self.engine = TransferEngine()
        self.hostname = hostname
        # Local accelerator ordinal; only used in the Ascend/NPU hostname suffix below.
        self.gpu_id = gpu_id
        # Optional RDMA device name passed through to engine.initialize().
        self.ib_device = ib_device

        self.initialize(
            hostname=self.hostname,
            device_name=self.ib_device,
        )
        # Peers use this session id to address this engine instance.
        self.session_id = (
            f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
        )

    def register(self, ptr, length):
        """Register one memory region [ptr, ptr+length); failures are only logged."""
        try:
            ret_value = self.engine.register_memory(ptr, length)
        except Exception:
            # Mark register as failed
            ret_value = -1

        if ret_value != 0:
            logger.debug("Mooncake memory registration %s failed.", ptr)

    def deregister(self, ptr):
        """Deregister a previously registered region; failures are only logged."""
        try:
            ret_value = self.engine.unregister_memory(ptr)
        except Exception:
            # Mark deregister as failed
            ret_value = -1

        if ret_value != 0:
            logger.debug("Mooncake memory deregistration %s failed.", ptr)

    def batch_register(self, ptrs: List[int], lengths: List[int]) -> int:
        """Batch register multiple memory regions."""
        try:
            ret_value = self.engine.batch_register_memory(ptrs, lengths)
        except Exception:
            # Mark batch register as failed
            ret_value = -1
            # Distinguish "old mooncake without the API" from a genuine failure.
            if not hasattr(self.engine, "batch_register_memory"):
                raise RuntimeError(
                    "Mooncake's batch register requires a newer version of mooncake-transfer-engine. "
                    "Please upgrade Mooncake."
                )

        if ret_value != 0:
            logger.debug("Mooncake batch memory registration failed.")
        return ret_value

    def batch_deregister(self, ptrs: List[int]) -> int:
        """Batch deregister multiple memory regions."""
        try:
            ret_value = self.engine.batch_unregister_memory(ptrs)
        except Exception:
            # Mark batch deregister as failed
            ret_value = -1

        if ret_value != 0:
            logger.debug("Mooncake batch memory deregistration failed.")
        return ret_value

    def initialize(
        self,
        hostname: str,
        device_name: Optional[str],
    ) -> None:
        """Initialize the mooncake instance.

        Raises RuntimeError when engine initialization reports a non-zero code.
        """
        if get_bool_env_var("ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE", "false"):
            # Ascend/NPU path: encode a free port and the NPU id into the hostname.
            hostname += f":{get_free_port()}:npu_{self.gpu_id}"
            ret_value = self.engine.initialize(
                hostname,
                "P2PHANDSHAKE",
                "ascend",
                device_name if device_name is not None else "",
            )
        else:
            ret_value = self.engine.initialize(
                hostname,
                "P2PHANDSHAKE",
                "rdma",
                device_name if device_name is not None else "",
            )
        if ret_value != 0:
            logger.error("Mooncake Transfer Engine initialization failed.")
            raise RuntimeError("Mooncake Transfer Engine initialization failed.")

    def transfer_sync(
        self, session_id: str, buffer: int, peer_buffer_address: int, length: int
    ) -> int:
        """Synchronously transfer data to the specified address.

        Returns the engine's status code (< 0 on failure); never raises.
        """
        try:
            # the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
            # later: based on the cached queue pair to send data
            ret = self.engine.transfer_sync_write(
                session_id, buffer, peer_buffer_address, length
            )
        except Exception:
            # Mark transfer request as failed
            ret = -1

        if ret < 0:
            # Do not raise an exception here, since some transfer requests fail should be accepted and the execution thread should not be stopped.
            logger.debug(
                "Failed to transfer data from %s to %s - %s.",
                buffer,
                session_id,
                peer_buffer_address,
            )

        return ret

    def batch_transfer_sync(
        self,
        session_id: str,
        buffers: List[int],
        peer_buffer_addresses: List[int],
        lengths: List[int],
    ) -> int:
        """Synchronously transfer data to the specified addresses in batches."""
        try:
            ret = self.engine.batch_transfer_sync_write(
                session_id, buffers, peer_buffer_addresses, lengths
            )
        except Exception:
            ret = -1
            # Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2
            if not hasattr(self.engine, "batch_transfer_sync_write"):
                raise RuntimeError(
                    "Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. "
                    "Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'"
                )

        if ret < 0:
            logger.debug(
                "Failed to batch transfer data. Buffers: %s, Session: %s, Peer addresses: %s",
                buffers,
                session_id,
                peer_buffer_addresses,
            )
        return ret

    def get_session_id(self):
        """Return the "<ip>:<rpc_port>" session id peers use to reach this engine."""
        return self.session_id
|
||||
6
python/sglang/srt/disaggregation/nixl/__init__.py
Normal file
6
python/sglang/srt/disaggregation/nixl/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from sglang.srt.disaggregation.nixl.conn import (
|
||||
NixlKVBootstrapServer,
|
||||
NixlKVManager,
|
||||
NixlKVReceiver,
|
||||
NixlKVSender,
|
||||
)
|
||||
696
python/sglang/srt/disaggregation/nixl/conn.py
Normal file
696
python/sglang/srt/disaggregation/nixl/conn.py
Normal file
@@ -0,0 +1,696 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import queue
|
||||
import socket
|
||||
import struct
|
||||
import threading
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from functools import cache
|
||||
from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import requests
|
||||
import zmq
|
||||
from aiohttp import web
|
||||
|
||||
from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll
|
||||
from sglang.srt.disaggregation.common.conn import (
|
||||
CommonKVBootstrapServer,
|
||||
CommonKVManager,
|
||||
CommonKVReceiver,
|
||||
)
|
||||
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
format_tcp_address,
|
||||
get_local_ip_auto,
|
||||
is_valid_ipv6_address,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GUARD = "NixlMsgGuard".encode("ascii")
|
||||
|
||||
|
||||
@dataclasses.dataclass
class TransferInfo:
    """Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""

    room: int
    endpoint: str
    dst_port: int
    agent_name: str
    dst_kv_indices: npt.NDArray[np.int32]
    dst_aux_index: int
    required_dst_info_num: int

    def is_dummy(self):
        # A transfer with no destination KV indices carries no payload.
        return not self.dst_kv_indices.size

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Deserialize the 7-frame ZMQ message produced by the KV receiver."""

        def _ascii(frame: bytes) -> str:
            return frame.decode("ascii")

        return cls(
            room=int(_ascii(msg[0])),
            endpoint=_ascii(msg[1]),
            dst_port=int(_ascii(msg[2])),
            agent_name=_ascii(msg[3]),
            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
            dst_aux_index=int(_ascii(msg[5])),
            required_dst_info_num=int(_ascii(msg[6])),
        )
|
||||
|
||||
|
||||
@dataclasses.dataclass
class KVArgsRegisterInfo:
    """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""

    room: str
    endpoint: str
    dst_port: int
    agent_name: str
    agent_metadata: bytes
    dst_kv_ptrs: list[int]
    dst_aux_ptrs: list[int]
    gpu_id: int
    decode_tp_size: int
    decode_tp_rank: int
    dst_kv_item_len: int

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Deserialize the 11-frame registration message; pointer frames are
        packed arrays of unsigned 64-bit integers."""

        def _ascii(frame: bytes) -> str:
            return frame.decode("ascii")

        kv_ptr_count = len(msg[5]) // 8
        aux_ptr_count = len(msg[6]) // 8
        return cls(
            room=_ascii(msg[0]),
            endpoint=_ascii(msg[1]),
            dst_port=int(_ascii(msg[2])),
            agent_name=_ascii(msg[3]),
            agent_metadata=msg[4],
            dst_kv_ptrs=list(struct.unpack(f"{kv_ptr_count}Q", msg[5])),
            dst_aux_ptrs=list(struct.unpack(f"{aux_ptr_count}Q", msg[6])),
            gpu_id=int(_ascii(msg[7])),
            decode_tp_size=int(_ascii(msg[8])),
            decode_tp_rank=int(_ascii(msg[9])),
            dst_kv_item_len=int(_ascii(msg[10])),
        )
|
||||
|
||||
|
||||
@dataclasses.dataclass
class TransferStatus:
    """Used by KV Receiver to know when a transfer is done."""

    # KV chunk IDs that have been received.
    received_kvs: Set[int] = dataclasses.field(default_factory=set)
    # Number of kv chunks to expect, will know this after last chunk is received.
    num_kvs_expected: Optional[int] = None
    # Whether aux data has been received.
    received_aux: bool = False

    def is_done(self):
        """True once all expected KV chunks and the aux data have arrived."""
        if self.num_kvs_expected is None:
            # The last chunk has not arrived yet, so the total is still unknown.
            return False
        return self.received_aux and len(self.received_kvs) == self.num_kvs_expected
|
||||
|
||||
|
||||
class NixlKVManager(CommonKVManager):
|
||||
def __init__(
|
||||
self,
|
||||
args: KVArgs,
|
||||
disaggregation_mode: DisaggregationMode,
|
||||
server_args: ServerArgs,
|
||||
is_mla_backend: Optional[bool] = False,
|
||||
):
|
||||
super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
|
||||
try:
|
||||
from nixl._api import nixl_agent
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install NIXL by following the instructions at "
|
||||
"https://github.com/ai-dynamo/nixl/blob/main/README.md "
|
||||
"to run SGLang with NixlTransferEngine."
|
||||
) from e
|
||||
self.agent = nixl_agent(str(uuid.uuid4()))
|
||||
self.local_ip = get_local_ip_auto()
|
||||
self.server_socket = zmq.Context().socket(zmq.PULL)
|
||||
if is_valid_ipv6_address(self.local_ip):
|
||||
self.server_socket.setsockopt(zmq.IPV6, 1)
|
||||
self.register_buffer_to_engine()
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
self.request_status: Dict[int, KVPoll] = {}
|
||||
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
||||
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
||||
self._start_bootstrap_thread()
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
|
||||
TransferStatus
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
||||
)
|
||||
|
||||
def check_status(self, bootstrap_room: int):
|
||||
return self.request_status[bootstrap_room]
|
||||
|
||||
def update_status(self, bootstrap_room: int, status: KVPoll):
|
||||
if bootstrap_room not in self.request_status:
|
||||
self.request_status[bootstrap_room] = status
|
||||
else:
|
||||
# NOTE: The prefill engine could recv bootstrapping first
|
||||
self.request_status[bootstrap_room] = max(
|
||||
self.request_status[bootstrap_room], status
|
||||
)
|
||||
|
||||
    def register_buffer_to_engine(self):
        """Register local KV (GPU) and aux (host) buffers with the NIXL agent.

        Raises Exception if either registration returns a falsy descriptor set.
        """
        kv_addrs = []
        for kv_data_ptr, kv_data_len in zip(
            self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
        ):
            # (ptr, len, device_id, tag) tuples as expected by register_memory.
            kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
        logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
        if not self.kv_descs:
            raise Exception("NIXL memory registration failed for kv tensors")
        aux_addrs = []
        for aux_data_ptr, aux_data_len in zip(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        ):
            # Aux buffers live in host memory, hence device id 0 and "DRAM".
            aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
        logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
        if not self.aux_descs:
            raise Exception("NIXL memory registration failed for aux tensors")
|
||||
|
||||
def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
|
||||
agent_name = decode_kv_args.agent_name
|
||||
if agent_name in self.decode_kv_args_table:
|
||||
logger.info(f"Peer {agent_name} was already registered, ignoring.")
|
||||
return
|
||||
self.decode_kv_args_table[agent_name] = decode_kv_args
|
||||
self.agent.add_remote_agent(decode_kv_args.agent_metadata)
|
||||
|
||||
    def send_kvcache(
        self,
        peer_name: str,
        prefill_kv_indices: npt.NDArray[np.int32],
        dst_kv_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int32],
        dst_gpu_id: int,
        notif: str,
    ):
        """Asynchronously WRITE this rank's KV pages into the decode peer.

        Used when prefill and decode TP sizes match (full-page copy, no head
        slicing). Returns the NIXL transfer handle; the peer is notified with
        `notif` when the write completes.

        Raises Exception if the transfer cannot be created or posted.
        """
        # group by indices: merge runs of contiguous pages into single blocks
        # so fewer, larger descriptors are issued.
        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
            prefill_kv_indices, dst_kv_indices
        )

        logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
        # Make descs
        num_layers = len(self.kv_args.kv_data_ptrs)
        src_addrs = []
        dst_addrs = []
        for layer_id in range(num_layers):
            src_ptr = self.kv_args.kv_data_ptrs[layer_id]
            dst_ptr = dst_kv_ptrs[layer_id]
            item_len = self.kv_args.kv_item_lens[layer_id]

            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                # One (addr, length, gpu) entry per contiguous block per layer.
                src_addr = src_ptr + int(prefill_index[0]) * item_len
                dst_addr = dst_ptr + int(decode_index[0]) * item_len
                length = item_len * len(prefill_index)
                src_addrs.append((src_addr, length, self.kv_args.gpu_id))
                dst_addrs.append((dst_addr, length, dst_gpu_id))

        logger.debug(
            f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
        )
        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
        # Transfer data
        xfer_handle = self.agent.initialize_xfer(
            "WRITE",
            src_descs,
            dst_descs,
            peer_name,
            notif.encode("ascii"),  # type: ignore
        )
        if not xfer_handle:
            raise Exception("KVSender failed to create transfer")
        state = self.agent.transfer(xfer_handle)
        if state == "ERR":
            raise Exception("KVSender failed to post transfer")
        return xfer_handle
|
||||
|
||||
def send_kvcache_slice(
|
||||
self,
|
||||
peer_name: str,
|
||||
prefill_kv_indices: npt.NDArray[np.int32],
|
||||
dst_kv_ptrs: list[int],
|
||||
dst_kv_indices: npt.NDArray[np.int32],
|
||||
dst_gpu_id: int,
|
||||
notif: str,
|
||||
prefill_tp_size: int,
|
||||
decode_tp_size: int,
|
||||
decode_tp_rank: int,
|
||||
dst_kv_item_len: int,
|
||||
):
|
||||
# Get configuration from kv_args
|
||||
local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
|
||||
dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
|
||||
num_kv_heads = self.kv_args.kv_head_num
|
||||
|
||||
# Calculate head distribution
|
||||
src_heads_per_rank = num_kv_heads
|
||||
dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
|
||||
|
||||
src_kv_item_len = self.kv_args.kv_item_lens[0]
|
||||
page_size = self.kv_args.page_size
|
||||
|
||||
bytes_per_head_slice_to_send = (
|
||||
dst_kv_item_len // page_size // dst_heads_per_rank
|
||||
)
|
||||
|
||||
# Determine which heads to send
|
||||
if prefill_tp_size > decode_tp_size:
|
||||
# Multiple prefill ranks to one decode rank
|
||||
src_head_start_offset = 0
|
||||
num_heads_to_send = src_heads_per_rank
|
||||
dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
|
||||
else:
|
||||
# Send KVCache from 1 prefill instance to multiple decode instances
|
||||
src_head_start_offset = (
|
||||
dst_tp_rank_in_group * dst_heads_per_rank
|
||||
) % src_heads_per_rank
|
||||
num_heads_to_send = dst_heads_per_rank
|
||||
dst_head_start_offset = 0
|
||||
|
||||
# Create transfer descriptors
|
||||
src_addrs = []
|
||||
dst_addrs = []
|
||||
|
||||
bytes_per_token_on_prefill = src_kv_item_len // page_size
|
||||
bytes_per_token_on_decode = dst_kv_item_len // page_size
|
||||
|
||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
||||
dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
|
||||
dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
|
||||
|
||||
# Calculate precise byte offset and length for the sub-slice within the token
|
||||
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
||||
dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
|
||||
heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
|
||||
|
||||
src_dst_ptr_pairs = [
|
||||
(
|
||||
src_k_ptrs[layer_id],
|
||||
dst_k_ptrs[layer_id],
|
||||
)
|
||||
for layer_id in range(len(src_k_ptrs))
|
||||
] + [
|
||||
(
|
||||
src_v_ptrs[layer_id],
|
||||
dst_v_ptrs[layer_id],
|
||||
)
|
||||
for layer_id in range(len(src_v_ptrs))
|
||||
]
|
||||
|
||||
src_addrs = []
|
||||
dst_addrs = []
|
||||
|
||||
# Calculate strides for a single token slot
|
||||
bytes_per_token_on_prefill = src_kv_item_len // page_size
|
||||
bytes_per_token_on_decode = dst_kv_item_len // page_size
|
||||
|
||||
for src_ptr, dst_ptr in src_dst_ptr_pairs:
|
||||
for i in range(len(prefill_kv_indices)):
|
||||
prefill_page_idx = int(prefill_kv_indices[i])
|
||||
decode_page_idx = int(dst_kv_indices[i])
|
||||
|
||||
# Get the starting addresses for the current src and dst pages
|
||||
src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
|
||||
dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
|
||||
|
||||
# Iterate through each valid token slot within the current page
|
||||
for token_slot_in_page in range(page_size):
|
||||
# Calculate the start address of the current token slot
|
||||
src_token_slot_start_addr = (
|
||||
src_page_start_addr
|
||||
+ token_slot_in_page * bytes_per_token_on_prefill
|
||||
)
|
||||
dst_token_slot_start_addr = (
|
||||
dst_page_start_addr
|
||||
+ token_slot_in_page * bytes_per_token_on_decode
|
||||
)
|
||||
|
||||
# Calculate final src and dst addresses by applying head-slice offsets
|
||||
src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
|
||||
dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
|
||||
|
||||
src_addrs.append(
|
||||
(
|
||||
src_slice_addr,
|
||||
heads_bytes_per_token_to_send,
|
||||
self.kv_args.gpu_id,
|
||||
)
|
||||
)
|
||||
dst_addrs.append(
|
||||
(dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
|
||||
)
|
||||
|
||||
# Use NIXL agent for transfer
|
||||
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
|
||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
|
||||
|
||||
xfer_handle = self.agent.initialize_xfer(
|
||||
"WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
|
||||
)
|
||||
if not xfer_handle:
|
||||
raise Exception("Failed to create sliced KV transfer")
|
||||
|
||||
state = self.agent.transfer(xfer_handle)
|
||||
if state == "ERR":
|
||||
raise Exception("Failed to post sliced KV transfer")
|
||||
|
||||
return xfer_handle
|
||||
|
||||
    def send_aux(
        self,
        peer_name: str,
        prefill_aux_index: int,
        dst_aux_ptrs: list[int],
        dst_aux_index: int,
        notif: str,
    ):
        """WRITE one aux (metadata) buffer slot from host memory to the decode peer.

        Returns the NIXL transfer handle; raises Exception on create/post failure.
        """
        # Make descs: aux buffers are single fixed-size slots in host memory.
        aux_item_len = self.kv_args.aux_item_lens[0]
        prefill_aux_addr = (
            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
        )
        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
        src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
        dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
        # Transfer data
        xfer_handle = self.agent.initialize_xfer(
            "WRITE",
            src_descs,
            dst_descs,
            peer_name,
            notif.encode("ascii"),  # type: ignore
        )
        if not xfer_handle:
            raise Exception("KVSender failed to create transfer")
        state = self.agent.transfer(xfer_handle)
        if state == "ERR":
            raise Exception("KVSender failed to post transfer")
        return xfer_handle
|
||||
|
||||
    def add_transfer_request(
        self,
        bootstrap_room: int,
        kv_indices: npt.NDArray[np.int32],
        index_slice: slice,
        is_last: bool,
        chunk_id: int,
        aux_index: Optional[int] = None,
    ):
        """Send one KV chunk (and, on the last chunk, the aux data) to every
        registered destination for this bootstrap room.

        Returns the list of NIXL transfer handles created for this chunk.
        """
        assert self.disaggregation_mode == DisaggregationMode.PREFILL
        # The final chunk must carry an aux index for the metadata transfer.
        assert not is_last or (is_last and aux_index is not None)

        reqs_to_be_processed = self.transfer_infos[bootstrap_room].values()
        handles = []
        for req in reqs_to_be_processed:
            assert bootstrap_room == req.room
            if req.is_dummy():
                # Dummy destinations take no data (TP-mapping placeholders).
                continue

            chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
            assert len(chunked_dst_kv_indice) == len(kv_indices)
            assert req.agent_name in self.decode_kv_args_table

            # Notification format "<room>_kv_<chunk_id>_<is_last>" is parsed
            # by update_transfer_status() on the decode side.
            notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
            decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size

            if decode_tp_size == self.tp_size:
                # Same TP layout: full-page copy.
                kv_xfer_handle = self.send_kvcache(
                    req.agent_name,
                    kv_indices,
                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
                    chunked_dst_kv_indice,
                    self.decode_kv_args_table[req.agent_name].gpu_id,
                    notif,
                )
            else:
                # Mismatched TP: copy only the overlapping head slice.
                kv_xfer_handle = self.send_kvcache_slice(
                    req.agent_name,
                    kv_indices,
                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
                    chunked_dst_kv_indice,
                    self.decode_kv_args_table[req.agent_name].gpu_id,
                    notif,
                    prefill_tp_size=self.tp_size,
                    decode_tp_size=decode_tp_size,
                    decode_tp_rank=self.decode_kv_args_table[
                        req.agent_name
                    ].decode_tp_rank,
                    dst_kv_item_len=self.decode_kv_args_table[
                        req.agent_name
                    ].dst_kv_item_len,
                )

            handles.append(kv_xfer_handle)
            # Only the last chunk we need to send the aux data.
            if is_last:
                assert aux_index is not None
                aux_xfer_handle = self.send_aux(
                    req.agent_name,
                    aux_index,
                    self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
                    req.dst_aux_index,
                    str(req.room) + "_aux",
                )
                handles.append(aux_xfer_handle)
        if is_last:
            # All chunks dispatched; drop the bookkeeping for this room.
            del self.transfer_infos[bootstrap_room]
        return handles
|
||||
|
||||
    def update_transfer_status(self):
        """Drain NIXL notifications and update per-room TransferStatus (decode side).

        Messages are either "<room>_kv_<chunk_id>_<is_last>" or "<room>_aux",
        as produced by add_transfer_request().
        """
        # Process notifications from received transfers.
        notif_map = self.agent.get_new_notifs()
        for peer_name, messages in notif_map.items():
            # We could also check that self.bootstrap_info['agent_name'] matches
            # the message sender. But the bootstrap room alone should be
            # sufficient to map the status.
            for msg in messages:
                components = msg.decode("ascii").split("_")
                room = int(components[0])
                if components[1] == "kv":
                    chunk_id = int(components[2])
                    is_last = bool(int(components[3]))
                    self.transfer_statuses[room].received_kvs.add(chunk_id)
                    if is_last:
                        # Last chunk fixes the expected total chunk count.
                        self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
                elif components[1] == "aux":
                    self.transfer_statuses[room].received_aux = True
|
||||
|
||||
def check_transfer_done(self, room: int):
|
||||
if room not in self.transfer_statuses:
|
||||
return False
|
||||
return self.transfer_statuses[room].is_done()
|
||||
|
||||
    def _bind_server_socket(self):
        """Bind the zmq PULL socket on this rank's advertised ip:port."""
        self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
||||
|
||||
def _start_bootstrap_thread(self):
|
||||
self._bind_server_socket()
|
||||
|
||||
def bootstrap_thread():
|
||||
"""This thread recvs transfer info from the decode engine"""
|
||||
while True:
|
||||
waiting_req_bytes = self.server_socket.recv_multipart()
|
||||
logger.debug(
|
||||
f"Received multipart with total byte size {sum(len(x) for x in waiting_req_bytes)}"
|
||||
)
|
||||
assert (
|
||||
waiting_req_bytes[0] == GUARD
|
||||
), f"First message should be {GUARD}. Foreign traffic?"
|
||||
waiting_req_bytes = waiting_req_bytes[1:]
|
||||
room = waiting_req_bytes[0].decode("ascii")
|
||||
agent_name = waiting_req_bytes[3].decode("ascii")
|
||||
if room == "None":
|
||||
# Register new peer and save KV base pointers.
|
||||
self._add_remote_peer(
|
||||
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
|
||||
)
|
||||
logger.debug(f"Register KVArgs from {agent_name} successfully")
|
||||
continue
|
||||
room = int(room)
|
||||
if room not in self.transfer_infos:
|
||||
self.transfer_infos[room] = {}
|
||||
self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
|
||||
waiting_req_bytes
|
||||
)
|
||||
required_dst_info_num = self.transfer_infos[room][
|
||||
agent_name
|
||||
].required_dst_info_num
|
||||
logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
|
||||
if len(self.transfer_infos[room]) == required_dst_info_num:
|
||||
logger.debug(f"{room=} is bootstrapped")
|
||||
self.update_status(room, KVPoll.WaitingForInput)
|
||||
|
||||
threading.Thread(target=bootstrap_thread).start()
|
||||
|
||||
|
||||
class NixlKVSender(BaseKVSender):
    """Prefill-side sender: streams KV chunks (and final aux data) for one request."""

    def __init__(
        self,
        mgr: NixlKVManager,
        bootstrap_addr: str,
        bootstrap_room: int,
        dest_tp_ranks: List[int],
        pp_rank: int,
    ):
        self.kv_mgr = mgr
        self.bootstrap_room = bootstrap_room
        self.aux_index = None
        self.bootstrap_server_url = bootstrap_addr
        # Handles of all NIXL transfers issued so far; polled for completion.
        self.xfer_handles = []
        self.has_sent = False
        self.chunk_id = 0
        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
        # inner state: running offset into the request's kv indices.
        self.curr_idx = 0

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Record total KV index count and the aux slot for the final chunk."""
        self.num_kv_indices = num_kv_indices
        self.aux_index = aux_index

    def send(
        self,
        kv_indices: npt.NDArray[np.int32],
    ):
        """Dispatch one chunk of KV indices; the last chunk also sends aux data."""
        index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
        self.curr_idx += len(kv_indices)
        # Last chunk once all indices announced via init() have been sent.
        is_last = self.curr_idx == self.num_kv_indices

        new_xfer_handles = self.kv_mgr.add_transfer_request(
            self.bootstrap_room,
            kv_indices,
            index_slice,
            is_last,
            self.chunk_id,
            self.aux_index,
        )
        self.xfer_handles.extend(new_xfer_handles)
        self.chunk_id += 1
        if is_last:
            self.has_sent = True
            # Bootstrap bookkeeping is no longer needed for this room.
            del self.kv_mgr.request_status[self.bootstrap_room]

    def poll(self) -> KVPoll:
        """Return bootstrap status before sending, else aggregate transfer state.

        Raises Exception if any underlying NIXL transfer reports an error.
        """
        if not self.has_sent:
            return self.kv_mgr.check_status(self.bootstrap_room)
        states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
        if all([x == "DONE" for x in states]):
            return KVPoll.Success  # type: ignore
        if any([x == "ERR" for x in states]):
            raise Exception("KVSender transfer encountered an error.")
        return KVPoll.WaitingForInput  # type: ignore

    def failure_exception(self):
        raise Exception("Fake KVSender Exception")
|
||||
|
||||
|
||||
class NixlKVReceiver(CommonKVReceiver):
    """Decode-side receiver: registers its buffers with the prefill engine and
    polls until every KV chunk and the aux payload for its room have arrived."""

    def __init__(
        self,
        mgr: NixlKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
        prefill_dp_rank: Optional[int] = None,
    ):
        self.started_transfer = False
        # Cached terminal poll result so completion is reported only once.
        self.conclude_state = None
        super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)

    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
        """Send this rank's destination KV indices / aux slot to each prefill peer.

        The multipart field order below is a wire contract with the prefill
        bootstrap thread (TransferInfo.from_zmq) — do not reorder.
        """
        for bootstrap_info in self.bootstrap_infos:
            logger.debug(
                f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
            )
            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
            is_dummy = bootstrap_info["is_dummy"]
            logger.debug(
                f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
            )
            with lock:
                sock.send_multipart(
                    [
                        GUARD,
                        str(self.bootstrap_room).encode("ascii"),
                        self.kv_mgr.local_ip.encode("ascii"),
                        str(self.kv_mgr.rank_port).encode("ascii"),
                        self.kv_mgr.agent.name.encode("ascii"),
                        # Dummy receivers advertise no destination indices.
                        kv_indices.tobytes() if not is_dummy else b"",
                        str(aux_index).encode("ascii"),
                        str(self.required_dst_info_num).encode("ascii"),
                    ]
                )

        self.started_transfer = True

    def poll(self) -> KVPoll:
        """Non-blocking completion check for this room's transfer."""
        if self.conclude_state is not None:
            return self.conclude_state
        if not self.started_transfer:
            return KVPoll.WaitingForInput  # type: ignore

        self.kv_mgr.update_transfer_status()
        if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
            self.conclude_state = KVPoll.Success
            # Free the tracking entry; conclude_state keeps the result sticky.
            del self.kv_mgr.transfer_statuses[self.bootstrap_room]
            return KVPoll.Success  # type: ignore
        return KVPoll.WaitingForInput  # type: ignore

    def _register_kv_args(self):
        """Send this rank's KV/aux base pointers and agent metadata to the
        prefill engine (room "None" marks a registration message).

        Field order is a wire contract with KVArgsRegisterInfo.from_zmq.
        """
        for bootstrap_info in self.bootstrap_infos:
            sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
            # Pointers are packed as little-endian uint64 ("Q") per pointer.
            packed_kv_data_ptrs = b"".join(
                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
            )
            packed_aux_data_ptrs = b"".join(
                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
            )

            with lock:
                sock.send_multipart(
                    [
                        GUARD,
                        "None".encode("ascii"),
                        self.kv_mgr.local_ip.encode("ascii"),
                        str(self.kv_mgr.rank_port).encode("ascii"),
                        self.kv_mgr.agent.name.encode("ascii"),
                        self.kv_mgr.agent.get_agent_metadata(),
                        packed_kv_data_ptrs,
                        packed_aux_data_ptrs,
                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
                        str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
                        str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
                        str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
                    ]
                )

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")
|
||||
|
||||
|
||||
class NixlKVBootstrapServer(CommonKVBootstrapServer):
    """Bootstrap server for the NIXL backend; the common implementation
    needs no NIXL-specific behavior."""

    pass
|
||||
867
python/sglang/srt/disaggregation/prefill.py
Normal file
867
python/sglang/srt/disaggregation/prefill.py
Normal file
@@ -0,0 +1,867 @@
|
||||
"""
|
||||
Life cycle of a request in the prefill server
|
||||
|
||||
1. Bootstrap Queue
|
||||
a. Initialize a sender for each request
|
||||
b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished
|
||||
c. Poll senders to check bootstrap state
|
||||
d. Once bootstrap is complete, move request to Waiting Queue
|
||||
|
||||
2. Waiting Queue
|
||||
a. Use PrefillAdder to pop requests
|
||||
b. Run forward
|
||||
c. Add the request to Inflight Queue
|
||||
|
||||
3. Inflight Queue
|
||||
a. Poll (non-blocking) the sender of the request
|
||||
b. Once the transfer has finished, return the request
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, List, Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
DisaggregationMode,
|
||||
KVClassType,
|
||||
MetadataBuffers,
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
get_kv_class,
|
||||
is_mla_backend,
|
||||
kv_to_page_indices,
|
||||
kv_to_page_num,
|
||||
poll_and_all_reduce,
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||
from sglang.srt.utils import (
|
||||
DynamicGradMode,
|
||||
broadcast_pyobj,
|
||||
point_to_point_pyobj,
|
||||
require_mlp_sync,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PrefillBootstrapQueue:
    """
    Store the requests in bootstrapping
    """

    def __init__(
        self,
        token_to_kv_pool: KVCache,
        draft_token_to_kv_pool: Optional[KVCache],
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: MetadataBuffers,
        tp_rank: int,
        tp_size: int,
        gpu_id: int,
        bootstrap_port: int,
        gloo_group: ProcessGroup,
        max_total_num_tokens: int,
        decode_tp_size: int,
        decode_dp_size: int,
        scheduler: Scheduler,
        pp_rank: int,
        pp_size: int,
        transfer_backend: TransferBackend,
    ):
        self.token_to_kv_pool = token_to_kv_pool
        # Draft-model KV pool when speculative decoding is enabled.
        self.draft_token_to_kv_pool = draft_token_to_kv_pool
        self.is_mla_backend = is_mla_backend(token_to_kv_pool)
        self.metadata_buffers = metadata_buffers
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.decode_tp_size = decode_tp_size
        self.decode_dp_size = decode_dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.gpu_id = gpu_id
        self.bootstrap_port = bootstrap_port
        # Requests whose bootstrap handshake has not yet completed.
        self.queue: List[Req] = []
        self.gloo_group = gloo_group
        self.max_total_num_tokens = max_total_num_tokens
        self.scheduler = scheduler
        self.transfer_backend = transfer_backend
        self.kv_manager = self._init_kv_manager()
|
||||
|
||||
    def _init_kv_manager(self) -> BaseKVManager:
        """Build the backend-specific KV manager for PREFILL mode.

        Collects contiguous KV buffer info (target model, plus draft model when
        present), aux/metadata buffer info, and TP/PP topology into a KVArgs,
        then instantiates the manager class for the configured backend.
        """
        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
        kv_args = kv_args_class()
        kv_args.engine_rank = self.tp_rank
        kv_args.pp_rank = self.pp_rank
        kv_args.system_dp_rank = self.scheduler.dp_rank
        # Per-DP-group decode TP size, as seen by the transfer backend.
        kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
        kv_args.prefill_pp_size = self.pp_size
        kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )

        if self.draft_token_to_kv_pool is not None:
            # We should also transfer draft model kv cache. The indices are
            # always shared with a target model.
            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
            )
            kv_data_ptrs += draft_kv_data_ptrs
            kv_data_lens += draft_kv_data_lens
            kv_item_lens += draft_kv_item_lens

        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens
        if not self.is_mla_backend:
            # MLA pools have no per-rank head split; head count only matters
            # for the non-MLA layout.
            kv_args.kv_head_num = self.token_to_kv_pool.head_num
        kv_args.page_size = self.token_to_kv_pool.page_size

        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
            self.metadata_buffers.get_buf_infos()
        )
        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
        kv_args.gpu_id = self.scheduler.gpu_id

        kv_manager_class: Type[BaseKVManager] = get_kv_class(
            self.transfer_backend, KVClassType.MANAGER
        )
        kv_manager: BaseKVManager = kv_manager_class(
            kv_args,
            DisaggregationMode.PREFILL,
            self.scheduler.server_args,
            self.is_mla_backend,
        )
        return kv_manager
|
||||
|
||||
    def add(self, req: Req, num_kv_heads: int) -> None:
        """Attach a KV sender to `req` and enqueue it for bootstrapping.

        Requests whose prompt exceeds KV capacity are aborted and dropped.
        NOTE(review): `num_kv_heads` is unused here — presumably kept for
        interface parity with other backends; confirm before removing.
        """
        if self._check_if_req_exceed_kv_capacity(req):
            return

        if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
            # Fake host means no real decode peer; use the no-op sender.
            kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
        else:
            kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)

        dest_tp_ranks = [self.tp_rank]

        req.disagg_kv_sender = kv_sender_class(
            mgr=self.kv_manager,
            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
            bootstrap_room=req.bootstrap_room,
            dest_tp_ranks=dest_tp_ranks,
            pp_rank=self.pp_rank,
        )
        self._process_req(req)
        self.queue.append(req)
|
||||
|
||||
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
|
||||
for req in reqs:
|
||||
self.add(req, num_kv_heads)
|
||||
|
||||
def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
|
||||
if len(req.origin_input_ids) > self.max_total_num_tokens:
|
||||
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
|
||||
logger.error(message)
|
||||
prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST)
|
||||
self.scheduler.stream_output([req], req.return_logprob)
|
||||
return True
|
||||
return False
|
||||
|
||||
    def _process_req(self, req: Req) -> None:
        """
        Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
        (a prefill worker never generates beyond the first token).
        """
        req.sampling_params.max_new_tokens = 1
|
||||
|
||||
    def pop_bootstrapped(
        self,
        return_failed_reqs: bool = False,
        rids_to_check: Optional[List[str]] = None,
    ) -> List[Req]:
        """
        pop the reqs which has finished bootstrapping

        return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
        rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.

        NOTE: when return_failed_reqs is True the return value is a 2-tuple
        (bootstrapped_reqs, failed_reqs) instead of a single list.
        """

        bootstrapped_reqs = []
        failed_reqs = []
        indices_to_remove = set()

        if len(self.queue) == 0:
            if return_failed_reqs is False:
                return []
            else:
                return [], []

        # Poll all senders and all-reduce so every rank agrees on the result.
        polls = poll_and_all_reduce(
            [req.disagg_kv_sender for req in self.queue], self.gloo_group
        )

        for i, (req, poll) in enumerate(zip(self.queue, polls)):
            if rids_to_check is not None:
                # if req not in reqs_info_to_check, skip
                if req.rid not in rids_to_check:
                    continue
                # Either waiting for input or failed
                assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed

            if poll == KVPoll.Bootstrapping:
                continue
            elif poll == KVPoll.Failed:
                error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
                try:
                    req.disagg_kv_sender.failure_exception()
                except Exception as e:
                    error_message += f" with exception {e}"
                logger.error(error_message)
                prepare_abort(
                    req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
                )
                self.scheduler.stream_output([req], req.return_logprob)
                indices_to_remove.add(i)
                failed_reqs.append(req)
                if self.scheduler.enable_metrics:
                    self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
                continue

            # KV.WaitingForInput - init here
            num_kv_indices = len(req.origin_input_ids)
            if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
                # No metadata slot free; retry remaining requests next round.
                break

            req.metadata_buffer_index = (
                self.req_to_metadata_buffer_idx_allocator.alloc()
            )
            assert req.metadata_buffer_index is not None

            # Senders operate on pages, not raw token indices.
            num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
            req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
            bootstrapped_reqs.append(req)
            indices_to_remove.add(i)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        if return_failed_reqs is False:
            return bootstrapped_reqs
        else:
            return bootstrapped_reqs, failed_reqs
|
||||
|
||||
|
||||
class SchedulerDisaggregationPrefillMixin:
    """
    Mixin for Scheduler to handle disaggregation prefill
    """

    @torch.no_grad()
    def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
        """A normal scheduler loop for prefill worker in disaggregation mode."""

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # Promote requests that finished the bootstrap handshake.
            self.waiting_queue.extend(
                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
            )
            self.process_prefill_chunk()
            batch = self.get_new_batch_prefill()

            if require_mlp_sync(self.server_args):
                batch = self.prepare_mlp_sync_batch(batch)
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result_disagg_prefill(batch, result)

            if len(self.disagg_prefill_inflight_queue) > 0:
                # Poll in-flight KV transfers and release finished requests.
                self.process_disagg_prefill_inflight_queue()

            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
                self.self_check_during_idle()

            self.last_batch = batch
            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
            # Otherwise, it hangs under high concurrency
            self.running_batch.batch_is_full = False
|
||||
|
||||
    @torch.no_grad()
    def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
        """Overlap-scheduled prefill loop: results are processed one batch late
        so GPU compute overlaps with CPU-side result handling."""
        # Queue of (batch, result) pairs awaiting deferred processing.
        self.result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            self.waiting_queue.extend(
                self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
            )
            self.process_prefill_chunk()
            batch = self.get_new_batch_prefill()

            if require_mlp_sync(self.server_args):
                batch = self.prepare_mlp_sync_batch(batch)
            self.cur_batch = batch
            if batch:
                result = self.run_batch(batch)
                # Defer processing; copy so later mutations don't affect it.
                self.result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.set_next_batch_sampling_info_done(tmp_batch)

            if self.last_batch:
                # Process the previous iteration's batch now that it overlapped.
                tmp_batch, tmp_result = self.result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)

            if len(self.disagg_prefill_inflight_queue) > 0:
                self.process_disagg_prefill_inflight_queue()

            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
                self.self_check_during_idle()

            self.last_batch = batch
            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
            # Otherwise, it hangs under high concurrency
            self.running_batch.batch_is_full = False
|
||||
|
||||
def process_batch_result_disagg_prefill(
    self: Scheduler,
    batch: ScheduleBatch,
    result: GenerationBatchResult,
    launch_done: Optional[threading.Event] = None,
) -> None:
    """
    Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue

    Adapted from process_batch_result_prefill.

    Args:
        batch: The (copied) batch whose forward has completed.
        result: Forward outputs for ``batch``.
        launch_done: Overlap-mode event used to resolve the last batch result.
    """
    (
        logits_output,
        next_token_ids,
        extend_input_len_per_req,
        extend_logprob_start_len_per_req,
    ) = (
        result.logits_output,
        result.next_token_ids,
        result.extend_input_len_per_req,
        result.extend_logprob_start_len_per_req,
    )

    logprob_pt = 0
    # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
    if self.enable_overlap:
        # wait: resolve the asynchronously-launched batch before reading outputs
        logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
            launch_done
        )
    else:
        next_token_ids = result.next_token_ids.tolist()
        if batch.return_logprob:
            # Move logprob tensors to host containers for serialization.
            if logits_output.next_token_logprobs is not None:
                logits_output.next_token_logprobs = (
                    logits_output.next_token_logprobs.tolist()
                )
            if logits_output.input_token_logprobs is not None:
                logits_output.input_token_logprobs = tuple(
                    logits_output.input_token_logprobs.tolist()
                )

    # Offset into the flattened per-token hidden_states buffer.
    hidden_state_offset = 0
    for i, (req, next_token_id) in enumerate(
        zip(batch.reqs, next_token_ids, strict=True)
    ):
        req: Req
        if req.is_chunked <= 0:
            # Prefill for this request is fully done.
            # There is no output_ids for prefill
            req.output_ids.append(next_token_id)
            self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
            self.disagg_prefill_inflight_queue.append(req)
            if (
                logits_output is not None
                and logits_output.hidden_states is not None
            ):
                # Keep only the last position's hidden state (for PD + spec decode).
                last_hidden_index = (
                    hidden_state_offset + extend_input_len_per_req[i] - 1
                )
                req.hidden_states_tensor = (
                    logits_output.hidden_states[last_hidden_index].cpu().clone()
                )
                hidden_state_offset += extend_input_len_per_req[i]
            else:
                req.hidden_states_tensor = None
            if req.return_logprob:
                assert extend_logprob_start_len_per_req is not None
                assert extend_input_len_per_req is not None
                extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                extend_input_len = extend_input_len_per_req[i]
                num_input_logprobs = extend_input_len - extend_logprob_start_len
                self.add_logprob_return_values(
                    i,
                    req,
                    logprob_pt,
                    next_token_ids,
                    num_input_logprobs,
                    logits_output,
                )
                logprob_pt += num_input_logprobs
            # Kick off the final KV transfer (also publishes metadata buffers).
            self.send_kv_chunk(req, last_chunk=True)

            if req.grammar is not None:
                # FIXME: this try-except block is for handling unexpected xgrammar issue.
                try:
                    req.grammar.accept_token(next_token_id)
                except ValueError as e:
                    # Grammar accept_token can raise ValueError if the token is not in the grammar.
                    # This can happen if the grammar is not set correctly or the token is invalid.
                    error_message = f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                    self.tree_cache.cache_finished_req(req)
                    prepare_abort(
                        req,
                        error_message,
                        status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                    )
                req.grammar.finished = req.finished()
        else:
            # being chunked reqs' prefill is not finished
            req.is_chunked -= 1

            if req.return_logprob:
                extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                extend_input_len = extend_input_len_per_req[i]
                if extend_logprob_start_len < extend_input_len:
                    # Update input logprobs.
                    num_input_logprobs = extend_input_len - extend_logprob_start_len
                    self.add_input_logprob_return_values(
                        i,
                        req,
                        logits_output,
                        logprob_pt,
                        num_input_logprobs,
                        last_prefill_chunk=False,
                    )
                    logprob_pt += num_input_logprobs

            if self.enable_overlap:
                # KV send was delayed in process_prefill_chunk; do it now that
                # the forward result is resolved.
                self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)

    # We need to remove the sync in the following function for overlap schedule.
    self.set_next_batch_sampling_info_done(batch)
    self.maybe_send_health_check_signal()
|
||||
|
||||
def process_disagg_prefill_inflight_queue(
    self: Scheduler, rids_to_check: Optional[List[str]] = None
) -> List[Req]:
    """
    Poll the requests in the middle of transfer. If done, return the request.

    rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.

    Returns:
        The requests whose KV transfer finished (successfully or not) this call.
    """
    if len(self.disagg_prefill_inflight_queue) == 0:
        return []

    done_reqs = []

    # MIN-reduced poll states across attention-TP ranks, so a request only
    # advances when every rank agrees it has advanced.
    polls = poll_and_all_reduce(
        [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
        self.attn_tp_cpu_group,
    )

    undone_reqs: List[Req] = []
    # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
    for req, poll in zip(self.disagg_prefill_inflight_queue, polls):

        if rids_to_check is not None:
            if req.rid not in rids_to_check:
                undone_reqs.append(req)
                continue
            # A rid in the consensus list must already be in a terminal state.
            assert poll == KVPoll.Success or poll == KVPoll.Failed

        if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
            undone_reqs.append(req)
        elif poll == KVPoll.Success:  # transfer done
            self.tree_cache.cache_finished_req(req)  # unlock the tree
            # Prefill emits exactly one token; decode continues generation.
            req.finished_reason = FINISH_LENGTH(length=0)
            # FIXME: clean up req's data in transfer engine
            if hasattr(req.disagg_kv_sender, "clear"):
                req.disagg_kv_sender.clear()
            done_reqs.append(req)
        elif poll == KVPoll.Failed:
            error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
            try:
                req.disagg_kv_sender.failure_exception()
            except Exception as e:
                error_message += f" with exception {e}"
            logger.warning(error_message)
            self.tree_cache.cache_finished_req(req)  # unlock the tree
            prepare_abort(
                req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
            )
            done_reqs.append(req)
            if self.enable_metrics:
                self.metrics_collector.increment_transfer_failed_reqs()
        else:
            assert False, f"Unexpected polling state {poll=}"

    # Stream requests which have finished transfer
    self.stream_output(
        done_reqs,
        any(req.return_logprob for req in done_reqs),
        None,
    )
    # Release the per-request metadata buffer slots.
    for req in done_reqs:
        req: Req
        self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
        req.metadata_buffer_index = -1

    self.disagg_prefill_inflight_queue = undone_reqs

    return done_reqs
|
||||
|
||||
def get_transferred_rids(self: Scheduler) -> List[str]:
    """
    Used by PP, get the transferred rids but **do not pop**
    """
    # Consensus poll across the TP group; a request counts as "transferred"
    # once it has reached a terminal state (success or failure) everywhere.
    statuses = poll_and_all_reduce(
        [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
        self.tp_worker.get_tp_group().cpu_group,
    )

    return [
        req.rid
        for req, status in zip(self.disagg_prefill_inflight_queue, statuses)
        if status == KVPoll.Success or status == KVPoll.Failed
    ]
|
||||
|
||||
def process_prefill_chunk(self: Scheduler) -> None:
    """Finalize the chunked request of the last extend batch, sending (or
    scheduling) its KV chunk and releasing its token-pool slot."""
    if self.last_batch and self.last_batch.forward_mode.is_extend():
        if self.chunked_req:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
            if self.enable_overlap:
                # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
                self.chunked_req.tmp_end_idx = min(
                    len(self.chunked_req.fill_ids),
                    len(self.chunked_req.origin_input_ids),
                )
            else:
                self.send_kv_chunk(self.chunked_req)
            # chunked request keeps its rid but will get a new req_pool_idx
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
        self.running_batch.batch_is_full = False
|
||||
|
||||
def send_kv_chunk(
    self: Scheduler,
    req: Req,
    last_chunk: bool = False,
    end_idx: Optional[int] = None,
) -> None:
    """
    Send a prefilled chunk to the decode server.

    Args:
        req: The request whose KV cache rows are being transferred.
        last_chunk: True when this is the final chunk of the prefill; also
            publishes the request's metadata buffers.
        end_idx: Explicit end token index; defaults to the number of tokens
            prefilled so far.
    """
    page_size = self.token_to_kv_pool_allocator.page_size
    start_idx = req.start_send_idx
    end_idx = (
        end_idx
        if end_idx is not None
        else min(len(req.fill_ids), len(req.origin_input_ids))
    )

    if not last_chunk:
        # if not the last chunk and the last page is partial, delay the last partial page to the next send
        end_idx = end_idx - end_idx % page_size

    # Token -> KV slot mapping for the [start_idx, end_idx) window.
    kv_indices = (
        self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
        .cpu()
        .numpy()
    )
    # Next send resumes where this one ended.
    req.start_send_idx = end_idx
    if last_chunk:
        self.disagg_metadata_buffers.set_buf(req)
    page_indices = kv_to_page_indices(kv_indices, page_size)
    if len(page_indices) == 0:
        logger.info(
            f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
        )
        return
    req.disagg_kv_sender.send(page_indices)
|
||||
|
||||
# PP
@DynamicGradMode()
def event_loop_pp_disagg_prefill(self: Scheduler):
    """
    An event loop for the prefill server in pipeline parallelism.

    Rules:
    1. Each stage runs in the same order and is notified by the previous stage.
    2. Each send/recv operation is blocking and matched by the neighboring stage.

    Regular Schedule:
    ====================================================================
    Stage i                 | Stage i+1
    send ith req            | recv ith req
    send ith proxy          | recv ith proxy
    send prev (i+1)th carry | recv prev (i+1)th carry
    ====================================================================

    Prefill Server Schedule:
    ====================================================================
    Stage i                       | Stage i+1
    send ith req                  | recv ith req
    send ith bootstrap req        | recv ith bootstrap req
    send ith transferred req      | recv ith transferred req
    send ith proxy                | recv ith proxy
    send prev (i+1)th carry       | recv prev (i+1)th carry
    send prev (i+1)th release req | recv prev (i+1)th release req
    ====================================================================

    There are two additional elements compared to the regular schedule:

    1. Bootstrap Requests:
        a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
        b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
        c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.

    2. Transferred Requests + Release Requests:
        a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
        b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
        c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
    """
    from sglang.srt.managers.scheduler import GenerationBatchResult

    # One microbatch slot per pipeline stage.
    mbs = [None] * self.pp_size
    last_mbs = [None] * self.pp_size
    self.running_mbs = [
        ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
    ]
    bids = [None] * self.pp_size
    pp_outputs: Optional[PPProxyTensors] = None

    # Either success or failed
    bootstrapped_rids: List[str] = []
    transferred_rids: List[str] = []
    release_rids: Optional[List[str]] = None

    # transferred microbatch
    tmbs = [None] * self.pp_size

    ENABLE_RELEASE = True  # For debug

    while True:
        server_is_idle = True

        for mb_id in range(self.pp_size):
            # Swap in this microbatch's scheduling state.
            self.running_batch = self.running_mbs[mb_id]
            self.last_batch = last_mbs[mb_id]

            recv_reqs = self.recv_requests()

            self.process_input_requests(recv_reqs)

            if self.pp_group.is_first_rank:
                # First rank, pop the bootstrap reqs from the bootstrap queue
                bootstrapped_reqs, failed_reqs = (
                    self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                        return_failed_reqs=True
                    )
                )
                bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
                    req.rid for req in failed_reqs
                ]
                self.waiting_queue.extend(bootstrapped_reqs)
            else:
                # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
                bootstrapped_rids = self.recv_pyobj_from_prev_stage()
                bootstrapped_reqs = (
                    self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                        rids_to_check=bootstrapped_rids
                    )
                )
                self.waiting_queue.extend(bootstrapped_reqs)

            if self.pp_group.is_first_rank:
                transferred_rids = self.get_transferred_rids()
            # if other ranks,
            else:
                # 1. recv previous stage's transferred reqs info
                prev_transferred_rids = self.recv_pyobj_from_prev_stage()
                # 2. get the current stage's transferred reqs info
                curr_transferred_rids = self.get_transferred_rids()
                # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
                transferred_rids = list(
                    set(prev_transferred_rids) & set(curr_transferred_rids)
                )

            tmbs[mb_id] = transferred_rids

            self.process_prefill_chunk()
            mbs[mb_id] = self.get_new_batch_prefill()
            self.running_mbs[mb_id] = self.running_batch

            self.cur_batch = mbs[mb_id]
            if self.cur_batch:
                server_is_idle = False
                result = self.run_batch(self.cur_batch)

            # send the outputs to the next step
            if self.pp_group.is_last_rank:
                if self.cur_batch:
                    next_token_ids, bids[mb_id] = (
                        result.next_token_ids,
                        result.bid,
                    )
                    pp_outputs = PPProxyTensors(
                        {
                            "next_token_ids": next_token_ids,
                        }
                    )
                    # send the output from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            if ENABLE_RELEASE:
                if self.pp_group.is_last_rank:
                    # At the last stage, all stages has reached the consensus to release memory for transferred_rids
                    release_rids = transferred_rids
                    # send to the first rank
                    self.send_pyobj_to_next_stage(release_rids)

            # receive outputs and post-process (filter finished reqs) the coming microbatch
            next_mb_id = (mb_id + 1) % self.pp_size
            next_pp_outputs = None
            next_release_rids = None

            if mbs[next_mb_id] is not None:
                next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                    self.pp_group.recv_tensor_dict(
                        all_gather_group=self.attn_tp_group
                    )
                )
                mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                output_result = GenerationBatchResult(
                    logits_output=None,
                    pp_hidden_states_proxy_tensors=None,
                    next_token_ids=next_pp_outputs["next_token_ids"],
                    extend_input_len_per_req=None,
                    extend_logprob_start_len_per_req=None,
                    bid=bids[next_mb_id],
                    can_run_cuda_graph=result.can_run_cuda_graph,
                )
                self.process_batch_result_disagg_prefill(
                    mbs[next_mb_id], output_result
                )

                last_mbs[next_mb_id] = mbs[next_mb_id]

            if ENABLE_RELEASE:
                if tmbs[next_mb_id] is not None:
                    # recv consensus rids from the previous rank
                    next_release_rids = self.recv_pyobj_from_prev_stage()
                    self.process_disagg_prefill_inflight_queue(next_release_rids)

            # carry the outputs to the next stage
            if not self.pp_group.is_last_rank:
                if self.cur_batch:
                    bids[mb_id] = result.bid
                if pp_outputs:
                    # send the outputs from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )
                if ENABLE_RELEASE:
                    if release_rids is not None:
                        self.send_pyobj_to_next_stage(release_rids)

            if not self.pp_group.is_last_rank:
                # send out reqs to the next stage
                self.send_pyobj_to_next_stage(recv_reqs)
                self.send_pyobj_to_next_stage(bootstrapped_rids)
                self.send_pyobj_to_next_stage(transferred_rids)

                # send out proxy tensors to the next stage
                if self.cur_batch:
                    self.pp_group.send_tensor_dict(
                        result.pp_hidden_states_proxy_tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            pp_outputs = next_pp_outputs
            release_rids = next_release_rids

            # See the HACK note in event_loop_overlap_disagg_prefill: the flag
            # is never reset by update_running_batch on this path.
            self.running_batch.batch_is_full = False

        if not ENABLE_RELEASE:
            if len(self.disagg_prefill_inflight_queue) > 0:
                self.process_disagg_prefill_inflight_queue()

        # When the server is idle, self-check and re-init some states
        if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
            self.check_memory()
            self.check_tree_cache()
            self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
def send_pyobj_to_next_stage(self, data):
    """Blocking send of a picklable object to the same DP slot on the next
    pipeline stage (only attn-TP rank 0 participates)."""
    if self.attn_tp_rank == 0:
        dp_offset = self.attn_dp_rank * self.attn_tp_size
        src_rank = self.pp_rank * self.tp_size + dp_offset
        dst_rank = ((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset
        point_to_point_pyobj(
            data,
            src_rank,
            self.world_group.device_group,
            src_rank,
            dst_rank,
        )
|
||||
|
||||
def recv_pyobj_from_prev_stage(self):
    """Blocking receive of a picklable object from the previous pipeline
    stage, then broadcast it to all ranks of the local TP group.

    Returns:
        The received object (on every TP rank after the broadcast).
    """
    if self.attn_tp_rank == 0:
        dp_offset = self.attn_dp_rank * self.attn_tp_size
        data = point_to_point_pyobj(
            [],
            self.pp_rank * self.tp_size + dp_offset,
            self.world_group.device_group,
            # Wraps around: stage 0 receives from the last stage.
            ((self.pp_rank - 1) % self.pp_size) * self.tp_size + dp_offset,
            self.pp_rank * self.tp_size + dp_offset,
        )
    else:
        data = None

    if self.tp_size != 1:
        # Fan the object out from TP rank 0 to the rest of the TP group.
        data = broadcast_pyobj(
            data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0]
        )
    return data
|
||||
329
python/sglang/srt/disaggregation/utils.py
Normal file
329
python/sglang/srt/disaggregation/utils.py
Normal file
@@ -0,0 +1,329 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.utils import is_npu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
|
||||
#########################
|
||||
# Constants & Enums
|
||||
#########################
|
||||
FAKE_BOOTSTRAP_HOST = "2.2.2.2"
|
||||
|
||||
|
||||
class DisaggregationMode(Enum):
    """Role of this server in prefill/decode disaggregation.

    NULL means disaggregation is disabled (a regular, unified server).
    """

    NULL = "null"
    PREFILL = "prefill"
    DECODE = "decode"
|
||||
|
||||
|
||||
#########################
|
||||
# Synchronization
|
||||
#########################
|
||||
|
||||
# env var for testing failure, convert to float explicitly
|
||||
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
|
||||
|
||||
|
||||
def poll_and_all_reduce(pollers, gloo_group):
    """Poll every sender/receiver and MIN-all-reduce the states across ranks.

    The MIN reduction makes the group-wide state the *least advanced* state of
    any rank, so no rank acts on a transfer before every rank has reached it.

    Args:
        pollers: Objects exposing ``poll()`` that returns an int-convertible state.
        gloo_group: CPU (gloo) process group to reduce over.

    Returns:
        List[int]: the reduced per-poller states.
    """
    # at a certain prob, the poll is failed to simulate failure
    if FAILURE_PROB > 0:
        from sglang.srt.disaggregation.base import KVPoll

        polls = [
            int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())
            for poller in pollers
        ]
    else:
        polls = [int(poller.poll()) for poller in pollers]
    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
    dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
    return tensor_to_reduce.tolist()
|
||||
|
||||
|
||||
#########################
|
||||
# Metadata Buffers
|
||||
#########################
|
||||
|
||||
|
||||
class ReqToMetadataIdxAllocator:
    """A memory pool that maps a request to its first output token location.

    Slot indices are handed out FIFO from a fixed pool of ``size`` entries;
    ``alloc`` returns ``None`` when the pool is exhausted.
    """

    def __init__(
        self,
        size: int,
    ):
        self.size = size
        # FIFO of currently unused slot indices.
        self.free_slots = deque(range(size))

    def available_size(self) -> int:
        """Number of slots currently free."""
        return len(self.free_slots)

    def alloc(self) -> Optional[int]:
        """Take the oldest free slot, or return None if none remain."""
        return self.free_slots.popleft() if self.free_slots else None

    def free(self, free_index: int) -> None:
        """Return a slot index to the pool."""
        self.free_slots.append(free_index)
|
||||
|
||||
|
||||
class MetadataBuffers:
    """Fixed-size tensor buffers holding the first output token's metadata
    (token id, logprobs, hidden state) for each in-flight request, so it can
    be transferred from the prefill server to the decode server.

    Each request owns one row, addressed by ``req.metadata_buffer_index``.
    """

    def __init__(
        self,
        size: int,
        hidden_size: int,
        dtype: torch.dtype,
        max_top_logprobs_num: int = 128,
        custom_mem_pool: torch.cuda.MemPool = None,
    ):
        # size: number of request slots.
        # hidden_size/dtype: per-request hidden-state row (used for PD + spec decode).
        # custom_mem_pool: optional CUDA memory pool to allocate from.
        self.custom_mem_pool = custom_mem_pool
        device = "cpu"
        if is_npu():
            # For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
            device = "npu"
        elif self.custom_mem_pool:
            # TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free
            device = "cpu"
        with (
            torch.cuda.use_mem_pool(self.custom_mem_pool)
            if self.custom_mem_pool
            else nullcontext()
        ):
            # TODO: abort top_logprobs_num > 128 in PD

            # We transfer the metadata of first output token to decode
            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)

            self.output_token_logprobs_val = torch.zeros(
                (size, 16), dtype=torch.float32, device=device
            )
            self.output_token_logprobs_idx = torch.zeros(
                (size, 16), dtype=torch.int32, device=device
            )
            self.output_top_logprobs_val = torch.zeros(
                (size, max_top_logprobs_num), dtype=torch.float32, device=device
            )
            self.output_top_logprobs_idx = torch.zeros(
                (size, max_top_logprobs_num), dtype=torch.int32, device=device
            )
            self.output_hidden_states = torch.zeros(
                (size, hidden_size), dtype=dtype, device=device
            )

    def get_buf_infos(self):
        """Return parallel lists (base pointers, total byte lengths, per-row
        byte lengths) for registering the buffers with a transfer engine."""
        ptrs = [
            self.output_ids.data_ptr(),
            self.output_token_logprobs_val.data_ptr(),
            self.output_token_logprobs_idx.data_ptr(),
            self.output_top_logprobs_val.data_ptr(),
            self.output_top_logprobs_idx.data_ptr(),
            self.output_hidden_states.data_ptr(),
        ]
        data_lens = [
            self.output_ids.nbytes,
            self.output_token_logprobs_val.nbytes,
            self.output_token_logprobs_idx.nbytes,
            self.output_top_logprobs_val.nbytes,
            self.output_top_logprobs_idx.nbytes,
            self.output_hidden_states.nbytes,
        ]
        item_lens = [
            self.output_ids[0].nbytes,
            self.output_token_logprobs_val[0].nbytes,
            self.output_token_logprobs_idx[0].nbytes,
            self.output_top_logprobs_val[0].nbytes,
            self.output_top_logprobs_idx[0].nbytes,
            self.output_hidden_states[0].nbytes,
        ]
        return ptrs, data_lens, item_lens

    def get_buf(self, idx: int):
        """Return the six per-request rows (views, not copies) at slot *idx*."""
        return (
            self.output_ids[idx],
            self.output_token_logprobs_val[idx],
            self.output_token_logprobs_idx[idx],
            self.output_top_logprobs_val[idx],
            self.output_top_logprobs_idx[idx],
            self.output_hidden_states[idx],
        )

    def set_buf(self, req: Req):
        """Copy the first output token's metadata from *req* into its slot."""

        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
        if req.return_logprob:
            if req.output_token_logprobs_val:  # not none or empty list
                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
                    req.output_token_logprobs_val[0]
                )
            if req.output_token_logprobs_idx:  # not none or empty list
                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
                    req.output_token_logprobs_idx[0]
                )

            if req.output_top_logprobs_val:  # not none or empty list
                self.output_top_logprobs_val[req.metadata_buffer_index][
                    : len(req.output_top_logprobs_val[0])
                ] = torch.tensor(
                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
                )
            if req.output_top_logprobs_idx:  # not none or empty list
                self.output_top_logprobs_idx[req.metadata_buffer_index][
                    : len(req.output_top_logprobs_idx[0])
                ] = torch.tensor(
                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
                )
        # for PD + spec decode
        if req.hidden_states_tensor is not None:
            self.output_hidden_states[req.metadata_buffer_index].copy_(
                req.hidden_states_tensor
            )
|
||||
|
||||
|
||||
#########################
|
||||
# Transfer Backend
|
||||
#########################
|
||||
|
||||
|
||||
class TransferBackend(Enum):
    """Available KV-cache transfer engine implementations."""

    MOONCAKE = "mooncake"
    NIXL = "nixl"
    ASCEND = "ascend"
    FAKE = "fake"  # presumably a no-op stub for testing — confirm against fake.py
|
||||
|
||||
|
||||
class KVClassType(Enum):
    """Component roles resolvable for a transfer backend via get_kv_class()."""

    KVARGS = "kvargs"
    MANAGER = "manager"
    SENDER = "sender"
    RECEIVER = "receiver"
    BOOTSTRAP_SERVER = "bootstrap_server"
|
||||
|
||||
|
||||
def get_kv_class(
    transfer_backend: TransferBackend, class_type: KVClassType
) -> Optional[Type]:
    """Resolve the concrete KV-transfer class for a backend and role.

    Imports are done lazily per-backend so that only the selected backend's
    dependencies are loaded. (Fix: removed a redundant top-level import of the
    fake backend's classes that was re-imported in the FAKE branch and loaded
    eagerly for every backend.)

    Args:
        transfer_backend: Which transfer implementation to use.
        class_type: Which role to resolve (kvargs/manager/sender/receiver/
            bootstrap server).

    Returns:
        The class registered for (backend, role), or ``None`` when the backend
        does not provide that role (e.g. FAKE has no manager or bootstrap server).

    Raises:
        ValueError: If *transfer_backend* is not a supported backend.
    """
    if transfer_backend == TransferBackend.MOONCAKE:
        from sglang.srt.disaggregation.base import KVArgs
        from sglang.srt.disaggregation.mooncake import (
            MooncakeKVBootstrapServer,
            MooncakeKVManager,
            MooncakeKVReceiver,
            MooncakeKVSender,
        )

        class_mapping = {
            KVClassType.KVARGS: KVArgs,
            KVClassType.MANAGER: MooncakeKVManager,
            KVClassType.SENDER: MooncakeKVSender,
            KVClassType.RECEIVER: MooncakeKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
        }
        return class_mapping.get(class_type)
    elif transfer_backend == TransferBackend.ASCEND:
        from sglang.srt.disaggregation.ascend import (
            AscendKVBootstrapServer,
            AscendKVManager,
            AscendKVReceiver,
            AscendKVSender,
        )
        from sglang.srt.disaggregation.base import KVArgs

        class_mapping = {
            KVClassType.KVARGS: KVArgs,
            KVClassType.MANAGER: AscendKVManager,
            KVClassType.SENDER: AscendKVSender,
            KVClassType.RECEIVER: AscendKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer,
        }
        return class_mapping.get(class_type)
    elif transfer_backend == TransferBackend.NIXL:
        from sglang.srt.disaggregation.base import KVArgs
        from sglang.srt.disaggregation.nixl import (
            NixlKVBootstrapServer,
            NixlKVManager,
            NixlKVReceiver,
            NixlKVSender,
        )

        class_mapping = {
            KVClassType.KVARGS: KVArgs,
            KVClassType.MANAGER: NixlKVManager,
            KVClassType.SENDER: NixlKVSender,
            KVClassType.RECEIVER: NixlKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
        }
        return class_mapping.get(class_type)
    elif transfer_backend == TransferBackend.FAKE:
        from sglang.srt.disaggregation.base import KVArgs
        from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

        class_mapping = {
            KVClassType.KVARGS: KVArgs,
            KVClassType.SENDER: FakeKVSender,
            KVClassType.RECEIVER: FakeKVReceiver,
        }
        return class_mapping.get(class_type)

    raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
||||
|
||||
|
||||
#########################
|
||||
# KV Pages
|
||||
#########################
|
||||
|
||||
|
||||
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
    """Map per-token KV slot indices to the indices of the pages they occupy.

    Every page is guaranteed to be full except possibly the last one, so one
    representative token per page suffices: page index = kv_index // page_size.
    """
    # Identity mapping: every token is its own page.
    if page_size == 1:
        return kv_indices

    first_token_of_each_page = kv_indices[::page_size]
    return first_token_of_each_page // page_size
|
||||
|
||||
|
||||
def kv_to_page_num(num_kv_indices: int, page_size: int):
    """Return the number of pages needed for *num_kv_indices* KV slots
    (i.e. ceil(num_kv_indices / page_size))."""
    # Ceiling division via negated floor division.
    return -(-num_kv_indices // page_size)
|
||||
|
||||
|
||||
#########################
|
||||
# Misc
|
||||
#########################
|
||||
|
||||
|
||||
def is_mla_backend(target_kv_pool) -> bool:
    """Return True when *target_kv_pool* is an MLA token-to-KV pool."""
    # Local import avoids a circular dependency with the memory-pool module.
    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool

    return isinstance(target_kv_pool, MLATokenToKVPool)
|
||||
|
||||
|
||||
def prepare_abort(req: Req, error_message: str, status_code=None):
    """Mark *req* as aborted so the scheduler streams the error to the client.

    Args:
        req: The request to abort.
        error_message: Human-readable failure reason, surfaced to the caller.
        status_code: Optional HTTP status code to attach to the abort.
    """
    from sglang.srt.managers.schedule_batch import FINISH_ABORT

    # populate finish metadata and stream output
    req.finished_reason = FINISH_ABORT(error_message, status_code)

    if req.return_logprob:
        # Reset the input-logprob accumulators so the aborted request is
        # serialized with consistent (empty) logprob fields.
        req.input_token_logprobs_val = []
        req.input_token_logprobs_idx = []
        req.input_top_logprobs_val = []
        req.input_top_logprobs_idx = []
        req.input_token_ids_logprobs_val = []
        req.input_token_ids_logprobs_idx = []
|
||||
3
python/sglang/srt/distributed/__init__.py
Normal file
3
python/sglang/srt/distributed/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from sglang.srt.distributed.communication_op import *
|
||||
from sglang.srt.distributed.parallel_state import *
|
||||
from sglang.srt.distributed.utils import *
|
||||
35
python/sglang/srt/distributed/communication_op.py
Normal file
35
python/sglang/srt/distributed/communication_op.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from .parallel_state import get_tp_group
|
||||
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce *input_* across the tensor-model-parallel group."""
    tp_group = get_tp_group()
    return tp_group.all_reduce(input_)
|
||||
|
||||
|
||||
def tensor_model_parallel_all_gather(
    input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
    """All-gather *input_* from every tensor-model-parallel rank,
    concatenating along *dim*."""
    tp_group = get_tp_group()
    return tp_group.all_gather(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_gather(
    input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
    """Gather *input_* across the tensor-model-parallel group onto rank *dst*
    along *dim*."""
    tp_group = get_tp_group()
    return tp_group.gather(input_, dst, dim)
|
||||
|
||||
|
||||
def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
    """Broadcast ``tensor_dict`` from rank ``src`` over the tensor-parallel group.

    When ``torch.distributed`` has not been initialized (single-process
    mode), there is nothing to broadcast and the dict is returned unchanged.
    """
    # Single-process fast path: no process group exists yet.
    if not torch.distributed.is_initialized():
        return tensor_dict
    return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
||||
@@ -0,0 +1,183 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py
|
||||
|
||||
"""This file is a pure Python wrapper for the cudart library.
|
||||
It avoids the need to compile a separate shared library, and is
|
||||
convenient for use when we just need to call a few functions.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# this line makes it possible to directly load `libcudart.so` using `ctypes`
|
||||
import torch # noqa
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# === export types and functions from cudart to Python ===
|
||||
# for the original cudart definition, please check
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
|
||||
|
||||
# ctypes aliases for the cudart C enums; both are plain C ints on the ABI.
cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int
|
||||
|
||||
class cudaIpcMemHandle_t(ctypes.Structure):
    """ctypes mirror of the opaque 128-byte ``cudaIpcMemHandle_t`` C struct."""

    _fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
@dataclass
class Function:
    """Declaration of one cudart entry point: symbol name plus ctypes signature."""

    # Symbol name as exported by libcudart.
    name: str
    # ctypes return type of the function.
    restype: Any
    # ctypes argument types, in call order.
    argtypes: List[Any]
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
    """Return the filesystem path of *lib_name* if loaded, else ``None``.

    According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
    the file `/proc/self/maps` contains the memory maps of the process,
    which includes the shared libraries loaded by the process. We can use
    this file to find the path of a loaded library.
    """  # noqa
    found = False
    with open("/proc/self/maps") as f:
        for line in f:
            if lib_name in line:
                found = True
                break
    if not found:
        # the library is not loaded in the current process
        return None
    # if lib_name is libcudart, we need to match a line with:
    # address /path/to/libcudart-hash.so.11.0
    start = line.index("/")
    path = line[start:].strip()
    filename = path.split("/")[-1]
    # Bug fix: the assertion message previously interpolated a stale
    # "(unknown)" placeholder; report the actual mismatching filename.
    assert filename.rpartition(".so")[0].startswith(
        lib_name
    ), f"Unexpected filename: {filename} for library {lib_name}"
    return path
||||
|
||||
|
||||
class CudaRTLibrary:
    """Thin ctypes wrapper around a small set of cudart entry points.

    Loads ``libcudart`` (or a caller-supplied ``.so``) via ``ctypes.CDLL``
    and exposes typed Python methods for the handful of runtime calls
    needed for IPC buffer sharing. Loaded libraries and their resolved
    function tables are cached at class level so repeated instantiation
    is cheap.
    """

    # Signatures of the cudart functions exposed by this wrapper.
    exported_functions = [
        # cudaError_t cudaSetDevice ( int device )
        Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
        # cudaError_t cudaDeviceSynchronize ( void )
        Function("cudaDeviceSynchronize", cudaError_t, []),
        # cudaError_t cudaDeviceReset ( void )
        Function("cudaDeviceReset", cudaError_t, []),
        # const char* cudaGetErrorString ( cudaError_t error )
        Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
        # cudaError_t cudaMalloc ( void** devPtr, size_t size )
        Function(
            "cudaMalloc",
            cudaError_t,
            [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
        ),
        # cudaError_t cudaFree ( void* devPtr )
        Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
        # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
        Function(
            "cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
        ),
        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
        Function(
            "cudaMemcpy",
            cudaError_t,
            [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
        ),
        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
        Function(
            "cudaIpcGetMemHandle",
            cudaError_t,
            [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
        ),
        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
        Function(
            "cudaIpcOpenMemHandle",
            cudaError_t,
            [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
        ),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        """Bind to the given ``.so`` file, or to the already-loaded libcudart."""
        if so_file is None:
            so_file = find_loaded_library("libcudart")
            assert so_file is not None, "libcudart is not loaded in the current process"
        if so_file not in CudaRTLibrary.path_to_library_cache:
            lib = ctypes.CDLL(so_file)
            CudaRTLibrary.path_to_library_cache[so_file] = lib
        self.lib = CudaRTLibrary.path_to_library_cache[so_file]

        if so_file not in CudaRTLibrary.path_to_dict_mapping:
            # Resolve each exported symbol once and attach its ctypes signature.
            _funcs = {}
            for func in CudaRTLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]

    def CUDART_CHECK(self, result: cudaError_t) -> None:
        """Raise ``RuntimeError`` if a cudart call returned a non-zero error."""
        if result != 0:
            error_str = self.cudaGetErrorString(result)
            raise RuntimeError(f"CUDART error: {error_str}")

    def cudaGetErrorString(self, error: cudaError_t) -> str:
        """Translate a cudart error code into its human-readable message."""
        return self.funcs["cudaGetErrorString"](error).decode("utf-8")

    def cudaSetDevice(self, device: int) -> None:
        """Select the active CUDA device for the calling host thread."""
        self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))

    def cudaDeviceSynchronize(self) -> None:
        """Block until all previously issued work on the device completes."""
        self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())

    def cudaDeviceReset(self) -> None:
        """Destroy the current device context and free its resources."""
        self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())

    def cudaMalloc(self, size: int) -> ctypes.c_void_p:
        """Allocate ``size`` bytes on the device and return the device pointer."""
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
        return devPtr

    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
        """Release device memory previously returned by ``cudaMalloc``."""
        self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))

    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
        """Fill ``count`` bytes of device memory with ``value``."""
        self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))

    def cudaMemcpy(
        self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int
    ) -> None:
        """Copy ``count`` bytes; cudaMemcpyDefault lets the driver infer direction."""
        cudaMemcpyDefault = 4
        kind = cudaMemcpyDefault
        self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))

    def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
        """Create an IPC handle for a device allocation, shareable across processes."""
        handle = cudaIpcMemHandle_t()
        self.CUDART_CHECK(
            self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)
        )
        return handle

    def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
        """Map another process's IPC handle and return the local device pointer."""
        cudaIpcMemLazyEnablePeerAccess = 1
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(
            self.funcs["cudaIpcOpenMemHandle"](
                ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess
            )
        )
        return devPtr
||||
@@ -0,0 +1,421 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py
|
||||
|
||||
import ctypes
|
||||
import logging
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
gpu_p2p_access_check,
|
||||
is_full_nvlink,
|
||||
is_weak_contiguous,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Accelerator-backend flags for this process.
_is_cuda = is_cuda()
_is_hip = is_hip()


# Probe which custom-allreduce implementation is available; `custom_ar`
# records whether any of them could be loaded.
try:
    if ops.use_vllm_custom_allreduce and not _is_hip:
        # Use vLLM custom allreduce
        ops.meta_size()
    else:
        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
        import sgl_kernel
    custom_ar = True
except Exception:
    # For CPUs: no custom allreduce kernel can be loaded.
    custom_ar = False
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
# SGLANG_SKIP_P2P_CHECK can be set to False in sglang
|
||||
SGLANG_SKIP_P2P_CHECK = os.getenv("SGLANG_SKIP_P2P_CHECK", "0") == "1"
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
continue
|
||||
if SGLANG_SKIP_P2P_CHECK:
|
||||
logger.info("Skipping P2P check and trusting the driver's P2P report.")
|
||||
return torch.cuda.can_device_access_peer(rank, i)
|
||||
if not gpu_p2p_access_check(rank, i):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class CustomAllreduce:
    """IPC-based custom allreduce for small tensors within a single node.

    Wraps the vLLM / sgl-kernel custom allreduce kernels. The constructor
    probes hardware support (NVLink topology, P2P access, world size) and
    silently sets ``self.disabled = True`` when any precondition fails, so
    callers can always construct one and fall back to NCCL.
    """

    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
    _MAX_CAR_SIZE = 8192 * 1024
    if _is_hip:
        # crossover is at 16MB buffer size for ROCm
        _MAX_CAR_SIZE = 2 * 8192 * 1024

    # max_size: max supported allreduce size
    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        max_size=_MAX_CAR_SIZE,
    ) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bind to f"cuda:{local_rank}".
        It is the caller's responsibility to make sure each communicator
        is bind to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        # Assume disabled until every capability check below has passed.
        self.disabled = True

        if not custom_ar:
            # disable because of missing custom allreduce library
            # e.g. in a non-cuda environment
            return

        self.group = group

        assert (
            dist.get_backend(group) != dist.Backend.NCCL
        ), "CustomAllreduce should be attached to a non-NCCL group."

        if not all(in_the_same_node_as(group, source_rank=0)):
            # No need to initialize custom allreduce for multi-node case.
            logger.warning(
                "Custom allreduce is disabled because this process group"
                " spans across nodes."
            )
            return

        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize custom allreduce for single GPU case.
            return

        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "Custom allreduce is disabled due to an unsupported world"
                " size: %d. Supported world sizes: %s. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly.",
                world_size,
                str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
            )
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        # Map the logical device index to the physical GPU id so the NVLink
        # topology query below sees the real hardware ids.
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices:
            device_ids = list(map(int, cuda_visible_devices.split(",")))
        else:
            device_ids = list(range(torch.cuda.device_count()))

        physical_device_id = device_ids[device.index]
        tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
        gather_list = [
            torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
        ]
        dist.all_gather(gather_list, tensor, group=self.group)
        physical_device_ids = [t.item() for t in gather_list]

        # test nvlink first, this will filter out most of the cases
        # where custom allreduce is not supported
        # this checks hardware and driver support for NVLink
        # NOTE(review): if neither _is_cuda nor _is_hip, `full_nvlink` below
        # would be unbound — presumably unreachable because `custom_ar` is
        # False on such platforms; confirm.
        if _is_cuda or _is_hip:
            full_nvlink = is_full_nvlink(physical_device_ids, world_size)

        if world_size > 2 and not full_nvlink:
            logger.warning(
                "Custom allreduce is disabled because it's not supported on"
                " more than two PCIe-only GPUs. To silence this warning, "
                "specify disable_custom_all_reduce=True explicitly."
            )
            return
        # test P2P capability, this checks software/cudaruntime support
        # this is expensive to compute at the first time
        # then we cache the result
        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
        if not _is_hip and not _can_p2p(rank, world_size):
            logger.warning(
                "Custom allreduce is disabled because your platform lacks "
                "GPU P2P capability or P2P test failed. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly."
            )
            return

        self.max_size = max_size
        self.rank = rank
        self.world_size = world_size
        self.full_nvlink = full_nvlink

        if not _is_hip:
            # Buffers memory are owned by this Python class and passed to C++.
            # Meta data composes of two parts: meta data for synchronization and a
            # temporary buffer for storing intermediate allreduce results.
            self.meta_ptrs = self.create_shared_buffer(
                ops.meta_size() + max_size, group=group
            )
            # This is a pre-registered IPC buffer. In eager mode, input tensors
            # are first copied into this buffer before allreduce is performed
            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
            # This is a buffer for storing the tuples of pointers pointing to
            # IPC buffers from all ranks. Each registered tuple has size of
            # 8*world_size bytes where world_size is at most 8. Allocating 8MB
            # is enough for 131072 such tuples. The largest model I've seen only
            # needs less than 10000 of registered tuples.
            self.rank_data = torch.empty(
                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
            )
            self._ptr = ops.init_custom_ar(
                self.meta_ptrs, self.rank_data, rank, self.full_nvlink
            )
            ops.register_buffer(self._ptr, self.buffer_ptrs)
        else:
            # meta data buffers need to be "uncached" for signal on MI200
            self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
            self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
            handle = ops.get_meta_buffer_ipc_handle(self.meta)
            shard_data = (
                bytes(handle),  # ipc handle to base ptr
                0,  # offset of base ptr
            )
            handles, offsets = self._gather_ipc_meta(shard_data)
            self.rank_data = torch.empty(
                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
            )
            self._ptr = ops.init_custom_ar(
                self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
            )
            self.register_buffer(self.buffer)

        self.disabled = False

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
        """
        Creates a shared buffer and returns a list of pointers
        representing the buffer on all processes in the group.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        handles = [None] * world_size
        dist.all_gather_object(handles, handle, group=group)

        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                # Own allocation: use the local pointer directly.
                pointers.append(pointer.value)  # type: ignore
            else:
                # Peer allocation: map it into this process via IPC.
                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore

        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
        """Free this rank's slice of a buffer created by create_shared_buffer."""
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        lib.cudaFree(ctypes.c_void_p(pointers[rank]))

    @contextmanager
    def capture(self):
        """
        The main responsibility of this context manager is the
        `register_graph_buffers` call at the end of the context.
        It records all the buffer addresses used in the CUDA graph.
        """
        try:
            self._IS_CAPTURING = True
            yield
        finally:
            self._IS_CAPTURING = False
            if not self.disabled:
                self.register_graph_buffers()

    def _get_ipc_meta(self, inp: torch.Tensor):
        """Return (handles, offsets) for *inp* gathered from all ranks."""
        # _share_cuda_() doesn't accept meta buffer not allocated from
        # PyTorch cache allocator, use direct HIP call to get IPC handle
        handle = ops.get_meta_buffer_ipc_handle(inp)
        shard_data = (
            bytes(handle),  # ipc handle to base ptr
            0,  # offset of base ptr
        )
        return self._gather_ipc_meta(shard_data)

    def _gather_ipc_meta(self, shard_data):
        """Exchange each rank's (handle, offset) pair across the group."""
        # Note: don't use `[[None]] * self.world_size` here
        # because it will create a list of the same reference
        all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)]
        all_data[self.rank][0] = shard_data

        ranks = dist.get_process_group_ranks(group=self.group)
        ranks.sort()
        for i, rank in enumerate(ranks):
            dist.broadcast_object_list(
                all_data[i], src=rank, group=self.group, device="cpu"
            )

        # we cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.

        handles = []
        offsets = []
        for i in range(len(all_data)):
            handles.append(all_data[i][0][0])  # type: ignore
            offsets.append(all_data[i][0][1])  # type: ignore
        return handles, offsets

    def register_buffer(self, inp: torch.Tensor):
        """Register *inp*'s IPC metadata with the native allreduce context."""
        handles, offsets = self._get_ipc_meta(inp)
        ops.register_buffer(self._ptr, inp, handles, offsets)

    def register_graph_buffers(self):
        """Register all buffer addresses recorded during CUDA graph capture."""
        if _is_hip:
            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
            handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
            logger.info("Registering %d cuda graph addresses", len(offset))
            ops.register_graph_buffers(self._ptr, handles, offsets)
        else:
            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
            logger.info("Registering %d cuda graph addresses", len(offset))
            # We cannot directly use `dist.all_gather_object` here
            # because it is incompatible with `gloo` backend under inference mode.
            # see https://github.com/pytorch/pytorch/issues/126032 for details.
            all_data = [
                [None, None] for _ in range(dist.get_world_size(group=self.group))
            ]
            all_data[self.rank] = [handle, offset]
            ranks = sorted(dist.get_process_group_ranks(group=self.group))
            for i, rank in enumerate(ranks):
                dist.broadcast_object_list(
                    all_data[i], src=rank, group=self.group, device="cpu"
                )
            # Unpack list of tuples to tuple of lists.
            handles = [d[0] for d in all_data]  # type: ignore
            offsets = [d[1] for d in all_data]  # type: ignore
            ops.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):
        """Decide whether *inp* is eligible for the custom allreduce path."""
        if self.disabled:
            return False
        inp_size = inp.numel() * inp.element_size()
        # custom allreduce requires input byte size to be multiples of 16
        if inp_size % 16 != 0:
            return False
        if not is_weak_contiguous(inp):
            return False
        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
        if not _is_hip:
            if self.world_size == 2 or self.full_nvlink:
                return inp_size < self.max_size
            return False

        if _is_hip:
            if self.full_nvlink:
                return inp_size < self.max_size
            return False

        # Defensive; unreachable since the two branches above are exhaustive.
        return False

    # all reduce, assuming inp tensor is IPC registered with register_buffer,
    # or, in the context of cuda graphs, register_graph_buffers
    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        ops.all_reduce_reg(self._ptr, inp, out)
        return out

    # all reduce, assuming inp tensor is NOT IPC registered
    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
        return out

    def all_reduce(
        self,
        inp: torch.Tensor,
        *,
        out: torch.Tensor = None,
        registered: bool = False,
    ):
        """Performs an out-of-place all reduce.

        If registered is True, this assumes inp's pointer is already
        IPC-registered. Otherwise, inp is first copied into a pre-registered
        buffer.
        """
        if out is None:
            out = torch.empty_like(inp)
        if registered:
            ops.all_reduce(self._ptr, inp, out, 0, 0)
        else:
            ops.all_reduce(
                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
            )
        return out

    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
        """The main allreduce API that provides support for cuda graph."""
        # When custom allreduce is disabled, this will be None.
        if self.disabled or not self.should_custom_ar(input):
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                if _is_hip:
                    return self.all_reduce_reg(input)
                else:
                    return self.all_reduce(input, registered=True)
            else:
                # If warm up, mimic the allocation pattern since custom
                # allreduce is out-of-place.
                return torch.zeros_like(input)
        else:
            if _is_hip:
                # note: outside of cuda graph context,
                # custom allreduce incurs a cost of cudaMemcpy, which should
                # be small(<=1% of overall latency) compared to the performance
                # gains of using custom kernels
                return self.all_reduce_unreg(input)
            else:
                return self.all_reduce(input, registered=False)

    def close(self):
        """Dispose the native context and free the shared IPC buffers."""
        if not self.disabled and self._ptr:
            ops.dispose(self._ptr)
            if _is_cuda:
                self.free_shared_buffer(self.meta_ptrs)
                self.free_shared_buffer(self.buffer_ptrs)
            self._ptr = 0

    def __del__(self):
        # Best-effort cleanup when the communicator is garbage-collected.
        self.close()
||||
@@ -0,0 +1,386 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py
|
||||
|
||||
import ctypes
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from functools import wraps
|
||||
from itertools import product
|
||||
from typing import Callable, Dict, List, Optional, Sequence, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if _is_cuda:
|
||||
try:
|
||||
import pynvml
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import pynvml with %r", e)
|
||||
|
||||
if _is_hip:
|
||||
try:
|
||||
from amdsmi import (
|
||||
AmdSmiException,
|
||||
amdsmi_get_processor_handles,
|
||||
amdsmi_init,
|
||||
amdsmi_shut_down,
|
||||
amdsmi_topo_get_link_type,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import amdsmi with %r", e)
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def update_environment_variables(envs: Dict[str, str]):
    """Set every key/value pair of *envs* in ``os.environ``.

    A warning is logged whenever an existing variable is overwritten with
    a different value.
    """
    for key, value in envs.items():
        previous = os.environ.get(key)
        if previous is not None and previous != value:
            logger.warning(
                "Overwriting environment variable %s from '%s' to '%s'",
                key,
                previous,
                value,
            )
        os.environ[key] = value
||||
|
||||
|
||||
def producer(
    batch_src: Sequence[int],
    producer_queue,
    consumer_queue,
    result_queue,
    cuda_visible_devices: Optional[str] = None,
):
    """Producer half of the batched P2P/IPC self-test (run in a subprocess).

    For each source device in *batch_src*: allocate 1024 bytes, memset them
    to 1, publish the buffer's IPC handle to the consumer, and — if the
    consumer managed to open the handle — synchronize at a two-queue
    barrier, then verify the consumer rewrote every byte to 2. The per-pair
    verdict is pushed onto *result_queue*.
    """
    if cuda_visible_devices is not None:
        update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})

    lib = CudaRTLibrary()
    for i in batch_src:
        lib.cudaSetDevice(i)
        pointer = lib.cudaMalloc(1024)
        lib.cudaMemset(pointer, 1, 1024)
        lib.cudaDeviceSynchronize()
        handle = lib.cudaIpcGetMemHandle(pointer)
        producer_queue.put(handle)
        open_success = consumer_queue.get()
        if open_success:
            # use two queues to simulate barrier
            producer_queue.put(0)
            consumer_queue.get()
            # check if the memory is modified
            host_data = (ctypes.c_char * 1024)()
            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
            for i in range(1024):
                if ord(host_data[i]) != 2:
                    open_success = False
                    break
        result_queue.put(open_success)
        # Reset the device so the next pair starts from a clean context.
        lib.cudaDeviceReset()
||||
|
||||
|
||||
def consumer(
    batch_tgt: Sequence[int],
    producer_queue,
    consumer_queue,
    result_queue,
    cuda_visible_devices: Optional[str] = None,
):
    """Consumer half of the batched P2P/IPC self-test (run in a subprocess).

    For each target device in *batch_tgt*: receive the producer's IPC
    handle, try to open it, report success back, and on success memset the
    shared buffer to 2, synchronize at the two-queue barrier, and verify
    the write locally. The per-pair verdict is pushed onto *result_queue*.
    """
    if cuda_visible_devices is not None:
        update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})

    lib = CudaRTLibrary()
    for j in batch_tgt:
        lib.cudaSetDevice(j)
        handle = producer_queue.get()
        open_success = False
        try:
            pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore
            open_success = True
        except RuntimeError:
            # cannot error out here, because the producer process
            # is still waiting for the response.
            pass
        consumer_queue.put(open_success)
        if open_success:
            # modify the memory
            lib.cudaMemset(pointer, 2, 1024)
            lib.cudaDeviceSynchronize()
            # use two queues to simulate barrier
            producer_queue.get()
            consumer_queue.put(0)
            # check if the memory is modified
            host_data = (ctypes.c_char * 1024)()
            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
            for i in range(1024):
                if ord(host_data[i]) != 2:
                    open_success = False
                    break
        result_queue.put(open_success)
        # Reset the device so the next pair starts from a clean context.
        lib.cudaDeviceReset()
||||
|
||||
|
||||
def can_actually_p2p(
    batch_src: Sequence[int],
    batch_tgt: Sequence[int],
) -> Sequence[bool]:
    """
    Usually, checking if P2P access is enabled can be done by
    `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
    the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
    returns `True` even if P2P access is not actually possible.
    See https://github.com/vllm-project/vllm/issues/2728 and
    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
    Therefore, we have to perform a real P2P access to check if it is actually
    possible.

    Note on p2p and cuda IPC:
    Usually, one process uses one GPU:
    GPU src --> cuda context src --> tensor src --> process src

    We need to combine p2p and cuda IPC, so that:
    GPU src --> cuda context src --> tensor src --> process src
                                         |shared|
    GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
    That is to say, process src creates a tensor in GPU src, passes IPC handle to
    process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
    tensor in process tgt will be reflected in the tensor in process src, because
    they are the same memory segment.
    It is important to note that process tgt accesses the tensor in GPU tgt, not
    GPU src. That's why we need p2p access.

    The most time-consuming part is the process creation. To avoid creating
    processes for every pair of GPUs, we use batched testing. We create two
    processes for testing all pairs of GPUs in batch. The trick is to reset
    the device after each test (which is not available in PyTorch).
    """  # noqa
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    # pass the CUDA_VISIBLE_DEVICES to the child process
    # to make sure they see the same set of GPUs

    # make sure the processes are spawned
    smp = mp.get_context("spawn")
    producer_queue = smp.Queue()
    consumer_queue = smp.Queue()
    result_queue = smp.Queue()
    # One producer process drives all source GPUs and one consumer process
    # drives all target GPUs; the two queues double as a barrier.
    p_src = smp.Process(
        target=producer,
        args=(
            batch_src,
            producer_queue,
            consumer_queue,
            result_queue,
            cuda_visible_devices,
        ),
    )
    p_tgt = smp.Process(
        target=consumer,
        args=(
            batch_tgt,
            producer_queue,
            consumer_queue,
            result_queue,
            cuda_visible_devices,
        ),
    )
    p_src.start()
    p_tgt.start()
    p_src.join()
    p_tgt.join()
    assert p_src.exitcode == 0 and p_tgt.exitcode == 0
    result: List[bool] = []
    for src, tgt in zip(batch_src, batch_tgt):
        # Each side reported its own verdict for this pair; accept only if
        # both agree that the access worked.
        a = result_queue.get()
        b = result_queue.get()
        if a != b:
            logger.warning(
                "Two processes do not agree on the P2P access"
                " status on %d -> %d, treat as disabled.",
                src,
                tgt,
            )
            result.append(False)
        else:
            result.append(a)
    return result
||||
|
||||
|
||||
# why do we need this cache?
|
||||
# we are testing peer-to-peer (p2p) access between GPUs,across processes.
|
||||
# if we test it every time, it will be very slow, because we need to create
|
||||
# N * N * 2 processes, where N is the world size. This is very slow.
|
||||
# to reduce the time, we use a cache file to store the p2p access status.
|
||||
# the cache file is generated by the master process if it does not exist.
|
||||
# then all the processes can read the cache file to check the p2p access status.
|
||||
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
|
||||
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
|
||||
# e.g. used by different vllm engines. The device id in the cache file is a
|
||||
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
|
||||
# of visible devices in the vllm engine.
|
||||
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
|
||||
|
||||
|
||||
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
    """Check if GPU src can access GPU tgt.

    Results are cached at two levels: a process-level dict
    (``_gpu_p2p_access_cache``) and an on-disk JSON file keyed by
    CUDA_VISIBLE_DEVICES, so the expensive probe runs at most once per
    device configuration. ``src``/``tgt`` are *local* device indices.
    """

    # if the cache variable is already calculated,
    # read from the cache instead of checking it again
    global _gpu_p2p_access_cache
    if _gpu_p2p_access_cache is not None:
        return _gpu_p2p_access_cache[f"{src}->{tgt}"]

    is_distributed = dist.is_initialized()

    num_dev = torch.cuda.device_count()
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if cuda_visible_devices is None:
        # Normalize to an explicit device list so the cache filename is stable.
        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))

    # VLLM_CACHE_ROOT -> SGLANG_CACHE_ROOT
    # "~/.cache/vllm" -> "~/.cache/sglang"
    SGLANG_CACHE_ROOT = os.path.expanduser("~/.cache/sglang")
    path = os.path.join(
        SGLANG_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
    )
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # Imported lazily to avoid a circular import at module load time.
    from sglang.srt.distributed.parallel_state import get_world_group

    if (not is_distributed or get_world_group().local_rank == 0) and (
        not os.path.exists(path)
    ):
        # only the local master process (with local_rank == 0) can
        # enter this block to calculate the cache
        logger.info("generating GPU P2P access cache in %s", path)
        cache: Dict[str, bool] = {}
        ids = list(range(num_dev))
        # batch of all pairs of GPUs
        batch_src, batch_tgt = zip(*list(product(ids, ids)))
        # NOTE: we use `subprocess` rather than `multiprocessing` here
        # because the caller might not have `if __name__ == "__main__":`,
        # in that case we cannot use spawn method in multiprocessing.
        # However, `can_actually_p2p` requires spawn method.
        # The fix is, we use `subprocess` to call the function,
        # where we have `if __name__ == "__main__":` in this file.

        # use a temporary file to store the result
        # we don't use the output of the subprocess directly,
        # because the subprocess might produce logging output
        with tempfile.NamedTemporaryFile() as output_file:
            input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
            returned = subprocess.run(
                [sys.executable, __file__], input=input_bytes, capture_output=True
            )
            # check if the subprocess is successful
            try:
                returned.check_returncode()
            except Exception as e:
                # wrap raised exception to provide more information
                raise RuntimeError(
                    f"Error happened when batch testing "
                    f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
                    f"{returned.stderr.decode()}"
                ) from e
            with open(output_file.name, "rb") as f:
                result = pickle.load(f)
        for _i, _j, r in zip(batch_src, batch_tgt, result):
            cache[f"{_i}->{_j}"] = r
        with open(path, "w") as f:
            json.dump(cache, f, indent=4)
    if is_distributed:
        # Ensure the cache file is fully written before other ranks read it.
        get_world_group().barrier()
    logger.info("reading GPU P2P access cache from %s", path)
    with open(path) as f:
        cache = json.load(f)
    _gpu_p2p_access_cache = cache
    return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorator that brackets *fn* with GPU-management-library init/shutdown.

    On NVIDIA it initializes NVML before the call and shuts it down after;
    on ROCm (``_is_hip``) it uses amdsmi instead.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        if not _is_hip:
            # NVIDIA: NVML must be initialized before any query call.
            pynvml.nvmlInit()
            try:
                return fn(*args, **kwargs)
            finally:
                pynvml.nvmlShutdown()
        # AMD: amdsmi_init is inside the try so shutdown always runs.
        try:
            amdsmi_init()
            return fn(*args, **kwargs)
        finally:
            amdsmi_shut_down()

    return wrapper
|
||||
|
||||
|
||||
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
    """Return True iff every pair of the given GPUs is directly linked.

    On ROCm this checks for 1-hop XGMI between all pairs; on NVIDIA it
    checks NVLink P2P capability via NVML. ``physical_device_ids`` are
    physical (not CUDA_VISIBLE_DEVICES-remapped) indices.
    """
    if _is_hip:
        """
        query if the set of gpus are fully connected by xgmi (1 hop)
        """
        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
        # Check every unordered pair once (i < j).
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
                        # type is 2 for XGMI
                        if link_type["hops"] != 1 or link_type["type"] != 2:
                            return False
                    except AmdSmiException as error:
                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                        return False
        return True
    else:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Check every unordered pair once (i < j).
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        # Expected on machines without NVLink; treat as not linked.
                        logger.exception(
                            "NVLink detection failed. This is normal if your"
                            " machine has no NVLink equipped."
                        )
                        return False
        return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["gpu_p2p_access_check"]


if __name__ == "__main__":
    # Subprocess entry point used by `gpu_p2p_access_check`: the parent
    # pickles (batch_src, batch_tgt, output_file) to our stdin; we run the
    # actual P2P probe and pickle the result list to the named file (stdout
    # is not used because logging could pollute it).
    batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
    result = can_actually_p2p(batch_src, batch_tgt)
    with open(output_file, "wb") as f:
        f.write(pickle.dumps(result))
|
||||
@@ -0,0 +1,49 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.utils import is_hpu
|
||||
|
||||
if is_hpu():
|
||||
import habana_frameworks.torch as htorch # noqa: F401
|
||||
|
||||
|
||||
class HpuCommunicator:
    """Collective-communication helper for Habana (HPU) devices.

    Becomes a disabled shell when not running on HPU.
    """

    def __init__(self, group: ProcessGroup):
        # On non-HPU builds the communicator exists but is marked disabled.
        if not is_hpu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """In-place sum all-reduce of ``x`` across the group; returns ``x``."""
        # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
        # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
        # (which is required for tensor parallel HPUGraph inference)
        htorch.core.mark_step()
        dist.all_reduce(x, group=self.group)
        return x

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Gather ``x`` from all ranks and concatenate along ``dim``."""
        world_size = self.world_size
        if dim < 0:
            # Convert negative dim to positive.
            dim += x.dim()
        input_size = x.size()
        # Allocate output tensor.
        output_tensor = torch.empty(
            (world_size,) + input_size, dtype=x.dtype, device=x.device
        )
        # All-gather.
        htorch.core.mark_step()
        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
        # Reshape: move the gathered world dimension to `dim`, then fold it
        # into `dim` so the result is a concatenation along `dim`.
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
        )
        return output_tensor
|
||||
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.utils import is_npu
|
||||
|
||||
|
||||
class NpuCommunicator:
    """Collective-communication helper for Ascend (NPU) devices.

    Becomes a disabled shell when not running on NPU.
    """

    def __init__(self, group: ProcessGroup):
        # On non-NPU builds the communicator exists but is marked disabled.
        if not is_npu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """In-place sum all-reduce of ``x`` across the group; returns ``x``."""
        dist.all_reduce(x, group=self.group)
        return x

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Gather ``x`` from all ranks and concatenate along ``dim``."""
        world_size = self.world_size
        if dim < 0:
            # Convert negative dim to positive.
            dim += x.dim()
        input_size = x.size()
        output_size = (input_size[0] * world_size,) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(output_size, dtype=x.dtype, device=x.device)
        # All-gather.
        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
        # Reshape: expose the gathered world dimension, move it to `dim`,
        # then fold it into `dim` so the result is a concatenation.
        output_tensor = output_tensor.reshape((world_size,) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
        )
        return output_tensor
|
||||
315
python/sglang/srt/distributed/device_communicators/pymscclpp.py
Normal file
315
python/sglang/srt/distributed/device_communicators/pymscclpp.py
Normal file
@@ -0,0 +1,315 @@
|
||||
import bisect
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from enum import IntEnum
|
||||
from typing import Any, Callable, List, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_is_cuda = is_cuda()
_is_hip = is_hip()

# Whether the mscclpp all-reduce kernels are usable on this platform.
mscclpp_is_available = False
if _is_hip:
    # TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
    mscclpp_is_available = False
if _is_cuda:
    try:
        import sgl_kernel

        mscclpp_is_available = True
    except:
        # sgl_kernel is not installed/importable; leave mscclpp disabled.
        mscclpp_is_available = False
|
||||
|
||||
|
||||
class MscclContextSelection(IntEnum):
    """Which prebuilt mscclpp all-reduce context to instantiate.

    The integer value is passed straight to ``ops.mscclpp_init_context``.
    Selected by world size in ``PyMscclppCommunicator.__init__``
    (8 ranks -> single node, 16 ranks -> two nodes).
    """

    MSCCL1SHOT1NODELL = 1  # one-shot low-latency, single node
    MSCCL1SHOT2NODELL = 2  # one-shot low-latency, two nodes
|
||||
|
||||
|
||||
def mscclpp_is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
|
||||
|
||||
def mscclpp_convert_to_bytes(size_str):
    """
    Convert a human-readable size string (e.g. "1MB", "2.5kb", "3 GB", "512")
    into the equivalent number of bytes using binary (1024-based) units.

    A bare number with no unit is interpreted as bytes.

    Args:
        size_str (str): A numeric value with an optional unit
            (B, KB, MB, GB; case-insensitive; surrounding whitespace ignored).

    Returns:
        int: Number of bytes.

    Raises:
        ValueError: If the string is empty, the numeric part is not a valid
            number, or the unit is not one of B, KB, MB, GB.
    """
    size_str = size_str.strip().lower()

    if not size_str:
        raise ValueError("Empty input string")

    # Split into numeric prefix and unit suffix. The for/else fixes the
    # original parsing bug where an all-digit input (e.g. "1024") lost its
    # last digit to the "unit" and raised a misleading ValueError.
    for i, ch in enumerate(size_str):
        if not ch.isdigit() and ch != ".":
            num_str, unit = size_str[:i], size_str[i:].strip()
            break
    else:
        num_str, unit = size_str, "b"

    try:
        num = float(num_str)
    except ValueError:
        raise ValueError(f"Invalid numeric value in '{size_str}'")

    # Binary conversion factors.
    factors = {"b": 1, "kb": 1024, "mb": 1024**2, "gb": 1024**3}
    if unit not in factors:
        raise ValueError(f"Unsupported unit: {unit}, support B, KB, MB, GB only")
    return int(num * factors[unit])
|
||||
|
||||
|
||||
def mscclpp_bench_time(func, test_niter: int = 10, warmup_niter: int = 2):
    """Time ``func`` on the GPU and return the mean cost per call in microseconds.

    Runs ``warmup_niter`` untimed calls, synchronizes the device and all ranks
    (``dist.barrier``) so every rank starts timing together, then times
    ``test_niter`` calls with CUDA events.
    """
    # warmup
    for _ in range(warmup_niter):
        func()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    dist.barrier()
    start_event.record()
    for _ in range(test_niter):
        func()
    end_event.record()
    end_event.synchronize()
    # elapsed_time() is in milliseconds; convert per-call cost to microseconds.
    func_cost_us = start_event.elapsed_time(end_event) / test_niter * 1000
    return func_cost_us
|
||||
|
||||
|
||||
class PyMscclppCommunicator:
    """All-reduce communicator backed by mscclpp kernels (via sgl_kernel).

    Stays ``disabled`` unless all preconditions hold (mscclpp available,
    supported world size, consecutive ranks); it is meant to be enabled
    only inside CUDA graph capture via ``change_state``.
    """

    # World sizes for which a prebuilt mscclpp context exists.
    _SUPPORTED_WORLD_SIZES = [8, 16]
    _MAX_BYTES = mscclpp_convert_to_bytes(os.getenv("SGLANG_MSCCLPP_MAX_BYTES", "1MB"))
    _SUPPORTED_DTYPE = [torch.float, torch.float16, torch.bfloat16]

    # max_bytes: max supported mscclpp allreduce size
    # in A100 mscclpp is faster than nccl only under condition of msg size smaller than 1MB
    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        max_bytes=_MAX_BYTES,
    ) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bind to f"cuda:{local_rank}".
        It is the caller's responsibility to make sure each communicator
        is bind to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        self.disabled = True

        if not mscclpp_is_available:
            # disable because of missing mscclpp library
            # e.g. in a non-cuda environment
            return

        self.group = group

        assert (
            dist.get_backend(group) != dist.Backend.NCCL
        ), "CustomAllreduce should be attached to a non-NCCL group."

        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize mscclpp for single GPU case.
            return

        if world_size not in PyMscclppCommunicator._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "PyMscclpp is disabled due to an unsupported world"
                " size: %d. Supported world sizes: %s. To silence this "
                "warning, specify disable_mscclpp=True explicitly.",
                world_size,
                str(PyMscclppCommunicator._SUPPORTED_WORLD_SIZES),
            )
            return

        self.ranks = torch.distributed.get_process_group_ranks(group)
        self.nranks_per_node = torch.cuda.device_count()
        # for now mscclpp with stride in the communicator is not tested
        if not (abs(self.ranks[-1] - self.ranks[0]) == world_size - 1):
            logger.warning(
                "PyMscclpp is disabled due to an unsupported group %s."
                "Please ensure all ranks in the group are consecutive."
                "To silence this warning, specify disable_mscclpp=True explicitly.",
                str(self.ranks),
            )
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        self.max_bytes = max_bytes
        self.rank = rank
        self.world_size = world_size

        # Rank 0 creates the mscclpp unique id and broadcasts it to all ranks.
        if dist.get_rank(group) == 0:
            unique_id = [ops.mscclpp_generate_unique_id()]
        else:
            unique_id = [None]
        dist.broadcast_object_list(unique_id, src=self.ranks[0], group=self.group)
        self.unique_id = unique_id[0]
        self.rank_to_node, self.rank_to_ib = list(range(world_size)), list(
            range(world_size)
        )
        for r in range(world_size):
            self.rank_to_node[r] = r // 8
            # NOTE(review): this fills every entry with this process's own
            # local rank (`self.rank % 8`) rather than `r % 8` — confirm this
            # is intentional for the IB-device mapping.
            self.rank_to_ib[r] = self.rank % 8

        self._context = None
        self.context_selection = None
        # Power-of-two message sizes from 1 KiB up to max_bytes, used for tuning.
        self.msg_size_for_finetune = [
            2**i for i in range(10, math.floor(math.log2(self.max_bytes)) + 1)
        ]
        self.msg_size2best_config = {}
        if world_size == 8:
            self.context_selection = MscclContextSelection.MSCCL1SHOT1NODELL
        elif world_size == 16:
            self.context_selection = MscclContextSelection.MSCCL1SHOT2NODELL
        if not _is_hip:
            # Scratch/put buffers sized 8x max message size, shared by the kernels.
            self.scratch = torch.empty(
                self.max_bytes * 8,
                dtype=torch.uint8,
                device=self.device,
            )
            self.put_buffer = torch.empty(
                self.max_bytes * 8 // self.nranks_per_node,
                dtype=torch.uint8,
                device=self.device,
            )
            self._context = ops.mscclpp_init_context(
                self.unique_id,
                self.rank,
                self.world_size,
                self.scratch,
                self.put_buffer,
                self.nranks_per_node,
                self.rank_to_node,
                self.rank_to_ib,
                int(self.context_selection),
            )
        else:
            raise NotImplementedError("HIP Mscclpp is not supported yet.")

        # Tune launch configs locally, then adopt rank 0's results everywhere
        # so all ranks launch identical kernels.
        self.msg_size2best_config = {}
        self.pre_tune_config()
        if dist.get_rank(group) == 0:
            msg_size2best_config = [self.msg_size2best_config]
        else:
            msg_size2best_config = [None]
        dist.broadcast_object_list(
            msg_size2best_config, src=self.ranks[0], group=self.group
        )
        self.msg_size2best_config = msg_size2best_config[0]

        # PyMscclpp is enabled only in cuda graph
        self.disabled = True

    def pre_tune_config(self, dtype=torch.bfloat16) -> None:
        """Benchmark (nthreads, nblocks) launch configs for each tuning
        message size and record the fastest in ``msg_size2best_config``."""
        logger.debug(f"start to pre-tune configs for rank {self.rank}")
        nthreads_to_try = [256, 512, 1024]
        nblocks_to_try = [21, 42, 84]
        # One buffer at the largest size; smaller sizes are prefixes of it.
        inp_randn = torch.ones(
            self.msg_size_for_finetune[-1] // dtype.itemsize, dtype=dtype, device="cuda"
        )
        oup_randn = torch.empty_like(inp_randn)
        for msg_size in self.msg_size_for_finetune:
            mock_inp, mock_outp = (
                inp_randn[: msg_size // dtype.itemsize],
                oup_randn[: msg_size // dtype.itemsize],
            )
            best_config, best_time = None, None
            for nthreads in nthreads_to_try:
                for nblocks in nblocks_to_try:
                    cur_cost = mscclpp_bench_time(
                        lambda: ops.mscclpp_allreduce(
                            self._context, mock_inp, mock_outp, nthreads, nblocks
                        )
                    )
                    if best_time is None or cur_cost < best_time:
                        best_config = (nthreads, nblocks)
                        best_time = cur_cost
            self.msg_size2best_config[msg_size] = best_config
            if self.rank == 0:
                logger.debug(
                    f"for msg_size {msg_size}, best_config: {best_config}, best_time: {best_time}us"
                )

    def should_mscclpp_allreduce(
        self, inp: torch.Tensor, op: ReduceOp = ReduceOp.SUM
    ) -> bool:
        """Return True if this tensor/op combination can be handled by mscclpp."""
        if self.disabled or self._context is None:
            return False
        if inp.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE:
            return False
        if not mscclpp_is_weak_contiguous(inp):
            return False
        # only support sum op
        if op != ReduceOp.SUM:
            return False
        if inp.numel() * inp.element_size() > self.max_bytes:
            return False
        return True

    def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM):
        """Out-of-place sum all-reduce using the tuned mscclpp kernel.

        Returns a new tensor with the reduced values.
        """
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                # NOTE(review): `self.graph_input_set` is not initialized in
                # __init__ within this file — confirm it is created elsewhere
                # before capture starts.
                self.graph_input_set.add((tensor.dtype, tensor.numel()))
        # Round the message size up to the nearest tuned size and use its config.
        msg_size = tensor.numel() * tensor.itemsize
        index = bisect.bisect_left(self.msg_size_for_finetune, msg_size)
        msg_size_finetune = self.msg_size_for_finetune[index]
        nthreads, nblocks = self.msg_size2best_config[msg_size_finetune]
        result = torch.empty_like(tensor)
        ops.mscclpp_allreduce(self._context, tensor, result, nthreads, nblocks)
        return result

    @contextmanager
    def change_state(
        self,
        enable: Optional[bool] = None,
    ):
        """Temporarily enable/disable this communicator within a `with` block."""
        if enable is None:
            # guess a default value when not specified
            # NOTE(review): `self.available` is never assigned in this class —
            # confirm it is set externally, otherwise this raises AttributeError.
            enable = self.available

        old_disable = self.disabled
        self.disabled = not enable

        yield

        self.disabled = old_disable
|
||||
341
python/sglang/srt/distributed/device_communicators/pynccl.py
Normal file
341
python/sglang/srt/distributed/device_communicators/pynccl.py
Normal file
@@ -0,0 +1,341 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Union
|
||||
|
||||
# ===================== import region =====================
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from sglang.srt.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary,
|
||||
buffer_type,
|
||||
cudaStream_t,
|
||||
ncclComm_t,
|
||||
ncclDataTypeEnum,
|
||||
ncclRedOpTypeEnum,
|
||||
ncclUniqueId,
|
||||
)
|
||||
from sglang.srt.distributed.utils import StatelessProcessGroup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PyNcclCommunicator:
    """Thin Python wrapper around a raw NCCL communicator.

    Drives NCCL through ctypes (``NCCLLibrary``) on a dedicated CUDA stream,
    independent of torch.distributed's own NCCL backend. Starts ``disabled``;
    enable it with ``change_state`` (typically during CUDA graph capture).
    """

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the PyNcclCommunicator to. If None,
                it will be bind to f"cuda:{local_rank}".
            library_path: the path to the NCCL library. If None, it will
                use the default library path.
        It is the caller's responsibility to make sure each communicator
        is bind to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert (
                dist.get_backend(group) != dist.Backend.NCCL
            ), "PyNcclCommunicator should be attached to a non-NCCL group."
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            self.stream = None
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            self.stream = None
            return

        self.available = True
        self.disabled = False

        self.nccl_version = self.nccl.ncclGetRawVersion()
        if self.rank == 0:
            logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()

        # Distribute rank 0's unique id to all ranks so they join one comm.
        if not isinstance(group, StatelessProcessGroup):
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank
            )
            self.stream = torch.cuda.Stream()

            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            self.stream.synchronize()
            del data

        # by default it is disabled, e.g. in profiling models and prefill phase.
        # to use it, use under `with obj.change_state(enable=True)`, usually
        # when we are using CUDA graph.
        self.disabled = True

    def all_reduce(
        self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
    ):
        """In-place all-reduce of ``tensor``; no-op when disabled."""
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = self.stream
        self.nccl.ncclAllReduce(
            buffer_type(tensor.data_ptr()),
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op),
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def all_gather(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        stream=None,
        sizes: Optional[list[int]] = None,
    ):
        """All-gather ``input_tensor`` from every rank into ``output_tensor``.

        When ``sizes`` is given (variable per-rank contribution sizes), the
        gather is implemented as a grouped series of broadcasts: each rank's
        input is broadcast from that rank into its slice of the output.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}"
        )
        if stream is None:
            stream = self.stream

        if sizes is not None:
            split_offset = 0

            self.nccl.ncclGroupStart()
            for root, split_size in enumerate(sizes):
                dst_slice = output_tensor[split_offset : split_offset + split_size]
                self.nccl.ncclBroadcast(
                    buffer_type(input_tensor.data_ptr()),
                    buffer_type(dst_slice.data_ptr()),
                    dst_slice.numel(),
                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
                    root,
                    self.comm,
                    cudaStream_t(stream.cuda_stream),
                )
                split_offset += split_size
            self.nccl.ncclGroupEnd()
        else:
            self.nccl.ncclAllGather(
                buffer_type(input_tensor.data_ptr()),
                buffer_type(output_tensor.data_ptr()),
                input_tensor.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )

    def reduce_scatter(
        self,
        output_tensor: torch.Tensor,
        input_tensor: torch.Tensor,
        op: ReduceOp = ReduceOp.SUM,
        stream=None,
        sizes: Optional[list[int]] = None,
    ):
        """Reduce-scatter ``input_tensor`` across ranks into ``output_tensor``.

        When ``sizes`` is given (variable chunk sizes), this is implemented as
        a grouped series of ``ncclReduce`` calls, chunk ``root`` being reduced
        to rank ``root``.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}"
        )
        if stream is None:
            stream = self.stream

        if sizes is not None:
            split_offset = 0
            self.nccl.ncclGroupStart()
            for root, split_size in enumerate(sizes):
                chunk = input_tensor[split_offset : split_offset + split_size, ...]

                self.nccl.ncclReduce(
                    buffer_type(chunk.data_ptr()),
                    buffer_type(output_tensor.data_ptr()),
                    chunk.numel(),
                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
                    ncclRedOpTypeEnum.from_torch(op),
                    root,
                    self.comm,
                    cudaStream_t(stream.cuda_stream),
                )
                split_offset += split_size
            self.nccl.ncclGroupEnd()
        else:
            self.nccl.ncclReduceScatter(
                buffer_type(input_tensor.data_ptr()),
                buffer_type(output_tensor.data_ptr()),
                output_tensor.numel(),
                ncclDataTypeEnum.from_torch(input_tensor.dtype),
                ncclRedOpTypeEnum.from_torch(op),
                self.comm,
                cudaStream_t(stream.cuda_stream),
            )

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        """Point-to-point send of ``tensor`` to group rank ``dst``."""
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = self.stream
        self.nccl.ncclSend(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            dst,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        """Point-to-point receive into ``tensor`` from group rank ``src``."""
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = self.stream
        self.nccl.ncclRecv(
            buffer_type(tensor.data_ptr()),
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        """Broadcast ``tensor`` from group rank ``src`` to all ranks (in place)."""
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}"
        )
        if stream is None:
            stream = self.stream
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(
            sendbuff,
            recvbuff,
            tensor.numel(),
            ncclDataTypeEnum.from_torch(tensor.dtype),
            src,
            self.comm,
            cudaStream_t(stream.cuda_stream),
        )

    def register_comm_window_raw(self, ptr: int, size: int):
        # Registers a raw device memory range as an NCCL communication window.
        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)

    def deregister_comm_window(self, window):
        # Releases a window previously returned by register_comm_window_raw.
        return self.nccl.ncclCommWindowDeregister(self.comm, window)

    def group_start(self):
        # Begin aggregating subsequent NCCL calls into one group launch.
        self.nccl.ncclGroupStart()

    def group_end(self):
        # End the aggregation started by group_start.
        self.nccl.ncclGroupEnd()

    @contextmanager
    def change_state(
        self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
    ):
        """
        A context manager to change the state of the communicator.
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available

        if stream is None:
            stream = self.stream

        old_disable = self.disabled
        old_stream = self.stream

        self.stream = stream
        self.disabled = not enable
        yield

        self.disabled = old_disable
        self.stream = old_stream
|
||||
@@ -0,0 +1,133 @@
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.cuda.memory import CUDAPluggableAllocator
|
||||
|
||||
from sglang.srt.distributed.parallel_state import GroupCoordinator
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
# C source for a minimal CUDA pluggable-allocator backend that routes
# allocations through NCCL (ncclMemAlloc / ncclMemFree) so the resulting
# buffers can later be registered as NCCL communication windows.
# JIT-compiled at runtime by get_nccl_mem_pool() below.
# NOTE(review): the NCCL result codes are captured but not checked; a
# failed ncclMemAlloc returns an undefined ptr -- confirm acceptable.
nccl_allocator_source = """
#include <nccl.h>
extern "C" {

void* nccl_alloc_plug(size_t size, int device, void* stream) {
  void* ptr;
  ncclResult_t err = ncclMemAlloc(&ptr, size);
  return ptr;

}

void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
  ncclResult_t err = ncclMemFree(ptr);
}

}
"""

# Lazily-initialized module singletons (populated by get_nccl_mem_pool()):
_allocator = None  # CUDAPluggableAllocator backing the NCCL mem pool
_mem_pool = None  # torch.cuda.MemPool drawing from ncclMemAlloc
_registered_base_addrs = set()  # pool segment base addresses already registered as NCCL windows
_graph_pool_id = None  # CUDA-graph memory pool id, recorded via set_graph_pool_id()
|
||||
|
||||
|
||||
def is_symmetric_memory_enabled():
    """Return the server-args flag controlling NCCL symmetric memory."""
    enabled = global_server_args_dict["enable_symm_mem"]
    return enabled
|
||||
|
||||
|
||||
def set_graph_pool_id(graph_pool_id):
    """Record the CUDA-graph memory-pool id in the module-level
    ``_graph_pool_id`` so ``use_symmetric_memory`` can pause/resume the
    graph pool around symmetric-memory allocations."""
    global _graph_pool_id
    _graph_pool_id = graph_pool_id
|
||||
|
||||
|
||||
def get_nccl_mem_pool():
    """Lazily build and return the process-wide ``torch.cuda.MemPool``
    whose allocations go through ``ncclMemAlloc``/``ncclMemFree``.

    On first call this JIT-compiles the allocator stubs in
    ``nccl_allocator_source`` into a shared library (written to the system
    temp dir), wraps the exported symbols in a ``CUDAPluggableAllocator``,
    and creates the pool. Subsequent calls return the cached pool.
    """
    global _allocator, _mem_pool
    if _mem_pool is None:
        out_dir = tempfile.gettempdir()
        nccl_allocator_libname = "nccl_allocator"
        # NOTE(review): accesses ``torch.utils.cpp_extension`` without an
        # explicit ``import torch.utils.cpp_extension``; some torch versions
        # do not auto-import that submodule -- confirm against the pinned
        # torch release (2.8.0 per pyproject).
        torch.utils.cpp_extension.load_inline(
            name=nccl_allocator_libname,
            cpp_sources=nccl_allocator_source,
            with_cuda=True,
            extra_ldflags=["-lnccl"],  # allocator stubs link against NCCL
            verbose=True,
            is_python_module=False,  # we only need the .so, not a module
            build_directory=out_dir,
        )
        _allocator = CUDAPluggableAllocator(
            f"{out_dir}/{nccl_allocator_libname}.so",
            "nccl_alloc_plug",
            "nccl_free_plug",
        ).allocator()
        _mem_pool = torch.cuda.MemPool(_allocator)
    return _mem_pool
|
||||
|
||||
|
||||
class use_symmetric_memory:
    """Context manager that routes CUDA allocations made inside the
    ``with`` body to the NCCL memory pool and, on exit, registers any new
    pool segments as NCCL communication windows for the group's pynccl
    communicator.

    When symmetric memory is disabled via server args, every method is a
    no-op, so the manager can be used unconditionally at call sites.
    """

    def __init__(self, group_coordinator: GroupCoordinator):
        if not is_symmetric_memory_enabled():
            # Disabled: leave all state None so __enter__/__exit__ no-op.
            self.group_coordinator = None
            self._mem_pool_ctx = None
            self.is_graph_capture = None
            self.device = None
            self.pre_2_8_0 = None
        else:
            self.group_coordinator = group_coordinator
            self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
            self.is_graph_capture = torch.cuda.is_current_stream_capturing()
            self.device = torch.cuda.current_device()
            # Torch < 2.8.0 uses different private pool-control APIs and has
            # a multi-thread MemPool bug worked around in __exit__.
            self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")

    def __enter__(self):
        """Validate pynccl/NCCL prerequisites, pause the CUDA-graph pool if
        capturing, and enter the NCCL mem-pool context."""
        if not is_symmetric_memory_enabled():
            return self
        assert (
            self.group_coordinator.pynccl_comm is not None
        ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
        # ncclCommWindowRegister and friends need NCCL >= 2.27.3.
        assert (
            self.group_coordinator.pynccl_comm.nccl_version >= 22703
        ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
        if self.is_graph_capture:
            assert (
                _graph_pool_id is not None
            ), "graph_pool_id is not set under graph capture"
            # Pause graph memory pool to use symmetric memory with cuda graph
            # NOTE(review): the end/begin API names differ across torch
            # versions; the pre/post-2.8.0 pairing here looks asymmetric by
            # design -- confirm against the targeted torch releases.
            if self.pre_2_8_0:
                torch._C._cuda_endAllocateCurrentStreamToPool(
                    self.device, _graph_pool_id
                )
            else:
                torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
        self._mem_pool_ctx.__enter__()
        return self

    def tag(self, tensor: torch.Tensor):
        """Mark ``tensor`` as symmetric-memory backed by setting a
        ``symmetric_memory`` attribute on it (presumably consumed by
        collective kernels elsewhere -- not visible in this module)."""
        if not is_symmetric_memory_enabled():
            return
        tensor.symmetric_memory = True

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the NCCL pool context, register any newly allocated pool
        segments as NCCL windows, and resume the CUDA-graph pool if the
        body ran under graph capture."""
        if not is_symmetric_memory_enabled():
            return
        global _registered_base_addrs
        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
        # Register each new segment exactly once across the process.
        for segment in get_nccl_mem_pool().snapshot():
            if segment["address"] not in _registered_base_addrs:
                if segment["stream"] == 0 and self.pre_2_8_0:
                    # PyTorch version < 2.8.0 has a multi-thread MemPool bug
                    # See https://github.com/pytorch/pytorch/issues/152861
                    # Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
                    # WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
                    continue
                self.group_coordinator.pynccl_comm.register_comm_window_raw(
                    segment["address"], segment["total_size"]
                )
                _registered_base_addrs.add(segment["address"])

        if self.is_graph_capture:
            # Resume routing allocations to the CUDA-graph pool paused in
            # __enter__.
            if self.pre_2_8_0:
                torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
            else:
                torch._C._cuda_beginAllocateCurrentThreadToPool(
                    self.device, _graph_pool_id
                )
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user