diff --git a/python/pyproject.toml b/python/pyproject.toml
index d6480ebb6..2a4bd774c 100755
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -57,7 +57,7 @@ dependencies = [
"uvicorn",
"uvloop",
"xgrammar==0.1.24",
- "sgl-kernel==0.3.13",
+ "sgl-kernel==0.3.11",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
@@ -67,7 +67,7 @@ dependencies = [
"tiktoken",
"anthropic>=0.20.0",
"torch_memory_saver==0.0.8",
- "nvidia-cutlass-dsl==4.2.1",
+ "nvidia-cutlass-dsl==4.2.0",
]
[project.optional-dependencies]
@@ -103,8 +103,8 @@ dev = ["sglang[test]", "sglang[decord]"]
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
"srt/layers/quantization/configs/*.json",
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
- "srt/speculative/cpp_ngram/*.cpp",
- "srt/speculative/cpp_ngram/*.h",
+ "srt/speculative/cpp_lookahead/*.cpp",
+ "srt/speculative/cpp_lookahead/*.h",
]
[tool.setuptools.packages.find]
diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml
index 6446dcd78..68960d0eb 100755
--- a/python/pyproject_other.toml
+++ b/python/pyproject_other.toml
@@ -65,23 +65,23 @@ tracing = [
srt = [
"sglang[runtime_common]",
- "sgl-kernel==0.3.13",
+ "sgl-kernel==0.3.11",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
- "flashinfer_python==0.4.0rc1",
+ "flashinfer_python==0.3.1",
]
blackwell = [
"sglang[runtime_common]",
- "sgl-kernel==0.3.13",
+ "sgl-kernel==0.3.11",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
- "flashinfer_python==0.4.0rc1",
- "nvidia-cutlass-dsl==4.2.1",
+ "flashinfer_python==0.3.1",
+ "nvidia-cutlass-dsl==4.2.0",
]
# HIP (Heterogeneous-computing Interface for Portability) for AMD
diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 92f6e20d1..ebd461ec3 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -443,9 +443,11 @@ def latency_test_run_once(
if profile:
profiler.stop()
- trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
- _save_profile_trace_results(profiler, trace_filename)
- rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")
+ profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+ _save_profile_trace_results(profiler, profile_filename)
+ rank_print(
+ f"torch profiler chrome trace for prefill saved to {profile_filename}"
+ )
# Decode
decode_latencies = []
@@ -477,10 +479,10 @@ def latency_test_run_once(
if profile and i == output_len / 2:
profiler.stop()
- trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
- _save_profile_trace_results(profiler, trace_filename)
+ profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+ _save_profile_trace_results(profiler, profile_filename)
rank_print(
- f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
+ f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
)
# Record decode timing from 2nd output
diff --git a/python/sglang/bench_one_batch_server.py b/python/sglang/bench_one_batch_server.py
index 711236b3c..ce904f967 100644
--- a/python/sglang/bench_one_batch_server.py
+++ b/python/sglang/bench_one_batch_server.py
@@ -9,7 +9,6 @@ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
-python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
"""
import argparse
@@ -20,17 +19,12 @@ import multiprocessing
import os
import random
import time
-from typing import List, Optional, Tuple
+from typing import List, Tuple
import numpy as np
import requests
-from pydantic import BaseModel
-from sglang.bench_serving import (
- get_tokenizer,
- sample_mmmu_requests,
- sample_random_requests,
-)
+from sglang.bench_serving import get_tokenizer, sample_random_requests
from sglang.profiler import run_profile
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
@@ -38,108 +32,6 @@ from sglang.srt.utils import is_blackwell, kill_process_tree
from sglang.test.test_utils import is_in_ci, write_github_step_summary
-class ProfileLinks(BaseModel):
- """Pydantic model for profile trace links."""
-
- extend: Optional[str] = None
- decode: Optional[str] = None
-
-
-class BenchmarkResult(BaseModel):
- """Pydantic model for benchmark results table data, for a single isl and osl"""
-
- model_path: str
- run_name: str
- batch_size: int
- input_len: int
- output_len: int
- latency: float
- ttft: float
- input_throughput: float
- output_throughput: float
- overall_throughput: float
- last_gen_throughput: float
- acc_length: Optional[float] = None
- profile_links: Optional[ProfileLinks] = None
-
- @staticmethod
- def help_str() -> str:
- return f"""
-Note: To view the traces through perfetto-ui, please:
- 1. open with Google Chrome
- 2. allow popup
-"""
-
- def to_markdown_row(
- self, trace_dir, base_url: str = "", relay_base: str = ""
- ) -> str:
- """Convert this benchmark result to a markdown table row."""
- # Calculate costs (assuming H100 pricing for now)
- hourly_cost_per_gpu = 2 # $2/hour for one H100
- hourly_cost = hourly_cost_per_gpu * 1 # Assuming tp_size = 1 for simplicity
- input_util = 0.7
- accept_length = (
- round(self.acc_length, 2) if self.acc_length is not None else "n/a"
- )
- itl = 1 / (self.output_throughput / self.batch_size) * 1000
- input_cost = 1e6 / (self.input_throughput * input_util) / 3600 * hourly_cost
- output_cost = 1e6 / self.output_throughput / 3600 * hourly_cost
-
- def get_perfetto_relay_link_from_trace_file(trace_file: str):
- import os
- from urllib.parse import quote
-
- rel_path = os.path.relpath(trace_file, trace_dir)
- raw_file_link = f"{base_url}/{rel_path}"
- relay_link = (
- f"{relay_base}?src={quote(raw_file_link, safe='')}"
- if relay_base and quote
- else raw_file_link
- )
- return relay_link
-
- # Handle profile links
- profile_link = "NA | NA"
- if self.profile_links:
- if self.profile_links.extend or self.profile_links.decode:
- # Create a combined link or use the first available one
- trace_files = [self.profile_links.extend, self.profile_links.decode]
- trace_files_relay_links = [
- f"[trace]({get_perfetto_relay_link_from_trace_file(trace_file)})"
- for trace_file in trace_files
- ]
-
- profile_link = " | ".join(trace_files_relay_links)
-
- # Build the row
- return f"| {self.batch_size} | {self.input_len} | {self.latency:.2f} | {self.input_throughput:.2f} | {self.output_throughput:.2f} | {accept_length} | {itl:.2f} | {input_cost:.2f} | {output_cost:.2f} | {profile_link} |\n"
-
- @classmethod
- def generate_markdown_report(
- cls, trace_dir, results: List["BenchmarkResult"]
- ) -> str:
- """Generate a markdown report from a list of BenchmarkResult object from a single run."""
- import os
-
- summary = f"### {results[0].model_path}\n"
-
- # summary += (
- # f"Input lens: {result.input_len}. Output lens: {result.output_len}.\n"
- # )
- summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | profile (extend) | profile (decode)|\n"
- summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ | --------------- | -------------- |\n"
-
- # all results should share the same isl & osl
- for result in results:
- base_url = os.getenv("TRACE_BASE_URL", "").rstrip("/")
- relay_base = os.getenv("PERFETTO_RELAY_URL", "").rstrip("/")
- relay_base = "https://docs.sglang.ai/ci-data/pages/perfetto_relay.html"
- # base_url = "https://github.com/sgl-project/ci-data/traces"
- summary += result.to_markdown_row(trace_dir, base_url, relay_base)
-
- return summary
-
-
@dataclasses.dataclass
class BenchArgs:
run_name: str = "default"
@@ -158,12 +50,8 @@ class BenchArgs:
profile: bool = False
profile_steps: int = 3
profile_by_stage: bool = False
- profile_filename_prefix: str = None
- append_to_github_summary: bool = True
dataset_path: str = ""
parallel_batch: bool = False
- dataset_name: str = "random"
- output_path: Optional[str] = None
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -179,13 +67,6 @@ class BenchArgs:
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
)
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
- parser.add_argument(
- "--dataset-name",
- type=str,
- default=BenchArgs.dataset_name,
- choices=["mmmu", "random"],
- help="Name of the dataset to benchmark on.",
- )
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument(
"--client-stream-interval",
@@ -215,36 +96,14 @@ class BenchArgs:
help="Path to the dataset.",
)
parser.add_argument("--parallel-batch", action="store_true")
- parser.add_argument(
- "--profile-filename-prefix",
- type=str,
- default=BenchArgs.profile_filename_prefix,
- )
- parser.add_argument(
- "--no-append-to-github-summary",
- action="store_false",
- dest="append_to_github_summary",
- help="Disable appending the output of this run to github ci summary",
- )
- parser.add_argument(
- "--output-path",
- type=str,
- default=BenchArgs.output_path,
- help="Path to save benchmark results as JSON format. If not specified, results will only be saved to result-filename.",
- )
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# use the default value's type to cast the args into correct types.
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
- kwargs = {}
- for attr, attr_type in attrs:
- val = getattr(args, attr)
- if attr_type is type(None):
- kwargs[attr] = val
- else:
- kwargs[attr] = attr_type(val)
- return cls(**kwargs)
+ return cls(
+ **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+ )
def launch_server_internal(server_args):
@@ -289,35 +148,23 @@ def run_one_case(
run_name: str,
result_filename: str,
tokenizer,
- dataset_name="",
profile: bool = False,
profile_steps: int = 3,
profile_by_stage: bool = False,
- profile_filename_prefix: str = None,
dataset_path: str = "",
parallel_batch: bool = False,
):
requests.post(url + "/flush_cache")
- # TODO: reuse bench_serving.get_dataset ?
- if dataset_name == "mmmu":
- input_requests = sample_mmmu_requests(
- num_requests=batch_size,
- tokenizer=tokenizer,
- fixed_output_len=output_len,
- apply_chat_template=True,
- random_sample=False,
- )
- elif dataset_name == "random":
- input_requests = sample_random_requests(
- input_len=input_len,
- output_len=output_len,
- num_prompts=batch_size,
- range_ratio=1.0,
- tokenizer=tokenizer,
- dataset_path=dataset_path,
- random_sample=True,
- return_text=False,
- )
+ input_requests = sample_random_requests(
+ input_len=input_len,
+ output_len=output_len,
+ num_prompts=batch_size,
+ range_ratio=1.0,
+ tokenizer=tokenizer,
+ dataset_path=dataset_path,
+ random_sample=True,
+ return_text=False,
+ )
use_structured_outputs = False
if use_structured_outputs:
@@ -334,48 +181,26 @@ def run_one_case(
profile_link = None
if profile:
- output_dir, profile_name = None, None
- if profile_filename_prefix:
- output_dir = os.path.dirname(profile_filename_prefix)
- profile_name = os.path.basename(profile_filename_prefix)
profile_link: str = run_profile(
- url,
- profile_steps,
- ["CPU", "GPU"],
- output_dir,
- profile_name,
- profile_by_stage,
+ url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
)
tic = time.perf_counter()
-
- payload = {
- "sampling_params": {
- "temperature": temperature,
- "max_new_tokens": output_len,
- "ignore_eos": True,
- "json_schema": json_schema,
- "stream_interval": stream_interval,
- },
- "return_logprob": return_logprob,
- "stream": True,
- **({"parallel_batch": parallel_batch} if parallel_batch else {}),
- }
- if dataset_name == "mmmu":
- # vlm
- input_ids = []
- for input_req in input_requests:
- input_ids += [tokenizer.encode(input_req.prompt)]
- payload["image_data"] = [req.image_data for req in input_requests]
-
- else:
- input_ids = [req.prompt for req in input_requests]
-
- payload["input_ids"] = input_ids
-
response = requests.post(
url + "/generate",
- json=payload,
+ json={
+ "input_ids": [req.prompt for req in input_requests],
+ "sampling_params": {
+ "temperature": temperature,
+ "max_new_tokens": output_len,
+ "ignore_eos": True,
+ "json_schema": json_schema,
+ "stream_interval": stream_interval,
+ },
+ "return_logprob": return_logprob,
+ "stream": True,
+ **({"parallel_batch": parallel_batch} if parallel_batch else {}),
+ },
stream=True,
)
@@ -439,100 +264,10 @@ def run_one_case(
overall_throughput,
last_gen_throughput,
acc_length,
- profile_link,
+ profile_link if profile else None,
)
-def save_results_as_json(result: List[Tuple], bench_args: BenchArgs, model: str):
- """Save benchmark results as JSON using Pydantic models."""
- json_results = []
-
- # Generate all parameter combinations to match with results
- param_combinations = list(
- itertools.product(
- bench_args.batch_size, bench_args.input_len, bench_args.output_len
- )
- )
-
- for i, (
- batch_size,
- latency,
- ttft,
- input_throughput,
- output_throughput,
- overall_throughput,
- last_gen_throughput,
- acc_length,
- profile_link,
- ) in enumerate(result):
- # Get the corresponding parameters for this result
- bs, input_len, output_len = param_combinations[i]
-
- # Parse profile links if available
- profile_links = None
- if profile_link:
- profile_links = parse_profile_links(
- profile_link, batch_size, input_len, output_len
- )
-
- benchmark_result = BenchmarkResult(
- model_path=model,
- run_name=bench_args.run_name,
- batch_size=batch_size,
- input_len=input_len,
- output_len=output_len,
- latency=latency,
- ttft=ttft,
- input_throughput=input_throughput,
- output_throughput=output_throughput,
- overall_throughput=overall_throughput,
- last_gen_throughput=last_gen_throughput,
- acc_length=acc_length,
- profile_links=profile_links,
- )
- json_results.append(benchmark_result.model_dump())
-
- # Save to JSON file
- with open(bench_args.output_path, "w", encoding="utf-8") as f:
- json.dump(json_results, f, indent=2, ensure_ascii=False)
-
- print(f"Results saved as JSON to {bench_args.output_path}")
-
-
-def parse_profile_links(
- profile_dir: str, batch_size: int, input_len: int, output_len: int
-) -> Optional[ProfileLinks]:
- """Parse profile directory to extract extend and decode trace file links."""
- if not profile_dir or not os.path.exists(profile_dir):
- return None
-
- extend_link = None
- decode_link = None
-
- # Look for extend/prefill trace files
- for file in os.listdir(profile_dir):
- if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
- if "extend" in file.lower() or "prefill" in file.lower():
- extend_link = os.path.join(profile_dir, file)
- elif "decode" in file.lower():
- decode_link = os.path.join(profile_dir, file)
-
- # If no specific extend/decode files found, try to find files with batch/input/output info
- if not extend_link or not decode_link:
- for file in os.listdir(profile_dir):
- if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
- if f"_batch{batch_size}_input{input_len}_output{output_len}_" in file:
- if "prefill" in file.lower() or "extend" in file.lower():
- extend_link = os.path.join(profile_dir, file)
- elif "decode" in file.lower():
- decode_link = os.path.join(profile_dir, file)
-
- if extend_link or decode_link:
- return ProfileLinks(extend=extend_link, decode=decode_link)
-
- return None
-
-
def get_report_summary(
result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
):
@@ -623,7 +358,6 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
- dataset_name=bench_args.dataset_name,
run_name="",
result_filename="",
tokenizer=tokenizer,
@@ -650,12 +384,10 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name,
- dataset_name=bench_args.dataset_name,
result_filename=bench_args.result_filename,
tokenizer=tokenizer,
dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch,
- profile_filename_prefix=bench_args.profile_filename_prefix,
)
)
@@ -678,13 +410,11 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
run_name=bench_args.run_name,
result_filename=bench_args.result_filename,
tokenizer=tokenizer,
- dataset_name=bench_args.dataset_name,
profile=bench_args.profile,
profile_steps=bench_args.profile_steps,
profile_by_stage=bench_args.profile_by_stage,
dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch,
- profile_filename_prefix=bench_args.profile_filename_prefix,
)[-1],
)
)
@@ -697,16 +427,13 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
print(f"\nResults are saved to {bench_args.result_filename}")
- # Save results as JSON if output_path is specified
- if bench_args.output_path:
- save_results_as_json(result, bench_args, model=server_args.model_path)
-
if not bench_args.show_report:
return
summary = get_report_summary(result, server_args, bench_args)
+ print(summary)
- if is_in_ci() and bench_args.append_to_github_summary:
+ if is_in_ci():
write_github_step_summary(summary)
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 3f515a1e9..ea670d97f 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -208,10 +208,6 @@ async def async_request_openai_completions(
"ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body,
}
-
- if request_func_input.image_data:
- payload.update({"image_data": request_func_input.image_data})
-
headers = get_auth_headers()
output = RequestFuncOutput.init_new(request_func_input)
@@ -1763,9 +1759,7 @@ async def benchmark(
pbar.close()
if "sglang" in backend:
- server_info = requests.get(
- base_url + "/get_server_info", headers=get_auth_headers()
- )
+ server_info = requests.get(base_url + "/get_server_info")
if server_info.status_code == 200:
server_info_json = server_info.json()
if "decode" in server_info_json:
diff --git a/python/sglang/srt/environ.py b/python/sglang/environ.py
similarity index 98%
rename from python/sglang/srt/environ.py
rename to python/sglang/environ.py
index de0d52742..e28120702 100644
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/environ.py
@@ -124,8 +124,6 @@ class Envs:
SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
- SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
- SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
# Model Parallel
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py
index 383bafe0d..f006bd94c 100644
--- a/python/sglang/global_config.py
+++ b/python/sglang/global_config.py
@@ -37,8 +37,8 @@ class GlobalConfig:
)
# Runtime constants: others
self.retract_decode_steps = 20
- self.flashinfer_workspace_size = int(
- os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
+ self.flashinfer_workspace_size = os.environ.get(
+ "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
)
# Output tokenization configs
diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index 90a9761cf..caae7b0f6 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -7,23 +7,9 @@ from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree
-MOVE_ENVS_WARN = """
-########################################################################
-# For contributors and developers: #
-# Please move environment variable definitions to sglang.srt.environ #
-# using the following pattern: #
-# SGLANG_XXX = EnvBool(False) #
-# #
-########################################################################
-"""
-
if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])
- from sglang.srt.server_args import print_deprecated_warning
-
- print_deprecated_warning(MOVE_ENVS_WARN)
-
try:
launch_server(server_args)
finally:
diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py
index fb8be846b..6ac003ea4 100644
--- a/python/sglang/srt/configs/load_config.py
+++ b/python/sglang/srt/configs/load_config.py
@@ -24,8 +24,6 @@ class LoadFormat(str, enum.Enum):
JAX = "jax"
REMOTE = "remote"
REMOTE_INSTANCE = "remote_instance"
- RDMA = "rdma"
- LOCAL_CACHED = "local_cached"
@dataclass
@@ -49,7 +47,6 @@ class LoadConfig:
checkpoints.
decryption_key_file: If set, decrypts the output files with a password read
from this file (after PBKDF2).
- decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
"""
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
@@ -57,11 +54,6 @@ class LoadConfig:
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
decryption_key_file: Optional[str] = None
- decrypt_max_concurrency: int = -1
- tp_rank: Optional[int] = None
- remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
- remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
- remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 92d0e130f..69d99f906 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -31,7 +31,7 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, is_hip, retry
+from sglang.srt.utils import get_bool_env_var, is_hip
from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
@@ -48,6 +48,30 @@ class ModelImpl(str, Enum):
TRANSFORMERS = "transformers"
+def is_deepseek_nsa(config: PretrainedConfig) -> bool:
+ return (
+ config.architectures is not None
+ and config.architectures[0]
+ in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
+ and getattr(config, "index_topk", None) is not None
+ )
+
+
+def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
+ assert is_deepseek_nsa(config)
+ return config.index_head_dim
+
+
+def get_nsa_index_topk(config: PretrainedConfig) -> int:
+ assert is_deepseek_nsa(config)
+ return config.index_topk
+
+
+def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
+ assert is_deepseek_nsa(config)
+ return config.index_n_heads
+
+
class ModelConfig:
def __init__(
self,
@@ -64,20 +88,35 @@ class ModelConfig:
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
+ tp_rank: Optional[int] = None,
+ remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
+ remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
+ remote_instance_weight_loader_send_weights_group_ports: Optional[
+ List[int]
+ ] = None,
) -> None:
# Parse args
self.model_path = model_path
self.revision = revision
self.quantization = quantization
- self.is_draft_model = is_draft_model
self.model_impl = model_impl
+ self.tp_rank = tp_rank
+ self.remote_instance_weight_loader_seed_instance_ip = (
+ remote_instance_weight_loader_seed_instance_ip
+ )
+ self.remote_instance_weight_loader_seed_instance_service_port = (
+ remote_instance_weight_loader_seed_instance_service_port
+ )
+ self.remote_instance_weight_loader_send_weights_group_ports = (
+ remote_instance_weight_loader_send_weights_group_ports
+ )
- # Get hf config
- self._maybe_pull_model_tokenizer_from_remote()
+ self.maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args)
kwargs = {}
if override_config_file and override_config_file.strip():
kwargs["_configuration_file"] = override_config_file.strip()
+
self.hf_config = get_config(
self.model_path,
trust_remote_code=trust_remote_code,
@@ -85,7 +124,7 @@ class ModelConfig:
model_override_args=self.model_override_args,
**kwargs,
)
- self.hf_text_config = get_hf_text_config(self.hf_config)
+
self.hf_generation_config = get_generation_config(
self.model_path,
trust_remote_code=trust_remote_code,
@@ -93,25 +132,7 @@ class ModelConfig:
**kwargs,
)
- # Set enable_multimodal
- if enable_multimodal is None:
- mm_disabled_models = [
- "Gemma3ForConditionalGeneration",
- "Llama4ForConditionalGeneration",
- "Step3VLForConditionalGeneration",
- ]
- if self.hf_config.architectures[0] in mm_disabled_models:
- enable_multimodal = False
- logger.info(
- f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
- )
- else:
- enable_multimodal = True
-
- # Config draft model
- self._config_draft_model()
-
- # Check model type
+ self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None
)
@@ -127,70 +148,20 @@ class ModelConfig:
self.hf_config.architectures, self.hf_text_config.num_hidden_layers
)
)
- self.is_generation = is_generation_model(
- self.hf_config.architectures, is_embedding
- )
- self.is_multimodal = enable_multimodal and is_multimodal_model(
- self.hf_config.architectures
- )
- self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
- self.hf_config.architectures
- )
- self.is_image_gen = enable_multimodal and is_image_gen_model(
- self.hf_config.architectures
- )
- self.is_audio_model = enable_multimodal and is_audio_model(
- self.hf_config.architectures
- )
- self.is_multimodal_chunked_prefill_supported = (
- enable_multimodal
- and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
- )
- self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
- self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
- # Derive context length and model shapes
- self._derive_context_length(context_length)
- self._derive_model_shapes()
-
- # Verify quantization
- self._verify_quantization()
-
- # Verify dual-chunk attention config
- self._verify_dual_chunk_attention_config()
-
- # Cache attributes
- self.hf_eos_token_id = self._get_hf_eos_token_id()
-
- # multimodal
- self.image_token_id = getattr(
- self.hf_config, "image_token_id", None
- ) or getattr(self.hf_config, "image_token_index", None)
-
- @staticmethod
- def from_server_args(
- server_args: ServerArgs,
- model_path: str = None,
- model_revision: str = None,
- **kwargs,
- ):
- return ModelConfig(
- model_path=model_path or server_args.model_path,
- trust_remote_code=server_args.trust_remote_code,
- revision=model_revision or server_args.revision,
- context_length=server_args.context_length,
- model_override_args=server_args.json_model_override_args,
- is_embedding=server_args.is_embedding,
- enable_multimodal=server_args.enable_multimodal,
- dtype=server_args.dtype,
- quantization=server_args.quantization,
- hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
- model_impl=server_args.model_impl,
- **kwargs,
- )
-
- def _config_draft_model(self):
- is_draft_model = self.is_draft_model
+ if enable_multimodal is None:
+ mm_disabled_models = [
+ "Gemma3ForConditionalGeneration",
+ "Llama4ForConditionalGeneration",
+ "Step3VLForConditionalGeneration",
+ ]
+ if self.hf_config.architectures[0] in mm_disabled_models:
+ enable_multimodal = False
+ logger.info(
+ f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
+ )
+ else:
+ enable_multimodal = True
if (
is_draft_model
@@ -225,10 +196,31 @@ class ModelConfig:
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
self.hf_config.num_nextn_predict_layers = 1
- def _derive_context_length(self, context_length: int):
- is_draft_model = self.is_draft_model
- derived_context_len = get_context_length(self.hf_text_config)
+ # Check model type
+ self.is_generation = is_generation_model(
+ self.hf_config.architectures, is_embedding
+ )
+ self.is_multimodal = enable_multimodal and is_multimodal_model(
+ self.hf_config.architectures
+ )
+ self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
+ self.hf_config.architectures
+ )
+ self.is_image_gen = enable_multimodal and is_image_gen_model(
+ self.hf_config.architectures
+ )
+ self.is_audio_model = enable_multimodal and is_audio_model(
+ self.hf_config.architectures
+ )
+ self.is_multimodal_chunked_prefill_supported = (
+ enable_multimodal
+ and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
+ )
+ self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+ self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+ # Derive context length
+ derived_context_len = get_context_length(self.hf_text_config)
if context_length is not None:
if context_length > derived_context_len:
reason = "Target model's" if is_draft_model else "User-specified"
@@ -242,11 +234,6 @@ class ModelConfig:
):
logger.warning(msg)
self.context_len = context_length
- if is_draft_model:
- self.hf_text_config.max_position_embeddings = context_length
- logger.warning(
- f"Overriding the draft model's max_position_embeddings to {context_length}."
- )
else:
raise ValueError(
f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
@@ -256,10 +243,6 @@ class ModelConfig:
else:
self.context_len = derived_context_len
- # Transfer context_len to HuggingFace config so models can access it
- self.hf_config.context_len = self.context_len
-
- def _derive_model_shapes(self):
# Unify the config keys for hf_text_config
self.head_dim = getattr(
self.hf_text_config,
@@ -270,6 +253,7 @@ class ModelConfig:
# FIXME: temporary special judge for MLA architecture
if (
"DeepseekV2ForCausalLM" in self.hf_config.architectures
+ or "DeepseekV32ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
or "LongcatFlashForCausalLM" in self.hf_config.architectures
@@ -282,6 +266,11 @@ class ModelConfig:
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.v_head_dim = self.hf_config.v_head_dim
+ self.index_head_dim = (
+ get_nsa_index_head_dim(self.hf_config)
+ if is_deepseek_nsa(self.hf_config)
+ else None
+ )
# Handle rope scaling with yarn
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
@@ -354,6 +343,45 @@ class ModelConfig:
)
self.vocab_size = self.hf_text_config.vocab_size
+ # Verify quantization
+ self._verify_quantization()
+
+ # Verify dual-chunk attention config
+ self._verify_dual_chunk_attention_config()
+
+ # Cache attributes
+ self.hf_eos_token_id = self.get_hf_eos_token_id()
+
+ # multimodal
+ self.image_token_id = getattr(
+ self.hf_config, "image_token_id", None
+ ) or getattr(self.hf_config, "image_token_index", None)
+
+ @staticmethod
+ def from_server_args(
+ server_args: ServerArgs,
+ model_path: str = None,
+ model_revision: str = None,
+ **kwargs,
+ ):
+ return ModelConfig(
+ model_path=model_path or server_args.model_path,
+ trust_remote_code=server_args.trust_remote_code,
+ revision=model_revision or server_args.revision,
+ context_length=server_args.context_length,
+ model_override_args=server_args.json_model_override_args,
+ is_embedding=server_args.is_embedding,
+ enable_multimodal=server_args.enable_multimodal,
+ dtype=server_args.dtype,
+ quantization=server_args.quantization,
+ hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
+ model_impl=server_args.model_impl,
+ remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
+ remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
+ remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
+ **kwargs,
+ )
+
def get_total_num_attention_heads(self) -> int:
return self.num_attention_heads
@@ -454,31 +482,13 @@ class ModelConfig:
from huggingface_hub import HfApi
hf_api = HfApi()
-
- def check_hf_quant_config():
- return hf_api.file_exists(
- self.model_path, "hf_quant_config.json"
- )
-
- # Retry HF API call up to 3 times
- file_exists = retry(
- check_hf_quant_config,
- max_retry=2,
- initial_delay=1.0,
- max_delay=5.0,
- )
-
- if file_exists:
+ if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
quant_cfg = modelopt_quant_config
-
except huggingface_hub.errors.OfflineModeIsEnabled:
logger.warning(
"Offline mode is enabled, skipping hf_quant_config.json check"
)
- except Exception as e:
- logger.warning(
- f"Failed to check hf_quant_config.json: {self.model_path} {e}"
- )
+ pass
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
quant_config_file = os.path.join(
@@ -606,7 +616,7 @@ class ModelConfig:
"sparse_attention_enabled"
] = True
- def _get_hf_eos_token_id(self) -> Optional[Set[int]]:
+ def get_hf_eos_token_id(self) -> Optional[Set[int]]:
eos_ids = getattr(self.hf_config, "eos_token_id", None)
if eos_ids is not None:
# it can be either int or list of int
@@ -626,7 +636,7 @@ class ModelConfig:
eos_ids = eos_ids | generation_eos_ids
return eos_ids
- def _maybe_pull_model_tokenizer_from_remote(self) -> None:
+ def maybe_pull_model_tokenizer_from_remote(self) -> None:
"""
Pull the model config files to a temporary
directory in case of remote.
@@ -769,8 +779,6 @@ multimodal_model_archs = [
"Qwen2AudioForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
- "Qwen3VLForConditionalGeneration",
- "Qwen3VLMoeForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"InternS1ForConditionalGeneration",
diff --git a/python/sglang/srt/configs/qwen3_vl.py b/python/sglang/srt/configs/qwen3_vl.py
deleted file mode 100644
index 4a995c856..000000000
--- a/python/sglang/srt/configs/qwen3_vl.py
+++ /dev/null
@@ -1,586 +0,0 @@
-from typing import Optional, Union
-
-from transformers import PretrainedConfig
-from transformers.modeling_rope_utils import rope_config_validation
-
-
-class Qwen3VLVisionConfig(PretrainedConfig):
- model_type = "qwen3_vl"
- base_config_key = "vision_config"
-
- def __init__(
- self,
- depth=27,
- hidden_size=1152,
- hidden_act="gelu_pytorch_tanh",
- intermediate_size=4304,
- num_heads=16,
- in_channels=3,
- patch_size=16,
- spatial_merge_size=2,
- temporal_patch_size=2,
- out_hidden_size=3584,
- num_position_embeddings=2304,
- deepstack_visual_indexes=[8, 16, 24],
- initializer_range=0.02,
- **kwargs,
- ):
- super().__init__(**kwargs)
-
- self.depth = depth
- self.hidden_size = hidden_size
- self.hidden_act = hidden_act
- self.intermediate_size = intermediate_size
- self.num_heads = num_heads
- self.in_channels = in_channels
- self.patch_size = patch_size
- self.spatial_merge_size = spatial_merge_size
- self.temporal_patch_size = temporal_patch_size
- self.out_hidden_size = out_hidden_size
- self.num_position_embeddings = num_position_embeddings
- self.initializer_range = initializer_range
- self.deepstack_visual_indexes = deepstack_visual_indexes
-
-
-class Qwen3VLTextConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a
- Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with the defaults will yield a similar configuration to that of
- Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
- Args:
- vocab_size (`int`, *optional*, defaults to 151936):
- Vocabulary size of the Qwen3VL model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`Qwen3VLModel`]
- hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 22016):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 32):
- Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 32):
- Number of attention heads for each attention layer in the Transformer encoder.
- num_key_value_heads (`int`, *optional*, defaults to 32):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details, check out [this
- paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
- head_dim (`int`, *optional*, defaults to 128):
- The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- max_position_embeddings (`int`, *optional*, defaults to 128000):
- The maximum sequence length that this model might ever be used with.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
- The epsilon used by the rms normalization layers.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
- Whether the model's input and output word embeddings should be tied.
- rope_theta (`float`, *optional*, defaults to 5000000.0):
- The base period of the RoPE embeddings.
- rope_scaling (`Dict`, *optional*):
- Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
- and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
- accordingly.
- Expected contents:
- `rope_type` (`str`):
- The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
- 'llama3'], with 'default' being the original RoPE implementation.
- `factor` (`float`, *optional*):
- Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
- most scaling types, a `factor` of x will enable the model to handle sequences of length x *
- original maximum pre-trained length.
- `original_max_position_embeddings` (`int`, *optional*):
- Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
- pretraining.
- `attention_factor` (`float`, *optional*):
- Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
- computation. If unspecified, it defaults to value recommended by the implementation, using the
- `factor` field to infer the suggested value.
- `beta_fast` (`float`, *optional*):
- Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
- ramp function. If unspecified, it defaults to 32.
- `beta_slow` (`float`, *optional*):
- Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
- ramp function. If unspecified, it defaults to 1.
- `short_factor` (`list[float]`, *optional*):
- Only used with 'longrope'. The scaling factor to be applied to short contexts (<
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
- size divided by the number of attention heads divided by 2
- `long_factor` (`list[float]`, *optional*):
- Only used with 'longrope'. The scaling factor to be applied to long contexts (<
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
- size divided by the number of attention heads divided by 2
- `low_freq_factor` (`float`, *optional*):
- Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
- `high_freq_factor` (`float`, *optional*):
- Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
- attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
-
- ```python
- >>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig
-
- >>> # Initializing a Qwen3VL style configuration
- >>> configuration = Qwen3VLTextConfig()
-
- >>> # Initializing a model from the Qwen3-VL-7B style configuration
- >>> model = Qwen3VLTextModel(configuration)
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- model_type = "qwen3_vl_text"
- base_config_key = "text_config"
-
- def __init__(
- self,
- vocab_size=151936,
- hidden_size=4096,
- intermediate_size=22016,
- num_hidden_layers=32,
- num_attention_heads=32,
- num_key_value_heads=32,
- head_dim=128,
- hidden_act="silu",
- max_position_embeddings=128000,
- initializer_range=0.02,
- rms_norm_eps=1e-6,
- use_cache=True,
- tie_word_embeddings=False,
- rope_theta=5000000.0,
- rope_scaling=None,
- attention_bias=False,
- attention_dropout=0.0,
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
-
- # for backward compatibility
- if num_key_value_heads is None:
- num_key_value_heads = num_attention_heads
-
- self.num_key_value_heads = num_key_value_heads
- self.head_dim = head_dim
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
- self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
- self.attention_bias = attention_bias
- self.attention_dropout = attention_dropout
-
- rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
-
- super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
-
-
-class Qwen3VLConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a
- Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with the defaults will yield a similar configuration to that of
- Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
-
- Args:
- text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`):
- The config object or dictionary of the text backbone.
- vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
- The config object or dictionary of the vision backbone.
- image_token_id (`int`, *optional*, defaults to 151655):
- The image token index to encode the image prompt.
- video_token_id (`int`, *optional*, defaults to 151656):
- The video token index to encode the image prompt.
- vision_start_token_id (`int`, *optional*, defaults to 151652):
- The start token index to encode the image prompt.
- vision_end_token_id (`int`, *optional*, defaults to 151653):
- The end token index to encode the image prompt.
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
- Whether to tie the word embeddings.
-
- ```python
- >>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig
-
- >>> # Initializing a Qwen3-VL style configuration
- >>> configuration = Qwen3VLConfig()
-
- >>> # Initializing a model from the Qwen3-VL-4B style configuration
- >>> model = Qwen3VLForConditionalGeneration(configuration)
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- model_type = "qwen3_vl"
- sub_configs = {
- "vision_config": Qwen3VLVisionConfig,
- "text_config": Qwen3VLTextConfig,
- }
- keys_to_ignore_at_inference = ["past_key_values"]
-
- def __init__(
- self,
- text_config=None,
- vision_config=None,
- image_token_id=151655,
- video_token_id=151656,
- vision_start_token_id=151652,
- vision_end_token_id=151653,
- tie_word_embeddings=False,
- **kwargs,
- ):
- if isinstance(vision_config, dict):
- self.vision_config = self.sub_configs["vision_config"](**vision_config)
- elif vision_config is None:
- self.vision_config = self.sub_configs["vision_config"]()
-
- if isinstance(text_config, dict):
- self.text_config = self.sub_configs["text_config"](**text_config)
- elif text_config is None:
- self.text_config = self.sub_configs["text_config"]()
-
- self.image_token_id = image_token_id
- self.video_token_id = video_token_id
- self.vision_start_token_id = vision_start_token_id
- self.vision_end_token_id = vision_end_token_id
- super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
-
-
-class Qwen3VLMoeTextConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a
- Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with the defaults will yield a similar configuration to that of
- Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
- Args:
- vocab_size (`int`, *optional*, defaults to 151936):
- Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`Qwen2MoeModel`]
- hidden_size (`int`, *optional*, defaults to 2048):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 5632):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 24):
- Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 16):
- Number of attention heads for each attention layer in the Transformer encoder.
- num_key_value_heads (`int`, *optional*, defaults to 16):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details checkout [this
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- max_position_embeddings (`int`, *optional*, defaults to 128000):
- The maximum sequence length that this model might ever be used with.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
- The epsilon used by the rms normalization layers.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
- Whether the model's input and output word embeddings should be tied.
- rope_theta (`float`, *optional*, defaults to 5000000.0):
- The base period of the RoPE embeddings.
- attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- decoder_sparse_step (`int`, *optional*, defaults to 1):
- The frequency of the MoE layer.
- moe_intermediate_size (`int`, *optional*, defaults to 1408):
- Intermediate size of the routed expert.
- num_experts_per_tok (`int`, *optional*, defaults to 4):
- Number of selected experts.
- num_experts (`int`, *optional*, defaults to 60):
- Number of routed experts.
- norm_topk_prob (`bool`, *optional*, defaults to `True`):
- Whether to normalize the topk probabilities.
- mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
- Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock
- The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
- If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
- rope_scaling (`Dict`, *optional*):
- Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
- and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
- accordingly.
- Expected contents:
- `rope_type` (`str`):
- The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
- 'llama3'], with 'default' being the original RoPE implementation.
- `factor` (`float`, *optional*):
- Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
- most scaling types, a `factor` of x will enable the model to handle sequences of length x *
- original maximum pre-trained length.
- `original_max_position_embeddings` (`int`, *optional*):
- Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
- pretraining.
- `attention_factor` (`float`, *optional*):
- Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
- computation. If unspecified, it defaults to value recommended by the implementation, using the
- `factor` field to infer the suggested value.
- `beta_fast` (`float`, *optional*):
- Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
- ramp function. If unspecified, it defaults to 32.
- `beta_slow` (`float`, *optional*):
- Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
- ramp function. If unspecified, it defaults to 1.
- `short_factor` (`List[float]`, *optional*):
- Only used with 'longrope'. The scaling factor to be applied to short contexts (<
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
- size divided by the number of attention heads divided by 2
- `long_factor` (`List[float]`, *optional*):
- Only used with 'longrope'. The scaling factor to be applied to long contexts (<
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
- size divided by the number of attention heads divided by 2
- `low_freq_factor` (`float`, *optional*):
- Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
- `high_freq_factor` (`float`, *optional*):
- Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
- head_dim (`int`, *optional*):
- The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
-
- ```python
- >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
-
- >>> # Initializing a Qwen3VLMoe style configuration
- >>> configuration = Qwen3VLMoeConfig()
-
- >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
- >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- model_type = "qwen3_vl_moe_text"
- base_config_key = "text_config"
- keys_to_ignore_at_inference = ["past_key_values"]
- # Default tensor parallel plan for base model `Qwen3VLMoe`
- base_model_tp_plan = {
- "layers.*.self_attn.q_proj": "colwise",
- "layers.*.self_attn.k_proj": "colwise",
- "layers.*.self_attn.v_proj": "colwise",
- "layers.*.self_attn.o_proj": "rowwise",
- "layers.*.mlp.gate_proj": "colwise",
- "layers.*.mlp.up_proj": "colwise",
- "layers.*.mlp.down_proj": "rowwise",
- }
- base_model_pp_plan = {
- "embed_tokens": (["input_ids"], ["inputs_embeds"]),
- "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
- "norm": (["hidden_states"], ["hidden_states"]),
- }
-
- def __init__(
- self,
- vocab_size=151936,
- hidden_size=2048,
- intermediate_size=5632,
- num_hidden_layers=24,
- num_attention_heads=16,
- num_key_value_heads=16,
- hidden_act="silu",
- max_position_embeddings=128000,
- initializer_range=0.02,
- rms_norm_eps=1e-6,
- use_cache=True,
- tie_word_embeddings=False,
- rope_theta=5000000.0,
- attention_bias=False,
- attention_dropout=0.0,
- decoder_sparse_step=1,
- moe_intermediate_size=1408,
- num_experts_per_tok=4,
- num_experts=60,
- norm_topk_prob=True,
- mlp_only_layers=None,
- rope_scaling=None,
- head_dim=None,
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
-
- # for backward compatibility
- if num_key_value_heads is None:
- num_key_value_heads = num_attention_heads
-
- self.num_key_value_heads = num_key_value_heads
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
- self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.attention_bias = attention_bias
- self.attention_dropout = attention_dropout
- self.rope_scaling = rope_scaling
- self.head_dim = head_dim or hidden_size // num_attention_heads
-
- rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
-
- # MoE arguments
- self.decoder_sparse_step = decoder_sparse_step
- self.moe_intermediate_size = moe_intermediate_size
- self.num_experts_per_tok = num_experts_per_tok
- self.num_experts = num_experts
- self.norm_topk_prob = norm_topk_prob
- self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
-
- super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
-
-
-class Qwen3VLMoeVisionConfig(PretrainedConfig):
- model_type = "qwen3_vl_moe"
- base_config_key = "vision_config"
-
- def __init__(
- self,
- depth=27,
- hidden_size=1152,
- hidden_act="gelu_pytorch_tanh",
- intermediate_size=4304,
- num_heads=16,
- in_channels=3,
- patch_size=16,
- spatial_merge_size=2,
- temporal_patch_size=2,
- out_hidden_size=3584,
- num_position_embeddings=2304,
- deepstack_visual_indexes=[8, 16, 24],
- initializer_range=0.02,
- **kwargs,
- ):
- super().__init__(**kwargs)
-
- self.depth = depth
- self.hidden_size = hidden_size
- self.hidden_act = hidden_act
- self.intermediate_size = intermediate_size
- self.num_heads = num_heads
- self.in_channels = in_channels
- self.patch_size = patch_size
- self.spatial_merge_size = spatial_merge_size
- self.temporal_patch_size = temporal_patch_size
- self.out_hidden_size = out_hidden_size
- self.num_position_embeddings = num_position_embeddings
- self.initializer_range = initializer_range
- self.deepstack_visual_indexes = deepstack_visual_indexes
-
-
-class Qwen3VLMoeConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a
- Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with the defaults will yield a similar configuration to that of
- Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
-
- Args:
- text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):
- The config object or dictionary of the text backbone.
- vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):
- The config object or dictionary of the vision backbone.
- image_token_id (`int`, *optional*, defaults to 151655):
- The image token index to encode the image prompt.
- video_token_id (`int`, *optional*, defaults to 151656):
- The video token index to encode the image prompt.
- vision_start_token_id (`int`, *optional*, defaults to 151652):
- The start token index to encode the image prompt.
- vision_end_token_id (`int`, *optional*, defaults to 151653):
- The end token index to encode the image prompt.
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
- Whether to tie the word embeddings.
-
- ```python
- >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
-
- >>> # Initializing a Qwen3-VL-MOE style configuration
- >>> configuration = Qwen3VLMoeConfig()
-
- >>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
- >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- model_type = "qwen3_vl_moe"
- sub_configs = {
- "vision_config": Qwen3VLMoeVisionConfig,
- "text_config": Qwen3VLMoeTextConfig,
- }
- keys_to_ignore_at_inference = ["past_key_values"]
-
- def __init__(
- self,
- text_config=None,
- vision_config=None,
- image_token_id=151655,
- video_token_id=151656,
- vision_start_token_id=151652,
- vision_end_token_id=151653,
- tie_word_embeddings=False,
- **kwargs,
- ):
- if isinstance(vision_config, dict):
- self.vision_config = self.sub_configs["vision_config"](**vision_config)
- elif vision_config is None:
- self.vision_config = self.sub_configs["vision_config"]()
-
- if isinstance(text_config, dict):
- self.text_config = self.sub_configs["text_config"](**text_config)
- elif text_config is None:
- self.text_config = self.sub_configs["text_config"]()
-
- self.image_token_id = image_token_id
- self.video_token_id = video_token_id
- self.vision_start_token_id = vision_start_token_id
- self.vision_end_token_id = vision_end_token_id
- super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
-
-
-__all__ = [
- "Qwen3VLMoeConfig",
- "Qwen3VLMoeVisionConfig",
- "Qwen3VLConfig",
- "Qwen3VLVisionConfig",
-]
diff --git a/python/sglang/srt/disaggregation/ascend/transfer_engine.py b/python/sglang/srt/disaggregation/ascend/transfer_engine.py
index 0ccffffd6..c87020d39 100644
--- a/python/sglang/srt/disaggregation/ascend/transfer_engine.py
+++ b/python/sglang/srt/disaggregation/ascend/transfer_engine.py
@@ -2,9 +2,19 @@ import logging
import os
from typing import List, Optional
+import torch
+
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
+try:
+ from mf_adapter import TransferEngine
+
+ import_error = None
+except ImportError as e:
+ import_error = e
+ pass
+
logger = logging.getLogger(__name__)
@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine):
def __init__(
self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
):
- try:
- from mf_adapter import TransferEngine
- except ImportError as e:
- raise ImportError(
+ if import_error is not None:
+ logger.warning(
"Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
- ) from e
+ )
+ raise import_error
self.engine = TransferEngine()
self.hostname = hostname
@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine):
self.initialize()
def initialize(self) -> None:
+ from sglang.srt.layers.dp_attention import (
+ get_tensor_model_parallel_world_size,
+ get_tp_group,
+ )
+
+ transfer_protocol = self._get_transfer_protocol()
+ if transfer_protocol is None or transfer_protocol == "sdma":
+ trans_op_type = TransferEngine.TransDataOpType.SDMA
+ else:
+ trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
+ """with device RDMA for PD transfer"""
+ tmp_tensor = torch.zeros(1, device="npu")
+ output_tensor_list = [
+ torch.empty_like(tmp_tensor)
+ for _ in range(get_tensor_model_parallel_world_size())
+ ]
+ # Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization.
+ torch.distributed.all_gather(
+ output_tensor_list, tmp_tensor, group=get_tp_group().device_group
+ )
"""Initialize the ascend transfer instance."""
ret_value = self.engine.initialize(
- self.store_url,
- self.session_id,
- self.role,
- self.npu_id,
+ self.store_url, self.session_id, self.role, self.npu_id, trans_op_type
)
if ret_value != 0:
logger.error("Ascend Transfer Engine initialization failed.")
@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine):
ret_value = -1
if ret_value != 0:
logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
+
+ @staticmethod
+ def _get_transfer_protocol():
+ protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
+ allowed_protocols = {"device_rdma", "sdma"}
+ if protocol and protocol.lower() in allowed_protocols:
+ return protocol.lower()
+ else:
+ logger.warning(
+ "Invalid or no transfer protocol specified, using default protocol."
+ )
+ return None
\ No newline at end of file
diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py
index 82876066f..096a1db59 100644
--- a/python/sglang/srt/disaggregation/common/conn.py
+++ b/python/sglang/srt/disaggregation/common/conn.py
@@ -95,6 +95,14 @@ class CommonKVManager(BaseKVManager):
def _bind_server_socket(self):
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
+ @cache
+ def _connect(self, endpoint: str, is_ipv6: bool = False):
+ socket = zmq.Context().socket(zmq.PUSH)
+ if is_ipv6:
+ socket.setsockopt(zmq.IPV6, 1)
+ socket.connect(endpoint)
+ return socket
+
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
@@ -148,33 +156,6 @@ class CommonKVManager(BaseKVManager):
socket.connect(endpoint)
return socket
- def get_mha_kv_ptrs_with_pp(
- self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
- ) -> Tuple[List[int], List[int], List[int], List[int], int]:
- # pp is not supported on the decode side yet
- start_layer = self.kv_args.prefill_start_layer
- num_kv_layers = len(src_kv_ptrs) // 2
- end_layer = start_layer + num_kv_layers
- dst_num_total_layers = len(dst_kv_ptrs) // 2
- src_k_ptrs = src_kv_ptrs[:num_kv_layers]
- src_v_ptrs = src_kv_ptrs[num_kv_layers:]
- dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
- dst_v_ptrs = dst_kv_ptrs[
- dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
- ]
- layers_current_pp_stage = len(src_k_ptrs)
- return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
-
- def get_mla_kv_ptrs_with_pp(
- self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
- ) -> Tuple[List[int], List[int], int]:
- # pp is not supported on the decode side yet
- start_layer = self.kv_args.prefill_start_layer
- end_layer = start_layer + len(src_kv_ptrs)
- sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
- layers_current_pp_stage = len(src_kv_ptrs)
- return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
-
class CommonKVSender(BaseKVSender):
diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index 1db475f15..f4d7e8f7f 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -609,21 +609,15 @@ class DecodeTransferQueue:
idx = decode_req.metadata_buffer_index
(
output_id,
- cached_tokens,
output_token_logprobs_val,
output_token_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
- output_topk_p,
- output_topk_index,
output_hidden_states,
) = self.metadata_buffers.get_buf(idx)
decode_req.req.output_ids.append(output_id[0].item())
- decode_req.req.cached_tokens = cached_tokens[0].item()
if not self.spec_algorithm.is_none():
- decode_req.req.output_topk_p = output_topk_p
- decode_req.req.output_topk_index = output_topk_index
decode_req.req.hidden_states_tensor = output_hidden_states
if decode_req.req.return_logprob:
decode_req.req.output_token_logprobs_val.append(
@@ -713,15 +707,12 @@ class SchedulerDisaggregationDecodeMixin:
elif prepare_mlp_sync_flag:
batch, _ = self._prepare_idle_batch_and_run(None)
- queue_size = (
+ if batch is None and (
len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
- )
- if self.server_args.disaggregation_decode_enable_offload_kvcache:
- queue_size += len(self.decode_offload_manager.ongoing_offload)
-
- if batch is None and queue_size == 0:
+ == 0
+ ):
self.self_check_during_idle()
self.last_batch = batch
@@ -790,15 +781,12 @@ class SchedulerDisaggregationDecodeMixin:
)
self.process_batch_result(tmp_batch, tmp_result)
- queue_size = (
+ if batch is None and (
len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
- )
- if self.server_args.disaggregation_decode_enable_offload_kvcache:
- queue_size += len(self.decode_offload_manager.ongoing_offload)
-
- if batch is None and queue_size == 0:
+ == 0
+ ):
self.self_check_during_idle()
self.last_batch = batch
@@ -917,6 +905,3 @@ class SchedulerDisaggregationDecodeMixin:
self.disagg_decode_transfer_queue.pop_transferred()
) # the requests which kv has arrived
self.waiting_queue.extend(alloc_reqs)
-
- if self.server_args.disaggregation_decode_enable_offload_kvcache:
- self.decode_offload_manager.check_offload_progress()
diff --git a/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py b/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
deleted file mode 100644
index f130c3fbb..000000000
--- a/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
+++ /dev/null
@@ -1,185 +0,0 @@
-import logging
-import threading
-import time
-
-import torch
-
-from sglang.srt.server_args import ServerArgs
-from sglang.srt.managers.cache_controller import HiCacheController
-from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import (
- MHATokenToKVPool,
- MLATokenToKVPool,
- ReqToTokenPool,
-)
-from sglang.srt.mem_cache.memory_pool_host import (
- MHATokenToKVPoolHost,
- MLATokenToKVPoolHost,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class DecodeKVCacheOffloadManager:
- """Manage decode-side KV cache offloading lifecycle and operations."""
-
- def __init__(
- self,
- req_to_token_pool: ReqToTokenPool,
- token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
- tp_group: torch.distributed.ProcessGroup,
- tree_cache: BasePrefixCache,
- server_args: ServerArgs,
- ) -> None:
- self.req_to_token_pool = req_to_token_pool
- self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
- self.page_size = server_args.page_size
- self.server_args = server_args
- self.request_counter = 0
- self.tree_cache = tree_cache
- kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
- if isinstance(kv_cache, MHATokenToKVPool):
- self.decode_host_mem_pool = MHATokenToKVPoolHost(
- kv_cache,
- server_args.hicache_ratio,
- server_args.hicache_size,
- self.page_size,
- server_args.hicache_mem_layout,
- )
- elif isinstance(kv_cache, MLATokenToKVPool):
- self.decode_host_mem_pool = MLATokenToKVPoolHost(
- kv_cache,
- server_args.hicache_ratio,
- server_args.hicache_size,
- self.page_size,
- server_args.hicache_mem_layout,
- )
- else:
- raise ValueError("Unsupported KV cache type for decode offload")
-
- self.tp_group = tp_group
- self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
- self.cache_controller = HiCacheController(
- token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
- mem_pool_host=self.decode_host_mem_pool,
- page_size=self.page_size,
- tp_group=tp_group,
- io_backend=server_args.hicache_io_backend,
- load_cache_event=threading.Event(),
- storage_backend=server_args.hicache_storage_backend,
- model_name=server_args.served_model_name,
- storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
- )
-
- self.ongoing_offload = {}
- self.ongoing_backup = {}
- logger.info("Enable offload kv cache for decode side")
-
- def offload_kv_cache(self, req) -> bool:
- """Offload a finished request's KV cache to storage."""
-
- if self.cache_controller is None or self.decode_host_mem_pool is None:
- return False
-
- if req.req_pool_idx == -1:
- return False
-
- token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
- if token_indices.dim() == 0 or token_indices.numel() == 0:
- logger.debug(
- f"Request {req.rid} has invalid token_indices: {token_indices}"
- )
- return False
-
- tokens = req.origin_input_ids + req.output_ids
- aligned_len = (len(tokens) // self.page_size) * self.page_size
- if aligned_len == 0:
- return False
-
- token_indices = token_indices[:aligned_len]
- tokens = tokens[:aligned_len]
-
- # Asynchronously offload KV cache from device to host by cache controller
- self.request_counter += 1
- ack_id = self.request_counter
- host_indices = self.cache_controller.write(
- device_indices=token_indices.long(),
- node_id=ack_id,
- )
- if host_indices is None:
- logger.error(f"Not enough host memory for request {req.rid}")
- return False
-
- self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
- return True
-
- def check_offload_progress(self):
- """Check the progress of offload from device to host and backup from host to storage."""
- cc = self.cache_controller
-
- qsizes = torch.tensor(
- [
- len(cc.ack_write_queue),
- cc.ack_backup_queue.qsize(),
- ],
- dtype=torch.int,
- )
- if self.tp_world_size > 1:
- torch.distributed.all_reduce(
- qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
- )
-
- n_write, n_backup = map(int, qsizes.tolist())
- self._check_offload_progress(n_write)
- self._check_backup_progress(n_backup)
-
- def _check_offload_progress(self, finish_count):
- """Check the progress of offload from device to host."""
- while finish_count > 0:
- _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
- finish_event.synchronize()
- for ack_id in ack_list:
- req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)
-
- # Release device
- self.tree_cache.cache_finished_req(req)
-
- # Trigger async backup from host to storage by cache controller
- self._trigger_backup(req.rid, host_indices, tokens, start_time)
- finish_count -= 1
-
- def _check_backup_progress(self, finish_count):
- """Check the progress of backup from host to storage."""
- for _ in range(finish_count):
- storage_operation = self.cache_controller.ack_backup_queue.get()
- ack_id = storage_operation.id
- req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)
-
- # Release host memory
- self.decode_host_mem_pool.free(host_indices)
-
- logger.debug(
- f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
- )
-
- def _trigger_backup(self, req_id, host_indices, tokens, start_time):
- """Trigger async backup from host to storage by cache controller."""
-
- # Generate page hashes and write to storage
- page_hashes = self._compute_prefix_hash(tokens)
- ack_id = self.cache_controller.write_storage(
- host_indices,
- tokens,
- hash_value=page_hashes,
- )
- self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)
-
- def _compute_prefix_hash(self, tokens):
- last_hash = ""
- page_hashes = []
- for offset in range(0, len(tokens), self.page_size):
- page_tokens = tokens[offset : offset + self.page_size]
- last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
- page_hashes.append(last_hash)
- return page_hashes
diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
index e2ae55780..be0383eec 100644
--- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
+++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
@@ -125,33 +125,25 @@ class ScheduleBatchDisaggregationDecodeMixin:
req.grammar.finished = req.finished()
self.output_ids = torch.tensor(self.output_ids, device=self.device)
- # Simulate the eagle run.
- if self.spec_algorithm.is_eagle():
+ # Simulate the eagle run. We add mock data to hidden states for the
+ # ease of implementation now meaning the first token will have acc rate
+ # of 0.
+ if not self.spec_algorithm.is_none():
b = len(self.reqs)
- topk = server_args.speculative_eagle_topk
- topk_p = torch.stack(
- [
- torch.as_tensor(
- req.output_topk_p[:topk],
- device=self.device,
- dtype=torch.float32,
- )
- for req in self.reqs
- ],
- dim=0,
+ topk_p = torch.arange(
+ b * server_args.speculative_eagle_topk,
+ 0,
+ -1,
+ device=self.device,
+ dtype=torch.float32,
)
- topk_index = torch.stack(
- [
- torch.as_tensor(
- req.output_topk_index[:topk],
- device=self.device,
- dtype=torch.int64,
- )
- for req in self.reqs
- ],
- dim=0,
+ topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
+ topk_p /= b * server_args.speculative_eagle_topk
+ topk_index = torch.arange(
+ b * server_args.speculative_eagle_topk, device=self.device
)
+ topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py
index b6f12e46e..f779e1fee 100644
--- a/python/sglang/srt/disaggregation/mooncake/conn.py
+++ b/python/sglang/srt/disaggregation/mooncake/conn.py
@@ -264,10 +264,12 @@ class MooncakeKVManager(CommonKVManager):
layers_params = None
# pp is not supported on the decode side yet
+ start_layer = self.kv_args.prefill_start_layer
+ end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
if self.is_mla_backend:
- src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
- self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
- )
+ src_kv_ptrs = self.kv_args.kv_data_ptrs
+ layers_per_pp_stage = len(src_kv_ptrs)
+ dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [
(
@@ -275,12 +277,18 @@ class MooncakeKVManager(CommonKVManager):
dst_kv_ptrs[layer_id],
kv_item_len,
)
- for layer_id in range(layers_current_pp_stage)
+ for layer_id in range(layers_per_pp_stage)
]
else:
- src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
- self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
- )
+ num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+ dst_num_total_layers = num_kv_layers * self.pp_size
+ src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+ src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+ layers_per_pp_stage = len(src_k_ptrs)
+ dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+ dst_v_ptrs = dst_kv_ptrs[
+ dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
+ ]
kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [
(
@@ -288,14 +296,14 @@ class MooncakeKVManager(CommonKVManager):
dst_k_ptrs[layer_id],
kv_item_len,
)
- for layer_id in range(layers_current_pp_stage)
+ for layer_id in range(layers_per_pp_stage)
] + [
(
src_v_ptrs[layer_id],
dst_v_ptrs[layer_id],
kv_item_len,
)
- for layer_id in range(layers_current_pp_stage)
+ for layer_id in range(layers_per_pp_stage)
]
assert layers_params is not None
@@ -393,9 +401,18 @@ class MooncakeKVManager(CommonKVManager):
num_heads_to_send = dst_heads_per_rank
dst_head_start_offset = 0
- src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
- self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
- )
+ # pp is not supported on the decode side yet
+ num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+ dst_num_total_layers = num_kv_layers * self.pp_size
+ src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+ src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+ layers_per_pp_stage = len(src_k_ptrs)
+ start_layer = self.pp_rank * layers_per_pp_stage
+ end_layer = start_layer + layers_per_pp_stage
+ dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
+ dst_v_ptrs = dst_kv_ptrs[
+ dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
+ ]
# Calculate precise byte offset and length for the sub-slice within the token
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
@@ -421,7 +438,7 @@ class MooncakeKVManager(CommonKVManager):
dst_head_slice_offset,
heads_bytes_per_token_to_send,
)
- for layer_id in range(layers_current_pp_stage)
+ for layer_id in range(layers_per_pp_stage)
] + [
(
src_v_ptrs[layer_id],
@@ -432,7 +449,7 @@ class MooncakeKVManager(CommonKVManager):
dst_head_slice_offset,
heads_bytes_per_token_to_send,
)
- for layer_id in range(layers_current_pp_stage)
+ for layer_id in range(layers_per_pp_stage)
]
def process_layer_tp_aware(layer_params):
diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index 3f794ea3a..5b9255e31 100644
--- a/python/sglang/srt/disaggregation/prefill.py
+++ b/python/sglang/srt/disaggregation/prefill.py
@@ -421,8 +421,6 @@ class SchedulerDisaggregationPrefillMixin:
last_hidden_index = (
hidden_state_offset + extend_input_len_per_req[i] - 1
)
- req.output_topk_p = batch.spec_info.topk_p[i]
- req.output_topk_index = batch.spec_info.topk_index[i]
if self.spec_algorithm.is_eagle3():
req.hidden_states_tensor = (
batch.spec_info.hidden_states[i].cpu().clone()
diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py
index fe4e7fb9f..43770e3e2 100644
--- a/python/sglang/srt/disaggregation/utils.py
+++ b/python/sglang/srt/disaggregation/utils.py
@@ -85,7 +85,7 @@ class MetadataBuffers:
self,
size: int,
hidden_size: int,
- hidden_states_dtype: torch.dtype,
+ dtype: torch.dtype,
max_top_logprobs_num: int = 128,
custom_mem_pool: torch.cuda.MemPool = None,
):
@@ -107,9 +107,7 @@ class MetadataBuffers:
# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
- self.cached_tokens = torch.zeros(
- (size, 16), dtype=torch.int32, device=device
- )
+
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
@@ -122,49 +120,33 @@ class MetadataBuffers:
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device=device
)
- # For PD + spec decode
- self.output_topk_p = torch.zeros(
- (size, 16), dtype=torch.float32, device=device
- )
- self.output_topk_index = torch.zeros(
- (size, 16), dtype=torch.int64, device=device
- )
self.output_hidden_states = torch.zeros(
- (size, hidden_size), dtype=hidden_states_dtype, device=device
+ (size, hidden_size), dtype=dtype, device=device
)
def get_buf_infos(self):
ptrs = [
self.output_ids.data_ptr(),
- self.cached_tokens.data_ptr(),
self.output_token_logprobs_val.data_ptr(),
self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(),
self.output_top_logprobs_idx.data_ptr(),
- self.output_topk_p.data_ptr(),
- self.output_topk_index.data_ptr(),
self.output_hidden_states.data_ptr(),
]
data_lens = [
self.output_ids.nbytes,
- self.cached_tokens.nbytes,
self.output_token_logprobs_val.nbytes,
self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes,
self.output_top_logprobs_idx.nbytes,
- self.output_topk_p.nbytes,
- self.output_topk_index.nbytes,
self.output_hidden_states.nbytes,
]
item_lens = [
self.output_ids[0].nbytes,
- self.cached_tokens[0].nbytes,
self.output_token_logprobs_val[0].nbytes,
self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes,
self.output_top_logprobs_idx[0].nbytes,
- self.output_topk_p[0].nbytes,
- self.output_topk_index[0].nbytes,
self.output_hidden_states[0].nbytes,
]
return ptrs, data_lens, item_lens
@@ -172,20 +154,16 @@ class MetadataBuffers:
def get_buf(self, idx: int):
return (
self.output_ids[idx],
- self.cached_tokens[idx],
self.output_token_logprobs_val[idx],
self.output_token_logprobs_idx[idx],
self.output_top_logprobs_val[idx],
self.output_top_logprobs_idx[idx],
- self.output_topk_p[idx],
- self.output_topk_index[idx],
self.output_hidden_states[idx],
)
def set_buf(self, req: Req):
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
- self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
if req.return_logprob:
if req.output_token_logprobs_val: # not none or empty list
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -208,17 +186,8 @@ class MetadataBuffers:
] = torch.tensor(
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
)
- # For PD + spec decode
+ # for PD + spec decode
if req.hidden_states_tensor is not None:
- # speculative_eagle_topk should not be greater than 16 currently
- topk = req.output_topk_p.size(0)
-
- self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
- req.output_topk_p
- )
- self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
- req.output_topk_index
- )
self.output_hidden_states[req.metadata_buffer_index].copy_(
req.hidden_states_tensor
)
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index f6a0f597b..66cf2c873 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
assert_pkg_version(
"sgl-kernel",
- "0.3.12",
+ "0.3.11",
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)
diff --git a/python/sglang/srt/entrypoints/grpc_request_manager.py b/python/sglang/srt/entrypoints/grpc_request_manager.py
index 61c1af24f..91c1d9e31 100644
--- a/python/sglang/srt/entrypoints/grpc_request_manager.py
+++ b/python/sglang/srt/entrypoints/grpc_request_manager.py
@@ -4,7 +4,6 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
"""
import asyncio
-import copy
import dataclasses
import logging
import os
@@ -12,8 +11,7 @@ import signal
import sys
import threading
import time
-import uuid
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
import grpc
import zmq
@@ -81,10 +79,11 @@ class GrpcReqState:
last_completion_tokens: int = 1
# Streaming state
+ last_output_offset: int = 0
stream_finished: bool = False
- input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming
- # Token accumulation (for non-streaming)
+ # Output accumulation
+ text: str = ""
output_ids: List[int] = dataclasses.field(default_factory=list)
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
@@ -140,6 +139,8 @@ class GrpcRequestManager:
self.is_pause_cond = asyncio.Condition()
# Metrics
+ self.request_counter = 0
+ self.request_counter_lock = asyncio.Lock()
self.last_receive_tstamp = time.time()
# Crash dump for debugging
@@ -157,133 +158,22 @@ class GrpcRequestManager:
obj: TokenizedGenerateReqInput,
request_id: Optional[str] = None,
grpc_context: Optional[grpc.aio.ServicerContext] = None,
- ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
+ ) -> asyncio.Queue:
"""
- Submit a generation request to the scheduler with n>1 parallel sampling support.
-
- This method implements the same two-phase approach as tokenizer_manager.py:
- 1. Phase 1: Send prefix caching request (max_new_tokens=0)
- 2. Phase 2: Send n generation requests that reuse the cached prefix
-
- Yields individual responses for streaming, or aggregated responses for non-streaming.
+ Submit a generation request to the scheduler.
+ Returns a queue for streaming outputs.
"""
- n = getattr(obj.sampling_params, "n", 1)
-
- if n <= 1:
- async for response in self._handle_single_request(
- obj, request_id, grpc_context
- ):
- yield response
- return
-
- # N>1 handling - two-phase approach
- logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
-
- # Generate base request ID if not provided
- if request_id is None:
- base_request_id = f"grpc-{uuid.uuid4().hex}"
- else:
- base_request_id = request_id
-
- # Phase 1: Cache the common prefix
- logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
- prefix_obj = copy.copy(obj)
- prefix_obj.sampling_params = copy.copy(obj.sampling_params)
- prefix_obj.sampling_params.max_new_tokens = 0 # Prefill-only
- prefix_obj.sampling_params.n = 1 # Don't replicate prefix request
-
- # Send prefix caching request and consume response
- async for _ in self._handle_single_request(
- prefix_obj, f"{base_request_id}-prefix", grpc_context
- ):
- # Consume prefix response (usually just one chunk with finish_reason)
- pass
-
- logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
-
- # Phase 2: Generate n parallel requests
- logger.debug(f"Phase 2: Generating {n} parallel requests")
- generators = []
- request_ids = []
-
- for i in range(n):
- # Create individual generation request
- gen_obj = copy.copy(obj)
- gen_obj.sampling_params = copy.copy(obj.sampling_params)
- gen_obj.sampling_params.n = 1 # Each request generates 1 response
-
- gen_request_id = f"{base_request_id}-{i}"
- request_ids.append(gen_request_id)
-
- # Start generation request
- generators.append(
- self._handle_single_request(gen_obj, gen_request_id, grpc_context)
- )
-
- # Handle response aggregation
- is_stream = getattr(obj, "stream", False)
-
- if not is_stream:
- # Non-streaming: collect all responses and return as batch
- logger.debug(f"Non-streaming mode: collecting {n} responses")
- responses = []
- for generator in generators:
- async for response in generator:
- responses.append(response)
- yield responses # Return all responses as a batch
- else:
- # Streaming mode: multiplex responses with index for ordering
- logger.debug(f"Streaming mode: multiplexing {n} streams")
- rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
-
- # Create async tasks for all generators
- task_map = {}
- for generator in generators:
- task = asyncio.create_task(generator.__anext__())
- task_map[task] = generator
-
- # Process responses as they arrive
- while task_map:
- done, _ = await asyncio.wait(
- task_map.keys(), return_when=asyncio.FIRST_COMPLETED
- )
-
- for task in done:
- generator = task_map.pop(task)
- try:
- response = await task
-
- # Add index for client-side ordering
- if isinstance(response, dict) and "meta_info" in response:
- response_rid = response["meta_info"].get("id", "")
- if response_rid in rid_to_index:
- response["index"] = rid_to_index[response_rid]
-
- yield response
-
- # Create next task for this generator
- next_task = asyncio.create_task(generator.__anext__())
- task_map[next_task] = generator
-
- except StopAsyncIteration:
- # This generator is finished
- pass
-
- async def _handle_single_request(
- self,
- obj: TokenizedGenerateReqInput,
- request_id: Optional[str] = None,
- grpc_context: Optional[grpc.aio.ServicerContext] = None,
- ):
- """Handle a single request - core implementation without n>1 logic."""
# Generate request ID if not provided
if request_id is None:
- request_id = f"grpc-{uuid.uuid4().hex}"
+ async with self.request_counter_lock:
+ request_id = f"grpc-{self.request_counter}"
+ self.request_counter += 1
obj.rid = request_id
- # Create and register request state
# TODO: support log_request
+
+ # Create request state
state = GrpcReqState(
request_id=request_id,
grpc_context=grpc_context,
@@ -299,51 +189,19 @@ class GrpcRequestManager:
state.session_id = obj.session_params.session_id
state.is_session_request = True
+ # Register state
self.rid_to_state[request_id] = state
self.record_request_for_crash_dump(obj)
+ # Send to scheduler via ZMQ
try:
- # Send to scheduler - let exceptions bubble up to grpc_server.py
await self._send_to_scheduler(obj)
-
- is_stream = getattr(obj, "stream", False)
-
- while True:
- # Client cancelled - notify scheduler and exit
- if grpc_context and grpc_context.cancelled():
- await self.abort_request(request_id)
- return
-
- try:
- response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
-
- if is_stream:
- yield response
-
- # Non-streaming: yield final response with accumulated tokens from state
- if isinstance(response, dict) and response.get("finished", False):
- if not is_stream:
- final_response = response.copy()
- final_response["token_ids"] = state.output_ids
- yield final_response
- break
-
- except asyncio.TimeoutError:
- # Timeout waiting for response - abort and cleanup
- logger.warning(
- f"Timeout waiting for response for request {request_id}"
- )
- await self.abort_request(request_id)
- return
-
- finally:
- # Always clean up request state when exiting
- self._cleanup_request_state(request_id)
-
- def _cleanup_request_state(self, request_id: str):
- """Clean up local request state (does not notify scheduler)."""
- if request_id in self.rid_to_state:
+ except Exception as e:
+ # Clean up on failure
del self.rid_to_state[request_id]
+ raise RuntimeError(f"Failed to send request to scheduler: {e}")
+
+ return state.out_queue
async def embedding_request(
self,
@@ -356,7 +214,9 @@ class GrpcRequestManager:
"""
# Generate request ID if not provided
if request_id is None:
- request_id = f"grpc-embed-{uuid.uuid4().hex}"
+ async with self.request_counter_lock:
+ request_id = f"grpc-embed-{self.request_counter}"
+ self.request_counter += 1
obj.rid = request_id
@@ -495,6 +355,7 @@ class GrpcRequestManager:
# Extract output for this request
output_data = {
"request_id": rid,
+ "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
"finished": batch_out.finished_reasons[i] is not None,
"meta_info": {
@@ -506,9 +367,6 @@ class GrpcRequestManager:
if batch_out.completion_tokens
else 0
),
- "cached_tokens": (
- batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
- ),
"finish_reason": (
str(batch_out.finished_reasons[i])
if batch_out.finished_reasons[i]
@@ -517,110 +375,29 @@ class GrpcRequestManager:
},
}
- # Accumulate input logprobs (only once, usually in first chunk)
- if batch_out.input_token_logprobs_val and i < len(
- batch_out.input_token_logprobs_val
- ):
- if not state.input_token_logprobs_val:
- state.input_token_logprobs_val.extend(
- batch_out.input_token_logprobs_val[i]
- )
- if batch_out.input_token_logprobs_idx and i < len(
- batch_out.input_token_logprobs_idx
- ):
- state.input_token_logprobs_idx.extend(
- batch_out.input_token_logprobs_idx[i]
- )
- if batch_out.input_top_logprobs_val and i < len(
- batch_out.input_top_logprobs_val
- ):
- state.input_top_logprobs_val.extend(
- batch_out.input_top_logprobs_val[i]
- )
- if batch_out.input_top_logprobs_idx and i < len(
- batch_out.input_top_logprobs_idx
- ):
- state.input_top_logprobs_idx.extend(
- batch_out.input_top_logprobs_idx[i]
- )
-
- # Send input logprobs based on mode
- if state.input_token_logprobs_val:
- if state.obj.stream and not state.input_logprobs_sent:
- # Streaming: send input logprobs once in first chunk that has them
- output_data["input_logprobs"] = {
- "token_logprobs_val": state.input_token_logprobs_val,
- "token_logprobs_idx": state.input_token_logprobs_idx,
- "top_logprobs_val": state.input_top_logprobs_val,
- "top_logprobs_idx": state.input_top_logprobs_idx,
- }
- state.input_logprobs_sent = True
- elif not state.obj.stream and output_data["finished"]:
- # Non-streaming: send input logprobs in final chunk
- output_data["input_logprobs"] = {
- "token_logprobs_val": state.input_token_logprobs_val,
- "token_logprobs_idx": state.input_token_logprobs_idx,
- "top_logprobs_val": state.input_top_logprobs_val,
- "top_logprobs_idx": state.input_top_logprobs_idx,
- }
-
- # Add output logprobs if available (RAW - no detokenization!)
+ # Add logprobs if available
if batch_out.output_token_logprobs_val and i < len(
batch_out.output_token_logprobs_val
):
- # Accumulate in state first
- state.output_token_logprobs_val.extend(
- batch_out.output_token_logprobs_val[i]
- )
- if batch_out.output_token_logprobs_idx and i < len(
- batch_out.output_token_logprobs_idx
- ):
- state.output_token_logprobs_idx.extend(
- batch_out.output_token_logprobs_idx[i]
- )
- if batch_out.output_top_logprobs_val and i < len(
- batch_out.output_top_logprobs_val
- ):
- state.output_top_logprobs_val.extend(
+ output_data["logprobs"] = {
+ "tokens": batch_out.output_token_logprobs_val[i],
+ "top_logprobs": (
batch_out.output_top_logprobs_val[i]
- )
- if batch_out.output_top_logprobs_idx and i < len(
- batch_out.output_top_logprobs_idx
- ):
- state.output_top_logprobs_idx.extend(
- batch_out.output_top_logprobs_idx[i]
- )
+ if batch_out.output_top_logprobs_val
+ and i < len(batch_out.output_top_logprobs_val)
+ else None
+ ),
+ }
- if state.obj.stream:
- # For streaming: send incremental logprobs (only new tokens in this chunk)
- # NOTE: this is different than TokenizerManager, which always accumulates
- def get_part(attr_name):
- source_list = getattr(batch_out, attr_name, None)
- return (
- source_list[i]
- if source_list and i < len(source_list)
- else []
- )
+ # Update state
+ if output_data["text"]:
+ state.text += output_data["text"][state.last_output_offset :]
+ state.last_output_offset = len(output_data["text"])
- output_data["output_logprobs"] = {
- "token_logprobs_val": batch_out.output_token_logprobs_val[i],
- "token_logprobs_idx": get_part("output_token_logprobs_idx"),
- "top_logprobs_val": get_part("output_top_logprobs_val"),
- "top_logprobs_idx": get_part("output_top_logprobs_idx"),
- }
- elif output_data["finished"]:
- # Non-streaming: send cumulative output logprobs in final chunk
- output_data["output_logprobs"] = {
- "token_logprobs_val": state.output_token_logprobs_val,
- "token_logprobs_idx": state.output_token_logprobs_idx,
- "top_logprobs_val": state.output_top_logprobs_val,
- "top_logprobs_idx": state.output_top_logprobs_idx,
- }
-
- # Update state for accumulation
if output_data["token_ids"]:
state.output_ids.extend(output_data["token_ids"])
+ # Send to output queue
await state.out_queue.put(output_data)
# Handle completion
diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py
index b772f3067..f7edf7743 100644
--- a/python/sglang/srt/entrypoints/grpc_server.py
+++ b/python/sglang/srt/entrypoints/grpc_server.py
@@ -181,34 +181,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
# Convert gRPC request to internal format
tokenized_req = self._convert_generate_request(request)
- # Submit to request manager (automatically handles n>1)
- response_generator = self.request_manager.generate_request(
+ # Submit to request manager
+ output_queue = await self.request_manager.generate_request(
obj=tokenized_req,
request_id=request.request_id,
grpc_context=context,
)
- async for output in response_generator:
- # Handle batch responses (for n>1 non-streaming)
- if isinstance(output, list):
- for batch_output in output:
- if "error" in batch_output:
- yield sglang_scheduler_pb2.GenerateResponse(
- request_id=request.request_id,
- error=sglang_scheduler_pb2.GenerateError(
- message=batch_output["error"],
- http_status_code=(
- "500" if "abort" not in batch_output else "499"
- ),
- ),
- )
- else:
- # All non-error batch outputs are final responses
- yield self._create_completion_response(
- request.request_id, batch_output
- )
- else:
- # Handle single response (for streaming or n=1 non-streaming)
+ # Stream outputs
+ while True:
+ try:
+ # Get output with timeout
+ output = await asyncio.wait_for(output_queue.get(), timeout=4)
+
+ # Check for errors
if "error" in output:
yield sglang_scheduler_pb2.GenerateResponse(
request_id=request.request_id,
@@ -219,13 +205,27 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
),
),
)
- elif output.get("finished", False):
+ break
+
+ # Check if finished
+ if output.get("finished", False):
+ # Send completion
yield self._create_completion_response(
request.request_id, output
)
+ break
else:
+ # Send chunk
yield self._create_chunk_response(request.request_id, output)
+ except asyncio.TimeoutError:
+ # Check if context is still active
+ if context.cancelled():
+ # Abort the request
+ await self.request_manager.abort_request(request.request_id)
+ break
+ continue
+
except Exception as e:
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
yield sglang_scheduler_pb2.GenerateResponse(
@@ -266,6 +266,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
prompt_tokens=result.get("prompt_tokens", 0),
cached_tokens=0,
embedding_dim=len(result["embedding"]),
+ generation_time=time.time() - self.start_time,
),
)
@@ -321,14 +322,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
logger.info(f"Sending health check request to request manager...")
# Submit and wait for response
- output_generator = self.request_manager.generate_request(
+ output_queue = await self.request_manager.generate_request(
health_request, request_id=rid
)
try:
- # Get first response with timeout
+ # Wait for response with configurable timeout
response = await asyncio.wait_for(
- output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
+ output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
)
# Clean up
@@ -403,8 +404,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
return_logprob=grpc_req.return_logprob,
logprob_start_len=grpc_req.logprob_start_len or -1,
top_logprobs_num=grpc_req.top_logprobs_num or 0,
- stream=grpc_req.stream or False,
- lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
+ stream=True, # Always stream for gRPC
+ lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
token_ids_logprob=(
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
),
@@ -437,7 +438,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
regex = None
json_schema = None
ebnf_grammar = None
- structural_tag = None
if grpc_params.HasField("regex"):
regex = grpc_params.regex
@@ -445,8 +445,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
json_schema = grpc_params.json_schema
elif grpc_params.HasField("ebnf_grammar"):
ebnf_grammar = grpc_params.ebnf_grammar
- elif grpc_params.HasField("structural_tag"):
- structural_tag = grpc_params.structural_tag
return SGLSamplingParams(
temperature=grpc_params.temperature or 1.0,
@@ -458,74 +456,33 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
repetition_penalty=grpc_params.repetition_penalty or 1.0,
max_new_tokens=grpc_params.max_new_tokens or 128,
min_new_tokens=grpc_params.min_new_tokens or 0,
- stop=list(grpc_params.stop) if grpc_params.stop else [],
+ stop=list(grpc_params.stop) if grpc_params.stop else None,
stop_token_ids=(
- list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
+ list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
),
skip_special_tokens=grpc_params.skip_special_tokens,
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
regex=regex,
json_schema=json_schema,
ebnf=ebnf_grammar,
- structural_tag=structural_tag,
n=grpc_params.n or 1,
ignore_eos=grpc_params.ignore_eos,
)
- def _convert_logprobs_to_proto(
- self, logprobs_data: Dict
- ) -> Optional[sglang_scheduler_pb2.LogProbs]:
- """Convert logprobs dict to proto LogProbs format (transport RAW data only)."""
- if not logprobs_data:
- return None
-
- token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
- token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
- top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
- top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
-
- # Build TopLogProbs entries
- top_logprobs_proto = []
- if top_logprobs_val and top_logprobs_idx:
- for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
- top_logprobs_proto.append(
- sglang_scheduler_pb2.TopLogProbs(
- values=val_list,
- token_ids=idx_list,
- )
- )
-
- return sglang_scheduler_pb2.LogProbs(
- token_logprobs=token_logprobs_val,
- token_ids=token_logprobs_idx,
- top_logprobs=top_logprobs_proto,
- )
-
def _create_chunk_response(
self, request_id: str, output: Dict
) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a streaming chunk response."""
- meta_info = output.get("meta_info", {})
-
- # Convert output logprobs if present
- output_logprobs_proto = self._convert_logprobs_to_proto(
- output.get("output_logprobs")
- )
-
- # Convert input logprobs if present (only in first chunk)
- input_logprobs_proto = self._convert_logprobs_to_proto(
- output.get("input_logprobs")
- )
-
return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id,
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
- token_ids=output.get("token_ids", []),
- prompt_tokens=meta_info.get("prompt_tokens", 0),
- completion_tokens=meta_info.get("completion_tokens", 0),
- cached_tokens=meta_info.get("cached_tokens", 0),
- output_logprobs=output_logprobs_proto,
- input_logprobs=input_logprobs_proto,
+ token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
+ text=output.get("text", ""),
+ prompt_tokens=0,
+ completion_tokens=len(output.get("token_ids", [])),
+ cached_tokens=0,
+ generation_time=time.time() - self.start_time,
+ queue_time=0.0,
),
)
@@ -534,56 +491,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a completion response."""
- # Extract meta info and finish reason details
+ # Determine finish reason
+ finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
meta_info = output.get("meta_info", {})
- finish_reason_data = meta_info.get("finish_reason")
-
- # Determine finish reason, default is stop
- finish_reason = "stop"
- if finish_reason_data:
- if isinstance(finish_reason_data, dict):
- finish_reason_type = finish_reason_data.get("type")
- else:
- # Handle legacy string format
- finish_reason_type = finish_reason_data
-
- if finish_reason_type == "length":
- finish_reason = "length"
- elif finish_reason_type == "abort":
- finish_reason = "abort"
-
- # Extract matched_stop information
- matched_stop_kwargs = {}
- if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
- matched = finish_reason_data["matched"]
- if isinstance(matched, int):
- matched_stop_kwargs["matched_token_id"] = matched
- elif isinstance(matched, str):
- matched_stop_kwargs["matched_stop_str"] = matched
-
- # Convert output logprobs if present
- output_logprobs_proto = self._convert_logprobs_to_proto(
- output.get("output_logprobs")
- )
-
- # Convert input logprobs if present
- input_logprobs_proto = self._convert_logprobs_to_proto(
- output.get("input_logprobs")
- )
+ if meta_info.get("finish_reason") == "length":
+ finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
+ elif meta_info.get("finish_reason") == "eos_token":
+ finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id,
complete=sglang_scheduler_pb2.GenerateComplete(
output_ids=output.get("token_ids", []),
+ output_text=output.get("text", ""),
finish_reason=finish_reason,
- prompt_tokens=meta_info.get("prompt_tokens", 0),
- completion_tokens=meta_info.get(
- "completion_tokens", len(output.get("token_ids", []))
- ),
- cached_tokens=meta_info.get("cached_tokens", 0),
- output_logprobs=output_logprobs_proto,
- input_logprobs=input_logprobs_proto,
- **matched_stop_kwargs,
),
)
diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 5a0a387c8..23830d86c 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -16,7 +16,7 @@
import time
import uuid
from dataclasses import dataclass
-from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
+from typing import Any, Dict, List, Optional, TypeAlias, Union
from openai.types.responses import (
ResponseFunctionToolCall,
@@ -228,15 +228,11 @@ class CompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
- # Extra key for classifying the request (e.g. cache_salt)
- extra_key: Optional[Union[List[str], str]] = None
- # Cache salt for request caching
- cache_salt: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
- # For custom metric labels
- custom_labels: Optional[Dict[str, str]] = None
+ # For customer metric labels
+ customer_labels: Optional[Dict[str, str]] = None
@field_validator("max_tokens")
@classmethod
@@ -343,7 +339,7 @@ class FunctionResponse(BaseModel):
"""Function response."""
name: Optional[str] = None
- arguments: Optional[str | Dict[str, Any]] = None
+ arguments: Optional[str] = None
class ToolCall(BaseModel):
@@ -392,7 +388,7 @@ class Function(BaseModel):
"""Function descriptions."""
description: Optional[str] = Field(default=None, examples=[None])
- name: str
+ name: Optional[str] = None
parameters: Optional[object] = None
strict: bool = False
@@ -549,10 +545,6 @@ class ChatCompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
- # Extra key for classifying the request (e.g. cache_salt)
- extra_key: Optional[Union[List[str], str]] = None
- # Cache salt for request caching
- cache_salt: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
@@ -786,13 +778,6 @@ class ResponsesRequest(BaseModel):
description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
)
priority: int = Field(default=0, description="Request priority")
- extra_key: Optional[str] = Field(
- default=None,
- description="Extra key for classifying the request (e.g. cache_salt)",
- )
- cache_salt: Optional[str] = Field(
- default=None, description="Cache salt for request caching"
- )
# SGLang-specific sampling parameters
frequency_penalty: float = 0.0
@@ -943,16 +928,6 @@ class MessageProcessingResult:
tool_call_constraint: Optional[Any] = None
-class ToolCallProcessingResult(NamedTuple):
- """Result of processing tool calls in a response."""
-
- tool_calls: Optional[
- List[Any]
- ] # List of ToolCall objects or None if parsing failed
- remaining_text: str # Text remaining after parsing tool calls
- finish_reason: Dict[str, Any] # Updated finish reason dictionary
-
-
class ResponseReasoningTextContent(BaseModel):
text: str
type: Literal["reasoning_text"] = "reasoning_text"
diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py
index 2e027fd48..5bc505108 100644
--- a/python/sglang/srt/entrypoints/openai/serving_base.py
+++ b/python/sglang/srt/entrypoints/openai/serving_base.py
@@ -27,10 +27,10 @@ class OpenAIServingBase(ABC):
self.tokenizer_manager = tokenizer_manager
self.allowed_custom_labels = (
set(
- self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
+ self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
)
if isinstance(self.tokenizer_manager.server_args, ServerArgs)
- and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
+ and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
else None
)
@@ -62,12 +62,6 @@ class OpenAIServingBase(ABC):
return self.create_error_response(
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
)
- except ValueError as e:
- return self.create_error_response(
- message=str(e),
- err_type="BadRequest",
- status_code=400,
- )
except Exception as e:
logger.exception(f"Error in request: {e}")
return self.create_error_response(
@@ -92,19 +86,6 @@ class OpenAIServingBase(ABC):
return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
- def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
- """Compute the final extra_key by concatenating cache_salt and extra_key if both are provided."""
- parts = []
- for key in ["cache_salt", "extra_key"]:
- value = getattr(request, key, None)
- if value:
- if not isinstance(value, str):
- raise TypeError(
- f"Value of {key} must be a string, but got {type(value).__name__}"
- )
- parts.append(value)
- return "".join(parts) if parts else None
-
@abstractmethod
def _convert_to_internal_request(
self,
@@ -184,14 +165,14 @@ class OpenAIServingBase(ABC):
)
return json.dumps({"error": error.model_dump()})
- def extract_custom_labels(self, raw_request):
+ def extract_customer_labels(self, raw_request):
if (
not self.allowed_custom_labels
or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
):
return None
- custom_labels = None
+ customer_labels = None
header = (
self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
)
@@ -206,9 +187,9 @@ class OpenAIServingBase(ABC):
raw_labels = None
if isinstance(raw_labels, dict):
- custom_labels = {
+ customer_labels = {
label: value
for label, value in raw_labels.items()
if label in self.allowed_custom_labels
}
- return custom_labels
+ return customer_labels
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index 13e40a19c..8bd57fc9e 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Uni
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
-from jsonschema import Draft202012Validator, SchemaError
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
@@ -26,8 +25,6 @@ from sglang.srt.entrypoints.openai.protocol import (
LogProbs,
MessageProcessingResult,
ToolCall,
- ToolCallProcessingResult,
- ToolChoice,
TopLogprob,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
@@ -36,10 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
-from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.function_call_parser import FunctionCallParser
-from sglang.srt.function_call.json_array_parser import JsonArrayParser
-from sglang.srt.function_call.utils import get_json_schema_constraint
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.parser.conversation import generate_chat_conv
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
@@ -64,7 +58,6 @@ class OpenAIServingChat(OpenAIServingBase):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
- self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
def _request_id_prefix(self) -> str:
return "chatcmpl-"
@@ -81,23 +74,6 @@ class OpenAIServingChat(OpenAIServingBase):
):
return "Tools cannot be empty if tool choice is set to required."
- if request.tool_choice is not None and not isinstance(request.tool_choice, str):
- if not request.tools:
- return "Tools cannot be empty if tool choice is set to a specific tool."
- tool_name = request.tool_choice.function.name
- tool_exists = any(tool.function.name == tool_name for tool in request.tools)
- if not tool_exists:
- return f"Tool '{tool_name}' not found in tools list."
-
- # Validate tool definitions
- for i, tool in enumerate(request.tools or []):
- if tool.function.parameters is None:
- continue
- try:
- Draft202012Validator.check_schema(tool.function.parameters)
- except SchemaError as e:
- return f"Tool {i} function has invalid 'parameters' schema: {str(e)}"
-
max_output_tokens = request.max_completion_tokens or request.max_tokens
server_context_length = self.tokenizer_manager.server_args.context_length
if (
@@ -152,8 +128,8 @@ class OpenAIServingChat(OpenAIServingBase):
else:
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
- # Extract custom labels from raw request headers
- custom_labels = self.extract_custom_labels(raw_request)
+ # Extract customer labels from raw request headers
+ customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput(
**prompt_kwargs,
@@ -173,9 +149,8 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
- extra_key=self._compute_extra_key(request),
priority=request.priority,
- custom_labels=custom_labels,
+ customer_labels=customer_labels,
)
return adapted_request, request
@@ -213,14 +188,6 @@ class OpenAIServingChat(OpenAIServingBase):
tool_call_constraint = parser.get_structure_constraint(
request.tool_choice
)
- # Handle JSON schema constraint directly for required or named tool choice
- if request.tool_choice == "required" or isinstance(
- request.tool_choice, ToolChoice
- ):
- json_schema = get_json_schema_constraint(
- request.tools, request.tool_choice
- )
- tool_call_constraint = ("json_schema", json_schema)
# Use chat template
if self.template_manager.chat_template_name is None:
@@ -468,10 +435,6 @@ class OpenAIServingChat(OpenAIServingBase):
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value.model_dump(by_alias=True)
)
- elif constraint_type == "json_schema":
- sampling_params[constraint_type] = convert_json_schema_to_str(
- constraint_value
- )
else:
sampling_params[constraint_type] = constraint_value
return sampling_params
@@ -564,7 +527,10 @@ class OpenAIServingChat(OpenAIServingBase):
stream_buffers[index] = stream_buffer + delta
# Handle reasoning content
- if self.reasoning_parser and request.separate_reasoning:
+ if (
+ self.tokenizer_manager.server_args.reasoning_parser
+ and request.separate_reasoning
+ ):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
)
@@ -754,7 +720,7 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle reasoning content
reasoning_text = None
- reasoning_parser = self.reasoning_parser
+ reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
if reasoning_parser and request.separate_reasoning:
is_force_reasoning = (
self.template_manager.force_reasoning
@@ -782,13 +748,8 @@ class OpenAIServingChat(OpenAIServingBase):
and request.tools
and self.tool_call_parser
):
- history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
tool_calls, text, finish_reason = self._process_tool_calls(
- text,
- request.tools,
- finish_reason,
- request.tool_choice,
- history_tool_calls_cnt,
+ text, request.tools, finish_reason
)
choice_data = ChatCompletionResponseChoice(
@@ -878,76 +839,13 @@ class OpenAIServingChat(OpenAIServingBase):
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
return ChoiceLogprobs(content=token_logprobs)
- def _process_tool_call_id(
- self,
- call_item: ToolCallItem,
- history_tool_calls_cnt: int,
- ) -> str:
- """Process for generating a new and unique `tool_call_id`"""
- if self.tool_call_parser != "kimi_k2":
- # A simple uuid is sufficient for all models except for Kimi-K2.
- tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
- return tool_call_id
- else:
- # Align with Kimi-K2 format: functions.{name}:{index}
- # Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
- # Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
- tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
- logger.debug(
- f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
- )
- return tool_call_id
-
def _process_tool_calls(
self,
text: str,
tools: List[Any],
finish_reason: Dict[str, Any],
- tool_choice: Optional[Union[str, ToolChoice]] = None,
- history_tool_calls_cnt: int = 0,
- ) -> ToolCallProcessingResult:
+ ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
"""Process tool calls in the response"""
-
- # Handle required or named tool choice
- if tool_choice == "required" or (
- isinstance(tool_choice, ToolChoice) and tool_choice.type == "function"
- ):
- # Set finish reason to tool_calls since we're processing tool calls
- if finish_reason["type"] == "stop":
- finish_reason["type"] = "tool_calls"
- finish_reason["matched"] = None
- try:
- # For required tool choice, we expect a JSON array of tool calls
- tool_call_data = json.loads(text)
- tool_calls = []
- for i, tool in enumerate(tool_call_data):
- # Create a ToolCallItem from the JSON data
- call_info = ToolCallItem(
- tool_index=i, # Use the loop index as tool_index
- name=tool["name"],
- parameters=json.dumps(tool["parameters"], ensure_ascii=False),
- )
- tool_id = self._process_tool_call_id(
- call_info, history_tool_calls_cnt
- )
- tool_calls.append(
- ToolCall(
- id=tool_id,
- index=i,
- function=FunctionResponse(
- name=tool["name"],
- arguments=json.dumps(
- tool["parameters"], ensure_ascii=False
- ),
- ),
- )
- )
- return ToolCallProcessingResult(tool_calls, "", finish_reason)
- except json.JSONDecodeError as e:
- logger.error(f"Tool call parsing error: {e}")
- return ToolCallProcessingResult(None, text, finish_reason)
-
- # Use parser since output is not constrained by JSON schema
parser = FunctionCallParser(tools, self.tool_call_parser)
if parser.has_tool_call(text):
if finish_reason["type"] == "stop":
@@ -957,9 +855,15 @@ class OpenAIServingChat(OpenAIServingBase):
text, call_info_list = parser.parse_non_stream(text)
tool_calls = []
for call_info in call_info_list:
- tool_id = self._process_tool_call_id(
- call_info, history_tool_calls_cnt
- )
+ # For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
+ if (
+ self.tool_call_parser == "kimi_k2"
+ and call_info.name is not None
+ ):
+ tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
+ else:
+ tool_id = f"call_{uuid.uuid4().hex[:24]}"
+
tool_calls.append(
ToolCall(
id=tool_id,
@@ -969,13 +873,13 @@ class OpenAIServingChat(OpenAIServingBase):
),
)
)
- return ToolCallProcessingResult(tool_calls, text, finish_reason)
+ return tool_calls, text, finish_reason
except Exception as e:
logger.error(f"Tool call parsing error: {e}")
# Return error but don't fail the whole request
- return ToolCallProcessingResult(None, text, finish_reason)
+ return None, text, finish_reason
- return ToolCallProcessingResult(None, text, finish_reason)
+ return None, text, finish_reason
def _process_streaming_logprobs(
self, content: Dict[str, Any], n_prev_token: int
@@ -1008,33 +912,13 @@ class OpenAIServingChat(OpenAIServingBase):
or self._get_enable_thinking_from_request(request)
)
reasoning_parser_dict[index] = ReasoningParser(
- self.reasoning_parser,
+ self.tokenizer_manager.server_args.reasoning_parser,
request.stream_reasoning,
is_force_reasoning,
)
reasoning_parser = reasoning_parser_dict[index]
return reasoning_parser.parse_stream_chunk(delta)
- def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
- """Counts the number of tool calls in the request's message history.
-
- NOTE: This method is only useful for models that include self-increasing
- history tool call idx in tool calls id, such as kimi-k2
-
- Args:
- request: The chat completion request object.
-
- Returns:
- The total number of tool calls in the history, or 0 if not applicable.
- """
- messages = getattr(request, "messages", [])
- idx = 0
- for msg in messages:
- if msg.role == "assistant":
- tool_calls = getattr(msg, "tool_calls", None)
- idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
- return idx
-
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
@@ -1048,11 +932,11 @@ class OpenAIServingChat(OpenAIServingBase):
"""
if hasattr(request, "chat_template_kwargs") and request.chat_template_kwargs:
# For Qwen3 models, `enable_thinking` is supported.
- if self.reasoning_parser in ["qwen3", "glm45"]:
- return request.chat_template_kwargs.get("enable_thinking", False)
+ if request.chat_template_kwargs.get("enable_thinking") is not None:
+ return request.chat_template_kwargs.get("enable_thinking")
# For DeepSeek-V3.1 models, `thinking` is supported.
- elif self.reasoning_parser in ["deepseek-v3"]:
- return request.chat_template_kwargs.get("thinking", False)
+ elif request.chat_template_kwargs.get("thinking") is not None:
+ return request.chat_template_kwargs.get("thinking")
else:
return False
return False
@@ -1068,25 +952,13 @@ class OpenAIServingChat(OpenAIServingBase):
):
"""Process tool calls in streaming response"""
if index not in parser_dict:
- # Use JSON detector directly for required or named tool choice
- if request.tool_choice == "required" or isinstance(
- request.tool_choice, ToolChoice
- ):
- parser_dict[index] = JsonArrayParser()
- else:
- parser_dict[index] = FunctionCallParser(
- tools=request.tools,
- tool_call_parser=self.tool_call_parser,
- )
-
+ parser_dict[index] = FunctionCallParser(
+ tools=request.tools,
+ tool_call_parser=self.tool_call_parser,
+ )
parser = parser_dict[index]
- # Handle both FunctionCallParser and JsonArrayParser
- if isinstance(parser, JsonArrayParser):
- result = parser.parse_streaming_increment(delta, request.tools)
- normal_text, calls = result.normal_text, result.calls
- else:
- normal_text, calls = parser.parse_stream_chunk(delta)
+ normal_text, calls = parser.parse_stream_chunk(delta)
# Yield normal text
if normal_text:
@@ -1104,7 +976,6 @@ class OpenAIServingChat(OpenAIServingBase):
yield f"data: {chunk.model_dump_json()}\n\n"
# Yield tool calls
- history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
for call_item in calls:
# Mark that this choice has tool calls
has_tool_calls[index] = True
@@ -1112,9 +983,11 @@ class OpenAIServingChat(OpenAIServingBase):
# Tool call ID should be generated only once per tool call
if call_item.name:
# First chunk: include ID and function name
- tool_call_id = self._process_tool_call_id(
- call_item, history_tool_calls_cnt
- )
+ if self.tool_call_parser == "kimi_k2":
+ # Align with Kimi-K2 format: functions.{name}:{index}
+ tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
+ else:
+ tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
function_name = call_item.name
else:
# Subsequent chunks: null ID and name for argument deltas
@@ -1145,7 +1018,7 @@ class OpenAIServingChat(OpenAIServingBase):
def _check_for_unstreamed_tool_args(
self,
- parser: Union[FunctionCallParser, JsonArrayParser],
+ parser: FunctionCallParser,
content: Dict[str, Any],
request: ChatCompletionRequest,
index: int,
@@ -1155,31 +1028,30 @@ class OpenAIServingChat(OpenAIServingBase):
when generation finishes. This ensures tool calls are properly completed
even if the model generates the final arguments in the last chunk.
"""
- # Get the detector - either from FunctionCallParser or directly if json detector
- detector = parser.detector if hasattr(parser, "detector") else parser
-
- # Only check if we have tool calls and the detector has tracked data
+ # Only check if we have tool calls and the parser has tracked data
if (
- not hasattr(detector, "prev_tool_call_arr")
- or not detector.prev_tool_call_arr
+ not hasattr(parser.detector, "prev_tool_call_arr")
+ or not parser.detector.prev_tool_call_arr
):
return None
if (
- not hasattr(detector, "streamed_args_for_tool")
- or not detector.streamed_args_for_tool
+ not hasattr(parser.detector, "streamed_args_for_tool")
+ or not parser.detector.streamed_args_for_tool
):
return None
# Get the last tool call that was being processed
- tool_index = len(detector.prev_tool_call_arr) - 1
- if tool_index < 0 or tool_index >= len(detector.streamed_args_for_tool):
+ tool_index = len(parser.detector.prev_tool_call_arr) - 1
+ if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
return None
# Get expected vs actual arguments
- expected_args = detector.prev_tool_call_arr[tool_index].get("arguments", {})
+ expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
+ "arguments", {}
+ )
expected_call = json.dumps(expected_args, ensure_ascii=False)
- actual_call = detector.streamed_args_for_tool[tool_index]
+ actual_call = parser.detector.streamed_args_for_tool[tool_index]
# Check if there are remaining arguments to send
remaining_call = (
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
index b065984aa..6aa4fe19e 100644
--- a/python/sglang/srt/entrypoints/openai/serving_completions.py
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -90,8 +90,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
else:
prompt_kwargs = {"input_ids": prompt}
- # Extract custom labels from raw request headers
- custom_labels = self.extract_custom_labels(raw_request)
+ # Extract customer labels from raw request headers
+ customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput(
**prompt_kwargs,
@@ -107,9 +107,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
- extra_key=self._compute_extra_key(request),
priority=request.priority,
- custom_labels=custom_labels,
+ customer_labels=customer_labels,
)
return adapted_request, request
diff --git a/python/sglang/srt/entrypoints/openai/serving_responses.py b/python/sglang/srt/entrypoints/openai/serving_responses.py
index 5e965e3bb..3f7619678 100644
--- a/python/sglang/srt/entrypoints/openai/serving_responses.py
+++ b/python/sglang/srt/entrypoints/openai/serving_responses.py
@@ -245,7 +245,6 @@ class OpenAIServingResponses(OpenAIServingChat):
sampling_params=sampling_params,
stream=request.stream,
rid=request.request_id,
- extra_key=self._compute_extra_key(request),
background=request.background,
)
@@ -1251,7 +1250,6 @@ class OpenAIServingResponses(OpenAIServingChat):
sampling_params=sampling_params,
stream=adapted_request.stream,
rid=request_id,
- extra_key=adapted_request.extra_key,
return_logprob=adapted_request.return_logprob,
logprob_start_len=adapted_request.logprob_start_len,
top_logprobs_num=adapted_request.top_logprobs_num,
diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py
index 4db273781..ee5f2c7ca 100644
--- a/python/sglang/srt/eplb/expert_location.py
+++ b/python/sglang/srt/eplb/expert_location.py
@@ -231,7 +231,6 @@ class ExpertLocationMetadata:
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=(
compute_logical_to_rank_dispatch_physical_map(
- server_args=server_args,
logical_to_all_physical_map=logical_to_all_physical_map,
num_gpus=ep_size,
num_physical_experts=num_physical_experts,
@@ -341,7 +340,6 @@ def _pad_nested_array(arr, pad_value):
# TODO optimize performance (rewrite and/or run in separate process with overlap)
def compute_logical_to_rank_dispatch_physical_map(
- server_args: ServerArgs,
logical_to_all_physical_map: torch.Tensor,
num_gpus: int,
num_physical_experts: int,
@@ -350,9 +348,7 @@ def compute_logical_to_rank_dispatch_physical_map(
):
r = random.Random(seed)
- num_local_gpu_physical_experts = num_physical_experts // num_gpus
- num_gpus_per_node = server_args.ep_size // server_args.nnodes
- num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
+ num_local_physical_experts = num_physical_experts // num_gpus
num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
dtype = logical_to_all_physical_map.dtype
@@ -376,28 +372,13 @@ def compute_logical_to_rank_dispatch_physical_map(
physical_expert_id
for physical_expert_id in candidate_physical_expert_ids
if _compute_gpu_id_of_physical_expert(
- physical_expert_id, num_local_gpu_physical_experts
+ physical_expert_id, num_local_physical_experts
)
== gpu_id
]
if len(same_gpu_physical_expert_ids) > 0:
- # 1. Prefer same-GPU experts
output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
- else:
- # 2. Otherwise, prefer same-node experts
- node_id = gpu_id // num_gpus_per_node
- same_node_physical_expert_ids = [
- physical_expert_id
- for physical_expert_id in candidate_physical_expert_ids
- if _compute_node_id_of_physical_expert(
- physical_expert_id, num_local_node_physical_experts
- )
- == node_id
- ]
- if len(same_node_physical_expert_ids) > 0:
- output_partial[gpu_id] = same_node_physical_expert_ids[0]
- # 3. Fill remaining slots with fair random choices
num_remain = torch.sum(output_partial == -1).item()
output_partial[output_partial == -1] = torch.tensor(
_fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
@@ -423,15 +404,9 @@ def _logical_to_all_physical_raw(
def _compute_gpu_id_of_physical_expert(
- physical_expert_id: int, num_local_gpu_physical_experts: int
+ physical_expert_id: int, num_local_physical_experts: int
) -> int:
- return physical_expert_id // num_local_gpu_physical_experts
-
-
-def _compute_node_id_of_physical_expert(
- physical_expert_id: int, num_local_host_physical_experts: int
-) -> int:
- return physical_expert_id // num_local_host_physical_experts
+ return physical_expert_id // num_local_physical_experts
def _fair_choices(arr: List, k: int, r: random.Random) -> List:
diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py
index e568d77fa..e28f4f5cf 100644
--- a/python/sglang/srt/function_call/function_call_parser.py
+++ b/python/sglang/srt/function_call/function_call_parser.py
@@ -20,7 +20,6 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.function_call.step3_detector import Step3Detector
-from sglang.srt.function_call.utils import get_json_schema_constraint
logger = logging.getLogger(__name__)
@@ -179,8 +178,8 @@ class FunctionCallParser:
strict_tag = self.get_structure_tag()
return ("structural_tag", strict_tag)
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
- json_schema = get_json_schema_constraint(self.tools, tool_choice)
- return ("json_schema", json_schema)
+ ebnf = self.get_ebnf(tool_choice)
+ return ("ebnf", ebnf) if ebnf is not None else None
def get_ebnf(
self, tool_choice: Union[ToolChoice, Literal["required"]]
diff --git a/python/sglang/srt/function_call/glm4_moe_detector.py b/python/sglang/srt/function_call/glm4_moe_detector.py
index 845b5d41f..6e89fe0a1 100644
--- a/python/sglang/srt/function_call/glm4_moe_detector.py
+++ b/python/sglang/srt/function_call/glm4_moe_detector.py
@@ -39,7 +39,7 @@ def parse_arguments(json_value):
class Glm4MoeDetector(BaseFormatDetector):
"""
- Detector for GLM-4.5 and GLM-4.6 models.
+ Detector for GLM-4.5 models.
Assumes function call format:
get_weather\ncity\n北京\ndate\n2024-06-27\n\nget_weather\ncity\n上海\ndate\n2024-06-27\n
"""
@@ -53,7 +53,7 @@ class Glm4MoeDetector(BaseFormatDetector):
self.func_arg_regex = r"(.*?)\s*(.*?)"
def has_tool_call(self, text: str) -> bool:
- """Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
+ """Check if the text contains a glm-4.5 format tool call."""
return self.bot_token in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
@@ -102,7 +102,7 @@ class Glm4MoeDetector(BaseFormatDetector):
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
- Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format.
+ Streaming incremental parsing tool calls for GLM-4.5 format.
"""
self._buffer += new_text
current_text = self._buffer
diff --git a/python/sglang/srt/function_call/json_array_parser.py b/python/sglang/srt/function_call/json_array_parser.py
deleted file mode 100644
index 5144cb83b..000000000
--- a/python/sglang/srt/function_call/json_array_parser.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import json
-import re
-from typing import List
-
-from sglang.srt.entrypoints.openai.protocol import Tool
-from sglang.srt.function_call.base_format_detector import BaseFormatDetector
-from sglang.srt.function_call.core_types import StreamingParseResult
-
-
-class JsonArrayParser(BaseFormatDetector):
- """
- Parser for JSON array tool calls when JSON schema constraints are active.
-
- This parser is used when tool_choice="required" or a specific tool is named,
- bypassing model-specific parsers in favor of direct JSON array parsing.
- """
-
- def __init__(self):
- super().__init__()
- # Configure for JSON array parsing
- self.bot_token = "["
- self.eot_token = "]"
- self.tool_call_separator = ","
-
- def has_tool_call(self, text: str) -> bool:
- """
- Check if the given text contains a JSON tool call (array or single object).
- """
- return "[" in text or "{" in text
-
- def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
- """
- Parse JSON tool calls using the base class implementation.
- """
- raise NotImplementedError(
- "Detect and parse not supported for JSON schema constraints."
- )
-
- def build_ebnf(self, tools: List[Tool]) -> str:
- """
- Build an EBNF grammar for constrained generation.
- This is not used for JSON schema constraints as they are handled
- by the constraint backends directly.
- """
- raise NotImplementedError(
- "EBNF generation is not supported for JSON schema constraints."
- )
-
- def parse_streaming_increment(
- self, new_text: str, tools: List[Tool]
- ) -> StreamingParseResult:
- """
- Streaming incremental parsing with tool validation.
- """
- return super().parse_streaming_increment(new_text, tools)
-
- def structure_info(self) -> callable:
- """
- Return a function that creates StructureInfo for constrained generation.
- This is not used for JSON schema constraints as they are handled
- by the constraint backends directly.
- """
- raise NotImplementedError("structure_info not used for JSON schema constraints")
diff --git a/python/sglang/srt/function_call/utils.py b/python/sglang/srt/function_call/utils.py
index 898e13b13..c4da456f3 100644
--- a/python/sglang/srt/function_call/utils.py
+++ b/python/sglang/srt/function_call/utils.py
@@ -1,13 +1,10 @@
import json
from json import JSONDecodeError, JSONDecoder
-from json.decoder import WHITESPACE
-from typing import Any, List, Literal, Optional, Tuple, Union
+from typing import Any, Tuple
import partial_json_parser
from partial_json_parser.core.options import Allow
-from sglang.srt.entrypoints.openai.protocol import Tool, ToolChoice
-
def _find_common_prefix(s1: str, s2: str) -> str:
prefix = ""
@@ -40,12 +37,10 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
"""
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
- except (JSONDecodeError, IndexError) as e:
- msg = getattr(e, "msg", str(e))
- if "Extra data" in msg or "pop from empty list" in msg:
- start = WHITESPACE.match(input_str, 0).end()
- obj, end = JSONDecoder().raw_decode(input_str, start)
- return obj, end
+ except JSONDecodeError as e:
+ if "Extra data" in e.msg:
+ dec = JSONDecoder()
+ return dec.raw_decode(input_str)
raise
@@ -55,89 +50,3 @@ def _is_complete_json(input_str: str) -> bool:
return True
except JSONDecodeError:
return False
-
-
-def _get_tool_schema_defs(tools: List[Tool]) -> dict:
- """
- Get consolidated $defs from all tools, validating for conflicts.
-
- Args:
- tools: List of tools to process
-
- Returns:
- Dictionary of consolidated $defs from all tools
-
- Raises:
- ValueError: If conflicting $defs are found
- """
- all_defs = {}
- for tool in tools:
- if tool.function.parameters is None:
- continue
- defs = tool.function.parameters.get("$defs", {})
- for def_name, def_schema in defs.items():
- if def_name in all_defs and all_defs[def_name] != def_schema:
- raise ValueError(
- f"Tool definition '{def_name}' has "
- "multiple schemas, which is not "
- "supported."
- )
- else:
- all_defs[def_name] = def_schema
- return all_defs
-
-
-def _get_tool_schema(tool: Tool) -> dict:
- return {
- "properties": {
- "name": {"type": "string", "enum": [tool.function.name]},
- "parameters": (
- tool.function.parameters
- if tool.function.parameters
- else {"type": "object", "properties": {}}
- ),
- },
- "required": ["name", "parameters"],
- }
-
-
-def get_json_schema_constraint(
- tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
-) -> Optional[dict]:
- """
- Get the JSON schema constraint for the specified tool choice.
-
- Args:
- tool_choice: The tool choice specification
-
- Returns:
- JSON schema dict, or None if no valid tools found
- """
-
- if isinstance(tool_choice, ToolChoice):
- # For specific function choice, return the user's parameters schema directly
- fn_name = tool_choice.function.name
- for tool in tools:
- if tool.function.name == fn_name:
- return {
- "type": "array",
- "minItems": 1,
- "maxItems": 1,
- "items": _get_tool_schema(tool),
- }
- return None
- elif tool_choice == "required":
- json_schema = {
- "type": "array",
- "minItems": 1,
- "items": {
- "type": "object",
- "anyOf": [_get_tool_schema(tool) for tool in tools],
- },
- }
- json_schema_defs = _get_tool_schema_defs(tools)
- if json_schema_defs:
- json_schema["$defs"] = json_schema_defs
- return json_schema
-
- return None
diff --git a/python/sglang/srt/grpc/sglang_scheduler.proto b/python/sglang/srt/grpc/sglang_scheduler.proto
index be6508b5a..e4c87925e 100644
--- a/python/sglang/srt/grpc/sglang_scheduler.proto
+++ b/python/sglang/srt/grpc/sglang_scheduler.proto
@@ -36,9 +36,9 @@ message SamplingParams {
float presence_penalty = 6;
float repetition_penalty = 7;
- optional int32 max_new_tokens = 8;
+ int32 max_new_tokens = 8;
repeated string stop = 9;
- repeated uint32 stop_token_ids = 10;
+ repeated int32 stop_token_ids = 10;
bool skip_special_tokens = 11;
bool spaces_between_special_tokens = 12;
@@ -47,24 +47,24 @@ message SamplingParams {
string regex = 13;
string json_schema = 14;
string ebnf_grammar = 15;
- string structural_tag = 16;
}
// LoRA adapter
- string lora_path = 17;
+ string lora_path = 16;
// Speculative decoding
- int32 n = 18; // Number of samples
+ int32 n = 17; // Number of samples
// Token healing
- bool token_healing = 19;
+ bool token_healing = 18;
// Additional parameters
- int32 min_new_tokens = 20;
- bool ignore_eos = 21;
- bool no_stop_trim = 22;
- int32 stream_interval = 23;
- map logit_bias = 24;
+ int32 min_new_tokens = 19;
+ bool ignore_eos = 20;
+ bool no_stop_trim = 21;
+ int32 stream_interval = 22;
+ map logit_bias = 23;
+ string structural_tag = 24;
// Custom parameters for extensibility
google.protobuf.Struct custom_params = 25;
@@ -98,7 +98,7 @@ message GenerateRequest {
bool return_logprob = 5;
int32 logprob_start_len = 6;
int32 top_logprobs_num = 7;
- repeated uint32 token_ids_logprob = 8;
+ repeated int32 token_ids_logprob = 8;
bool return_hidden_states = 9;
// For disaggregated serving
@@ -122,14 +122,11 @@ message GenerateRequest {
// For load balancing
int32 dp_balance_id = 17;
-
- // Whether client wants streaming response
- bool stream = 18;
}
message TokenizedInput {
string original_text = 1; // For reference
- repeated uint32 input_ids = 2;
+ repeated int32 input_ids = 2;
}
message MultimodalInputs {
@@ -166,50 +163,51 @@ message GenerateResponse {
}
message GenerateStreamChunk {
- // Generated tokens (incremental chunk)
- repeated uint32 token_ids = 1;
+ // Generated token
+ int32 token_id = 1;
+ string text = 2;
// Cumulative counts
- int32 prompt_tokens = 2;
- int32 completion_tokens = 3;
- int32 cached_tokens = 4;
-
- // Output logprobs (if requested) - incremental for streaming
- LogProbs output_logprobs = 5;
-
- // Hidden states (if requested)
- repeated float hidden_states = 6;
-
- // Input logprobs (if requested) - only in first chunk
- LogProbs input_logprobs = 7;
-}
-
-message GenerateComplete {
- // Final output
- repeated uint32 output_ids = 1;
-
- // Finish reason as OpenAI-compatible string ("stop", "length", "abort")
- string finish_reason = 2;
-
- // Token usage counts
int32 prompt_tokens = 3;
int32 completion_tokens = 4;
int32 cached_tokens = 5;
- // Output logprobs if requested (cumulative)
- LogProbs output_logprobs = 6;
+ // Logprobs (if requested)
+ LogProbs logprobs = 6;
+
+ // Hidden states (if requested)
+ repeated float hidden_states = 7;
+
+ // Metadata
+ float generation_time = 8; // Time to generate this token
+ int32 queue_time = 9; // Time spent in queue
+}
+
+message GenerateComplete {
+ // Final output
+ repeated int32 output_ids = 1;
+ string output_text = 2;
+
+ // Finish reason
+ enum FinishReason {
+ // The model generated a stop sequence.
+ STOP = 0;
+ // The model reached the maximum generation length.
+ LENGTH = 1;
+ // The model generated an end-of-sequence (EOS) token.
+ EOS_TOKEN = 2;
+ // The model generated a user-provided stop string.
+ STOP_STR = 3;
+ // The request was aborted by the user or system.
+ ABORT = 4;
+ }
+ FinishReason finish_reason = 3;
+
+ // All logprobs if requested
+ repeated LogProbs all_logprobs = 11;
// All hidden states if requested
- repeated HiddenStates all_hidden_states = 7;
-
- // Matched stop information (for stop sequences)
- oneof matched_stop {
- uint32 matched_token_id = 8;
- string matched_stop_str = 9;
- }
-
- // Input logprobs if requested (for prompt tokens)
- LogProbs input_logprobs = 10;
+ repeated HiddenStates all_hidden_states = 12;
}
message GenerateError {
@@ -224,11 +222,15 @@ message LogProbs {
// Top logprobs at each position
repeated TopLogProbs top_logprobs = 3;
+
+ // Decoded text for tokens
+ repeated string token_texts = 4;
}
message TopLogProbs {
repeated float values = 1;
repeated int32 token_ids = 2;
+ repeated string token_texts = 3;
}
message HiddenStates {
@@ -283,9 +285,10 @@ message EmbedComplete {
// Additional metadata
int32 embedding_dim = 4;
+ float generation_time = 5;
// For batch embeddings
- repeated Embedding batch_embeddings = 5;
+ repeated Embedding batch_embeddings = 6;
}
message Embedding {
diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.py b/python/sglang/srt/grpc/sglang_scheduler_pb2.py
index 2f80f83bb..4b288d768 100644
--- a/python/sglang/srt/grpc/sglang_scheduler_pb2.py
+++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.py
@@ -1,6 +1,3 @@
-# This file is auto-generated. Do not edit manually.
-# Regenerate with: python compile_proto.py
-
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
@@ -29,7 +26,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe1\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokens\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xf9\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 
\x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x15\n\rdp_balance_id\x18\x11 \x01(\x05\x12\x0e\n\x06stream\x18\x12 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\xfb\x01\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12\x38\n\x0foutput_logprobs\x18\x05 
\x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12\x37\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\"\x81\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12\x38\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12\x37\n\x0einput_logprobs\x18\n \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbsB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"o\n\x08LogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 
\x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 
\x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xc7\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x16\n\x0emax_new_tokens\x18\x08 \x01(\x05\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\x05\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x11\n\tlora_path\x18\x10 \x01(\t\x12\t\n\x01n\x18\x11 \x01(\x05\x12\x15\n\rtoken_healing\x18\x12 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x13 \x01(\x05\x12\x12\n\nignore_eos\x18\x14 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x15 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x16 \x01(\x05\x12H\n\nlogit_bias\x18\x17 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12\x16\n\x0estructural_tag\x18\x18 \x01(\t\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraint\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe9\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 
\x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\x05\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x15\n\rdp_balance_id\x18\x11 \x01(\x05\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\x05\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\xf5\x01\n\x13GenerateStreamChunk\x12\x10\n\x08token_id\x18\x01 \x01(\x05\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12\x31\n\x08logprobs\x18\x06 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12\x15\n\rhidden_states\x18\x07 \x03(\x02\x12\x17\n\x0fgeneration_time\x18\x08 
\x01(\x02\x12\x12\n\nqueue_time\x18\t \x01(\x05\"\xcd\x02\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\x05\x12\x13\n\x0boutput_text\x18\x02 \x01(\t\x12K\n\rfinish_reason\x18\x03 \x01(\x0e\x32\x34.sglang.grpc.scheduler.GenerateComplete.FinishReason\x12\x35\n\x0c\x61ll_logprobs\x18\x0b \x03(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12>\n\x11\x61ll_hidden_states\x18\x0c \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\"L\n\x0c\x46inishReason\x12\x08\n\x04STOP\x10\x00\x12\n\n\x06LENGTH\x10\x01\x12\r\n\tEOS_TOKEN\x10\x02\x12\x0c\n\x08STOP_STR\x10\x03\x12\t\n\x05\x41\x42ORT\x10\x04\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\x84\x01\n\x08LogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\x12\x13\n\x0btoken_texts\x18\x04 \x03(\t\"E\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x13\n\x0btoken_texts\x18\x03 \x03(\t\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 
\x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xbc\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12\x17\n\x0fgeneration_time\x18\x05 \x01(\x02\x12:\n\x10\x62\x61tch_embeddings\x18\x06 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 
\x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -39,69 +36,71 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._loaded_options = None
_globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_options = b'8\001'
_globals['_SAMPLINGPARAMS']._serialized_start=113
- _globals['_SAMPLINGPARAMS']._serialized_end=850
- _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_start=769
- _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_end=817
- _globals['_DISAGGREGATEDPARAMS']._serialized_start=852
- _globals['_DISAGGREGATEDPARAMS']._serialized_end=945
- _globals['_GENERATEREQUEST']._serialized_start=948
- _globals['_GENERATEREQUEST']._serialized_end=1581
- _globals['_TOKENIZEDINPUT']._serialized_start=1583
- _globals['_TOKENIZEDINPUT']._serialized_end=1641
- _globals['_MULTIMODALINPUTS']._serialized_start=1644
- _globals['_MULTIMODALINPUTS']._serialized_end=1855
- _globals['_GENERATERESPONSE']._serialized_start=1858
- _globals['_GENERATERESPONSE']._serialized_end=2085
- _globals['_GENERATESTREAMCHUNK']._serialized_start=2088
- _globals['_GENERATESTREAMCHUNK']._serialized_end=2339
- _globals['_GENERATECOMPLETE']._serialized_start=2342
- _globals['_GENERATECOMPLETE']._serialized_end=2727
- _globals['_GENERATEERROR']._serialized_start=2729
- _globals['_GENERATEERROR']._serialized_end=2804
- _globals['_LOGPROBS']._serialized_start=2806
- _globals['_LOGPROBS']._serialized_end=2917
- _globals['_TOPLOGPROBS']._serialized_start=2919
- _globals['_TOPLOGPROBS']._serialized_end=2967
- _globals['_HIDDENSTATES']._serialized_start=2969
- _globals['_HIDDENSTATES']._serialized_end=3032
- _globals['_EMBEDREQUEST']._serialized_start=3035
- _globals['_EMBEDREQUEST']._serialized_end=3365
- _globals['_EMBEDRESPONSE']._serialized_start=3368
- _globals['_EMBEDRESPONSE']._serialized_end=3525
- _globals['_EMBEDCOMPLETE']._serialized_start=3528
- _globals['_EMBEDCOMPLETE']._serialized_end=3691
- _globals['_EMBEDDING']._serialized_start=3693
- _globals['_EMBEDDING']._serialized_end=3735
- _globals['_EMBEDERROR']._serialized_start=3737
- _globals['_EMBEDERROR']._serialized_end=3797
- _globals['_HEALTHCHECKREQUEST']._serialized_start=3799
- _globals['_HEALTHCHECKREQUEST']._serialized_end=3877
- _globals['_HEALTHCHECKRESPONSE']._serialized_start=3879
- _globals['_HEALTHCHECKRESPONSE']._serialized_end=3934
- _globals['_ABORTREQUEST']._serialized_start=3936
- _globals['_ABORTREQUEST']._serialized_end=3986
- _globals['_ABORTRESPONSE']._serialized_start=3988
- _globals['_ABORTRESPONSE']._serialized_end=4037
- _globals['_LOADLORAREQUEST']._serialized_start=4039
- _globals['_LOADLORAREQUEST']._serialized_end=4112
- _globals['_LOADLORARESPONSE']._serialized_start=4114
- _globals['_LOADLORARESPONSE']._serialized_end=4186
- _globals['_UNLOADLORAREQUEST']._serialized_start=4188
- _globals['_UNLOADLORAREQUEST']._serialized_end=4227
- _globals['_UNLOADLORARESPONSE']._serialized_start=4229
- _globals['_UNLOADLORARESPONSE']._serialized_end=4283
- _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4285
- _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4404
- _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4406
- _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4463
- _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4465
- _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4510
- _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4512
- _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4578
- _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4580
- _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4645
- _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4647
- _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4707
- _globals['_SGLANGSCHEDULER']._serialized_start=4710
- _globals['_SGLANGSCHEDULER']._serialized_end=5092
+ _globals['_SAMPLINGPARAMS']._serialized_end=824
+ _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_start=762
+ _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_end=810
+ _globals['_DISAGGREGATEDPARAMS']._serialized_start=826
+ _globals['_DISAGGREGATEDPARAMS']._serialized_end=919
+ _globals['_GENERATEREQUEST']._serialized_start=922
+ _globals['_GENERATEREQUEST']._serialized_end=1539
+ _globals['_TOKENIZEDINPUT']._serialized_start=1541
+ _globals['_TOKENIZEDINPUT']._serialized_end=1599
+ _globals['_MULTIMODALINPUTS']._serialized_start=1602
+ _globals['_MULTIMODALINPUTS']._serialized_end=1813
+ _globals['_GENERATERESPONSE']._serialized_start=1816
+ _globals['_GENERATERESPONSE']._serialized_end=2043
+ _globals['_GENERATESTREAMCHUNK']._serialized_start=2046
+ _globals['_GENERATESTREAMCHUNK']._serialized_end=2291
+ _globals['_GENERATECOMPLETE']._serialized_start=2294
+ _globals['_GENERATECOMPLETE']._serialized_end=2627
+ _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_start=2551
+ _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_end=2627
+ _globals['_GENERATEERROR']._serialized_start=2629
+ _globals['_GENERATEERROR']._serialized_end=2704
+ _globals['_LOGPROBS']._serialized_start=2707
+ _globals['_LOGPROBS']._serialized_end=2839
+ _globals['_TOPLOGPROBS']._serialized_start=2841
+ _globals['_TOPLOGPROBS']._serialized_end=2910
+ _globals['_HIDDENSTATES']._serialized_start=2912
+ _globals['_HIDDENSTATES']._serialized_end=2975
+ _globals['_EMBEDREQUEST']._serialized_start=2978
+ _globals['_EMBEDREQUEST']._serialized_end=3308
+ _globals['_EMBEDRESPONSE']._serialized_start=3311
+ _globals['_EMBEDRESPONSE']._serialized_end=3468
+ _globals['_EMBEDCOMPLETE']._serialized_start=3471
+ _globals['_EMBEDCOMPLETE']._serialized_end=3659
+ _globals['_EMBEDDING']._serialized_start=3661
+ _globals['_EMBEDDING']._serialized_end=3703
+ _globals['_EMBEDERROR']._serialized_start=3705
+ _globals['_EMBEDERROR']._serialized_end=3765
+ _globals['_HEALTHCHECKREQUEST']._serialized_start=3767
+ _globals['_HEALTHCHECKREQUEST']._serialized_end=3845
+ _globals['_HEALTHCHECKRESPONSE']._serialized_start=3847
+ _globals['_HEALTHCHECKRESPONSE']._serialized_end=3902
+ _globals['_ABORTREQUEST']._serialized_start=3904
+ _globals['_ABORTREQUEST']._serialized_end=3954
+ _globals['_ABORTRESPONSE']._serialized_start=3956
+ _globals['_ABORTRESPONSE']._serialized_end=4005
+ _globals['_LOADLORAREQUEST']._serialized_start=4007
+ _globals['_LOADLORAREQUEST']._serialized_end=4080
+ _globals['_LOADLORARESPONSE']._serialized_start=4082
+ _globals['_LOADLORARESPONSE']._serialized_end=4154
+ _globals['_UNLOADLORAREQUEST']._serialized_start=4156
+ _globals['_UNLOADLORAREQUEST']._serialized_end=4195
+ _globals['_UNLOADLORARESPONSE']._serialized_start=4197
+ _globals['_UNLOADLORARESPONSE']._serialized_end=4251
+ _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4253
+ _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4372
+ _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4374
+ _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4431
+ _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4433
+ _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4478
+ _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4480
+ _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4546
+ _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4548
+ _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4613
+ _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4615
+ _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4675
+ _globals['_SGLANGSCHEDULER']._serialized_start=4678
+ _globals['_SGLANGSCHEDULER']._serialized_end=5060
# @@protoc_insertion_point(module_scope)
diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi
index 3578abe74..d9388463d 100644
--- a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi
+++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi
@@ -3,6 +3,7 @@ import datetime
from google.protobuf import timestamp_pb2 as _timestamp_pb2
from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers
+from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
@@ -11,7 +12,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class SamplingParams(_message.Message):
- __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
+ __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
class LogitBiasEntry(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
@@ -34,7 +35,6 @@ class SamplingParams(_message.Message):
REGEX_FIELD_NUMBER: _ClassVar[int]
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
- STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
LORA_PATH_FIELD_NUMBER: _ClassVar[int]
N_FIELD_NUMBER: _ClassVar[int]
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
@@ -43,6 +43,7 @@ class SamplingParams(_message.Message):
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
+ STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
temperature: float
top_p: float
@@ -59,7 +60,6 @@ class SamplingParams(_message.Message):
regex: str
json_schema: str
ebnf_grammar: str
- structural_tag: str
lora_path: str
n: int
token_healing: bool
@@ -68,8 +68,9 @@ class SamplingParams(_message.Message):
no_stop_trim: bool
stream_interval: int
logit_bias: _containers.ScalarMap[str, float]
+ structural_tag: str
custom_params: _struct_pb2.Struct
- def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
+ def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class DisaggregatedParams(_message.Message):
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
@@ -82,7 +83,7 @@ class DisaggregatedParams(_message.Message):
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
class GenerateRequest(_message.Message):
- __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
+ __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
@@ -100,7 +101,6 @@ class GenerateRequest(_message.Message):
LORA_ID_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
- STREAM_FIELD_NUMBER: _ClassVar[int]
request_id: str
tokenized: TokenizedInput
mm_inputs: MultimodalInputs
@@ -118,8 +118,7 @@ class GenerateRequest(_message.Message):
lora_id: str
data_parallel_rank: int
dp_balance_id: int
- stream: bool
- def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
+ def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
class TokenizedInput(_message.Message):
__slots__ = ("original_text", "input_ids")
@@ -162,46 +161,52 @@ class GenerateResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message):
- __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
- TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
+ __slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
+ TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+ TEXT_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
- OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+ LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
- INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
- token_ids: _containers.RepeatedScalarFieldContainer[int]
+ GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
+ QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
+ token_id: int
+ text: str
prompt_tokens: int
completion_tokens: int
cached_tokens: int
- output_logprobs: LogProbs
+ logprobs: LogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
- input_logprobs: LogProbs
- def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
+ generation_time: float
+ queue_time: int
+ def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
class GenerateComplete(_message.Message):
- __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
+ __slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
+ class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
+ __slots__ = ()
+ STOP: _ClassVar[GenerateComplete.FinishReason]
+ LENGTH: _ClassVar[GenerateComplete.FinishReason]
+ EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
+ STOP_STR: _ClassVar[GenerateComplete.FinishReason]
+ ABORT: _ClassVar[GenerateComplete.FinishReason]
+ STOP: GenerateComplete.FinishReason
+ LENGTH: GenerateComplete.FinishReason
+ EOS_TOKEN: GenerateComplete.FinishReason
+ STOP_STR: GenerateComplete.FinishReason
+ ABORT: GenerateComplete.FinishReason
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
+ OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
- PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
- COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
- CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
- OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+ ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
- MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
- MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
- INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int]
- finish_reason: str
- prompt_tokens: int
- completion_tokens: int
- cached_tokens: int
- output_logprobs: LogProbs
+ output_text: str
+ finish_reason: GenerateComplete.FinishReason
+ all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
- matched_token_id: int
- matched_stop_str: str
- input_logprobs: LogProbs
- def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
+ def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details")
@@ -214,22 +219,26 @@ class GenerateError(_message.Message):
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class LogProbs(_message.Message):
- __slots__ = ("token_logprobs", "token_ids", "top_logprobs")
+ __slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+ TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
- def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
+ token_texts: _containers.RepeatedScalarFieldContainer[str]
+ def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class TopLogProbs(_message.Message):
- __slots__ = ("values", "token_ids")
+ __slots__ = ("values", "token_ids", "token_texts")
VALUES_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
+ TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
- def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...
+ token_texts: _containers.RepeatedScalarFieldContainer[str]
+ def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class HiddenStates(_message.Message):
__slots__ = ("values", "layer", "position")
@@ -274,18 +283,20 @@ class EmbedResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
class EmbedComplete(_message.Message):
- __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings")
+ __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
+ GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
embedding: _containers.RepeatedScalarFieldContainer[float]
prompt_tokens: int
cached_tokens: int
embedding_dim: int
+ generation_time: float
batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
- def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
+ def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
class Embedding(_message.Message):
__slots__ = ("values", "index")
diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py b/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py
index 402f71725..d9bdf0462 100644
--- a/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py
+++ b/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py
@@ -1,6 +1,3 @@
-# This file is auto-generated. Do not edit manually.
-# Regenerate with: python compile_proto.py
-
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index 89c5b63f6..85eb77abd 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -119,6 +119,37 @@ def get_hf_text_config(config: PretrainedConfig):
return config
+def _load_deepseek_v32_model(
+ model_path: str,
+ trust_remote_code: bool = False,
+ revision: Optional[str] = None,
+ **kwargs,
+):
+ # first get the local path
+ local_path = download_from_hf(model_path)
+ # then load the config file in json
+ config_file = os.path.join(local_path, "config.json")
+ if not os.path.exists(config_file):
+ raise RuntimeError(f"Can't find config file in {local_path}.")
+
+ with open(config_file, "r") as f:
+ config_json = json.load(f)
+
+ config_json["architectures"] = ["DeepseekV3ForCausalLM"]
+ config_json["model_type"] = "deepseek_v3"
+
+ tmp_path = os.path.join(local_path, "_tmp_config_folder")
+ os.makedirs(tmp_path, exist_ok=True)
+
+ unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
+ with open(unique_path, "w") as f:
+ json.dump(config_json, f)
+
+ return AutoConfig.from_pretrained(
+ unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+ )
+
+
@lru_cache_frozenset(maxsize=32)
def get_config(
model: str,
@@ -140,9 +171,17 @@ def get_config(
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
model = client.get_local_dir()
- config = AutoConfig.from_pretrained(
- model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
- )
+ try:
+ config = AutoConfig.from_pretrained(
+ model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+ )
+ except ValueError as e:
+ if not "deepseek_v32" in str(e):
+ raise e
+ config = _load_deepseek_v32_model(
+ model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+ )
+
if (
config.architectures is not None
and config.architectures[0] == "Phi4MMForCausalLM"
diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index f1b2da5f8..188d772c7 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -619,11 +619,7 @@ class AiterAttnBackend(AttentionBackend):
assert len(k.shape) == 3
assert len(v.shape) == 3
- if (
- forward_batch.forward_mode.is_extend()
- and not forward_batch.forward_mode.is_target_verify()
- and not forward_batch.forward_mode.is_draft_extend()
- ):
+ if forward_batch.forward_mode.is_extend():
if kv_indices.shape[0] == 0:
o = flash_attn_varlen_func(
q,
diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py
index 52192b7bc..e9be57599 100644
--- a/python/sglang/srt/layers/attention/ascend_backend.py
+++ b/python/sglang/srt/layers/attention/ascend_backend.py
@@ -3,6 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
+import custom_ops
import torch
import torch_npu
from torch.nn.functional import scaled_dot_product_attention
@@ -36,6 +37,8 @@ class ForwardMetadata:
seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_list: Optional[List[int]] = None
seq_lens_list_cumsum: Optional[List[int]] = None
+ seq_lens: Optional[torch.Tensor] = None
+ actual_seq_lengths_q: Optional[torch.Tensor] = None
class AscendAttnBackend(AttentionBackend):
@@ -67,6 +70,9 @@ class AscendAttnBackend(AttentionBackend):
if self.use_mla:
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+ self.q_head_dim = (
+ self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
+ )
self.native_attn = TorchNativeAttnBackend(model_runner)
self.graph_metadata = {}
self.max_context_len = model_runner.model_config.context_len
@@ -102,10 +108,6 @@ class AscendAttnBackend(AttentionBackend):
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
- if forward_batch.is_extend_in_batch:
- seq_lens_list_cumsum[-1] = (
- (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
- ) * tp_size
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
self.graph_mode = False
@@ -133,6 +135,10 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+ metadata.seq_lens = seq_lens
+ metadata.actual_seq_lengths_q = torch.tensor(
+ [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
+ )
self.graph_metadata[bs] = metadata
self.forward_metadata = metadata
@@ -161,6 +167,8 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
metadata.block_tables[bs:, :].fill_(0)
+ metadata.seq_lens[:bs].copy_(seq_lens[:bs])
+
self.forward_metadata = metadata
self.graph_mode = True
@@ -168,6 +176,64 @@ class AscendAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 0
+ def forward_sparse(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ layer: RadixAttention,
+ forward_batch: ForwardBatch,
+ save_kv_cache: bool = True,
+ # For multi_head latent attention
+ q_rope: Optional[torch.Tensor] = None,
+ k_rope: Optional[torch.Tensor] = None,
+ topk_indices: torch.Tensor = None,
+ ):
+
+ is_prefill = forward_batch.forward_mode.is_extend()
+
+ if save_kv_cache:
+ k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+ k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer, forward_batch.out_cache_loc, k, k_rope
+ )
+ q_nope, q_pe = q, q_rope
+ k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+ block_table = self.forward_metadata.block_tables
+ if is_prefill:
+ actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
+ else:
+ if self.forward_metadata.actual_seq_lengths_q is None:
+ actual_seq_qlen = (
+ torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+ )
+ else:
+ actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
+ if self.forward_metadata.seq_lens_cpu_int is None:
+ actual_seq_lengths_kv = self.forward_metadata.seq_lens
+ else:
+ actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
+
+ attn_out = torch.ops.custom.npu_sparse_flash_attention(
+ query=q_nope,
+ key=k_nope,
+ value=k_nope,
+ query_rope=q_pe,
+ key_rope=k_pe,
+ sparse_indices=topk_indices,
+ scale_value=layer.scaling,
+ actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
+ actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
+ block_table=block_table,
+ sparse_block_size=1,
+ layout_query="TND",
+ layout_kv="PA_BSND",
+ sparse_mode=3,
+ )
+
+ return attn_out
+
def forward_extend(
self,
q,
@@ -176,7 +242,23 @@ class AscendAttnBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
+        # For multi-head latent attention
+ q_rope: Optional[torch.Tensor] = None,
+ k_rope: Optional[torch.Tensor] = None,
+ topk_indices: Optional[torch.Tensor] = None,
):
+ if topk_indices is not None:
+ return self.forward_sparse(
+ q,
+ k,
+ v,
+ layer,
+ forward_batch,
+ save_kv_cache,
+ q_rope,
+ k_rope,
+ topk_indices,
+ )
if not self.use_mla:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -437,10 +519,23 @@ class AscendAttnBackend(AttentionBackend):
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
+ topk_indices: Optional[torch.Tensor] = None,
):
if is_mla_preprocess_enabled():
# MLAPO does saving kv_cache
save_kv_cache = False
+ if topk_indices is not None:
+ return self.forward_sparse(
+ q,
+ k,
+ v,
+ layer,
+ forward_batch,
+ save_kv_cache,
+ q_rope,
+ k_rope,
+ topk_indices,
+ )
if self.graph_mode:
return self.forward_decode_graph(
diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py
index 658ad1f0f..cd023b4b4 100644
--- a/python/sglang/srt/layers/attention/attention_registry.py
+++ b/python/sglang/srt/layers/attention/attention_registry.py
@@ -1,7 +1,3 @@
-import logging
-
-logger = logging.getLogger(__name__)
-
ATTENTION_BACKENDS = {}
@@ -66,6 +62,13 @@ def create_ascend_backend(runner):
return AscendAttnBackend(runner)
+@register_attention_backend("nsa")
+def create_nsa_backend(runner):
+ from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+ return NativeSparseAttnBackend(runner)
+
+
@register_attention_backend("triton")
def create_triton_backend(runner):
assert not runner.model_config.is_encoder_decoder, (
@@ -162,37 +165,35 @@ def create_dual_chunk_flash_attn_backend(runner):
return DualChunkFlashAttentionBackend(runner)
-def attn_backend_wrapper(runner, full_attn_backend):
- """
- Wrapper for special models like hybrid GDN, so we don't
- need to change the code of the original attention backend.
- """
- assert not (
- runner.is_hybrid_gdn and runner.use_mla_backend
- ), "hybrid_gdn can only be used with non-MLA models."
+@register_attention_backend("hybrid_linear_attn")
+def create_hybrid_linear_attn_backend(runner):
+ assert (
+ runner.is_hybrid_gdn
+ ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
+ from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+ HybridLinearAttnBackend,
+ MambaAttnBackend,
+ )
+ from sglang.srt.utils import is_blackwell, is_npu
- # wrap for hybrid GDN models
- if runner.is_hybrid_gdn:
- from sglang.srt.utils import is_blackwell, is_npu
+ if is_npu():
+ from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
- if is_blackwell():
- assert (
- runner.server_args.attention_backend == "triton"
- ), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
- if is_npu():
- assert (
- runner.server_args.attention_backend == "ascend"
- ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
- logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
- from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
- HybridLinearAttnBackend,
- MambaAttnBackend,
+ full_attn_backend = AscendAttnBackend(runner)
+ elif is_blackwell():
+ from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+ full_attn_backend = TritonAttnBackend(runner)
+ else:
+ from sglang.srt.layers.attention.flashattention_backend import (
+ FlashAttentionBackend,
)
- linear_attn_backend = MambaAttnBackend(runner)
- full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
- return HybridLinearAttnBackend(
- full_attn_backend, linear_attn_backend, full_attn_layers
- )
+ full_attn_backend = FlashAttentionBackend(runner)
- return full_attn_backend
+ linear_attn_backend = MambaAttnBackend(runner)
+ full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
+
+ return HybridLinearAttnBackend(
+ full_attn_backend, linear_attn_backend, full_attn_layers
+ )
diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py
index 3025d0b11..b3482cc98 100644
--- a/python/sglang/srt/layers/attention/base_attn_backend.py
+++ b/python/sglang/srt/layers/attention/base_attn_backend.py
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
if TYPE_CHECKING:
+ from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
def support_triton(self):
"""Check if the current backend supports triton."""
return True
+
+ def get_indexer_metadata(
+ self,
+ layer_id: int,
+ forward_batch: ForwardBatch,
+ ) -> Optional[BaseIndexerMetadata]:
+ """Get the indexer metadata. None means don't support indexer."""
+ return None
diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 67cad8d23..be7fed8de 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -692,13 +692,8 @@ class FlashAttentionBackend(AttentionBackend):
k_descale, v_descale = None, None
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
- # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
- # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
- if (
- self.kv_cache_dtype_str != "auto"
- and layer.head_dim <= 256
- and self.fa_impl_ver != 4
- ):
+ # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
+ if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
if layer.k_scale is not None:
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 2b69d734c..aaa8b520b 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -29,7 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.utils import (
get_int_env_var,
is_flashinfer_available,
@@ -344,7 +344,9 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
@@ -451,7 +453,9 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
@@ -669,7 +673,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -684,7 +690,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -710,7 +718,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -760,7 +770,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -794,7 +806,9 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
seq_lens_cpu: Optional[torch.Tensor],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
@@ -905,7 +919,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
fixed_split_size: Optional[int] = None,
):
# Keep the signature for type checking. It will be assigned during runtime.
@@ -921,7 +937,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
fixed_split_size: Optional[int] = None,
):
if use_ragged:
@@ -959,7 +977,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
@@ -1006,7 +1026,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
@@ -1049,7 +1071,9 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
):
@@ -1078,7 +1102,7 @@ class FlashInferIndicesUpdaterPrefill:
custom_mask = None
else:
assert isinstance(
- spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
+ spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
)
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py
index d1acb1a58..52ae480b3 100644
--- a/python/sglang/srt/layers/attention/flashmla_backend.py
+++ b/python/sglang/srt/layers/attention/flashmla_backend.py
@@ -5,13 +5,20 @@ Support attention backend for FlashMLA.
"""
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union
import torch
import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
+from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
+from sglang.srt.layers.attention.nsa.utils import (
+ NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+ NSA_KV_CACHE_STORE_FP8,
+ compute_nsa_seqlens,
+)
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -74,10 +81,17 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.scaling = model_runner.model_config.scaling
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
- self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+ self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+ self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
+ self.nsa_index_topk = (
+ get_nsa_index_topk(model_runner.model_config.hf_config)
+ if self.use_nsa
+ else None
+ )
+
def init_forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
@@ -100,10 +114,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
max_seqlen_pad,
)
- mla_metadata, num_splits = get_mla_metadata(
- forward_batch.seq_lens.to(torch.int32),
- self.num_q_heads,
- 1,
+ mla_metadata, num_splits = _get_mla_metadata_wrapped(
+ cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+ seq_len_q=1,
+ num_heads_q=self.num_q_heads,
+ num_heads_k=1,
+ nsa_index_topk=self.nsa_index_topk,
)
self.forward_metadata = FlashMLADecodeMetadata(
mla_metadata,
@@ -130,10 +146,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
max_seqlen_pad,
)
- mla_metadata, num_splits = get_mla_metadata(
- seq_lens.to(torch.int32),
- self.num_draft_tokens * self.num_q_heads,
- 1,
+ mla_metadata, num_splits = _get_mla_metadata_wrapped(
+ cache_seqlens=seq_lens.to(torch.int32),
+ seq_len_q=self.num_draft_tokens,
+ num_heads_q=self.num_q_heads,
+ num_heads_k=1,
+ nsa_index_topk=self.nsa_index_topk,
)
# Use FlashMLADecodeMetadata which has the attributes forward_extend expects
@@ -162,20 +180,28 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
cuda_graph_kv_indices = block_kv_indices
if self.num_draft_tokens:
- self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
- torch.ones(
- max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
- ),
- self.num_draft_tokens * self.num_q_heads,
- 1,
+ self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
+ _get_mla_metadata_wrapped(
+ cache_seqlens=torch.ones(
+ max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
+ ),
+ seq_len_q=self.num_draft_tokens,
+ num_heads_q=self.num_q_heads,
+ num_heads_k=1,
+ nsa_index_topk=self.nsa_index_topk,
+ )
)
else:
- self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
- torch.ones(
- max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
- ),
- self.num_q_heads,
- 1,
+ self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
+ _get_mla_metadata_wrapped(
+ cache_seqlens=torch.ones(
+ max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
+ ),
+ seq_len_q=1,
+ num_heads_q=self.num_q_heads,
+ num_heads_k=1,
+ nsa_index_topk=self.nsa_index_topk,
+ )
)
self.cuda_graph_kv_indices = cuda_graph_kv_indices
@@ -201,10 +227,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
- mla_metadata, num_splits = get_mla_metadata(
- seq_lens.to(torch.int32),
- self.num_q_heads,
- 1,
+ mla_metadata, num_splits = _get_mla_metadata_wrapped(
+ cache_seqlens=seq_lens.to(torch.int32),
+ seq_len_q=1,
+ num_heads_q=self.num_q_heads,
+ num_heads_k=1,
+ nsa_index_topk=self.nsa_index_topk,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -226,10 +254,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
- mla_metadata, num_splits = get_mla_metadata(
- seq_lens.to(torch.int32),
- self.num_draft_tokens * self.num_q_heads,
- 1,
+ mla_metadata, num_splits = _get_mla_metadata_wrapped(
+ cache_seqlens=seq_lens.to(torch.int32),
+ seq_len_q=self.num_draft_tokens,
+ num_heads_q=self.num_q_heads,
+ num_heads_k=1,
+ nsa_index_topk=self.nsa_index_topk,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -275,10 +305,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
- mla_metadata, num_splits = get_mla_metadata(
- seq_lens.to(torch.int32),
- self.num_q_heads,
- 1,
+ mla_metadata, num_splits = _get_mla_metadata_wrapped(
+ cache_seqlens=seq_lens.to(torch.int32),
+ seq_len_q=1,
+ num_heads_q=self.num_q_heads,
+ num_heads_k=1,
+ nsa_index_topk=self.nsa_index_topk,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -300,10 +332,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
- mla_metadata, num_splits = get_mla_metadata(
- seq_lens.to(torch.int32),
- self.num_draft_tokens * self.num_q_heads,
- 1,
+ mla_metadata, num_splits = _get_mla_metadata_wrapped(
+ cache_seqlens=seq_lens.to(torch.int32),
+ seq_len_q=self.num_draft_tokens,
+ num_heads_q=self.num_q_heads,
+ num_heads_k=1,
+ nsa_index_topk=self.nsa_index_topk,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -335,6 +369,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
+ topk_indices: Optional[torch.Tensor] = None,
):
cache_loc = forward_batch.out_cache_loc
@@ -349,13 +384,14 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+ k_cache = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
- if self.data_type == torch.float8_e4m3fn:
+ if (not self.use_nsa) and self.data_type == torch.float8_e4m3fn:
reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
o, _ = flash_mla_with_kvcache(
q=reshape_q_fp8,
- k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+ k_cache=k_cache,
block_table=self.forward_metadata.block_kv_indices[:bs],
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
@@ -369,17 +405,49 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
+ block_table = self.forward_metadata.block_kv_indices[:bs]
+ cache_seqlens = forward_batch.seq_lens.to(torch.int32)
+
+ extra_kwargs: Dict
+ if self.use_nsa:
+ assert topk_indices is not None
+ extra_kwargs = dict(
+ indices=_compute_indices_in_kvcache(
+ block_table=block_table,
+ topk_indices=topk_indices.to(torch.int32),
+ page_size=self.page_size,
+ ),
+                # The docs say block_table is unused here, but passing None raises an error.
+ block_table=block_table,
+ is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+ )
+ cache_seqlens = compute_nsa_seqlens(
+ cache_seqlens, nsa_index_topk=self.nsa_index_topk
+ )
+ else:
+ extra_kwargs = dict(
+ block_table=block_table,
+ causal=True,
+ )
+
+ if (
+ self.use_nsa
+ and NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8
+ and not NSA_KV_CACHE_STORE_FP8
+ ):
+ # inefficiently quantize the whole cache
+ k_cache = quantize_k_cache(k_cache)
+
+            # TODO: verify whether causal should be True or False in all cases.
o, _ = flash_mla_with_kvcache(
q=reshape_q,
- k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
- block_table=self.forward_metadata.block_kv_indices[:bs],
- cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+ k_cache=k_cache,
+ cache_seqlens=cache_seqlens,
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=layer.scaling,
- causal=True,
+ **extra_kwargs,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -539,3 +607,52 @@ class FlashMLAMultiStepDraftBackend:
)
self.common_template(forward_batch, call_fn)
+
+
+def _get_mla_metadata_wrapped(
+ *,
+ cache_seqlens: torch.Tensor,
+ seq_len_q: int,
+ num_heads_q: int,
+ num_heads_k: int,
+ nsa_index_topk: Optional[int],
+):
+ if nsa_index_topk is not None:
+ assert nsa_index_topk is not None
+ return get_mla_metadata(
+ cache_seqlens=cache_seqlens,
+        # TODO: the docs say `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`,
+        # but the parameter name suggests it expects seq_len_q — verify.
+ num_q_tokens_per_head_k=seq_len_q * num_heads_q // num_heads_k,
+ num_heads_k=num_heads_k,
+ num_heads_q=num_heads_q,
+ is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+ topk=nsa_index_topk,
+ )
+ else:
+ assert nsa_index_topk is None
+ return get_mla_metadata(
+ cache_seqlens=cache_seqlens,
+ num_heads_per_head_k=seq_len_q * num_heads_q // num_heads_k,
+ num_heads_k=num_heads_k,
+ )
+
+
+# TODO: speed this up.
+def _compute_indices_in_kvcache(block_table, topk_indices, page_size):
+ topk_indices_safe = topk_indices.masked_fill(topk_indices == -1, 0)
+
+ idx0 = torch.arange(block_table.size(0), device=topk_indices_safe.device).unsqueeze(
+ 1
+ )
+ block_idx = block_table[idx0, topk_indices_safe // page_size]
+ offset = topk_indices_safe % page_size
+ indices_in_kvcache = block_idx * page_size + offset
+
+ # the kernel requires invalid entry to be -1
+ assert indices_in_kvcache.shape == topk_indices.shape
+ indices_in_kvcache[topk_indices == -1] = -1
+
+ # return: (batch_size, seqlen_q_ori, topk)
+ indices_in_kvcache = indices_in_kvcache[:, None, :]
+ return indices_in_kvcache
diff --git a/python/sglang/srt/layers/attention/hybrid_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_attn_backend.py
index ec40100d1..37f27bc6a 100644
--- a/python/sglang/srt/layers/attention/hybrid_attn_backend.py
+++ b/python/sglang/srt/layers/attention/hybrid_attn_backend.py
@@ -3,6 +3,7 @@ from typing import Optional, Union
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -138,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
return backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
+
+ def get_indexer_metadata(
+ self, layer_id: int, forward_batch: ForwardBatch
+ ) -> Optional[BaseIndexerMetadata]:
+ backend = self._select_backend(forward_batch.forward_mode)
+ return backend.get_indexer_metadata(layer_id, forward_batch)
diff --git a/python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py b/python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py
index 84efe2ce4..06a552545 100644
--- a/python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py
+++ b/python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py
@@ -76,12 +76,14 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
self.rotary_emb = rotary_emb
self.layer_id = layer_id
self.has_preprocess_weights = False
+ self.dtype = None
self.q_lora_rank = self.q_b_proj.input_size # 1536
self.kv_lora_rank = self.kv_a_layernorm.hidden_size # 512
self.num_local_heads = num_local_heads # tp
self.qk_nope_head_dim = qk_nope_head_dim # 128
self.qk_rope_head_dim = qk_rope_head_dim # 64
+ self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
def preprocess_weights(self, hidden_states):
self.dummy = torch.empty(
@@ -236,7 +238,83 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
return k_cache, v_cache, slot_mapping
- def forward(self, positions, hidden_states, forward_batch, zero_allocator):
+ def forward_absorb_prepare_npu_rms_norm_cache(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch,
+ zero_allocator,
+ ):
+ bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape
+ self.dtype = hidden_states.dtype
+ self.cos, self.sin = self.get_sin_cos(positions)
+ self.kvCache, self.kvCacheRope, self.slotmapping = (
+ self.get_kv_cache_and_cache_idx(forward_batch)
+ )
+
+ if not self.has_preprocess_weights:
+ self.has_preprocess_weights = True
+
+ cos, sin = self.cos, self.sin
+
+ if self.q_lora_rank is not None:
+ fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]
+ q_lowrank, latent_cache = fused_qkv_a_proj_out.split(
+ [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+ )
+ q = self.q_a_layernorm(q_lowrank)
+ q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+ else:
+ q = self.q_proj(hidden_states)[0].view(
+ -1, self.num_local_heads, self.qk_head_dim
+ )
+ latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+
+ q_nope, q_pe = torch.split(
+ q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ ) # b*s,n,d
+
+ q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)
+ q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)
+
+ q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
+ cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
+ sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
+ q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin) # (B,N,S,D)
+ q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)
+
+ latent_cache = latent_cache.view(
+ -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
+ ) # (B*S,N,1,D)
+
+ cache_mode = "PA_BNSD"
+ self.kvCache = self.kvCache.view(
+ -1,
+ forward_batch.attn_backend.page_size,
+ 1,
+ forward_batch.attn_backend.kv_lora_rank,
+ )
+ self.kvCacheRope = self.kvCacheRope.view(
+ -1,
+ forward_batch.attn_backend.page_size,
+ 1,
+ forward_batch.attn_backend.qk_rope_head_dim,
+ )
+ k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+ latent_cache,
+ self.kv_a_layernorm.weight,
+ cos,
+ sin,
+ self.slotmapping.to(torch.int64),
+ self.kvCacheRope,
+ self.kvCache,
+ epsilon=self.kv_a_layernorm.variance_epsilon,
+ cache_mode=cache_mode,
+ )
+
+ return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)
+
+ def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):
input_dtype = hidden_states.dtype
if not self.has_preprocess_weights:
self.preprocess_weights(hidden_states)
@@ -298,3 +376,18 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
zero_allocator,
positions,
)
+
+ def forward(self, positions, hidden_states, forward_batch, zero_allocator):
+ _is_w8a8 = (
+ hasattr(self.qkv_a_proj.quant_method, "quantization_config")
+ and self.qkv_a_proj.quant_method.quantization_config.get_name()
+ == "w8a8_int8"
+ )
+ if _is_w8a8:
+ return self.forward_mlapo(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
+ else:
+ return self.forward_absorb_prepare_npu_rms_norm_cache(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
diff --git a/python/sglang/srt/layers/attention/nsa/cuda/__init__.py b/python/sglang/srt/layers/attention/nsa/cuda/__init__.py
new file mode 100644
index 000000000..d7e76b78e
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/cuda/__init__.py
@@ -0,0 +1,3 @@
+from .topk import fast_topk, fast_topk_transform
+
+__all__ = ["fast_topk", "fast_topk_transform"]
diff --git a/python/sglang/srt/layers/attention/nsa/cuda/csrc/topk.cu b/python/sglang/srt/layers/attention/nsa/cuda/csrc/topk.cu
new file mode 100644
index 000000000..d8657ef5e
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/cuda/csrc/topk.cu
@@ -0,0 +1,505 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace {
+
+constexpr int TopK = 2048;
+constexpr int kThreadsPerBlock = 1024;
+constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t); // 128KB
+
+struct FastTopKParams {
+ const float *__restrict__ input; // [B, input_stride]
+ int32_t *__restrict__ indices; // [B, TopK]
+ int32_t *__restrict__ lengths; // [B]
+ int64_t input_stride;
+ bool use_tilelang;
+};
+
+// when length <= TopK, we can directly write the indices
+__device__ void naive_topk_cuda(const float *__restrict__ score,
+ int32_t *__restrict__ indice, int32_t length) {
+ const auto tid = threadIdx.x;
+ for (int i = tid; i < TopK; i += kThreadsPerBlock) {
+ indice[i] = (i < length) ? i : -1;
+ }
+}
+
+// keep the first `length` entries, set others to -1
+__device__ void
+naive_topk_transform(const float *__restrict__ score, int32_t length,
+ int32_t *__restrict__ dst_page_table,
+ const int32_t *__restrict__ src_page_table) {
+ const auto tid = threadIdx.x;
+ for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
+ dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
+ }
+}
+
+__device__ __forceinline__ uint8_t convert_to_uint8(float x) {
+ __half h = __float2half_rn(x);
+ uint16_t bits = __half_as_ushort(h);
+ uint16_t key = (bits & 0x8000) ? static_cast(~bits & 0xFFFF)
+ : static_cast(bits | 0x8000);
+ return static_cast(key >> 8);
+}
+
+__device__ __forceinline__ uint32_t convert_to_uint32(float x) {
+ uint32_t bits = __float_as_uint(x);
+ return (bits & 0x80000000u) ? (~bits & 0xFFFFFFFFu) : (bits | 0x80000000u);
+}
+
+template
+__device__ __forceinline__ auto
+radix_topk(Indexer indexer, Loader loader, uint32_t length, int topk,
+ int *__restrict__ index, int &__restrict__ s_counter,
+ int (&__restrict__ s_histogram)[LENGTH],
+ int &__restrict__ s_remain_cnt,
+ int (&__restrict__ s_remain_idx)[MAX_REMAIN]) -> int {
+ constexpr auto RADIX = LENGTH - 1;
+ static_assert(RADIX > 1 && (RADIX & (RADIX - 1)) == 0,
+ "RADIX must be power of 2");
+ static_assert(RADIX <= kThreadsPerBlock);
+ __shared__ uint32_t s_threshold_bin_id;
+
+ const auto tx = threadIdx.x;
+ if (tx < RADIX + 1)
+ s_histogram[tx] = 0;
+ __syncthreads();
+
+ /// NOTE: Use uint32_t as the index
+ for (auto i = tx; i < length; i += kThreadsPerBlock) {
+ const auto idx = indexer(i);
+ const auto bin = loader(idx);
+ ::atomicAdd(&s_histogram[bin], 1);
+ }
+ __syncthreads();
+
+ // cumsum (descending)
+ if (tx == 0) {
+ s_histogram[RADIX] = 0;
+ s_remain_cnt = 0;
+ for (int i = RADIX - 2; i >= 0; --i) {
+ s_histogram[i] += s_histogram[i + 1];
+ }
+ // threshold bin
+ for (int i = 0; i < RADIX; i++) {
+ if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
+ s_threshold_bin_id = i;
+ break;
+ }
+ }
+ }
+ __syncthreads();
+
+ const auto threshold_bin = s_threshold_bin_id;
+ const auto new_topk = topk - s_histogram[threshold_bin + 1];
+
+ for (auto i = tx; i < length; i += kThreadsPerBlock) {
+ const auto idx = indexer(i);
+ const auto bin_id = static_cast(loader(idx));
+ if (bin_id > threshold_bin) {
+ index[::atomicAdd(&s_counter, 1)] = idx;
+ } else if (bin_id == threshold_bin && new_topk > 0) {
+ if constexpr (Is_Epilogue) {
+ index[::atomicAdd(&s_counter, 1)] = idx;
+ } else {
+ if (const auto cnt = ::atomicAdd(&s_remain_cnt, 1);
+ C10_LIKELY(cnt < MAX_REMAIN)) {
+ s_remain_idx[cnt] = idx;
+ }
+ }
+ }
+ }
+ __syncthreads();
+
+ return new_topk;
+}
+
+__device__ void fast_topk_cuda(const float *__restrict__ input,
+ int *__restrict__ index, int length,
+ int topk = TopK) {
+ constexpr auto RADIX = 256;
+ constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
+
+ __shared__ int s_histogram[RADIX + 1];
+ __shared__ int s_num_input[2];
+ __shared__ int s_counter;
+
+ // allocate for two rounds
+ extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
+ s_counter = 0;
+
+ // collect candidates
+ const auto indexer = [](int idx) { return idx; };
+ const auto loader = [&input](int idx) {
+ return convert_to_uint8(input[idx]);
+ };
+ int new_topk = radix_topk(indexer, loader, length, topk, index, s_counter,
+ s_histogram, s_num_input[0], s_input_idx[0]);
+ if (new_topk <= 0)
+ return;
+
+ // round 0
+ const auto indexer_0 = [](int idx) { return s_input_idx[0][idx]; };
+ const auto loader_0 = [&input](int idx) {
+ return (convert_to_uint32(input[idx]) >> 24) & 0xFF;
+ };
+ new_topk = radix_topk(indexer_0, loader_0, s_num_input[0], new_topk, index,
+ s_counter, s_histogram, s_num_input[1], s_input_idx[1]);
+ if (new_topk <= 0)
+ return;
+
+ // round 1
+ const auto indexer_1 = [](int idx) { return s_input_idx[1][idx]; };
+ const auto loader_1 = [&input](int idx) {
+ return (convert_to_uint32(input[idx]) >> 16) & 0xFF;
+ };
+ new_topk = radix_topk(indexer_1, loader_1, s_num_input[1], new_topk, index,
+ s_counter, s_histogram, s_num_input[0], s_input_idx[0]);
+ if (new_topk <= 0)
+ return;
+
+ // round 2
+ const auto loader_2 = [&input](int idx) {
+ return (convert_to_uint32(input[idx]) >> 8) & 0xFF;
+ };
+ new_topk = radix_topk(indexer_0, loader_2, s_num_input[0], new_topk, index,
+ s_counter, s_histogram, s_num_input[1], s_input_idx[1]);
+ if (new_topk <= 0)
+ return;
+
+ // round 3
+ const auto loader_3 = [&input](int idx) {
+ return convert_to_uint32(input[idx]) & 0xFF;
+ };
+ // epilogue
+ radix_topk(indexer_1, loader_3, s_num_input[1], new_topk, index,
+ s_counter, s_histogram, s_num_input[0], s_input_idx[0]);
+}
+
+__device__ void fast_topk_cuda_tl(const float *__restrict__ input,
+ int *__restrict__ index, int length,
+ int topk = TopK) {
+ constexpr auto BLOCK_SIZE = 1024;
+ constexpr auto RADIX = 256;
+ constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
+
+ __shared__ int s_threshold_bin_id;
+ __shared__ int s_histogram[RADIX + 1];
+ __shared__ int s_num_input[2];
+ __shared__ int s_counter;
+
+ // allocate for two rounds
+ extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
+
+ int tx = threadIdx.x;
+
+ // stage 1: 8bit coarse histogram
+ if (tx < RADIX + 1)
+ s_histogram[tx] = 0;
+ __syncthreads();
+
+ for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+ const auto bin = convert_to_uint8(input[idx]);
+ ::atomicAdd(&s_histogram[bin], 1);
+ }
+ __syncthreads();
+
+ // cumsum (descending)
+ if (tx == 0) {
+ for (int i = RADIX - 2; i >= 0; --i) {
+ s_histogram[i] += s_histogram[i + 1];
+ }
+ // threshold bin
+ for (int i = 0; i < RADIX; i++) {
+ if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
+ s_threshold_bin_id = i;
+ break;
+ }
+ }
+ s_num_input[0] = 0;
+ s_counter = 0;
+ }
+ __syncthreads();
+
+ int threshold_bin = s_threshold_bin_id;
+ int new_topk = topk - s_histogram[threshold_bin + 1];
+
+ // collect candidates
+ for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
+ const auto bin_id = static_cast(convert_to_uint8(input[idx]));
+ if (bin_id > threshold_bin) {
+ int pos = ::atomicAdd(&s_counter, 1);
+ index[pos] = idx;
+ } else if (bin_id == threshold_bin && new_topk > 0) {
+ int pos = ::atomicAdd(&s_num_input[0], 1);
+ if (pos < SMEM_INPUT_SIZE) {
+ [[likely]] s_input_idx[0][pos] = idx;
+ }
+ }
+ }
+ __syncthreads();
+
+ // stage 2: refine with 8bit radix passes
+#pragma unroll 4
+ for (int round = 0; round < 4; ++round) {
+ if (new_topk <= 0)
+ break;
+ int r_idx = round % 2;
+
+ // reset
+ if (tx < RADIX + 1)
+ s_histogram[tx] = 0;
+ __syncthreads();
+
+ int num_input = s_num_input[r_idx];
+ for (int i = tx; i < num_input; i += BLOCK_SIZE) {
+ int idx = s_input_idx[r_idx][i];
+ uint32_t bin32 =
+ (convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
+ ::atomicAdd(&s_histogram[bin32], 1);
+ }
+ __syncthreads();
+
+ if (tx == 0) {
+ for (int i = RADIX - 2; i >= 0; --i)
+ s_histogram[i] += s_histogram[i + 1];
+ for (int i = 0; i < RADIX; i++) {
+ if (s_histogram[i] >= new_topk && s_histogram[i + 1] < new_topk) {
+ s_threshold_bin_id = i;
+ break;
+ }
+ }
+ s_num_input[r_idx ^ 1] = 0;
+ }
+ __syncthreads();
+
+ new_topk -= s_histogram[s_threshold_bin_id + 1];
+ int threshold_bin = s_threshold_bin_id;
+
+ for (int i = tx; i < num_input; i += BLOCK_SIZE) {
+ int idx = s_input_idx[r_idx][i];
+ uint32_t bin32 =
+ (convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
+ if (bin32 > threshold_bin) {
+ int pos = ::atomicAdd(&s_counter, 1);
+ index[pos] = idx;
+ } else if (bin32 == threshold_bin && new_topk > 0) {
+ if (round == 3) {
+ int pos = ::atomicAdd(&s_counter, 1);
+ index[pos] = idx;
+ } else {
+ int pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
+ if (pos < SMEM_INPUT_SIZE)
+ s_input_idx[r_idx ^ 1][pos] = idx;
+ }
+ }
+ }
+ __syncthreads();
+ }
+}
+
+__global__ void topk_kernel(const FastTopKParams params) {
+ const auto &[input, indices, lengths, input_stride, use_tilelang] = params;
+ const auto bid = blockIdx.x;
+ const auto length = *(lengths + bid);
+ const auto indice = indices + bid * TopK;
+ const auto score = input + bid * input_stride;
+ if (length <= TopK) {
+ return naive_topk_cuda(score, indice, length);
+ } else {
+ if (use_tilelang) {
+ return fast_topk_cuda_tl(score, indice, length);
+ } else {
+ return fast_topk_cuda(score, indice, length);
+ }
+ }
+}
+
+__global__ void topk_kernel_transform_decode( // decode
+ const FastTopKParams params, int32_t *__restrict__ dst_page_table,
+ const int32_t *__restrict__ src_page_table, const int64_t src_stride) {
+ const auto &[input, _, lengths, input_stride, use_tilelang] = params;
+ const auto bid = blockIdx.x;
+ const auto tid = threadIdx.x;
+ const auto length = *(lengths + bid);
+ const auto src_page_entry = src_page_table + bid * src_stride;
+ const auto dst_page_entry = dst_page_table + bid * TopK;
+ const auto score = input + bid * input_stride;
+ if (length <= TopK) {
+ return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
+ } else {
+ __shared__ int s_indices[TopK];
+ if (use_tilelang) {
+ fast_topk_cuda_tl(score, s_indices, length);
+ } else {
+ fast_topk_cuda(score, s_indices, length);
+ }
+ // copy src[s_indices] to dst, we manually unroll here
+ static_assert(TopK % kThreadsPerBlock == 0);
+ static_assert(TopK / kThreadsPerBlock == 2);
+ const auto idx_0 = tid;
+ const auto pos_0 = s_indices[idx_0];
+ dst_page_entry[idx_0] = src_page_entry[pos_0];
+ const auto idx_1 = tid + kThreadsPerBlock;
+ const auto pos_1 = s_indices[idx_1];
+ dst_page_entry[idx_1] = src_page_entry[pos_1];
+ }
+}
+
+__global__ void topk_kernel_transform_prefill( // prefill
+ const FastTopKParams params, int32_t *__restrict__ dst_page_table,
+ const int32_t *__restrict__ src_page_table, const int64_t src_stride,
+ const int32_t *__restrict__ cu_seqlens, const int64_t prefill_bs) {
+ const auto &[input, _, lengths, input_stride, use_tilelang] = params;
+ const auto bid = blockIdx.x;
+ const auto tid = threadIdx.x;
+ const auto length = *(lengths + bid);
+ const auto dst_page_entry = dst_page_table + bid * TopK;
+ const auto score = input + bid * input_stride;
+
+ /// NOTE: prefill bs is usually small, we can just use a simple loop here
+ /// We ensure that last cu_seqlens is equal to number of blocks launched
+ assert(gridDim.x == cu_seqlens[prefill_bs] &&
+ "Invalid cu_seqlens in topk-transform-prefill");
+ __shared__ const int32_t *s_src_page_entry;
+ if (tid == 0) {
+ for (int64_t offset = 0; offset < prefill_bs; ++offset) {
+ if (bid < cu_seqlens[offset + 1]) {
+ s_src_page_entry = src_page_table + offset * src_stride;
+ break;
+ }
+ }
+ }
+ __syncthreads();
+ const auto src_page_entry = s_src_page_entry;
+
+ if (length <= TopK) {
+ return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
+ } else {
+ __shared__ int s_indices[TopK];
+ if (use_tilelang) {
+ fast_topk_cuda_tl(score, s_indices, length);
+ } else {
+ fast_topk_cuda(score, s_indices, length);
+ }
+ // copy src[s_indices] to dst, we manually unroll here
+ static_assert(TopK % kThreadsPerBlock == 0);
+ static_assert(TopK / kThreadsPerBlock == 2);
+ const auto idx_0 = tid;
+ const auto pos_0 = s_indices[idx_0];
+ dst_page_entry[idx_0] = src_page_entry[pos_0];
+ const auto idx_1 = tid + kThreadsPerBlock;
+ const auto pos_1 = s_indices[idx_1];
+ dst_page_entry[idx_1] = src_page_entry[pos_1];
+ }
+}
+
+auto get_params(at::Tensor score, at::Tensor lengths, bool use_tilelang,
+ std::optional indices_opt = std::nullopt)
+ -> FastTopKParams {
+ const auto B = score.size(0);
+ TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
+ TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
+ TORCH_CHECK(lengths.size(0) == B);
+ int32_t *indices_data_ptr = nullptr;
+ if (indices_opt.has_value()) {
+ const auto &indices = indices_opt.value();
+ TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
+ TORCH_CHECK(indices.size(0) == B);
+ TORCH_CHECK(indices.size(1) == TopK);
+ indices_data_ptr = indices.data_ptr();
+ }
+
+ return FastTopKParams{
+ .input = score.data_ptr(),
+ .indices = indices_data_ptr,
+ .lengths = lengths.data_ptr(),
+ .input_stride = score.stride(0),
+ .use_tilelang = use_tilelang,
+ };
+}
+
+template
+auto setup_kernel_smem_once() -> void {
+ [[maybe_unused]]
+ static const auto result = [] {
+ return ::cudaFuncSetAttribute(
+ f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
+ }();
+ TORCH_CHECK(result == cudaSuccess,
+ "set_up_kernel_once failed:", ::cudaGetErrorString(result));
+}
+
+auto fast_topk_interface(at::Tensor score, at::Tensor indices,
+ at::Tensor lengths, bool use_tilelang) -> void {
+ const auto params = get_params(score, lengths, use_tilelang, indices);
+ const auto B = score.size(0);
+ const auto stream = at::cuda::getCurrentCUDAStream().stream();
+ const auto grid = dim3{static_cast(B)};
+ const auto block = dim3{kThreadsPerBlock};
+ setup_kernel_smem_once();
+ topk_kernel<<>>(params);
+ const auto result = cudaGetLastError();
+ TORCH_CHECK(result == cudaSuccess,
+ "topk kernel failed:", ::cudaGetErrorString(result));
+}
+
+auto fast_topk_transform_interface(at::Tensor score, at::Tensor lengths,
+ at::Tensor dst_page_table,
+ at::Tensor src_page_table,
+ at::Tensor cu_seqlens,
+ bool use_tilelang) -> void {
+ const auto params = get_params(score, lengths, use_tilelang);
+ const auto B = score.size(0);
+ TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
+ TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
+ TORCH_CHECK(cu_seqlens.dim() == 1 && cu_seqlens.is_contiguous());
+ const auto prefill_bs = cu_seqlens.size(0) - 1;
+ TORCH_CHECK(dst_page_table.size(0) == B);
+ TORCH_CHECK(dst_page_table.size(1) == TopK);
+ TORCH_CHECK(src_page_table.size(0) == prefill_bs);
+ TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs
+
+ // launch kernel
+ const auto stream = at::cuda::getCurrentCUDAStream().stream();
+ const auto grid = dim3{static_cast(B)};
+ const auto block = dim3{kThreadsPerBlock};
+ const auto src_stride = src_page_table.stride(0);
+
+ // dispatch to decode or prefill
+ const auto is_decode = (prefill_bs == B);
+ if (is_decode) {
+ setup_kernel_smem_once();
+ topk_kernel_transform_decode<<>>(
+ params, dst_page_table.data_ptr(),
+ src_page_table.data_ptr(), src_stride);
+ } else {
+ setup_kernel_smem_once();
+ topk_kernel_transform_prefill<<>>(
+ params, dst_page_table.data_ptr(),
+ src_page_table.data_ptr(), src_stride,
+ cu_seqlens.data_ptr(), prefill_bs);
+ }
+
+ const auto result = cudaGetLastError();
+ TORCH_CHECK(result == cudaSuccess,
+ "topk kernel failed:", ::cudaGetErrorString(result));
+}
+
+} // namespace
+
+PYBIND11_MODULE(topk_kernel, m) {
+ m.def("fast_topk", &fast_topk_interface);
+ m.def("fast_topk_transform", &fast_topk_transform_interface);
+}
diff --git a/python/sglang/srt/layers/attention/nsa/cuda/topk.py b/python/sglang/srt/layers/attention/nsa/cuda/topk.py
new file mode 100644
index 000000000..389289644
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/cuda/topk.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+
+from .utils import load_kernel_module
+
+
+def _load_topk_module() -> Any:
+ """
+    Load the compiled top-k CUDA kernel module.
+ """
+ return load_kernel_module("topk.cu", "topk_kernel")
+
+
+# TODO(dark): figure out why my cuda impl is a little slower....
+# I believe it has something to do with unrolling loops (?)
+_USE_TL = True
+
+
+def fast_topk(
+ score: torch.Tensor,
+ indices: torch.Tensor,
+ lengths: torch.Tensor,
+) -> torch.Tensor:
+ return _load_topk_module().fast_topk(score, indices, lengths, _USE_TL)
+
+
+def fast_topk_transform(
+ score: torch.Tensor,
+ lengths: torch.Tensor,
+ dst_page_table: torch.Tensor,
+ src_page_table: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+ return _load_topk_module().fast_topk_transform(
+ score, lengths, dst_page_table, src_page_table, cu_seqlens, _USE_TL
+ )
diff --git a/python/sglang/srt/layers/attention/nsa/cuda/utils.py b/python/sglang/srt/layers/attention/nsa/cuda/utils.py
new file mode 100644
index 000000000..7daffede8
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/cuda/utils.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+import os
+from functools import lru_cache
+from typing import Any, Iterable
+
+
+@lru_cache()
+def _prepare_for_load() -> str:
+ import os
+ import warnings
+
+ warnings.filterwarnings(
+ "ignore", category=UserWarning, module="torch.utils.cpp_extension"
+ )
+ return os.path.dirname(os.path.abspath(__file__))
+
+
+@lru_cache()
+def load_kernel_module(
+ path: str | Iterable[str],
+ name: str,
+ *,
+ build: str = "build",
+ cflags: Iterable[str] | None = None,
+ cuda_flags: Iterable[str] | None = None,
+ ldflags: Iterable[str] | None = None,
+) -> Any:
+ from torch.utils.cpp_extension import load
+
+ if isinstance(path, str):
+ path = (path,)
+
+ abs_path = _prepare_for_load()
+ build_dir = f"{abs_path}/{build}"
+ os.makedirs(build_dir, exist_ok=True)
+ return load(
+ name=name,
+ sources=[f"{abs_path}/csrc/{p}" for p in path],
+ extra_cflags=list(cflags or []) or ["-O3", "-std=c++17"],
+ extra_cuda_cflags=list(cuda_flags or []) or ["-O3", "-std=c++17"],
+ extra_ldflags=list(ldflags or []) or None,
+ build_directory=build_dir,
+ )
diff --git a/python/sglang/srt/layers/attention/nsa/dequant_k_cache.py b/python/sglang/srt/layers/attention/nsa/dequant_k_cache.py
new file mode 100644
index 000000000..b6c2269f5
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/dequant_k_cache.py
@@ -0,0 +1,163 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.attention.nsa.utils import NSA_DEQUANT_K_CACHE_FAST
+
+
+def dequantize_k_cache(quant_k_cache):
+ if NSA_DEQUANT_K_CACHE_FAST:
+ return _dequantize_k_cache_fast_wrapped(quant_k_cache)
+ else:
+ return _dequantize_k_cache_slow(quant_k_cache)
+
+
+def _dequantize_k_cache_slow(
+ quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
+ dv: int = 512,
+ tile_size: int = 128,
+ d: int = 576,
+) -> torch.Tensor:
+ """
+ De-quantize the k-cache
+ """
+ assert dv % tile_size == 0
+ num_tiles = dv // tile_size
+ num_blocks, block_size, h_k, _ = quant_k_cache.shape
+ assert h_k == 1
+ result = torch.empty(
+ (num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
+ )
+
+ quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
+
+ input_nope = quant_k_cache[..., :dv]
+ input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
+ input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
+ result[..., dv:] = input_rope
+
+ for tile_idx in range(0, num_tiles):
+ cur_nope = input_nope[
+ ..., tile_idx * tile_size : (tile_idx + 1) * tile_size
+ ].to(torch.float32)
+ cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
+ result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
+ cur_nope * cur_scales
+ )
+
+ result = result.view(num_blocks, block_size, 1, d)
+ return result
+
+
+def _dequantize_k_cache_fast_wrapped(
+ quant_k_cache: torch.Tensor,
+ dv: int = 512,
+ tile_size: int = 128,
+) -> torch.Tensor:
+ # TODO the final API may be 2D instead of 4D, thus we convert them here
+ num_blocks, block_size, _, dim_quant = quant_k_cache.shape
+ assert dv == 512
+ assert dim_quant == 656
+ assert tile_size == 128
+ quant_k_cache = quant_k_cache.view((-1, dim_quant))
+
+ output = _dequantize_k_cache_fast(quant_k_cache)
+
+ return output.view(num_blocks, block_size, 1, -1)
+
+
+def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
+ num_tokens, dim_quant = quant_k_cache.shape
+
+ assert quant_k_cache.dtype == torch.float8_e4m3fn
+ dim_nope = 512
+ dim_rope = 64
+ num_tiles = dim_nope // group_size
+ assert dim_quant == 656
+
+ output = torch.empty(
+ (num_tokens, dim_nope + dim_rope),
+ dtype=torch.bfloat16,
+ device=quant_k_cache.device,
+ )
+
+ num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
+ assert num_blocks_per_token == 5
+
+ assert dim_nope % group_size == 0
+ NUM_NOPE_BLOCKS = dim_nope // group_size
+
+ input_nope_q = quant_k_cache[:, :dim_nope]
+ input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
+ torch.float32
+ )
+ input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)
+
+ _dequantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
+ output,
+ input_nope_q,
+ input_nope_s,
+ input_rope,
+ output.stride(0),
+ input_nope_q.stride(0),
+ input_nope_s.stride(0),
+ input_rope.stride(0),
+ NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
+ GROUP_SIZE=group_size,
+ DIM_NOPE=dim_nope,
+ DIM_ROPE=dim_rope,
+ )
+
+ return output
+
+
+@triton.jit
+def _dequantize_k_cache_fast_kernel(
+ output_ptr,
+ input_nope_q_ptr,
+ input_nope_s_ptr,
+ input_rope_ptr,
+ output_stride_0: int,
+ input_nope_q_stride_0: int,
+ input_nope_s_stride_0: int,
+ input_rope_stride_0: int,
+ NUM_NOPE_BLOCKS: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+ DIM_NOPE: tl.constexpr,
+ DIM_ROPE: tl.constexpr,
+):
+ token_id = tl.program_id(0)
+ raw_block_id = tl.program_id(1)
+
+ if raw_block_id < NUM_NOPE_BLOCKS:
+ # a. dequant nope
+ effective_block_id = raw_block_id
+
+ offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+ mask = offs_q < DIM_NOPE
+ ptr_q = input_nope_q_ptr + token_id * input_nope_q_stride_0 + offs_q
+ ptr_s = input_nope_s_ptr + token_id * input_nope_s_stride_0 + effective_block_id
+
+ y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
+ y_s = tl.load(ptr_s)
+
+ y = (y_q * y_s).to(output_ptr.dtype.element_ty)
+
+ dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
+ tl.store(dst_ptr, y, mask=mask)
+ else:
+ # b. copy rope
+ effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
+
+ offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+ mask = offs < DIM_ROPE
+
+ src_ptr = input_rope_ptr + token_id * input_rope_stride_0 + offs
+ dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs
+
+ data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
+ tl.store(dst_ptr, data, mask=mask)
+
+
+if __name__ == "__main__":
+ raise Exception("UT is in quant_k_cache.py")
diff --git a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
new file mode 100644
index 000000000..d887cfddd
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
@@ -0,0 +1,354 @@
+from typing import TYPE_CHECKING
+
+import torch
+import triton
+import triton.language as tl
+
+if TYPE_CHECKING:
+ from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
+
+"""
+k: data, 128 item per token, fp8
+s: scale, 1 item per token, fp32
+"""
+
+
+class GetK:
+ @classmethod
+ def execute(cls, *args, **kwargs):
+ return cls.torch_fast(*args, **kwargs)
+
+ @classmethod
+ def slow(
+ cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
+ ):
+ num_pages = (seq_len + pool.page_size - 1) // pool.page_size
+ seq_len_ = num_pages * pool.page_size
+ index_k_fp8 = torch.empty(
+ (seq_len_, pool.index_head_dim),
+ dtype=torch.uint8,
+ device=pool.device,
+ )
+ for i in range(num_pages):
+ page_index = page_indices[i]
+ index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
+ page_index
+ ][: pool.page_size * pool.index_head_dim].view(-1, pool.index_head_dim)
+
+ return index_k_fp8[:seq_len]
+
+ @classmethod
+ def torch_fast(
+ cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
+ ):
+ """
+ :param page_indices: (num_pages,), int32
+ :return: (seq_len, index_head_dim), uint8
+ """
+
+ # can handle per 128B instead of per element
+
+ # page_indices: (num_pages,), element := a page index
+ buf_numel_per_page = buf.shape[1]
+
+ num_k_bytes_per_page = pool.page_size * pool.index_head_dim
+ num_k_bytes_per_token = pool.index_head_dim
+
+ # buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8
+ # flat_buf: (whatever,), uint8
+ flat_buf = buf.flatten()
+
+ # flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an index into flat_buf that we want to access
+ flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange(
+ num_k_bytes_per_page, dtype=torch.int32, device="cuda"
+ )[None, :]
+ flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token]
+
+ out = flat_buf[flat_indices]
+ return out.view(-1, 128)
+
+
+class GetS:
+ @classmethod
+ def execute(cls, *args, **kwargs):
+ return cls.torch_fast(*args, **kwargs)
+
+ @classmethod
+ def slow(
+ cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
+ ):
+ num_pages = (seq_len + pool.page_size - 1) // pool.page_size
+ seq_len_ = num_pages * pool.page_size
+ assert pool.index_head_dim // pool.quant_block_size == 1
+ index_k_scale_fp8 = torch.empty(
+ (seq_len_, 4),
+ dtype=torch.uint8,
+ device=pool.device,
+ )
+ for i in range(num_pages):
+ page_index = page_indices[i]
+ index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
+ page_index
+ ][pool.page_size * pool.index_head_dim :].view(-1, 4)
+ return index_k_scale_fp8[:seq_len]
+
+ @classmethod
+ def torch_fast(
+ cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
+ ):
+ """
+ :param page_indices: (num_pages,), int32
+ :return: (seq_len, index_head_dim // quant_block_size), uint8
+ """
+ buf_numel_per_page = buf.shape[1]
+
+ num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim
+ num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4
+ s_offset_in_page = pool.page_size * pool.index_head_dim
+
+ flat_buf = buf.flatten()
+ flat_indices = (
+ (page_indices * buf_numel_per_page)[:, None]
+ + torch.arange(num_s_bytes_per_page, dtype=torch.int32, device="cuda")[
+ None, :
+ ]
+ + s_offset_in_page
+ )
+ flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token]
+
+ out = flat_buf[flat_indices]
+ return out.view(-1, 4)
+
+
+class SetK:
+ @classmethod
+ def execute(cls, *args, buf, **kwargs):
+ return cls.torch_fast(*args, **kwargs, buf=buf)
+
+ @classmethod
+ def slow(
+ cls,
+ pool: "NSATokenToKVPool",
+ buf: torch.Tensor,
+ loc: torch.Tensor,
+ index_k: torch.Tensor,
+ ):
+ for i in range(len(loc)):
+ page_index = loc[i] // pool.page_size
+ offset = loc[i] % pool.page_size
+ buf[
+ page_index,
+ offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim,
+ ] = index_k[i].view(torch.uint8)
+
+ @classmethod
+ def torch_fast(
+ cls,
+ pool: "NSATokenToKVPool",
+ buf: torch.Tensor,
+ loc: torch.Tensor,
+ index_k: torch.Tensor,
+ ):
+ (num_tokens_to_write,) = loc.shape
+ buf_numel_per_page = buf.shape[1]
+ num_k_bytes_per_token = pool.index_head_dim
+
+ # loc: (num_tokens_to_write,), int32, element := the token index to write to
+ loc_page_index = loc // pool.page_size
+ loc_token_offset_in_page = loc % pool.page_size
+
+ flat_buf = buf.flatten()
+ flat_indices = (
+ (loc_page_index * buf_numel_per_page)[:, None]
+ + (loc_token_offset_in_page * num_k_bytes_per_token)[:, None]
+ + torch.arange(num_k_bytes_per_token, dtype=torch.int32, device="cuda")[
+ None, :
+ ]
+ )
+ num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token
+ flat_indices = flat_indices.flatten()[:num_k_bytes_total]
+ flat_buf[flat_indices] = index_k.view(torch.uint8).flatten()
+
+
+class SetS:
+ @classmethod
+ def execute(cls, *args, buf, **kwargs):
+ return cls.torch_fast(*args, **kwargs, buf=buf)
+
+ @classmethod
+ def slow(
+ cls,
+ pool: "NSATokenToKVPool",
+ buf: torch.Tensor,
+ loc: torch.Tensor,
+ index_k_scale: torch.Tensor,
+ ):
+ for i in range(len(loc)):
+ page_index = loc[i] // pool.page_size
+ offset = loc[i] % pool.page_size
+ start = pool.page_size * pool.index_head_dim
+ buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = (
+ index_k_scale[i].view(torch.uint8)
+ )
+
+ @classmethod
+ def torch_fast(
+ cls,
+ pool: "NSATokenToKVPool",
+ buf: torch.Tensor,
+ loc: torch.Tensor,
+ index_k_scale: torch.Tensor,
+ ):
+ (num_tokens_to_write,) = loc.shape
+ buf_numel_per_page = buf.shape[1]
+ num_s_bytes_per_token = 4
+ s_offset_in_page = pool.page_size * pool.index_head_dim
+
+ # loc: (num_tokens_to_write,), int32, element := the token index to write to
+ loc_page_index = loc // pool.page_size
+ loc_token_offset_in_page = loc % pool.page_size
+
+ flat_buf = buf.flatten()
+ flat_indices = (
+ (loc_page_index * buf_numel_per_page)[:, None]
+ + s_offset_in_page
+ + (loc_token_offset_in_page * num_s_bytes_per_token)[:, None]
+ + torch.arange(num_s_bytes_per_token, dtype=torch.int32, device="cuda")[
+ None, :
+ ]
+ )
+ number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token
+ flat_indices = flat_indices.flatten()[:number_s_bytes_total]
+ flat_buf[flat_indices] = index_k_scale.view(torch.uint8).flatten()
+
+
+class SetKAndS:
+ @classmethod
+ def execute(cls, *args, buf, **kwargs):
+ if 0:
+ # print("SetK, SetS comparison test")
+ buf_cloned = buf.clone()
+ cls.vanilla(*args, **kwargs, buf=buf)
+ cls.triton(*args, **kwargs, buf=buf_cloned)
+
+ def _clear_token_0(target):
+ target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0
+
+ _clear_token_0(buf)
+ _clear_token_0(buf_cloned)
+
+ assert torch.all(
+ buf == buf_cloned
+ ), f"{buf=} {buf_cloned=} {kwargs['loc'].to_list()=}"
+ return
+
+ cls.triton(*args, **kwargs, buf=buf)
+
+ @classmethod
+ def vanilla(cls, pool, buf, loc, index_k, index_k_scale):
+ SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k)
+ SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale)
+
+ @classmethod
+ def triton(cls, pool, buf, loc, index_k, index_k_scale):
+ _set_k_and_s_triton(
+ buf=buf,
+ loc=loc,
+ index_k=index_k,
+ index_k_scale=index_k_scale,
+ page_size=pool.page_size,
+ )
+
+
+def _set_k_and_s_triton(
+ buf: torch.Tensor,
+ loc: torch.Tensor,
+ index_k: torch.Tensor,
+ index_k_scale: torch.Tensor,
+ page_size: int,
+):
+ """
+ :param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8
+ :param loc: (num_tokens_to_write,), int, element := the token index to write to
+ :param index_k: (num_tokens_to_write, 128 elem), fp8
+ :param index_k_scale: (num_tokens_to_write, 1 elem), fp32
+ :return:
+ """
+ num_pages, buf_numel_per_page = buf.shape
+ (num_tokens_to_write,) = loc.shape
+ num_tokens_to_write_, index_head_dim = index_k.shape
+ num_tokens_to_write__, scale_dim = index_k_scale.shape
+ assert buf_numel_per_page == 64 * (128 + 4)
+ assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
+ assert index_head_dim == 128
+ assert scale_dim == 1
+ assert page_size == 64
+
+ assert buf.dtype == torch.uint8
+ assert loc.dtype == torch.int64, f"{loc.dtype=}" # can be int32
+ assert index_k.dtype == torch.float8_e4m3fn
+ assert index_k_scale.dtype == torch.float32
+
+ assert buf.is_contiguous()
+ assert loc.is_contiguous()
+ assert index_k.is_contiguous()
+ assert index_k_scale.is_contiguous()
+
+ buf_fp8 = buf.view(torch.float8_e4m3fn)
+ buf_fp32 = buf.view(torch.float32)
+
+ _set_k_and_s_triton_kernel[(num_tokens_to_write,)](
+ buf_fp8,
+ buf_fp32,
+ loc,
+ index_k,
+ index_k_scale,
+ index_k.stride(0),
+ PAGE_SIZE=page_size,
+ BUF_NUMEL_PER_PAGE=buf_numel_per_page,
+ NUM_K_ELEMS_PER_TOKEN=index_head_dim,
+ S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim,
+ )
+
+
+@triton.jit
+def _set_k_and_s_triton_kernel(
+ buf_fp8_ptr,
+ buf_fp32_ptr,
+ loc_ptr,
+ index_k_ptr,
+ index_k_scale_ptr,
+ index_k_ptr_stride_0,
+ PAGE_SIZE: tl.constexpr,
+ BUF_NUMEL_PER_PAGE: tl.constexpr,
+ NUM_K_ELEMS_PER_TOKEN: tl.constexpr,
+ S_OFFSET_NBYTES_IN_PAGE: tl.constexpr,
+):
+ token_id = tl.program_id(0)
+
+ loc = tl.load(loc_ptr + token_id)
+
+ in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
+
+ # no need for `mask`, since we read 128B for k and 4B for scale, both pow of 2
+ k = tl.load(index_k_ptr + in_k_offsets)
+ k_scale = tl.load(index_k_scale_ptr + token_id)
+
+ loc_page_index = loc // PAGE_SIZE
+ loc_token_offset_in_page = loc % PAGE_SIZE
+
+ out_k_offsets = (
+ loc_page_index * BUF_NUMEL_PER_PAGE
+ + loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN
+ + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
+ )
+
+ # "//4" b/c it is fp32 instead of uint8
+ out_s_offset = (
+ loc_page_index * BUF_NUMEL_PER_PAGE // 4
+ + S_OFFSET_NBYTES_IN_PAGE // 4
+ + loc_token_offset_in_page
+ )
+
+ tl.store(buf_fp8_ptr + out_k_offsets, k)
+ tl.store(buf_fp32_ptr + out_s_offset, k_scale)
diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
new file mode 100644
index 000000000..922cd2974
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -0,0 +1,682 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.debug_utils.dumper import dumper
+from sglang.srt.utils import add_prefix, is_npu
+
+if not is_npu():
+ from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
+ import deep_gemm
+
+from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
+from sglang.srt.layers.dp_attention import get_attention_tp_group
+from sglang.srt.layers.linear import ReplicatedLinear
+from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import add_prefix, align, is_cuda
+
+try:
+ import deep_gemm_v32
+except ImportError as e:
+ print("Error when importing deep_gemm_v32, try deep_gemm")
+ try:
+ import deep_gemm as deep_gemm_v32
+ except ImportError as e:
+ print("Error when importing deep_gemm, skip")
+
+
+if TYPE_CHECKING:
+ from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
+
+DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0
+
+
+class BaseIndexerMetadata(ABC):
+    """Metadata/helpers an attention backend supplies to the NSA indexer.
+
+    The indexer uses these to read sequence lengths, the 64-token page table,
+    and to run the backend-specific top-k transformation on its logits.
+    """
+
+    @abstractmethod
+    def get_seqlens_int32(self) -> torch.Tensor:
+        """
+        Return: (batch_size,) int32 tensor
+        """
+
+    @abstractmethod
+    def get_page_table_64(self) -> torch.Tensor:
+        """
+        Return: (batch_size, num_blocks) int32, page table.
+        The page size of the table is 64.
+        """
+
+    @abstractmethod
+    def get_seqlens_expanded(self) -> torch.Tensor:
+        """
+        Return: (sum_extend_seq_len,) int32 tensor
+        """
+
+    @abstractmethod
+    def topk_transform(
+        self,
+        logits: torch.Tensor,
+        topk: int,
+    ) -> torch.Tensor:
+        """
+        Perform topk selection on the logits and possibly transform the result.
+
+        NOTE that attention backend may override this function to do some
+        transformation, which means the result of this topk_transform may not
+        be the topk indices of the input logits.
+
+        Return: Anything, since it will be passed to the attention backend
+        for further processing on sparse attention computation.
+        Don't assume it is the topk indices of the input logits.
+        """
+
+
+def rotate_activation(x: torch.Tensor) -> torch.Tensor:
+    """Apply a scaled Hadamard transform along the last dim of a bf16 tensor.
+
+    The last dimension must be a power of two; the transform is scaled by
+    1/sqrt(hidden_size) so it is orthonormal.
+    """
+    assert x.dtype == torch.bfloat16
+    # Lazy import: only needed when this code path actually runs.
+    from fast_hadamard_transform import hadamard_transform
+
+    hidden_size = x.size(-1)
+    assert (
+        hidden_size & (hidden_size - 1)
+    ) == 0, "Hidden size must be a power of 2 for Hadamard transform."
+    return hadamard_transform(x, scale=hidden_size**-0.5)
+
+
+class V32LayerNorm(nn.Module):
+    """LayerNorm with fp32 weight/bias that normalizes in fp32, then casts back
+    to the input dtype."""
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.dim = dim  # normalized feature dimension
+        self.eps = eps  # numerical-stability epsilon
+        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+
+    def forward(self, x: torch.Tensor):
+        # Compute in fp32 for accuracy, cast the result back to x's dtype.
+        return F.layer_norm(
+            x.float(), (self.dim,), self.weight, self.bias, self.eps
+        ).type_as(x)
+
+
+class Indexer(CustomOp):
+    """NSA "lightning indexer": selects top-k KV positions per query token.
+
+    Projects q_lora/hidden states into small index queries/keys, quantizes
+    them to FP8, caches index keys in the token-to-KV pool, scores all cached
+    keys with deep_gemm MQA logits kernels (paged for decode, ragged for
+    extend), and reduces to top ``index_topk`` indices via the attention
+    backend's ``topk_transform``. A separate NPU path uses torch_npu ops.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        index_n_heads: int,
+        index_head_dim: int,
+        rope_head_dim: int,
+        index_topk: int,
+        q_lora_rank: int,
+        max_position_embeddings: int,
+        rope_theta: float,
+        layer_id: int,
+        scale_fmt: Optional[str],
+        block_size: int = 128,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
+        quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.n_heads = index_n_heads
+        self.head_dim = index_head_dim
+        self.rope_head_dim = rope_head_dim
+        self.index_topk = index_topk
+        self.q_lora_rank = q_lora_rank
+        self.layer_id = layer_id
+        self.alt_stream = alt_stream
+        if not is_npu():
+            # SM budget for splitting GEMM work across the two CUDA streams.
+            self.sm_count = deep_gemm.get_num_sms()
+            self.half_device_sm_count = align(self.sm_count // 2, 8)
+
+        self.wq_b = ReplicatedLinear(
+            self.q_lora_rank,
+            self.n_heads * self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("wq_b", prefix),
+        )
+        self.wk = ReplicatedLinear(
+            self.hidden_size,
+            self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("wk", prefix),
+        )
+        self.k_norm = V32LayerNorm(self.head_dim)
+        # NOTE: weight_proj is not quantized
+        self.weights_proj = ReplicatedLinear(
+            self.hidden_size,
+            self.n_heads,
+            bias=False,
+            prefix=add_prefix("weights_proj", prefix),
+        )
+        self.rotary_emb = get_rope_wrapper(
+            rope_head_dim,
+            rotary_dim=rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,  # type: ignore
+            rope_scaling=rope_scaling,
+            is_neox_style=False,
+            device=global_server_args_dict["device"],
+        )
+        self.block_size = block_size
+        self.scale_fmt = scale_fmt
+        self.softmax_scale = self.head_dim**-0.5
+
+    def _forward_fake(
+        self,
+        x: torch.Tensor,
+        q_lora: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+    ):
+        """Debug stand-in: return [0..topk) per row, with out-of-range slots
+        (positions beyond each token's visible KV length) set to -1."""
+        bs = x.shape[0]
+        assert self.index_topk == 2048
+        ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
+            None, ...
+        ].repeat(bs, 1)
+        if forward_batch.forward_mode.is_extend():
+            assert (
+                forward_batch.extend_seq_lens_cpu is not None
+                and forward_batch.seq_lens_cpu is not None
+            )
+            which = 0
+            for i, (kv_len, qo_len) in enumerate(
+                zip(
+                    forward_batch.seq_lens_cpu.tolist(),
+                    forward_batch.extend_seq_lens_cpu,
+                    strict=True,
+                )
+            ):
+                # One output row per extend token; row for position j may see kv[0..j].
+                for j in range(kv_len - qo_len, kv_len):
+                    ans[which, j + 1 :] = -1
+                    which += 1
+            assert which == ans.shape[0]
+        else:
+            assert forward_batch.seq_lens_cpu is not None
+            for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
+                ans[i, seq_len:] = -1
+
+        return ans
+
+    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
+        # Per-head gate: weights_proj(x) scaled by 1/sqrt(n_heads), folded with
+        # the query quantization scale and the softmax scale.
+        weights, _ = self.weights_proj(x)
+        weights = weights * self.n_heads**-0.5
+        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
+        return weights
+
+    def _get_q_k_bf16(
+        self,
+        q_lora: torch.Tensor,
+        x: torch.Tensor,
+        positions: torch.Tensor,
+        enable_dual_stream: bool,
+    ):
+        """Project q_lora -> index queries and x -> index keys (bf16), apply
+        partial RoPE on the first rope_head_dim dims, then a Hadamard rotation.
+        When dual-stream is enabled, Q and K work run on separate CUDA streams."""
+        if enable_dual_stream:
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+
+            # Cap deep_gemm to half the SMs so both streams can make progress.
+            with deep_gemm_wrapper.configure_deep_gemm_num_sms(
+                self.half_device_sm_count
+            ):
+                query, _ = self.wq_b(q_lora)
+                query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
+                q_rope, _ = torch.split(
+                    query,
+                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+                    dim=-1,
+                )
+            with torch.cuda.stream(self.alt_stream):
+                key, _ = self.wk(x)
+                key = self.k_norm(key)
+
+                k_rope, _ = torch.split(
+                    key,
+                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+                    dim=-1,
+                )
+
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            query, _ = self.wq_b(q_lora)
+            # NOTE(review): the clones below are debug captures; they are not
+            # used in this function — presumably picked up by the dumper. TODO confirm.
+            if dumper._enable:
+                after_wq_b = query.clone()
+            query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
+
+            q_rope, _ = torch.split(
+                query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
+            )
+
+            key, _ = self.wk(x)
+            if dumper._enable:
+                after_wk = key.clone()
+            key = self.k_norm(key)
+            if dumper._enable:
+                after_k_norm = key.clone()
+            k_rope, _ = torch.split(
+                key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
+            )
+
+        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
+
+        # Write the rotated rope parts back in place (first rope_head_dim dims).
+        query[..., : self.rope_head_dim] = q_rope
+        key[..., : self.rope_head_dim] = k_rope
+
+        if dumper._enable:
+            q_before_hadamard = query.clone()
+            k_before_hadamard = key.clone()
+
+        if enable_dual_stream:
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            query = rotate_activation(query)
+
+            with torch.cuda.stream(self.alt_stream):
+                key = rotate_activation(key)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            query = rotate_activation(query)
+            key = rotate_activation(key)
+
+        return query, key
+
+    def _get_topk_paged(
+        self,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+        q_fp8: torch.Tensor,
+        weights: torch.Tensor,
+        metadata: BaseIndexerMetadata,
+    ) -> torch.Tensor:
+        """Top-k for decode/idle: score against the paged FP8 index-K cache."""
+        if TYPE_CHECKING:
+            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
+
+        page_size = forward_batch.token_to_kv_pool.page_size
+        # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm_v32
+        assert page_size == 64, "only support page size 64"
+
+        # NOTE(dark): this support extend/decode/decode+graph
+        block_tables = metadata.get_page_table_64()
+
+        max_seq_len = block_tables.shape[1] * page_size
+        kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
+            layer_id=layer_id
+        )
+
+        blocksize = page_size
+        seqlens_32 = metadata.get_seqlens_int32()
+        # NOTE(dark): schedule over this device's SM count (e.g. 132 on H200/B200)
+        schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
+            seqlens_32, blocksize, self.sm_count
+        )
+
+        assert len(q_fp8.shape) == 3
+        q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
+        assert len(kv_cache_fp8.shape) == 2
+        block_kv = 64
+        num_heads_kv = 1
+        # 132 fp8 elems per token = head_dim K elems + 4 bytes of fp32 scale,
+        # matching the index-K page layout written by the pool setter.
+        head_dim_with_sf = 132
+        kv_cache_fp8 = kv_cache_fp8.view(
+            kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
+        )
+        assert len(weights.shape) == 3
+        weights = weights.squeeze(2)
+
+        logits = deep_gemm_v32.fp8_paged_mqa_logits(
+            q_fp8,
+            kv_cache_fp8,
+            weights,
+            seqlens_32,
+            block_tables,
+            schedule_metadata,
+            max_seq_len,
+            clean_logits=False,
+        )
+
+        # NOTE(dark): logits should be cleaned in topk_transform
+        topk_result = metadata.topk_transform(logits, self.index_topk)
+        return topk_result
+
+    def _get_topk_ragged(
+        self,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+        q_fp8: torch.Tensor,
+        weights: torch.Tensor,
+        metadata: BaseIndexerMetadata,
+    ) -> torch.Tensor:
+        """Top-k for extend/prefill: gather per-request K/scale into ragged
+        concatenated tensors and score with the ragged MQA logits kernel."""
+        if TYPE_CHECKING:
+            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
+
+        page_size = forward_batch.token_to_kv_pool.page_size
+        assert page_size == 64, "only support page size 64"
+        assert len(weights.shape) == 3
+        weights = weights.squeeze(-1)
+        k_fp8_list = []
+        k_scale_list = []
+        ks_list = []
+        offset = 0
+
+        block_tables = metadata.get_page_table_64()
+
+        assert (
+            forward_batch.seq_lens_cpu is not None
+            and forward_batch.extend_seq_lens_cpu is not None
+        )
+
+        for i in range(forward_batch.batch_size):
+            seq_len = forward_batch.seq_lens_cpu[i].item()
+            assert isinstance(seq_len, int)
+            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
+                layer_id,
+                seq_len,
+                block_tables[i],
+            )
+            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
+                layer_id,
+                seq_len,
+                block_tables[i],
+            )
+            extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
+            # ks: per query token of request i, start offset of its K range.
+            # NOTE(review): offset accumulates extend_seq_len while k_fp8_list
+            # grows by seq_len rows per request — verify this matches the
+            # ks/ke semantics expected by fp8_mqa_logits.
+            ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
+            k_fp8_list.append(k_fp8)
+            k_scale_list.append(k_scale)
+            ks_list.append(ks)
+            offset += extend_seq_len
+
+        k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
+        k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
+        kv_fp8 = (k_fp8, k_scale)
+        ks = torch.cat(ks_list, dim=0)
+        seq_lens_expanded = metadata.get_seqlens_expanded()
+        ke = ks + seq_lens_expanded  # end offset of each query token's K range
+
+        logits = deep_gemm_v32.fp8_mqa_logits(
+            q_fp8,
+            kv_fp8,
+            weights,
+            ks,
+            ke,
+            clean_logits=False,
+        )
+
+        assert logits.shape[0] == len(seq_lens_expanded)
+        # Logits are masked/cleaned inside topk_transform, not here.
+        topk_result = metadata.topk_transform(logits, self.index_topk)
+
+        return topk_result
+
+    def _forward(
+        self,
+        x: torch.Tensor,
+        q_lora: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+    ) -> Optional[torch.Tensor]:
+        """Compute top-k KV indices per token, or None when the backend skips NSA."""
+        if TYPE_CHECKING:
+            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
+
+        metadata = forward_batch.attn_backend.get_indexer_metadata(
+            layer_id, forward_batch
+        )
+
+        # Dual-stream only applies to small batches during CUDA-graph capture.
+        enable_dual_stream = (
+            NSA_DUAL_STREAM
+            and self.alt_stream is not None
+            and get_is_capture_mode()
+            and q_lora.shape[0] > 0
+            and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+        )
+
+        # skip NSA if attention backend choose to skip this batch
+        if metadata is None:
+            return None
+
+        if not NSA_USE_REAL_INDEXER:  # temporary
+            return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
+
+        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
+
+        # NOTE(review): these four defaults are dead stores — both branches
+        # below overwrite them via act_quant.
+        q_fp8 = query.to(torch.float32)
+        k_fp8 = key.to(torch.float32)
+        q_scale = torch.ones((query.shape[0], 1), dtype=torch.float32, device="cuda")
+        k_scale = torch.ones((key.shape[0], 1), dtype=torch.float32, device="cuda")
+
+        if enable_dual_stream:
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+
+            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
+            with torch.cuda.stream(self.alt_stream):
+                k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
+            k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
+
+        # k_fp8: (seq_len, head_dim) fp8_e4m3fn
+        # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
+        # k_scale: (seq_len, head_dim // block_size = 1) fp32 (asserted by the pool setter)
+        # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp32
+        forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
+            layer_id=layer_id,
+            loc=forward_batch.out_cache_loc,
+            index_k=k_fp8,
+            index_k_scale=k_scale,
+        )
+
+        weights = self._get_logits_head_gate(x, q_scale)
+
+        assert forward_batch.seq_lens_cpu is not None
+        if len(forward_batch.seq_lens_cpu) == 0:
+            # this seems b/c max-pad, no worries?
+            # if x.shape[0] != 0:
+            #     print(
+            #         "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result"
+            #     )
+            return torch.full(
+                (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
+            )
+
+        # Paged kernel for decode/idle, ragged kernel for extend/prefill.
+        if forward_batch.forward_mode.is_decode_or_idle():
+            topk_result = self._get_topk_paged(
+                forward_batch, layer_id, q_fp8, weights, metadata
+            )
+        else:
+            topk_result = self._get_topk_ragged(
+                forward_batch, layer_id, q_fp8, weights, metadata
+            )
+
+        return topk_result
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        q_lora: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+    ) -> Optional[torch.Tensor]:
+        # CUDA path delegates to the shared deep_gemm-based implementation.
+        return self._forward(x, q_lora, positions, forward_batch, layer_id)
+
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        q_lora: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+    ) -> torch.Tensor:
+        """Ascend NPU path: same contract as _forward, via torch_npu custom ops."""
+        import custom_ops
+        import torch_npu
+
+        from sglang.srt.layers.dp_attention import (
+            get_attention_tp_rank,
+            get_attention_tp_size,
+        )
+        from sglang.srt.utils import get_bool_env_var
+
+        if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = (
+                forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int
+            )
+        # With SGLANG_USE_AG_AFTER_QLORA set, layers >= 4 shard the prefill work
+        # across attention-TP ranks and all-gather K / topk results below.
+        enable_index_cp = (
+            get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
+        )
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        attention_tp_rank = get_attention_tp_rank()
+        attention_tp_size = get_attention_tp_size()
+
+        cos_sin = self.rotary_emb.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
+        sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
+        if is_prefill and enable_index_cp:
+            # Keep only this rank's contiguous slice of the positions.
+            slice_length = cos.shape[0] // attention_tp_size
+            cos = cos[
+                slice_length
+                * attention_tp_rank : slice_length
+                * (attention_tp_rank + 1)
+            ]
+            sin = sin[
+                slice_length
+                * attention_tp_rank : slice_length
+                * (attention_tp_rank + 1)
+            ]
+
+        slot_mapping = forward_batch.out_cache_loc
+        block_table = forward_batch.attn_backend.forward_metadata.block_tables
+
+        bs = x.shape[0]
+
+        q = self.wq_b(q_lora)[0]  # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
+        q = q.view(bs, self.n_heads, self.head_dim)  # [bs, 64, 128]
+        q_pe, q_nope = torch.split(
+            q,
+            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+            dim=-1,
+        )  # [bs, 64, 64 + 64]
+
+        q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
+        q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view(
+            bs, self.n_heads, self.rope_head_dim
+        )  # [bs, n, d]
+        q = torch.cat([q_pe, q_nope], dim=-1)
+
+        k_proj = self.wk(x)[0]  # [b, s, 7168] @ [7168, 128] = [b, s, 128]
+        k = self.k_norm(k_proj)
+        k_pe, k_nope = torch.split(
+            k,
+            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+            dim=-1,
+        )  # [bs, 64 + 64]
+
+        k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
+        k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view(
+            bs, 1, self.rope_head_dim
+        )  # [bs, 1, d]
+        k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1)  # [bs, 1, 128]
+
+        if is_prefill and enable_index_cp:
+            # All-gather K across attention-TP ranks before writing the cache.
+            k, local_k = (
+                torch.empty(
+                    (k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]),
+                    dtype=k.dtype,
+                    device=k.device,
+                ),
+                k,
+            )
+            get_attention_tp_group().all_gather_into_tensor(k, local_k)
+
+        forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k)
+
+        # NOTE(review): indexer_input is assigned but never used — dead local?
+        indexer_input = {}
+        if is_prefill:
+            actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device)
+            actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to(
+                device=q.device
+            )
+            if enable_index_cp:
+                # Clamp cumulative query lengths into this rank's [0, bs] window.
+                actual_seq_lengths_q -= bs * attention_tp_rank
+                actual_seq_lengths_q = torch.max(
+                    actual_seq_lengths_q,
+                    torch.zeros_like(actual_seq_lengths_q).to(
+                        device=actual_seq_lengths_q.device
+                    ),
+                )
+                actual_seq_lengths_q = torch.min(
+                    actual_seq_lengths_q,
+                    torch.full(actual_seq_lengths_q.shape, bs).to(
+                        device=actual_seq_lengths_q.device
+                    ),
+                )
+
+        else:
+            if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
+                # Decode fallback: one query token per request -> [1, 2, ..., bs].
+                actual_seq_lengths_q = torch.tensor(
+                    [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device
+                )
+            else:
+                actual_seq_lengths_q = (
+                    forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
+                )
+
+        past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)
+
+        x = x.view(-1, self.hidden_size)
+        weights = self.weights_proj(x)[0]
+        block_table = (
+            block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
+        )
+
+        topk_indices = torch.ops.custom.npu_lightning_indexer(
+            query=q.view(-1, self.n_heads, self.head_dim),
+            key=past_key_states,
+            weights=weights,
+            actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),
+            actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32),
+            block_table=block_table,
+            layout_query="TND",
+            layout_key="PA_BSND",
+            sparse_count=self.index_topk,
+            sparse_mode=3,
+        )
+
+        if is_prefill and enable_index_cp:
+            # Gather each rank's topk indices back to the full batch.
+            topk_indices, local_topk_indices = (
+                torch.empty(
+                    (
+                        topk_indices.shape[0] * attention_tp_size,
+                        topk_indices.shape[1],
+                        topk_indices.shape[2],
+                    ),
+                    dtype=topk_indices.dtype,
+                    device=topk_indices.device,
+                ),
+                topk_indices,
+            )
+            get_attention_tp_group().all_gather_into_tensor(
+                topk_indices, local_topk_indices
+            )
+
+        return topk_indices
diff --git a/python/sglang/srt/layers/attention/nsa/quant_k_cache.py b/python/sglang/srt/layers/attention/nsa/quant_k_cache.py
new file mode 100644
index 000000000..1c7ae38b5
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/quant_k_cache.py
@@ -0,0 +1,255 @@
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.attention.nsa.utils import NSA_QUANT_K_CACHE_FAST
+
+
+def quantize_k_cache(cache_k):
+    """Quantize a (num_blocks, block_size, h_k, d) bf16 K-cache into the packed
+    fp8 layout, dispatching to the Triton fast path or the reference
+    implementation based on NSA_QUANT_K_CACHE_FAST."""
+    # TODO upstream can skip concat([k_nope, k_pe]) since we split them here
+    if NSA_QUANT_K_CACHE_FAST:
+        return _quantize_k_cache_fast_wrapped(cache_k)
+    else:
+        return _quantize_k_cache_slow(cache_k)
+
+
+# Copied from original (reference implementation; the Triton path must match it)
+def _quantize_k_cache_slow(
+    input_k_cache: torch.Tensor,  # (num_blocks, block_size, h_k, d)
+    dv: int = 512,
+    tile_size: int = 128,
+) -> torch.Tensor:
+    """
+    Quantize the k-cache
+    Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
+    For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
+    """
+    assert dv % tile_size == 0
+    num_tiles = dv // tile_size
+    num_blocks, block_size, h_k, d = input_k_cache.shape
+    assert h_k == 1
+    input_k_cache = input_k_cache.squeeze(2)  # [num_blocks, block_size, d]
+    input_elem_size = input_k_cache.element_size()
+
+    # Packed per-token layout: [fp8 nope | fp32 scale per tile | raw rope bytes].
+    result = torch.empty(
+        (num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
+        dtype=torch.float8_e4m3fn,
+        device=input_k_cache.device,
+    )
+    result_k_nope_part = result[..., :dv]
+    result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
+    result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
+    # The rope tail is stored unquantized (bit-copied in the input dtype).
+    result_k_rope_part[:] = input_k_cache[..., dv:]
+
+    for tile_idx in range(0, num_tiles):
+        # Per-tile absmax scale so each tile maps into the fp8 e4m3 range (448).
+        cur_scale_factors_inv = (
+            torch.abs(
+                input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
+            )
+            .max(dim=-1)
+            .values
+            / 448.0
+        )  # [num_blocks, block_size]
+        result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
+
+        cur_scale_factors_inv.unsqueeze_(-1)  # [num_blocks, block_size, 1]
+        cur_quantized_nope = (
+            input_k_cache[
+                ..., tile_idx * tile_size : (tile_idx + 1) * tile_size
+            ].float()
+            / cur_scale_factors_inv.float()
+        ).to(torch.float8_e4m3fn)
+        result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
+            cur_quantized_nope
+        )
+
+    result = result.view(num_blocks, block_size, 1, -1)
+    return result
+
+
+def _quantize_k_cache_fast_wrapped(
+    input_k_cache: torch.Tensor,
+    dv: int = 512,
+    tile_size: int = 128,
+) -> torch.Tensor:
+    """Adapter around the Triton fast path: flatten the 4D cache to 2D, split
+    nope/rope, quantize, and restore the 4D output shape."""
+    # TODO the final API may be 2D instead of 4D, thus we convert them here
+    num_blocks, block_size, _, dim_nope_and_rope = input_k_cache.shape
+    assert dv == 512
+    assert dim_nope_and_rope == 512 + 64
+    assert tile_size == 128
+    input_k_cache = input_k_cache.view((-1, dim_nope_and_rope))
+
+    # TODO deliberately split into two tensors, then upstream can provide the two tensors instead of concat into one
+    k_nope = input_k_cache[:, :dv]
+    k_rope = input_k_cache[:, dv:]
+
+    output = _quantize_k_cache_fast(k_nope=k_nope, k_rope=k_rope)
+
+    return output.view(num_blocks, block_size, 1, -1)
+
+
+def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128):
+    """
+    Per-group (128-wide) absmax FP8 quantization of k_nope; k_rope is copied as bf16.
+
+    :param k_nope: (num_tokens, dim_nope 512)
+    :param k_rope: (num_tokens, dim_rope 64)
+    """
+
+    assert k_nope.dtype == torch.bfloat16
+    assert k_rope.dtype == torch.bfloat16
+
+    num_tokens, dim_nope = k_nope.shape
+    num_tokens_, dim_rope = k_rope.shape
+    assert num_tokens == num_tokens_
+    assert dim_nope == 512
+    assert dim_rope == 64
+    assert k_nope.dtype == k_rope.dtype
+    num_tiles = dim_nope // group_size
+
+    # Rows may be non-contiguous slices; only the last dim must be dense.
+    assert k_nope.stride(1) == 1
+    assert k_rope.stride(1) == 1
+
+    # Per-token layout as raw bytes: [512 fp8 nope | 4 fp32 scales | 64 bf16 rope].
+    output = torch.empty(
+        (num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope),
+        dtype=torch.float8_e4m3fn,
+        device=k_nope.device,
+    )
+    output_nope_q = output[..., :dim_nope]
+    output_nope_s = output[..., dim_nope : dim_nope + num_tiles * 4].view(torch.float32)
+    output_rope = output[..., dim_nope + num_tiles * 4 :].view(torch.bfloat16)
+
+    # Grid dim 1 covers the 4 nope tiles plus 1 rope-copy block.
+    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
+    assert num_blocks_per_token == 5
+
+    assert dim_nope % group_size == 0
+    NUM_NOPE_BLOCKS = dim_nope // group_size
+
+    _quantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
+        output_nope_q,
+        output_nope_s,
+        output_rope,
+        k_nope,
+        k_rope,
+        output_nope_q.stride(0),
+        output_nope_s.stride(0),
+        output_rope.stride(0),
+        k_nope.stride(0),
+        k_rope.stride(0),
+        NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
+        GROUP_SIZE=group_size,
+        DIM_NOPE=dim_nope,
+        DIM_ROPE=dim_rope,
+        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
+        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
+    )
+
+    return output
+
+
+# Grid: (num_tokens, NUM_NOPE_BLOCKS + rope blocks). Programs with
+# block id < NUM_NOPE_BLOCKS quantize one nope tile; the rest copy rope bytes.
+@triton.jit
+def _quantize_k_cache_fast_kernel(
+    output_nope_q_ptr,
+    output_nope_s_ptr,
+    output_rope_ptr,
+    k_nope_ptr,
+    k_rope_ptr,
+    output_nope_q_stride_0: int,
+    output_nope_s_stride_0: int,
+    output_rope_stride_0: int,
+    k_nope_stride_0: int,
+    k_rope_stride_0: int,
+    NUM_NOPE_BLOCKS: tl.constexpr,
+    GROUP_SIZE: tl.constexpr,
+    DIM_NOPE: tl.constexpr,
+    DIM_ROPE: tl.constexpr,
+    FP8_MIN: tl.constexpr,
+    FP8_MAX: tl.constexpr,
+):
+    token_id = tl.program_id(0)
+    raw_block_id = tl.program_id(1)
+
+    if raw_block_id < NUM_NOPE_BLOCKS:
+        # a. quant nope
+        effective_block_id = raw_block_id
+
+        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+        mask = offs < DIM_NOPE
+        ptr = k_nope_ptr + token_id * k_nope_stride_0 + offs
+
+        y = tl.load(ptr, mask=mask, other=0.0).to(tl.float32)
+
+        # the ref impl do not have a `tl.maximum(... eps)`, so we remove it here
+        # (matches _quantize_k_cache_slow; an all-zero tile would divide by 0).
+        y_s = tl.max(tl.abs(y)) / FP8_MAX
+        y_s_inv = 1.0 / y_s
+        y_q = tl.clamp(y * y_s_inv, FP8_MIN, FP8_MAX).to(
+            output_nope_q_ptr.dtype.element_ty
+        )
+
+        dst_q_ptr = output_nope_q_ptr + token_id * output_nope_q_stride_0 + offs
+        dst_s_ptr = (
+            output_nope_s_ptr + token_id * output_nope_s_stride_0 + effective_block_id
+        )
+
+        tl.store(dst_q_ptr, y_q, mask=mask)
+        tl.store(dst_s_ptr, y_s)
+    else:
+        # b. copy rope
+        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
+
+        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+        mask = offs < DIM_ROPE
+
+        src_ptr = k_rope_ptr + token_id * k_rope_stride_0 + offs
+        dst_ptr = output_rope_ptr + token_id * output_rope_stride_0 + offs
+
+        data = tl.load(src_ptr, mask=mask)
+        tl.store(dst_ptr, data, mask=mask)
+
+
+if __name__ == "__main__":
+    # Smoke test: quantize with the slow reference and the fast Triton path,
+    # then compare their round-trip dequantized values.
+    for num_blocks, block_size in [
+        (1, 1),
+        (10, 64),
+    ]:
+        dim_nope_and_rope = 512 + 64
+
+        input_k_cache = torch.randn(
+            (num_blocks, block_size, 1, dim_nope_and_rope),
+            dtype=torch.bfloat16,
+            device="cuda",
+        )
+        # temp debug
+        # input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)
+
+        ref_quant = _quantize_k_cache_slow(input_k_cache)
+        actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
+        # print(f"{input_k_cache=}")
+        # print(f"{ref_quant=}")
+        # print(f"{actual_quant=}")
+        # print(f"{ref_quant == actual_quant=}")
+        # print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
+        # print(f"{ref_quant.view(torch.bfloat16)=}")
+        # print(f"{actual_quant.view(torch.bfloat16)=}")
+        # assert torch.all(ref_quant == actual_quant)
+
+        # Sibling module (same directory) providing the dequantization reference.
+        import dequant_k_cache
+
+        ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
+        ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
+        actual_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(
+            actual_quant
+        )
+
+        print(f"{ref_ref_dequant=}")
+        print(f"{actual_actual_dequant=}")
+        print(f"{actual_actual_dequant - ref_ref_dequant=}")
+        print(f"{torch.mean(ref_ref_dequant - actual_actual_dequant)=}")
+
+        # TODO too different?
+        torch.testing.assert_close(
+            ref_ref_dequant, ref_actual_dequant, atol=0.2, rtol=0.2
+        )
+        torch.testing.assert_close(
+            ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
+        )
+
+        print("Passed")
diff --git a/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py
new file mode 100644
index 000000000..d2f271e17
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py
@@ -0,0 +1,774 @@
+from typing import Optional, Tuple
+
+import tilelang
+import tilelang.language as T
+import torch
+
+tilelang.set_log_level("WARNING")
+
+pass_configs = {
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+ tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
+}
+
+BF16 = "bfloat16"
+FP8 = "float8_e4m3"
+FP32 = "float32"
+
+
+def fast_log2_ceil(x):
+ bits_x = T.reinterpret("uint32", x)
+ exp_x = (bits_x >> 23) & 0xFF
+ man_bits = bits_x & ((1 << 23) - 1)
+ return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
+
+
+def fast_pow2(x):
+ bits_x = (x + 127) << 23
+ return T.reinterpret("float32", bits_x)
+
+
+def fast_round_scale(amax, fp8_max_inv):
+ return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
+
+
+@tilelang.jit(pass_configs=pass_configs)
+def act_quant_kernel(
+ N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
+):
+ M = T.symbolic("M")
+ fp8_min = -448.0
+ fp8_max = 448.0
+ fp8_max_inv = 1 / fp8_max
+ num_stages = 0 if round_scale else 2
+ blk_m = 32
+ group_size = 128
+
+ @T.prim_func
+ def act_quant_kernel_(
+ X: T.Tensor[(M, N), in_dtype],
+ Y: T.Tensor[(M, N), out_dtype],
+ S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
+ ):
+ with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
+ pid_m,
+ pid_n,
+ ):
+ x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
+ x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
+ amax_local = T.alloc_fragment((blk_m,), scale_dtype)
+ s_local = T.alloc_fragment((blk_m,), scale_dtype)
+ y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
+ y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
+
+ for _ in T.Pipelined(1, num_stages=num_stages):
+ T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
+ T.copy(x_shared, x_local)
+ T.reduce_absmax(x_local, amax_local, dim=1)
+ for i in T.Parallel(blk_m):
+ amax_local[i] = T.max(amax_local[i], 1e-4)
+ if round_scale:
+ s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
+ else:
+ s_local[i] = amax_local[i] * fp8_max_inv
+ for i, j in T.Parallel(blk_m, group_size):
+ y_local[i, j] = T.clamp(
+ x_local[i, j] / s_local[i], fp8_min, fp8_max
+ )
+ for i in T.Parallel(blk_m):
+ S[pid_m * blk_m + i, pid_n] = s_local[i]
+ T.copy(y_local, y_shared)
+ T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
+
+ return act_quant_kernel_
+
+
+def act_quant(
+ x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Quantizes the input tensor `x` using block-wise quantization.
+
+ Args:
+ x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
+ block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
+ scale_fmt (Optional[str], optional): The format of the scale. Default is None.
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+ - The quantized tensor with dtype `torch.float8_e4m3fn`.
+ - A tensor of scaling factors with dtype `torch.float32`.
+ """
+ assert x.is_contiguous(), "Input tensor must be contiguous"
+ assert (
+ x.size(-1) % block_size == 0
+ ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
+ N = x.size(-1)
+ y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+ s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
+ kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
+ kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
+ return y, s
+
+
+@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
+def fp8_index_kernel(h: int, d: int):
+ b = T.symbolic("b")
+ m = T.symbolic("m")
+ n = T.symbolic("n")
+
+ blk_n1 = 512
+ blk_n2 = 128
+
+ @T.prim_func
+ def fp8_index_kernel_(
+ q: T.Tensor[(b, m, h, d), FP8],
+ q_s: T.Tensor[(b, m, h), FP32],
+ k: T.Tensor[(b, n, d), FP8],
+ k_s: T.Tensor[(b, n), FP32],
+ o: T.Tensor[(b, m, n), FP32],
+ ) -> None:
+ with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
+ q_smem = T.alloc_shared((h, d), FP8)
+ T.copy(q[i_b, i_m, 0, 0], q_smem)
+
+ q_s_frag = T.alloc_fragment(h, FP32)
+ T.copy(q_s[i_b, i_m, 0], q_s_frag)
+
+ for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
+ k_smem = T.alloc_shared((blk_n2, d), FP8)
+ T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
+
+ k_s_frag = T.alloc_fragment(blk_n2, FP32)
+ T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
+
+ logits = T.alloc_fragment((blk_n2, h), FP32)
+ T.gemm(
+ k_smem,
+ q_smem,
+ logits,
+ transpose_A=False,
+ transpose_B=True,
+ clear_accum=True,
+ )
+
+ for i_h, i3_n in T.Parallel(h, blk_n2):
+ logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
+
+ logits_sum = T.alloc_fragment(blk_n2, FP32)
+ T.reduce_sum(logits, logits_sum, dim=1)
+
+ for i3_n in T.Parallel(blk_n2):
+ logits_sum[i3_n] *= k_s_frag[i3_n]
+
+ T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
+
+ return fp8_index_kernel_
+
+
+def fp8_index(
+ q: torch.Tensor,
+ q_s: torch.Tensor,
+ k: torch.Tensor,
+ k_s: torch.Tensor,
+) -> torch.Tensor:
+ """
+ Perform index score using FP8 precision.
+
+ Args:
+ q (torch.Tensor): The Q tensor, must be contiguous.
+ q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
+ k (torch.Tensor): The K tensor, must be contiguous.
+ k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
+
+ fp8 q @ fp8 k -> fp32 logits
+ relu(fp32 logits) * q_s (weights) -> fp32 logits
+ fp32 logits -> fp32 logits_sum
+ fp32 logits_sum * k_s (e8m0) -> fp32 index_score
+ """
+ return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
+
+
+@tilelang.jit(
+ out_idx=[-1],
+ pass_configs={
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+ },
+)
+def sparse_attention_fwd_kernel_v1(
+ num_heads,
+ dim,
+ tail_dim,
+ topk,
+ *,
+ kv_group=1,
+ sm_scale=None,
+ is_causal=True,
+ block_I=64,
+ num_stages=2,
+ threads=256,
+):
+ assert dim == tilelang.math.next_power_of_2(
+ dim
+ ), f"haven't check padding correctness yet, dim={dim}"
+ assert tail_dim == tilelang.math.next_power_of_2(
+ tail_dim
+ ), f"haven't check padding correctness yet, dim={tail_dim}"
+    assert is_causal == True, "non-causal is not supported"
+ assert (
+ topk % block_I == 0
+ ), "otherwise will load some index=0 thus causing wrong kv to be loaded"
+ if sm_scale is None:
+ sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e)
+ else:
+ sm_scale = sm_scale * 1.44269504 # log2(e)
+
+ batch = T.symbolic("batch")
+ seq_len = T.symbolic("seq_len")
+ seq_len_kv = T.symbolic("seq_len_kv")
+
+ head_kv = num_heads // kv_group
+ q_shape = [batch, seq_len, num_heads, dim + tail_dim]
+ kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
+ o_shape = [batch, seq_len, num_heads, dim]
+ indices_shape = [batch, seq_len, kv_group, topk]
+ indices_dtype = "int32"
+ dtype = "bfloat16"
+ accum_dtype = "float"
+
+ H = head_kv
+ padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
+ if padded_H != H:
+ assert kv_group == 1
+ BI = block_I
+ NI = tilelang.cdiv(topk, block_I)
+ D = dim
+ D_tail = tail_dim
+
+ if head_kv > 64:
+ assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
+ REPLICATE_H = head_kv // 64
+ else:
+ REPLICATE_H = 1
+
+ H_per_block = padded_H if REPLICATE_H == 1 else 64
+
+ @T.prim_func
+ def main(
+ Q: T.Tensor(q_shape, dtype), # type: ignore
+ KV: T.Tensor(kv_shape, dtype), # type: ignore
+ Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
+ Output: T.Tensor(o_shape, dtype), # type: ignore
+ ):
+ with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
+ bx,
+ by,
+ bz,
+ ):
+ Q_shared = T.alloc_shared([H_per_block, D], dtype)
+ Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
+ KV_shared = T.alloc_shared([BI, D], dtype)
+ K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
+ O_shared = T.alloc_shared([H_per_block, D], dtype)
+ mask = T.alloc_fragment([BI], "bool")
+
+ acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
+ acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
+ S_shared = T.alloc_shared([H_per_block, BI], dtype)
+ sumexp = T.alloc_fragment([H_per_block], accum_dtype)
+ sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
+ alpha = T.alloc_fragment([H_per_block], accum_dtype)
+ m_i = T.alloc_fragment([H_per_block], accum_dtype)
+ m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
+
+ T.fill(acc_o, 0)
+ T.fill(sumexp, 0)
+            T.fill(m_i, -(2**30))  # avoid -inf - inf, which would produce nan
+
+ b_i, g_i = by, bz
+ s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
+ q_i = s_i
+ max_kv_i = q_i
+
+ H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
+ H1 = H0 + H_per_block
+
+ T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)
+ T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
+
+ for i_i in T.Pipelined(NI, num_stages=num_stages):
+
+ for bi_i in T.Parallel(BI):
+ mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] >= 0
+
+ for bi_i, d_i in T.Parallel(BI, D):
+ KV_shared[bi_i, d_i] = KV[
+ b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i
+ ]
+ for bi_i, d_i in T.Parallel(BI, D_tail):
+ K_tail_shared[bi_i, d_i] = KV[
+ b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i
+ ]
+
+ for h_i, bi_i in T.Parallel(H_per_block, BI):
+ acc_s[h_i, bi_i] = T.if_then_else(
+ mask[bi_i], 0, -T.infinity(acc_s.dtype)
+ )
+ T.gemm(
+ Q_shared,
+ KV_shared,
+ acc_s,
+ transpose_B=True,
+ policy=T.GemmWarpPolicy.FullCol,
+ )
+ T.gemm(
+ Q_tail_shared,
+ K_tail_shared,
+ acc_s,
+ transpose_B=True,
+ policy=T.GemmWarpPolicy.FullCol,
+ )
+ T.copy(m_i, m_i_prev)
+ T.reduce_max(acc_s, m_i, dim=1, clear=False)
+ for h_i in T.Parallel(H_per_block):
+ alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
+ for h_i, bi_i in T.Parallel(H_per_block, BI):
+ acc_s[h_i, bi_i] = T.exp2(
+ acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
+ )
+                T.reduce_sum(acc_s, sumexp_i, dim=1)  # is this an accumulate operator?
+ for h_i in T.Parallel(H_per_block):
+ sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
+ for h_i, d_i in T.Parallel(H_per_block, D):
+ acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
+
+ T.copy(acc_s, S_shared)
+ T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
+
+ # Rescale
+ for h_i, d_i in T.Parallel(H_per_block, D):
+ acc_o[h_i, d_i] /= sumexp[h_i]
+ for h_i in T.Parallel(H_per_block):
+ sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
+
+ T.copy(acc_o, O_shared)
+ T.copy(acc_o, Output[b_i, s_i, H0:H1, :])
+
+ return main
+
+
+@tilelang.jit(
+ out_idx=[-1],
+ compile_flags=[
+ "-O3",
+ "-Wno-deprecated-declarations",
+ "-U__CUDA_NO_HALF_OPERATORS__",
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
+ "-U__CUDA_NO_HALF2_OPERATORS__",
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+ "--expt-relaxed-constexpr",
+ "--expt-extended-lambda",
+ "--ptxas-options=-v,--register-usage-level=10",
+ "-DNDEBUG",
+ ],
+) # type: ignore
+def sparse_attention_fwd_kernel_v2(
+ num_heads: int,
+ dim: int,
+ tail_dim: int,
+ topk: int,
+ *,
+ kv_group: int = 1,
+ sm_scale: Optional[float] = None,
+ block_I: int = 64,
+):
+ assert dim == tilelang.math.next_power_of_2(
+ dim
+ ), f"haven't check padding correctness yet, dim={dim}"
+ assert tail_dim == tilelang.math.next_power_of_2(
+ tail_dim
+ ), f"haven't check padding correctness yet, dim={tail_dim}"
+ assert (
+ topk % block_I == 0
+ ), "otherwise will load some index=0 thus causing wrong kv to be loaded"
+ if sm_scale is None:
+ sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e)
+ else:
+ sm_scale = sm_scale * 1.44269504 # log2(e)
+ threads = 384
+
+ batch = T.symbolic("batch")
+ qo_len = T.symbolic("seq_len")
+ num_pages = T.symbolic("num_pages")
+
+ q_shape = [batch, qo_len, num_heads, dim + tail_dim]
+ kv_shape = [batch, num_pages, kv_group, dim + tail_dim]
+ o_shape = [batch, qo_len, num_heads, dim]
+ indices_shape = [batch, qo_len, kv_group, topk]
+
+ indices_dtype = "int32"
+ dtype = "bfloat16"
+ accum_dtype = "float"
+
+ H = num_heads
+ padded_H = max(tilelang.math.next_power_of_2(num_heads), 16)
+ if padded_H != H:
+ assert kv_group == 1
+ BI = block_I
+ NI = tilelang.cdiv(topk, block_I)
+ assert NI % 2 == 0, "NI should be a multiple of 2"
+ D = dim
+ D_tail = tail_dim
+ if num_heads > 64:
+        assert num_heads % 64 == 0, "num_heads should be a multiple of 64"
+ REPLICATE_H = num_heads // 64
+ else:
+ REPLICATE_H = 1
+
+ H_per_block = padded_H if REPLICATE_H == 1 else 64
+
+ @T.prim_func
+ def main(
+ Q: T.Tensor(q_shape, dtype), # type: ignore
+ KV: T.Tensor(kv_shape, dtype), # type: ignore
+ Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
+ Output: T.Tensor(o_shape, dtype), # type: ignore
+ ):
+ """
+ Q: [b, qo_len, H, D + D_tail] (bfloat16)
+ KV: [b, num_pages, kv_group, D + D_tail] (bfloat16)
+ Indices: [b, qo_len, kv_group, topk] (int32)
+ """
+
+ with T.Kernel(qo_len * REPLICATE_H, batch, 1, threads=threads) as (bx, by, bz): # type: ignore
+ Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)
+ Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)
+ Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
+ KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype)
+ KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype)
+ KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype)
+ KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype)
+ K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype)
+ K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype)
+ O_shared_l = Q_shared_l
+ O_shared_r = Q_shared_r
+ is_kv_valid_0 = T.alloc_shared([BI], "bool", scope="shared")
+ is_kv_valid_1 = T.alloc_shared([BI], "bool", scope="shared")
+
+ acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
+ acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
+ acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
+ S_shared = T.alloc_shared([H_per_block, BI], dtype)
+ sumexp = T.alloc_fragment([H_per_block], accum_dtype)
+ sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype)
+ sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
+ alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared")
+ alpha_local = T.alloc_fragment([H_per_block], accum_dtype)
+ m_i = T.alloc_fragment([H_per_block], accum_dtype)
+ m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
+ indices_local = T.alloc_local([1], indices_dtype)
+ indices_tmp = T.alloc_local([1], indices_dtype)
+
+ bar_q = T.alloc_barrier(arrive_count=384)
+ bar_k_0_ready = T.alloc_barrier(arrive_count=128)
+ bar_k_1_ready = T.alloc_barrier(arrive_count=128)
+ bar_k_0_free = T.alloc_barrier(arrive_count=256)
+ bar_k_1_free = T.alloc_barrier(arrive_count=256)
+ bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)
+ bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)
+
+ bar_0_128 = T.alloc_barrier(arrive_count=128)
+ bar_1_128 = T.alloc_barrier(arrive_count=128)
+ bar_2_128 = T.alloc_barrier(arrive_count=128)
+ bar_final = T.alloc_barrier(arrive_count=128)
+
+ b_i, g_i = by, bz
+ s_i = bx if REPLICATE_H == 1 else bx // REPLICATE_H
+
+ H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
+ H1 = H0 + H_per_block
+
+ tx = T.get_thread_binding()
+
+ T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l)
+ T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r)
+ T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
+ T.barrier_arrive(bar_q)
+
+ if tx < 128:
+ T.set_max_nreg(240, 1)
+ T.fill(sumexp, 0)
+                T.fill(m_i, -(2**30))  # avoid -inf - inf, which would produce nan
+ T.fill(acc_o_l, 0)
+ T.barrier_wait(bar_q, 0)
+
+ for i_i in T.serial(T.ceildiv(NI, 2)):
+ # Buffer 0
+ # with sync_at(bar_0_128, 0):
+ T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
+ T.barrier_arrive(bar_0_128)
+ T.barrier_wait(bar_0_128, 0)
+
+ for h_i, bi_i in T.Parallel(H_per_block, BI):
+ acc_s[h_i, bi_i] = T.if_then_else(
+ is_kv_valid_0[bi_i], 0, -T.infinity(acc_s.dtype)
+ )
+ T.gemm(
+ Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1
+ )
+ T.gemm(
+ Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1
+ )
+ T.gemm(
+ Q_tail_shared,
+ K_tail_shared_0,
+ acc_s,
+ transpose_B=True,
+ wg_wait=-1,
+ )
+
+ T.wait_wgmma(0)
+
+ if i_i != 0:
+ T.barrier_arrive(bar_sScale_and_sS_free)
+ T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)
+
+ T.copy(m_i, m_i_prev)
+ T.reduce_max(acc_s, m_i, dim=1, clear=False)
+ for h_i in T.Parallel(H_per_block):
+ alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
+ for h_i, bi_i in T.Parallel(H_per_block, BI):
+ acc_s[h_i, bi_i] = T.exp2(
+ acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
+ )
+ T.reduce_sum(
+ acc_s, sumexp_i, dim=1
+                    )  # is this an accumulate operator?
+ for h_i in T.Parallel(H_per_block):
+ sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
+ for h_i, d_i in T.Parallel(H_per_block, D // 2):
+ acc_o_l[h_i, d_i] *= alpha_local[h_i]
+ T.copy(alpha_local, alpha_shared)
+
+ T.copy(acc_s, S_shared)
+ T.gemm(S_shared, KV_shared_0_l, acc_o_l)
+
+ T.barrier_arrive(bar_sScale_and_sS_ready)
+ T.barrier_arrive(bar_k_0_free[0])
+
+ # Buffer 1
+ T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
+ T.barrier_arrive(bar_0_128)
+ T.barrier_wait(bar_0_128, 1)
+
+ for h_i, bi_i in T.Parallel(H_per_block, BI):
+ acc_s[h_i, bi_i] = T.if_then_else(
+ is_kv_valid_1[bi_i], 0, -T.infinity(acc_s.dtype)
+ )
+ T.gemm(
+ Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1
+ )
+ T.gemm(
+ Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1
+ )
+ T.gemm(
+ Q_tail_shared,
+ K_tail_shared_1,
+ acc_s,
+ transpose_B=True,
+ wg_wait=-1,
+ )
+
+ T.wait_wgmma(0)
+
+ T.barrier_arrive(bar_sScale_and_sS_free)
+ T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)
+
+ T.copy(m_i, m_i_prev)
+ T.reduce_max(acc_s, m_i, dim=1, clear=False)
+ for h_i in T.Parallel(H_per_block):
+ alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
+ for h_i, bi_i in T.Parallel(H_per_block, BI):
+ acc_s[h_i, bi_i] = T.exp2(
+ acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
+ )
+ T.reduce_sum(
+ acc_s, sumexp_i, dim=1
+                    )  # is this an accumulate operator?
+ for h_i in T.Parallel(H_per_block):
+ sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
+ for h_i, d_i in T.Parallel(H_per_block, D // 2):
+ acc_o_l[h_i, d_i] *= alpha_local[h_i]
+ T.copy(alpha_local, alpha_shared)
+
+ T.copy(acc_s, S_shared)
+ T.gemm(S_shared, KV_shared_1_l, acc_o_l)
+
+ T.barrier_arrive(bar_sScale_and_sS_ready)
+ T.barrier_arrive(bar_k_1_free[0])
+
+ # Rescale
+ for h_i in T.Parallel(H_per_block):
+ sum_exp_shared[h_i] = sumexp[h_i]
+ T.barrier_arrive(bar_final)
+ for h_i, d_i in T.Parallel(H_per_block, D // 2):
+ acc_o_l[h_i, d_i] /= sumexp[h_i]
+ for h_i in T.Parallel(H_per_block):
+ sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
+ T.copy(acc_o_l, O_shared_l)
+ T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2])
+ elif tx >= 128 and tx < 256:
+ # T.set_max_nreg(168, 1)
+ T.fill(acc_o_r, 0)
+ for i_i in T.serial(T.ceildiv(NI, 2)):
+ # Buffer 0
+ T.barrier_arrive(bar_sScale_and_sS_ready)
+ T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))
+ T.barrier_arrive(bar_1_128)
+ T.barrier_wait(bar_1_128, 0)
+ for h_i, d_i in T.Parallel(H_per_block, D // 2):
+ acc_o_r[h_i, d_i] *= alpha_shared[h_i]
+ T.gemm(S_shared, KV_shared_0_r, acc_o_r)
+ T.barrier_arrive(bar_k_0_free[0])
+ T.barrier_arrive(bar_sScale_and_sS_free)
+
+ # Buffer 1
+ T.barrier_arrive(bar_sScale_and_sS_ready)
+ T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))
+ T.barrier_arrive(bar_1_128)
+ T.barrier_wait(bar_1_128, 1)
+ for h_i, d_i in T.Parallel(H_per_block, D // 2):
+ acc_o_r[h_i, d_i] *= alpha_shared[h_i]
+ T.gemm(S_shared, KV_shared_1_r, acc_o_r)
+ T.barrier_arrive(bar_k_1_free[0])
+ if i_i != T.ceildiv(NI, 2) - 1:
+ T.barrier_arrive(bar_sScale_and_sS_free)
+
+ # Rescale
+ T.barrier_wait(bar_final, 0)
+ for h_i, d_i in T.Parallel(H_per_block, D // 2):
+ acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
+
+ T.copy(acc_o_r, O_shared_r)
+ T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D])
+ elif tx >= 256:
+ # producer
+ T.set_max_nreg(80, 0)
+ indices_local[0] = 0
+ for i_i in T.serial(T.ceildiv(NI, 2)):
+ # Buffer 0
+ T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
+ T.barrier_arrive(bar_2_128)
+ T.barrier_wait(bar_2_128, 0)
+
+ for r in T.serial(4):
+ indices_tmp[0] = Indices[
+ b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8
+ ]
+ is_kv_valid_0[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
+ if is_kv_valid_0[r * 16 + (tx - 256) // 8]:
+ indices_local[0] = indices_tmp[0]
+
+ with T.attr("default", "async_scope", 1): # type: ignore
+ for u in T.serial(4):
+ for v in T.vectorized(8):
+ KV_shared_0_l[
+ r * 16 + (tx - 256) // 8,
+ 64 * u + (tx - 256) % 8 * 8 + v,
+ ] = KV[
+ b_i,
+ indices_local[0],
+ g_i,
+ 64 * u + (tx - 256) % 8 * 8 + v,
+ ]
+ KV_shared_0_r[
+ r * 16 + (tx - 256) // 8,
+ 64 * u + (tx - 256) % 8 * 8 + v,
+ ] = KV[
+ b_i,
+ indices_local[0],
+ g_i,
+ D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
+ ]
+ with T.attr("default", "async_scope", 1): # type: ignore
+ for v in T.vectorized(8):
+ K_tail_shared_0[
+ r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
+ ] = KV[
+ b_i,
+ indices_local[0],
+ g_i,
+ D + (tx - 256) % 8 * 8 + v,
+ ]
+
+ T.cp_async_barrier_noinc(bar_k_0_ready[0])
+
+ # Buffer 1
+ T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
+ T.barrier_arrive(bar_2_128)
+ T.barrier_wait(bar_2_128, 1)
+
+ for r in T.serial(4):
+ indices_tmp[0] = Indices[
+ b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8
+ ]
+ is_kv_valid_1[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
+ if is_kv_valid_1[r * 16 + (tx - 256) // 8]:
+ indices_local[0] = indices_tmp[0]
+
+ with T.attr("default", "async_scope", 1): # type: ignore
+ for u in T.serial(4):
+ for v in T.vectorized(8):
+ KV_shared_1_l[
+ r * 16 + (tx - 256) // 8,
+ 64 * u + (tx - 256) % 8 * 8 + v,
+ ] = KV[
+ b_i,
+ indices_local[0],
+ g_i,
+ 64 * u + (tx - 256) % 8 * 8 + v,
+ ]
+ KV_shared_1_r[
+ r * 16 + (tx - 256) // 8,
+ 64 * u + (tx - 256) % 8 * 8 + v,
+ ] = KV[
+ b_i,
+ indices_local[0],
+ g_i,
+ D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
+ ]
+ with T.attr("default", "async_scope", 1): # type: ignore
+ for v in T.vectorized(8):
+ K_tail_shared_1[
+ r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
+ ] = KV[
+ b_i,
+ indices_local[0],
+ g_i,
+ D + (tx - 256) % 8 * 8 + v,
+ ]
+
+ T.cp_async_barrier_noinc(bar_k_1_ready[0])
+
+ return main
+
+
+def tilelang_sparse_fwd(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ indices: torch.Tensor,
+ sm_scale: float,
+ d_v: int = 512,
+) -> torch.Tensor:
+ assert q.dim() == 3 and kv.dim() == 3 and indices.dim() == 3
+ num_heads = q.shape[1]
+ dim = q.shape[2]
+ tail_dim = dim - d_v
+ topk = indices.shape[-1]
+ assert topk == 2048
+ # NOTE(dark): v2 offers better performance than v1
+ kernel = sparse_attention_fwd_kernel_v2(
+ num_heads, d_v, tail_dim, topk, sm_scale=sm_scale
+ )
+ return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0)) # type: ignore
diff --git a/python/sglang/srt/layers/attention/nsa/topk.py b/python/sglang/srt/layers/attention/nsa/topk.py
new file mode 100644
index 000000000..684950621
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/topk.py
@@ -0,0 +1,65 @@
+import torch
+
+from sglang.srt.utils import align
+
+# NOTE(dark): flashmla P requires `params.topk % (2*B_TOPK) == 0`,
+# where `B_TOPK=64`. So we align to 128 by default.
+
+_TOPK_ALIGNMENT = 128
+
+
+# TODO(dark): maybe this torch_op can support torch.compile
+def _fast_topk_torch(
+ input: torch.Tensor, seq_lens: torch.Tensor, topk: int, alignment: int
+) -> torch.Tensor:
+ # Fallback to torch.topk
+ bs, max_seq_len = input.shape
+ assert len(seq_lens) == bs
+ # set those out-of-bound input to -inf
+ padded_max_seq_len = align(max_seq_len, alignment)
+ positions = torch.arange(
+ padded_max_seq_len, device=input.device, dtype=seq_lens.dtype
+ )
+ positions = positions.unsqueeze(0).expand(bs, -1)
+ mask = positions >= seq_lens.unsqueeze(1)
+
+ # NOTE(dark): just return all valid indices as an optimization
+ if padded_max_seq_len <= topk:
+ return positions.masked_fill(mask, -1)
+
+ assert topk % alignment == 0
+
+ # in-place operation: mask invalid inputs to -inf
+ input = input.masked_fill_(mask[:, :max_seq_len], float("-inf"))
+ result = input.topk(topk, dim=-1, sorted=True)
+ return result.indices.masked_fill_(mask[:, :topk], -1)
+
+
+def fast_topk_impl(
+ input: torch.Tensor,
+ seq_lens: torch.Tensor,
+ topk: int,
+ alignment: int = _TOPK_ALIGNMENT,
+) -> torch.Tensor:
+ return _fast_topk_torch(input, seq_lens, topk, alignment)
+
+
+def fast_topk_transform_fused_cuda(
+ input: torch.Tensor,
+ seq_lens: torch.Tensor,
+ topk: int,
+ dst_page_table: torch.Tensor,
+ src_page_table: torch.Tensor,
+ cu_seqlens_q: torch.Tensor,
+ alignment: int = _TOPK_ALIGNMENT,
+) -> torch.Tensor:
+ from sglang.srt.layers.attention.nsa.cuda import fast_topk_transform
+
+ assert topk == 2048 and topk % alignment == 0
+ return fast_topk_transform(
+ score=input,
+ lengths=seq_lens,
+ dst_page_table=dst_page_table,
+ src_page_table=src_page_table,
+ cu_seqlens=cu_seqlens_q,
+ )
diff --git a/python/sglang/srt/layers/attention/nsa/transform_index.py b/python/sglang/srt/layers/attention/nsa/transform_index.py
new file mode 100644
index 000000000..442dd113d
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/transform_index.py
@@ -0,0 +1,144 @@
+from typing import List, Optional
+
+import torch
+import triton
+import triton.language as tl
+
+
+def transform_index_page_table_prefill(**kwargs):
+ return transform_index_page_table_prefill_ref(**kwargs)
+
+
+def transform_index_page_table_decode(**kwargs):
+ return transform_index_page_table_decode_ref(**kwargs)
+
+
+@triton.jit
+def transform_index_page_table_decode_kernel(
+ page_table_ptr: torch.Tensor,
+ topk_indices_ptr: torch.Tensor,
+ result_ptr: torch.Tensor,
+ page_size: tl.constexpr,
+ max_seqlen_k: tl.constexpr,
+):
+ TOPK: tl.constexpr = 2048
+ req_id = tl.program_id(0)
+ page_table_ptr = page_table_ptr + req_id * max_seqlen_k
+ topk_indices_ptr = topk_indices_ptr + req_id * TOPK
+ result_ptr = result_ptr + req_id * TOPK
+
+ offset = tl.arange(0, TOPK) # topk should be 2048
+ loaded_topk_indices = tl.load(topk_indices_ptr + offset)
+ mask = loaded_topk_indices >= 0
+ loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask)
+ tl.store(result_ptr + offset, loaded_kv_indices, mask=mask)
+ tl.store(result_ptr + offset, -1, mask=~mask)
+
+
+def transform_index_page_table_decode_fast(
+ page_table: torch.Tensor,
+ topk_indices: torch.Tensor,
+ result: Optional[torch.Tensor] = None,
+ page_size: int = 1,
+) -> torch.Tensor:
+ """
+ Transform the page table according to topk indices for sparse topk attention.
+ Args:
+ page_table: [qo_len, max_seqlen_k], the original page table
+ topk_indices: [qo_len, topk], the topk indices for each query position
+ Returns:
+ transformed_page_table: [qo_len, topk], the transformed page table
+ For out-of-bound indices in topk_indices, this should be filled with -1.
+ """
+ assert page_size == 1
+ assert page_table.shape[0] == topk_indices.shape[0]
+ assert topk_indices.shape[1] == 2048
+ qo_len = topk_indices.shape[0]
+ max_seqlen_k = page_table.shape[1]
+ if result is None:
+ result = torch.empty_like(topk_indices, dtype=torch.int32)
+ # Launch triton kernel
+ grid = (qo_len,)
+ transform_index_page_table_decode_kernel[grid](
+ page_table,
+ topk_indices,
+ result,
+ page_size,
+ max_seqlen_k=max_seqlen_k,
+ )
+ return result
+
+
+def transform_index_page_table_prefill_fast(
+ page_table: torch.Tensor,
+ topk_indices: torch.Tensor,
+ extend_lens_cpu: List[int],
+ page_size: int = 1,
+) -> torch.Tensor:
+ # TODO(baizhou): can be implemented with another triton kernel
+ assert page_size == 1
+ result = torch.empty_like(topk_indices, dtype=torch.int32)
+ assert len(extend_lens_cpu) == page_table.shape[0]
+ offset = 0
+ for i, l in enumerate(extend_lens_cpu):
+ transform_index_page_table_decode_fast(
+ page_table[i].unsqueeze(0).expand(l, -1),
+ topk_indices[offset : offset + l],
+ result=result[offset : offset + l],
+ )
+ offset += l
+ assert offset == topk_indices.shape[0]
+ return result
+
+
+def transform_index_page_table_decode_ref(
+ page_table: torch.Tensor,
+ topk_indices: torch.Tensor,
+ result: Optional[torch.Tensor] = None,
+ page_size: int = 1,
+) -> torch.Tensor:
+ assert page_size == 1
+ assert page_table.shape[0] == topk_indices.shape[0]
+ if result is None:
+ result = torch.empty_like(topk_indices, dtype=torch.int32)
+ assert result.shape == topk_indices.shape
+ torch.gather(
+ page_table,
+ dim=1,
+ index=topk_indices.clamp(min=0),
+ out=result,
+ )
+ result[topk_indices < 0] = -1
+ return result
+
+
+def transform_index_page_table_prefill_ref(
+ page_table: torch.Tensor,
+ topk_indices: torch.Tensor,
+ extend_lens_cpu: List[int],
+ page_size: int = 1,
+) -> torch.Tensor:
+ assert page_size == 1
+ result = torch.empty_like(topk_indices, dtype=torch.int32)
+ assert len(extend_lens_cpu) == page_table.shape[0]
+ offset = 0
+ for i, l in enumerate(extend_lens_cpu):
+ transform_index_page_table_decode_ref(
+ page_table[i].unsqueeze(0).expand(l, -1),
+ topk_indices[offset : offset + l],
+ result=result[offset : offset + l],
+ )
+ offset += l
+ assert offset == topk_indices.shape[0]
+ return result
+
+
+if __name__ == "__main__":
+ bs, topk, max_seqlen = 10, 2048, 3000
+ page_table = torch.randint(0, 100, (bs, max_seqlen), device="cuda")
+ topk_indices = torch.full((bs, topk), -1, device="cuda")
+ topk_indices[:, :1600] = torch.arange(1600).unsqueeze(0).repeat(bs, 1)
+ ref_result = transform_index_page_table_decode_ref(page_table, topk_indices)
+ result = transform_index_page_table_decode_fast(page_table, topk_indices)
+ assert torch.all(result == ref_result)
+ print("Passed")
diff --git a/python/sglang/srt/layers/attention/nsa/unit_test/get_logits_ut.py b/python/sglang/srt/layers/attention/nsa/unit_test/get_logits_ut.py
new file mode 100644
index 000000000..17edf8a4f
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/unit_test/get_logits_ut.py
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+
+
+class DummyModel(nn.Module):
+ def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
+ super().__init__()
+ self.weights_proj = nn.Linear(d_in, 1024)
+ self.n_heads = n_heads
+ self.softmax_scale = softmax_scale
+
+ def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
+ weights = self.weights_proj(x)
+ weights = weights * self.n_heads**-0.5
+ q_scale = q_scale.unsqueeze(1) # (B,1,1)
+ weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
+ return weights
+
+ def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
+ weights = self.weights_proj(x)
+ q_scale = q_scale.unsqueeze(1) # (B,1,1)
+ scale_const = self.n_heads**-0.5 * q_scale * self.softmax_scale # (B,1,1)
+ weights = weights.unsqueeze(-1) * scale_const # (B,1024,1)
+ return weights
+
+
+def main():
+ torch.manual_seed(0)
+ model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
+ x = torch.randn(128, 2048) # batch=128, d_in=2048
+ q_scale = torch.randn(128, 1)
+
+ import time
+
+ start = time.time()
+ for _ in range(1000):
+ out_orig = model._get_logits_head_gate_orig(x, q_scale)
+ print("Original version time:", time.time() - start)
+
+ start = time.time()
+ for _ in range(1000):
+ out_opt = model._get_logits_head_gate_opt(x, q_scale)
+ print("Optimized version time:", time.time() - start)
+
+ print("Difference:", (out_orig - out_opt).abs().max().item())
+ assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"
+
+
+if __name__ == "__main__":
+ main()
+
+
+"""
+Original version time: 0.49235057830810547
+Optimized version time: 0.4087331295013428
+Difference: 1.4901161193847656e-08
+"""
diff --git a/python/sglang/srt/layers/attention/nsa/utils.py b/python/sglang/srt/layers/attention/nsa/utils.py
new file mode 100644
index 000000000..cdc812382
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa/utils.py
@@ -0,0 +1,32 @@
+# Temporary environment variables for NSA debugging; read once at import time.
+from sglang.srt.utils import get_bool_env_var
+
+NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
+NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
+NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")  # when set, top-k returns a transformed page table directly
+
+NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 = get_bool_env_var(
+    "SGLANG_NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8", "true"
+)
+NSA_KV_CACHE_STORE_FP8 = get_bool_env_var("SGLANG_NSA_KV_CACHE_STORE_FP8", "false")  # requires FP8 decode compute (asserted below)
+NSA_QUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_QUANT_K_CACHE_FAST", "false")
+NSA_DEQUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_DEQUANT_K_CACHE_FAST", "false")
+
+
+def _print_bool_env_vars():  # Log all NSA_* boolean flags on one line for debugging.
+    msg = ""
+    for k, v in globals().items():  # scan module globals for the NSA_* flags defined above
+        if k.startswith("NSA_") and isinstance(v, bool):
+            msg += f"{k}={v} "
+    print(msg, flush=True)  # flush so the line appears even if the process dies early
+
+
+_print_bool_env_vars()  # import-time side effect: print current flag settings
+
+
+if not NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8:  # storing KV in FP8 only makes sense with FP8 decode compute
+    assert not NSA_KV_CACHE_STORE_FP8
+
+
+def compute_nsa_seqlens(original_seq_lens, nsa_index_topk: int):  # Clip each sequence length to the NSA top-k attention budget.
+    return original_seq_lens.clamp(max=nsa_index_topk)
diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py
new file mode 100644
index 000000000..54e62c94d
--- /dev/null
+++ b/python/sglang/srt/layers/attention/nsa_backend.py
@@ -0,0 +1,870 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ TypeAlias,
+ Union,
+ override,
+)
+
+import torch
+from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+
+from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache
+from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
+from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
+from sglang.srt.layers.attention.nsa.topk import (
+ fast_topk_impl,
+ fast_topk_transform_fused_cuda,
+)
+from sglang.srt.layers.attention.nsa.transform_index import (
+ transform_index_page_table_decode,
+ transform_index_page_table_prefill,
+)
+from sglang.srt.layers.attention.nsa.utils import (
+ NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+ NSA_FUSE_TOPK,
+ NSA_KV_CACHE_STORE_FP8,
+ compute_nsa_seqlens,
+)
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.two_batch_overlap import global_server_args_dict
+
+if TYPE_CHECKING:
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+@dataclass(frozen=True)
+class NSAFlashMLAMetadata:
+    """Metadata only needed by FlashMLA"""
+
+    flashmla_metadata: torch.Tensor  # tile-scheduler metadata from flash_mla.get_mla_metadata
+    num_splits: torch.Tensor  # per-sequence split counts from flash_mla.get_mla_metadata
+
+    def slice(self, sli):  # View with num_splits sliced; flashmla_metadata is shared unsliced.
+        return NSAFlashMLAMetadata(
+            flashmla_metadata=self.flashmla_metadata,
+            num_splits=self.num_splits[sli],
+        )
+
+    def copy_(self, other: "NSAFlashMLAMetadata"):  # In-place tensor copy; valid on a frozen dataclass since fields are not rebound.
+        self.flashmla_metadata.copy_(other.flashmla_metadata)
+        self.num_splits.copy_(other.num_splits)
+
+
+@dataclass(frozen=True)
+class NSAMetadata:
+    page_size: int
+
+    # Sequence lengths for the forward batch
+    cache_seqlens_int32: torch.Tensor
+    # Maximum sequence length for query
+    max_seq_len_q: int
+    # Maximum sequence length for key
+    max_seq_len_k: int
+    # Cumulative sequence lengths for query
+    cu_seqlens_q: torch.Tensor
+    # Cumulative sequence lengths for key
+    cu_seqlens_k: torch.Tensor
+    # Page table, the index of KV Cache Tables/Blocks
+    # this table is always with page_size = 1
+    page_table_1: torch.Tensor
+
+    # NOTE(dark): This property will be used in:
+    # 1. dense decode/prefill, we use paged flash attention, need real_page_table
+    # 2. sparse decode/prefill, indexer need real_page_table to compute the score
+    real_page_table: torch.Tensor
+
+    # NSA metadata (nsa prefill are expanded)
+    nsa_cache_seqlens_int32: torch.Tensor  # this seqlens is clipped to `topk`
+    nsa_cu_seqlens_q: torch.Tensor  # must be arange(0, len(nsa_cu_seqlens_k))
+    nsa_cu_seqlens_k: torch.Tensor  # cumsum of `nsa_cache_seqlens_int32`
+    nsa_extend_seq_lens_list: List[int]
+    nsa_seqlens_expanded: torch.Tensor  # expanded, unclipped `seqlens`
+    nsa_max_seqlen_q: Literal[1] = 1  # always 1 for decode, variable for extend
+
+    flashmla_metadata: Optional[NSAFlashMLAMetadata] = None  # only set when decode impl is flashmla_decode
+
+
+@dataclass(frozen=True)
+class NSAIndexerMetadata(BaseIndexerMetadata):  # Thin adapter exposing NSAMetadata fields to the NSA indexer.
+    attn_metadata: NSAMetadata
+
+    @override
+    def get_seqlens_int32(self) -> torch.Tensor:  # unclipped per-sequence KV lengths
+        return self.attn_metadata.cache_seqlens_int32
+
+    @override
+    def get_page_table_64(self) -> torch.Tensor:  # page table at the backend's real page size
+        return self.attn_metadata.real_page_table
+
+    @override
+    def get_seqlens_expanded(self) -> torch.Tensor:  # one KV length per query token (unclipped)
+        return self.attn_metadata.nsa_seqlens_expanded
+
+    @override
+    def topk_transform(
+        self,
+        logits: torch.Tensor,
+        topk: int,
+    ) -> torch.Tensor:
+        if not NSA_FUSE_TOPK:  # unfused path: return raw top-k indices only
+            return fast_topk_impl(logits, self.get_seqlens_expanded(), topk)
+
+        # NOTE(dark): if fused, we return a transformed page table directly
+        dst_page_table = torch.empty(
+            (logits.shape[0], topk), dtype=torch.int32, device=logits.device
+        )
+        fast_topk_transform_fused_cuda(
+            input=logits,
+            seq_lens=self.get_seqlens_expanded(),
+            topk=topk,
+            dst_page_table=dst_page_table,
+            src_page_table=self.attn_metadata.page_table_1,
+            cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
+        )
+        return dst_page_table
+
+
+def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:  # Prefix sum with a leading zero: [a, b] -> [0, a, a+b].
+    assert seqlens.dtype == torch.int32 and seqlens.is_cuda
+    return torch.nn.functional.pad(
+        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
+    )
+
+
+_NSA_IMPL_T: TypeAlias = Literal[  # selectable kernel implementations for NSA prefill/decode
+    "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
+]
+
+NSA_PREFILL_IMPL: _NSA_IMPL_T  # assigned from server args in NativeSparseAttnBackend.__init__
+NSA_DECODE_IMPL: _NSA_IMPL_T  # assigned from server args in NativeSparseAttnBackend.__init__
+
+
+class NativeSparseAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):  # Cache runner/model config needed to build per-batch NSA metadata.
+        super().__init__()
+        self.forward_metadata: NSAMetadata  # set per batch by init_forward_metadata / CUDA-graph hooks
+        self.device = model_runner.device
+        assert isinstance(model_runner.page_size, int)
+        self.real_page_size = model_runner.page_size
+        self.num_splits = (  # 1 forces a fixed split count for determinism; 0 lets the kernel choose
+            1 if model_runner.server_args.enable_deterministic_inference else 0
+        )
+        self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
+        assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
+        self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
+        self.max_context_len = model_runner.model_config.context_len
+        self.num_q_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
+
+        assert model_runner.req_to_token_pool is not None
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+        global NSA_PREFILL_IMPL, NSA_DECODE_IMPL  # module-level impl selectors read by the forward paths
+        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
+        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
+
+        self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)  # reusable arange for cu_seqlens_q
+
+    def get_device_int32_arange(self, l: int) -> torch.Tensor:  # Return arange(l) on device, growing the cached buffer as needed.
+        if l > len(self._arange_buf):
+            next_pow_of_2 = 1 << (l - 1).bit_length()  # grow to next power of two to amortize reallocation
+            self._arange_buf = torch.arange(
+                next_pow_of_2, device=self.device, dtype=torch.int32
+            )
+        return self._arange_buf[:l]
+
+    def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:  # Convert a page-size-1 table to the real page-size table.
+        page_size = self.real_page_size
+        if page_size == 1:  # already the real table; returned as-is (aliased, not copied)
+            return page_table
+        max_seqlen_k = page_table.shape[1]
+        strided_indices = torch.arange(
+            0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
+        )
+        return page_table[:, strided_indices] // page_size  # one entry per page, token index -> page index
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        batch_size = forward_batch.batch_size
+        device = forward_batch.seq_lens.device
+
+        assert (
+            forward_batch.spec_info is None
+        ), "Spec decoding is not supported for NSA backend now"
+        cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
+        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+        assert forward_batch.seq_lens_cpu is not None
+        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
+        page_table = forward_batch.req_to_token_pool.req_to_token[
+            forward_batch.req_pool_indices, :max_seqlen_k
+        ]
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            extend_seq_lens_cpu = [1] * batch_size  # one query token per sequence in decode
+            max_seqlen_q = 1
+            cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
+            seqlens_expanded = cache_seqlens_int32
+        elif forward_batch.forward_mode.is_extend():
+            assert (
+                forward_batch.extend_seq_lens_cpu is not None
+                and forward_batch.extend_seq_lens is not None
+                and forward_batch.extend_prefix_lens_cpu is not None
+            ), "All of them must not be None"
+            extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
+            assert forward_batch.extend_seq_lens is not None  # NOTE(review): redundant, already asserted above
+            if any(forward_batch.extend_prefix_lens_cpu):  # cached prefix present: q and k lengths differ
+                max_seqlen_q = max(extend_seq_lens_cpu)
+                cu_seqlens_q = compute_cu_seqlens(
+                    forward_batch.extend_seq_lens.to(torch.int32)
+                )
+            else:  # no cached prefix: query spans the whole sequence
+                max_seqlen_q = max_seqlen_k
+                cu_seqlens_q = cu_seqlens_k
+            seqlens_expanded = torch.cat(  # per query token: the (causal) kv length it attends to
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    for qo_len, kv_len in zip(
+                        forward_batch.extend_seq_lens_cpu,
+                        forward_batch.seq_lens_cpu.tolist(),
+                        strict=True,
+                    )
+                ]
+            )
+        else:
+            assert False, f"Unsupported {forward_batch.forward_mode = }"
+
+        # 1D, expanded seqlens (1D means cheap to compute, so always compute it)
+        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+            original_seq_lens=seqlens_expanded,
+            nsa_index_topk=self.nsa_index_topk,
+        )
+        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
+        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
+
+        metadata = NSAMetadata(
+            page_size=self.real_page_size,
+            cache_seqlens_int32=cache_seqlens_int32,
+            max_seq_len_q=max_seqlen_q,
+            max_seq_len_k=max_seqlen_k,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            page_table_1=page_table,
+            flashmla_metadata=(
+                self._compute_flashmla_metadata(
+                    cache_seqlens=nsa_cache_seqlens_int32,
+                    seq_len_q=1,  # TODO handle MTP which is not 1
+                )
+                if NSA_DECODE_IMPL == "flashmla_decode"
+                else None
+            ),
+            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
+            nsa_cu_seqlens_q=nsa_cu_seqlens_q,
+            nsa_cu_seqlens_k=nsa_cu_seqlens_k,
+            nsa_seqlens_expanded=seqlens_expanded,
+            nsa_extend_seq_lens_list=extend_seq_lens_cpu,
+            real_page_table=self._transform_table_1_to_real(page_table),
+        )
+
+        self.forward_metadata = metadata
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        """Initialize CUDA graph state for the attention backend.
+
+        Args:
+            max_bs (int): Maximum batch size to support in CUDA graphs
+
+        This creates fixed-size tensors that will be reused during CUDA graph replay
+        to avoid memory allocations.
+        """
+        self.decode_cuda_graph_metadata: Dict = {
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            # fake page_table for sparse_prefill
+            "page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "flashmla_metadata": (
+                self._compute_flashmla_metadata(
+                    cache_seqlens=torch.ones(
+                        max_bs, dtype=torch.int32, device=self.device
+                    ),
+                    seq_len_q=1,  # TODO handle MTP which is not 1
+                )
+                if NSA_DECODE_IMPL == "flashmla_decode"
+                else None
+            ),
+        }
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        """Initialize forward metadata for capturing CUDA graph."""
+        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
+        assert (
+            spec_info is None
+        ), "Speculative decoding is not supported for NSA backend now"
+
+        # Normal Decode
+        # Get sequence information
+        cache_seqlens_int32 = seq_lens.to(torch.int32)
+        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+
+        # Use max context length for seq_len_k
+        page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+        max_seq_len_k = page_table_1.shape[1]
+
+        # Precompute page table
+        # Precompute cumulative sequence lengths
+
+        # NOTE(dark): this is always arange, since we are decoding
+        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
+        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+            cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
+        )
+        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
+        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
+        real_page_table = self._transform_table_1_to_real(page_table_1)
+
+        if NSA_DECODE_IMPL == "flashmla_decode":
+            flashmla_metadata = self.decode_cuda_graph_metadata[
+                "flashmla_metadata"
+            ].slice(slice(0, bs + 1))  # NOTE(review): num_splits is sliced to bs+1 entries; flashmla_metadata tensor is shared
+            flashmla_metadata.copy_(
+                self._compute_flashmla_metadata(
+                    cache_seqlens=nsa_cache_seqlens_int32,
+                    seq_len_q=1,  # TODO handle MTP which is not 1
+                )
+            )
+        else:
+            flashmla_metadata = None
+
+        metadata = NSAMetadata(
+            page_size=self.real_page_size,
+            cache_seqlens_int32=cache_seqlens_int32,
+            max_seq_len_q=1,
+            max_seq_len_k=max_seq_len_k,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            page_table_1=page_table_1,
+            flashmla_metadata=flashmla_metadata,
+            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
+            nsa_cu_seqlens_q=nsa_cu_seqlens_q,
+            nsa_cu_seqlens_k=nsa_cu_seqlens_k,
+            nsa_seqlens_expanded=cache_seqlens_int32,
+            real_page_table=real_page_table,
+            nsa_extend_seq_lens_list=[1] * bs,
+        )
+        self.decode_cuda_graph_metadata[bs] = metadata  # cached per batch size for replay
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+        out_cache_loc: Optional[torch.Tensor] = None,
+    ):
+        """Initialize forward metadata for replaying CUDA graph."""
+        assert seq_lens_cpu is not None
+        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
+        assert (
+            spec_info is None
+        ), "Speculative decoding is not supported for NSA backend now"
+        seq_lens = seq_lens[:bs]
+        seq_lens_cpu = seq_lens_cpu[:bs]
+        req_pool_indices = req_pool_indices[:bs]
+
+        # Normal Decode
+        metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
+        max_len = int(seq_lens_cpu.max().item())
+
+        cache_seqlens = seq_lens.to(torch.int32)
+        metadata.cache_seqlens_int32.copy_(cache_seqlens)  # in-place updates into the captured tensors
+        metadata.cu_seqlens_k[1:].copy_(
+            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+        )
+        page_indices = self.req_to_token[req_pool_indices, :max_len]
+        metadata.page_table_1[:, :max_len].copy_(page_indices)
+        assert (
+            metadata.nsa_cache_seqlens_int32 is not None
+            and metadata.nsa_cu_seqlens_k is not None
+            and self.nsa_index_topk is not None
+        )
+        nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
+        metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
+        metadata.nsa_cu_seqlens_k[1:].copy_(
+            torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
+        )
+        # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
+
+        assert self.real_page_size == metadata.page_size
+        if self.real_page_size > 1:
+            real_table = self._transform_table_1_to_real(page_indices)
+            new_len = real_table.shape[1]
+            metadata.real_page_table[:, :new_len].copy_(real_table)
+        else:  # page size 1: real table aliases page_table_1, already updated above
+            assert metadata.real_page_table is metadata.page_table_1
+
+        if NSA_DECODE_IMPL == "flashmla_decode":
+            metadata.flashmla_metadata.copy_(
+                self._compute_flashmla_metadata(
+                    cache_seqlens=nsa_cache_seqlens,
+                    seq_len_q=1,  # TODO handle MTP which is not 1
+                )
+            )
+
+        self.forward_metadata = metadata
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        assert (
+            not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ), "NSA backend doesn't support speculative decoding"
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
+                    layer,
+                    cache_loc,
+                    k,
+                    k_rope,
+                )
+
+        metadata = self.forward_metadata
+        causal = not layer.is_cross_attention
+        assert causal, "NSA is causal only"
+
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}  # NOTE(review): currently unused below
+
+        # Do absorbed multi-latent attention
+        assert q_rope is not None
+        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        # when store in fp8 and compute in fp8, no need to convert dtype
+        if not (NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and NSA_KV_CACHE_STORE_FP8):
+            kv_cache = kv_cache.to(q.dtype)
+
+        if q_rope is not None:  # NOTE(review): always true given the assert above; else-branch is unreachable here
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+        # NOTE(dark): here, we use page size = 1
+
+        if NSA_FUSE_TOPK:  # fused top-k already produced a page table
+            page_table_1 = topk_indices
+        else:
+            assert metadata.nsa_extend_seq_lens_list is not None
+            page_table_1 = transform_index_page_table_prefill(
+                page_table=metadata.page_table_1,
+                topk_indices=topk_indices,
+                extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
+                page_size=1,
+            )
+        if NSA_PREFILL_IMPL == "tilelang":
+            from sglang.srt.layers.attention.nsa.tilelang_kernel import (
+                tilelang_sparse_fwd,
+            )
+
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_tilelang(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                page_table_1=page_table_1,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+            )
+        elif NSA_PREFILL_IMPL == "flashmla_prefill":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_flashmla_prefill(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                page_table_1=page_table_1,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+            )
+        elif NSA_PREFILL_IMPL == "flashmla_decode":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_flashmla_decode(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+                # TODO optimize args
+                layer=layer,
+                forward_batch=forward_batch,
+                metadata=metadata,
+                topk_indices=topk_indices,
+                block_table=metadata.real_page_table,
+            )
+        elif NSA_PREFILL_IMPL == "fa3":
+            return self._forward_fa3(
+                q_rope=q_rope,
+                kv_cache=kv_cache,
+                v_head_dim=layer.v_head_dim,
+                q_nope=q_nope,
+                page_table=page_table_1,
+                cache_seqlens=metadata.nsa_cache_seqlens_int32,
+                cu_seqlens_q=metadata.nsa_cu_seqlens_q,
+                cu_seqlens_k=metadata.nsa_cu_seqlens_k,
+                max_seqlen_q=metadata.nsa_max_seqlen_q,
+                sm_scale=layer.scaling,
+                logit_cap=layer.logit_cap,
+                page_size=1,
+            )
+        else:
+            raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }")
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
+                    layer,
+                    cache_loc,
+                    k,
+                    k_rope,
+                )
+
+        metadata = self.forward_metadata
+        causal = not layer.is_cross_attention
+        assert causal, "NSA is causal only"
+
+        # Do absorbed multi-latent attention
+        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:  # q carries nope+rope packed together; split by v_head_dim
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+        if NSA_FUSE_TOPK:  # fused top-k already produced a page table
+            page_table_1 = topk_indices
+        else:
+            page_table_1 = transform_index_page_table_decode(
+                page_table=metadata.page_table_1,
+                topk_indices=topk_indices,
+                page_size=1,
+            )
+
+        if NSA_DECODE_IMPL == "flashmla_prefill":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_flashmla_prefill(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                page_table_1=page_table_1,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+            )
+        elif NSA_DECODE_IMPL == "flashmla_decode":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_flashmla_decode(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+                # TODO optimize args
+                layer=layer,
+                forward_batch=forward_batch,
+                metadata=metadata,
+                topk_indices=topk_indices,
+                block_table=metadata.real_page_table,
+            )
+        elif NSA_DECODE_IMPL == "tilelang":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_tilelang(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                page_table_1=page_table_1,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+            )
+        elif NSA_DECODE_IMPL == "fa3":
+            return self._forward_fa3(
+                q_rope=q_rope,
+                kv_cache=kv_cache,
+                v_head_dim=layer.v_head_dim,
+                q_nope=q_nope,
+                page_table=page_table_1,
+                cache_seqlens=metadata.nsa_cache_seqlens_int32,
+                cu_seqlens_q=metadata.nsa_cu_seqlens_q,
+                cu_seqlens_k=metadata.nsa_cu_seqlens_k,
+                max_seqlen_q=metadata.nsa_max_seqlen_q,
+                sm_scale=layer.scaling,
+                logit_cap=layer.logit_cap,
+                page_size=1,
+            )
+        else:
+            assert False, f"Unsupported {NSA_DECODE_IMPL = }"
+
+    def _forward_fa3(
+        self,
+        q_rope: torch.Tensor,
+        kv_cache: torch.Tensor,
+        v_head_dim: int,
+        q_nope: torch.Tensor,
+        page_table: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        sm_scale: float,
+        logit_cap: float,
+        page_size: int,
+    ) -> torch.Tensor:  # FA3 path: split the MLA cache into rope-key and compressed-kv views.
+        k_rope_cache = kv_cache[:, :, v_head_dim:]
+        c_kv_cache = kv_cache[:, :, :v_head_dim]
+        qk_rope_dim = k_rope_cache.shape[-1]
+        k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
+        c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
+        o = flash_attn_with_kvcache(
+            q=q_rope,
+            k_cache=k_rope_cache,
+            v_cache=c_kv_cache,
+            qv=q_nope,
+            page_table=page_table,
+            cache_seqlens=cache_seqlens,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k_new=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            softmax_scale=sm_scale,
+            causal=True,
+            softcap=logit_cap,
+            return_softmax_lse=False,
+            num_splits=self.num_splits,
+        )
+        return o  # type: ignore
+
+    def _forward_flashmla_prefill(
+        self,
+        q_all: torch.Tensor,
+        kv_cache: torch.Tensor,
+        v_head_dim: int,
+        page_table_1: torch.Tensor,
+        sm_scale: float,
+    ) -> torch.Tensor:  # Sparse FlashMLA prefill over page-size-1 indices.
+        from flash_mla import flash_mla_sparse_fwd
+
+        o, _, _ = flash_mla_sparse_fwd(
+            q=q_all,
+            kv=kv_cache,
+            indices=page_table_1.unsqueeze(1),
+            sm_scale=sm_scale,
+            d_v=v_head_dim,
+        )
+        return o
+
+    def _forward_flashmla_decode(
+        self,
+        q_all: torch.Tensor,
+        kv_cache: torch.Tensor,
+        v_head_dim: int,
+        sm_scale: float,
+        layer,
+        forward_batch: ForwardBatch,
+        metadata: NSAMetadata,
+        topk_indices,
+        block_table,
+    ) -> torch.Tensor:  # FlashMLA decode; requires real page size 64 and precomputed flashmla_metadata.
+        from flash_mla import flash_mla_with_kvcache
+
+        cache_seqlens = metadata.nsa_cache_seqlens_int32
+
+        # TODO the 2nd dim is seq_len_q, need to be >1 when MTP
+        q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
+        kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)
+        assert self.real_page_size == 64, "only page size 64 is supported"
+
+        if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not NSA_KV_CACHE_STORE_FP8:
+            # inefficiently quantize the whole cache
+            kv_cache = quantize_k_cache(kv_cache)
+
+        o, _ = flash_mla_with_kvcache(
+            q=q_all,
+            k_cache=kv_cache,
+            cache_seqlens=cache_seqlens,
+            head_dim_v=v_head_dim,
+            tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
+            num_splits=metadata.flashmla_metadata.num_splits,
+            softmax_scale=sm_scale,
+            # TODO improve
+            indices=_compute_indices_in_kvcache(
+                block_table=block_table,
+                topk_indices=topk_indices.to(torch.int32),
+                page_size=self.real_page_size,
+                nsa_index_topk=self.nsa_index_topk,
+            ),
+            # doc says it is not used, but if pass in None then error
+            block_table=torch.empty(
+                (q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
+            ),
+            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+        )
+
+        # TODO shape correct?
+        return o
+
+    def _forward_tilelang(
+        self,
+        q_all: torch.Tensor,
+        kv_cache: torch.Tensor,
+        v_head_dim: int,
+        page_table_1: torch.Tensor,
+        sm_scale: float,
+    ) -> torch.Tensor:  # TileLang sparse kernel; same argument shape as the FlashMLA prefill path.
+        from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd
+
+        return tilelang_sparse_fwd(
+            q=q_all,
+            kv=kv_cache,
+            indices=page_table_1.unsqueeze(1),
+            sm_scale=sm_scale,
+            d_v=v_head_dim,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        """Get the fill value for sequence length in CUDA graph."""
+        return 1
+
+    def get_indexer_metadata(
+        self, layer_id: int, forward_batch: ForwardBatch
+    ) -> NSAIndexerMetadata:  # Wrap the current forward metadata for the NSA indexer.
+        return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
+
+    def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):  # Build FlashMLA scheduler metadata for the given seqlens.
+        from flash_mla import get_mla_metadata
+
+        flashmla_metadata, num_splits = get_mla_metadata(
+            cache_seqlens=cache_seqlens,
+            # TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`
+            # but the name looks like need seq_len_q?
+            num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,
+            num_heads_k=1,
+            num_heads_q=self.num_q_heads,
+            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+            topk=self.nsa_index_topk,
+        )
+
+        return NSAFlashMLAMetadata(
+            flashmla_metadata=flashmla_metadata,
+            num_splits=num_splits,
+        )
+
+
+# TODO speedup
+def _compute_indices_in_kvcache(block_table, topk_indices, page_size, nsa_index_topk):  # Map per-sequence top-k token indices to flat KV-cache slot indices for FlashMLA.
+    topk_indices_safe = topk_indices.masked_fill(topk_indices == -1, 0)  # avoid gathering with the -1 sentinel
+
+    idx0 = torch.arange(block_table.size(0), device=topk_indices_safe.device).unsqueeze(
+        1
+    )
+    block_idx = block_table[idx0, topk_indices_safe // page_size]  # physical page per selected token
+    offset = topk_indices_safe % page_size  # offset within the page
+    indices_in_kvcache = block_idx * page_size + offset
+
+    # the kernel requires invalid entry to be -1
+    assert indices_in_kvcache.shape == topk_indices.shape
+    indices_in_kvcache[topk_indices == -1] = -1
+
+    # return: (batch_size, seqlen_q_ori, topk)
+    indices_in_kvcache = indices_in_kvcache[:, None, :]  # seqlen_q_ori == 1 (no MTP support yet)
+
+    indices_in_kvcache = torch.nn.functional.pad(  # pad last dim up to nsa_index_topk with -1
+        indices_in_kvcache,
+        (0, nsa_index_topk - indices_in_kvcache.shape[-1]),
+        "constant",
+        -1,
+    )
+    assert indices_in_kvcache.shape[-1] == nsa_index_topk
+
+    return indices_in_kvcache
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 7a3f31128..185764ad7 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -127,8 +127,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
"disable_chunked_prefix_cache"
]
- self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
-
def _calc_padded_blocks(self, max_seq_len: int) -> int:
"""
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
@@ -219,7 +217,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
"""Initialize metadata for CUDA graph capture."""
# Delegate to parent for non-decode modes.
- if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
+ if not forward_mode.is_decode_or_idle():
return super().init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
@@ -230,9 +228,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
spec_info,
)
- if forward_mode.is_target_verify():
- seq_lens = seq_lens + self.num_draft_tokens
-
# Custom fast-path for decode/idle.
# Capture with full width so future longer sequences are safe during replay
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
@@ -275,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
):
"""Replay CUDA graph with new inputs."""
# Delegate to parent for non-decode modes.
- if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
+ if not forward_mode.is_decode_or_idle():
return super().init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
@@ -287,10 +282,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
seq_lens_cpu,
)
- if forward_mode.is_target_verify():
- seq_lens = seq_lens + self.num_draft_tokens
- del seq_lens_sum # not handle "num_draft_tokens" but we do not need it
-
metadata = self.decode_cuda_graph_metadata[bs]
# Update block indices for new sequences.
@@ -341,10 +332,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
cum_seq_lens_q,
seq_lens,
)
- elif (
- forward_batch.forward_mode.is_decode_or_idle()
- or forward_batch.forward_mode.is_target_verify()
- ):
+ elif forward_batch.forward_mode.is_decode_or_idle():
bs = forward_batch.batch_size
# Get maximum sequence length.
@@ -353,19 +341,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
else:
max_seq = forward_batch.seq_lens.max().item()
- seq_lens = forward_batch.seq_lens
-
- if forward_batch.forward_mode.is_target_verify():
- max_seq = max_seq + self.num_draft_tokens
- seq_lens = seq_lens + self.num_draft_tokens
-
max_seqlen_pad = self._calc_padded_blocks(max_seq)
block_kv_indices = self._create_block_kv_indices(
bs,
max_seqlen_pad,
forward_batch.req_pool_indices,
- seq_lens,
- seq_lens.device,
+ forward_batch.seq_lens,
+ forward_batch.seq_lens.device,
)
max_seq_len_val = int(max_seq)
@@ -505,7 +487,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
q_rope_reshaped = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
- query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
+ if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
+ query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
+ else:
+ query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
else:
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
@@ -568,134 +553,84 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- if forward_batch.forward_mode.is_draft_extend():
+ ):
+ if (
+ forward_batch.forward_mode.is_target_verify()
+ or forward_batch.forward_mode.is_draft_extend()
+ ):
+ return super().forward_extend(
+ q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+ )
+ # chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
+ if forward_batch.attn_attend_prefix_cache is None:
return super().forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
)
- # Save KV cache if requested
- if save_kv_cache:
- assert (
- k is not None and k_rope is not None
- ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
- forward_batch.token_to_kv_pool.set_mla_kv_buffer(
- layer, forward_batch.out_cache_loc, k, k_rope
- )
-
- if q_rope is not None:
- q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
- q_rope = q_rope.view(
- -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
- )
- q = _concat_mla_absorb_q_general(q, q_rope)
-
- q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-
- if k_rope is not None:
- k = torch.cat([k, k_rope], dim=-1)
- k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
-
- v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
-
- if forward_batch.forward_mode.is_target_verify():
- metadata = (
- getattr(forward_batch, "decode_trtllm_mla_metadata", None)
- or self.forward_decode_metadata
- )
-
- # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
- bs = forward_batch.batch_size
- q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
-
- k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
- kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
-
- q_scale = 1.0
- k_scale = (
- layer.k_scale_float
- if getattr(layer, "k_scale_float", None) is not None
- else 1.0
- )
-
- bmm1_scale = q_scale * k_scale * layer.scaling
-
- seq_lens = (
- forward_batch.seq_lens.to(torch.int32)
- + forward_batch.spec_info.draft_token_num
- )
- max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
-
- # TODO may use `mla_rope_quantize_fp8` fusion
- q = q.to(self.data_type)
- assert kv_cache.dtype == self.data_type
-
- raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
- query=q,
- kv_cache=kv_cache,
- workspace_buffer=self.workspace_buffer,
- qk_nope_head_dim=self.qk_nope_head_dim,
- kv_lora_rank=self.kv_lora_rank,
- qk_rope_head_dim=self.qk_rope_head_dim,
- block_tables=metadata.block_kv_indices,
- seq_lens=seq_lens,
- max_seq_len=max_seq_len,
- bmm1_scale=bmm1_scale,
- )
-
- # Reshape output directly without slicing
- output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
- return output
-
- if forward_batch.attn_attend_prefix_cache:
- # MHA for chunked prefix kv cache when running model with MLA
- assert forward_batch.prefix_chunk_idx is not None
- assert forward_batch.prefix_chunk_cu_seq_lens is not None
- assert q_rope is None
- assert k_rope is None
- chunk_idx = forward_batch.prefix_chunk_idx
-
- output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
- return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+ if not forward_batch.attn_attend_prefix_cache:
+ q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+ k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+ v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+ output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
query=q,
key=k,
value=v,
workspace_buffer=self.workspace_buffer,
- seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
+ seq_lens=self.forward_prefill_metadata.seq_lens,
max_q_len=self.forward_prefill_metadata.max_seq_len,
- max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+ max_kv_len=self.forward_prefill_metadata.max_seq_len,
bmm1_scale=layer.scaling,
bmm2_scale=1.0,
- o_sf_scale=-1.0,
+ o_sf_scale=1.0,
batch_size=forward_batch.batch_size,
window_left=-1,
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
- cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+ cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
enable_pdl=False,
- is_causal=False,
- return_lse=True,
- out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
+ is_causal=True,
+ return_lse=forward_batch.mha_return_lse,
)
+ else:
+ if not (
+ forward_batch.attn_attend_prefix_cache is not None
+ and forward_batch.mha_return_lse
+ ):
+ output = super().forward_extend(
+ q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+ )
+ else:
+ # MHA for chunked prefix kv cache when running model with MLA
+ assert forward_batch.prefix_chunk_idx is not None
+ assert forward_batch.prefix_chunk_cu_seq_lens is not None
+ assert q_rope is None
+ assert k_rope is None
+ chunk_idx = forward_batch.prefix_chunk_idx
- return flashinfer.prefill.trtllm_ragged_attention_deepseek(
- query=q,
- key=k,
- value=v,
- workspace_buffer=self.workspace_buffer,
- seq_lens=self.forward_prefill_metadata.seq_lens,
- max_q_len=self.forward_prefill_metadata.max_seq_len,
- max_kv_len=self.forward_prefill_metadata.max_seq_len,
- bmm1_scale=layer.scaling,
- bmm2_scale=1.0,
- o_sf_scale=1.0,
- batch_size=forward_batch.batch_size,
- window_left=-1,
- cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
- cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
- enable_pdl=False,
- is_causal=True,
- return_lse=forward_batch.mha_return_lse,
- )
+ q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+ k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
+ v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
+ output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+ output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+ query=q,
+ key=k,
+ value=v,
+ workspace_buffer=self.workspace_buffer,
+ seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
+ max_q_len=self.forward_prefill_metadata.max_seq_len,
+ max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+ bmm1_scale=layer.scaling,
+ bmm2_scale=1.0,
+ o_sf_scale=-1.0,
+ batch_size=forward_batch.batch_size,
+ window_left=-1,
+ cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+ cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+ enable_pdl=False,
+ is_causal=False,
+ return_lse=True,
+ out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
+ )
+ return output
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
@@ -713,10 +648,3 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
kv_indptr_buf=self.kv_indptr[i],
q_indptr_decode_buf=self.q_indptr_decode,
)
-
-
-def _concat_mla_absorb_q_general(q_nope, q_rope):
- if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
- return concat_mla_absorb_q(q_nope, q_rope)
- else:
- return torch.cat([q_nope, q_rope], dim=-1)
diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py
index 489b8248b..2be3e450b 100644
--- a/python/sglang/srt/layers/attention/vision.py
+++ b/python/sglang/srt/layers/attention/vision.py
@@ -16,19 +16,14 @@ from sglang.srt.utils import (
get_device_capability,
is_blackwell,
is_cuda,
- is_npu,
print_info_once,
)
_is_cuda = is_cuda()
-_is_npu = is_npu()
if _is_cuda:
from sgl_kernel.flash_attn import flash_attn_varlen_func
-if _is_npu:
- import torch_npu
-
from sglang.srt.distributed import (
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
@@ -336,63 +331,10 @@ class VisionFlash3Attention(nn.Module):
return output
-class VisionAscendAttention(nn.Module):
-
- def __init__(
- self,
- **kwargs,
- ):
- if not _is_npu:
- raise Exception("VisionAscendAttention is only available for ascend npu")
- super().__init__()
-
- def forward(
- self,
- q: torch.Tensor,
- k: torch.Tensor,
- v: torch.Tensor,
- cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
- bsz: int,
- seq_len: int,
- **kwargs,
- ) -> torch.Tensor:
- r"""
- Args:
- cu_seqlens: [b]
- Returns:
- [b * s, h, head_size]
- """
- if cu_seqlens is None:
- cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
-
- seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
- if seq_lens.is_npu:
- # cu_seqlens must be on cpu because of operator restriction
- seq_lens = seq_lens.to("cpu")
- _, num_heads, head_size = q.shape
- num_kv_heads = k.shape[1]
- output = torch.empty_like(q)
-
- # operator requires pta version >= 2.5.1
- torch_npu._npu_flash_attention_unpad(
- query=q,
- key=k,
- value=v,
- seq_len=seq_lens.to(torch.int32),
- scale_value=head_size**-0.5,
- num_heads=num_heads,
- num_kv_heads=num_kv_heads,
- out=output,
- )
-
- return output
-
-
QKV_BACKEND_IMPL = {
"triton_attn": VisionTritonAttention,
"sdpa": VisionSdpaAttention,
"fa3": VisionFlash3Attention,
- "ascend_attn": VisionAscendAttention,
}
diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index e050da91d..fba8d8f18 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -50,7 +50,6 @@ from sglang.srt.utils import (
is_hip,
is_sm90_supported,
is_sm100_supported,
- prepare_weight_cache,
)
_is_flashinfer_available = is_flashinfer_available()
@@ -276,11 +275,7 @@ class LayerCommunicator:
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
- cache=None,
):
- if cache is not None:
- self._context.cache = cache
-
return self._communicate_with_all_reduce_and_layer_norm_fn(
hidden_states=hidden_states,
residual=residual,
@@ -354,7 +349,6 @@ class CommunicateContext:
attn_tp_size: int
attn_dp_size: int
tp_size: int
- cache = None
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -539,8 +533,6 @@ class CommunicateWithAllReduceAndLayerNormFn:
)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
- if context.cache is not None:
- _ = prepare_weight_cache(hidden_states, context.cache)
hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual
diff --git a/python/sglang/srt/layers/elementwise.py b/python/sglang/srt/layers/elementwise.py
index 899518034..e05d88b32 100644
--- a/python/sglang/srt/layers/elementwise.py
+++ b/python/sglang/srt/layers/elementwise.py
@@ -187,9 +187,7 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
assert len(x.shape) == 2
- assert (
- x.shape == residual.shape and x.dtype == residual.dtype
- ), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
+ assert x.shape == residual.shape and x.dtype == residual.dtype
output, mid = torch.empty_like(x), torch.empty_like(x)
bs, hidden_dim = x.shape
if autotune:
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 4c1e8ddfa..87a392d55 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -127,45 +127,21 @@ class RMSNorm(CustomOp):
return output, residual_out
return rms_norm(x, self.weight.data, self.variance_epsilon)
- # def forward_hip(
- # self,
- # x: torch.Tensor,
- # residual: Optional[torch.Tensor] = None,
- # ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- # if not x.is_contiguous():
- # # NOTE: Remove this if aiter kernel supports discontinuous input
- # x = x.contiguous()
- # if residual is not None:
- # if _vllm_version < Version("0.9"):
- # fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
- # return x, residual
- # else:
- # residual_out = torch.empty_like(x)
- # output = torch.empty_like(x)
- # fused_add_rms_norm(
- # output,
- # x,
- # residual_out,
- # residual,
- # self.weight.data,
- # self.variance_epsilon,
- # )
- # return output, residual_out
- # out = torch.empty_like(x)
- # rms_norm(out, x, self.weight.data, self.variance_epsilon)
- # return out
def forward_hip(
- self,
- x: torch.Tensor,
- residual: Optional[torch.Tensor] = None,
- ):
+ self,
+ x: torch.Tensor,
+ residual: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
+ # NOTE: Remove this if aiter kernel supports discontinuous input
x = x.contiguous()
-
if residual is not None:
- try:
- output = torch.empty_like(x)
+ if _vllm_version < Version("0.9"):
+ fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+ return x, residual
+ else:
residual_out = torch.empty_like(x)
+ output = torch.empty_like(x)
fused_add_rms_norm(
output,
x,
@@ -175,21 +151,10 @@ class RMSNorm(CustomOp):
self.variance_epsilon,
)
return output, residual_out
- except TypeError:
- fused_add_rms_norm(
- x,
- residual,
- self.weight.data,
- self.variance_epsilon,
- )
- return x, residual
-
out = torch.empty_like(x)
rms_norm(out, x, self.weight.data, self.variance_epsilon)
return out
-
-
def forward_native(
self,
x: torch.Tensor,
diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index 2b34a2965..0765b673a 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -31,7 +31,6 @@ from sglang.srt.layers.parameter import (
_ColumnvLLMParameter,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.layers.utils import pad_or_narrow_weight
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
if TYPE_CHECKING:
@@ -626,16 +625,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
- # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
- end_idx = start_idx + shard_size
- if end_idx > loaded_weight.shape[output_dim]:
- loaded_weight = pad_or_narrow_weight(
- loaded_weight, output_dim, start_idx, shard_size
- )
- else:
- loaded_weight = loaded_weight.narrow(
- output_dim, start_idx, shard_size
- )
+ loaded_weight = loaded_weight.narrow(
+ output_dim, start_idx, shard_size
+ )
# Special case for AQLM codebooks.
elif is_metadata:
@@ -1310,16 +1302,7 @@ class RowParallelLinear(LinearBase):
shard_size,
)
else:
- # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
- end_idx = start_idx + shard_size
- if end_idx > loaded_weight.shape[input_dim]:
- loaded_weight = pad_or_narrow_weight(
- loaded_weight, input_dim, start_idx, shard_size
- )
- else:
- loaded_weight = loaded_weight.narrow(
- input_dim, start_idx, shard_size
- )
+ loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 5f9651086..e39727842 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -220,7 +220,6 @@ class LogitsProcessor(nn.Module):
self.config = config
self.logit_scale = logit_scale
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
- self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = (
@@ -462,11 +461,7 @@ class LogitsProcessor(nn.Module):
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
if hasattr(lm_head, "weight"):
- if self.use_fp32_lm_head:
- logits = torch.matmul(
- hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
- )
- elif use_intel_amx_backend(lm_head):
+ if use_intel_amx_backend(lm_head):
logits = torch.ops.sgl_kernel.weight_packed_linear(
hidden_states.to(lm_head.weight.dtype),
lm_head.weight,
@@ -480,15 +475,7 @@ class LogitsProcessor(nn.Module):
else:
# GGUF models
# TODO: use weight_packed_linear for GGUF models
- if self.use_fp32_lm_head:
- with torch.cuda.amp.autocast(enabled=False):
- logits = lm_head.quant_method.apply(
- lm_head, hidden_states.to(torch.float32), embedding_bias
- )
- else:
- logits = lm_head.quant_method.apply(
- lm_head, hidden_states, embedding_bias
- )
+ logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
if self.logit_scale is not None:
logits.mul_(self.logit_scale)
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 0bd49600e..c72689c07 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -789,45 +789,69 @@ class DeepEPMoE(EPMoE):
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
- else:
- # dynamic quant
- hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
- hidden_states
- )
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
hidden_states.device
)
+ if self.w13_weight.dtype != torch.int8:
+ # gmm1: gate_up_proj
+ hidden_states = torch_npu.npu_grouped_matmul(
+ x=[hidden_states],
+ weight=[self.w13_weight.permute(0, 2, 1)],
+ # per_token_scale=[per_token_scale],
+ split_item=2,
+ group_list_type=group_list_type,
+ group_type=0,
+ group_list=group_list,
+ output_dtype=output_dtype,
+ )[0]
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
+ # gmm2: down_proj
+ hidden_states = torch_npu.npu_grouped_matmul(
+ x=[hidden_states],
+ weight=[self.w2_weight.permute(0, 2, 1)],
+ split_item=2,
+ group_list_type=group_list_type,
+ group_type=0,
+ group_list=group_list,
+ output_dtype=output_dtype,
+ )[0]
+ else:
+ if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
+ hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+ hidden_states
+ )
+ # gmm1: gate_up_proj
+ hidden_states = torch_npu.npu_grouped_matmul(
+ x=[hidden_states],
+ weight=[self.w13_weight],
+ scale=[self.w13_weight_scale.to(output_dtype)],
+ per_token_scale=[per_token_scale],
+ split_item=2,
+ group_list_type=group_list_type,
+ group_type=0,
+ group_list=group_list,
+ output_dtype=output_dtype,
+ )[0]
- # gmm1: gate_up_proj
- hidden_states = torch_npu.npu_grouped_matmul(
- x=[hidden_states],
- weight=[self.w13_weight],
- scale=[self.w13_weight_scale.to(output_dtype)],
- per_token_scale=[per_token_scale],
- split_item=2,
- group_list_type=group_list_type,
- group_type=0,
- group_list=group_list,
- output_dtype=output_dtype,
- )[0]
+ # act_fn: swiglu
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
+ hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
+ hidden_states
+ )
- # act_fn: swiglu
- hidden_states = torch_npu.npu_swiglu(hidden_states)
- hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
-
- # gmm2: down_proj
- hidden_states = torch_npu.npu_grouped_matmul(
- x=[hidden_states],
- weight=[self.w2_weight],
- scale=[self.w2_weight_scale.to(output_dtype)],
- per_token_scale=[swiglu_out_scale],
- split_item=2,
- group_list_type=group_list_type,
- group_type=0,
- group_list=group_list,
- output_dtype=output_dtype,
- )[0]
+ # gmm2: down_proj
+ hidden_states = torch_npu.npu_grouped_matmul(
+ x=[hidden_states],
+ weight=[self.w2_weight],
+ scale=[self.w2_weight_scale.to(output_dtype)],
+ per_token_scale=[swiglu_out_scale],
+ split_item=2,
+ group_list_type=group_list_type,
+ group_type=0,
+ group_list=group_list,
+ output_dtype=output_dtype,
+ )[0]
return hidden_states
@@ -836,47 +860,72 @@ class DeepEPMoE(EPMoE):
assert isinstance(dispatch_output, DeepEPLLOutput)
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
- per_token_scale = hidden_states[1]
- hidden_states = hidden_states[0]
+ if isinstance(hidden_states, tuple):
+ per_token_scale = hidden_states[1]
+ hidden_states = hidden_states[0]
group_list = group_list.to(torch.int64)
- # gmm1: gate_up_proj
- hidden_states = torch_npu.npu_grouped_matmul(
- x=[hidden_states],
- weight=[self.w13_weight],
- split_item=2,
- group_list_type=group_list_type,
- group_type=0,
- group_list=group_list,
- output_dtype=torch.int32,
- )[0]
+ if self.w13_weight.dtype != torch.int8:
+ # gmm1: gate_up_proj
+ hidden_states = torch_npu.npu_grouped_matmul(
+ x=[hidden_states],
+ weight=[self.w13_weight.permute(0, 2, 1)],
+ # per_token_scale=[per_token_scale],
+ split_item=2,
+ group_list_type=group_list_type,
+ group_type=0,
+ group_list=group_list,
+ output_dtype=output_dtype,
+ )[0]
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
+ # gmm2: down_proj
+ hidden_states = torch_npu.npu_grouped_matmul(
+ x=[hidden_states],
+ weight=[self.w2_weight.permute(0, 2, 1)],
+ split_item=2,
+ group_list_type=group_list_type,
+ group_type=0,
+ group_list=group_list,
+ output_dtype=output_dtype,
+ )[0]
+ else:
+ # gmm1: gate_up_proj
+ hidden_states = torch_npu.npu_grouped_matmul(
+ x=[hidden_states],
+ weight=[self.w13_weight],
+ split_item=2,
+ group_list_type=group_list_type,
+ group_type=0,
+ group_list=group_list,
+ output_dtype=torch.int32,
+ )[0]
- # act_fn: swiglu
- hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
- x=hidden_states,
- weight_scale=self.w13_weight_scale.to(torch.float32),
- activation_scale=per_token_scale,
- bias=None,
- quant_scale=None,
- quant_offset=None,
- group_index=group_list,
- activate_left=True,
- quant_mode=1,
- )
+ # act_fn: swiglu
+ hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+ x=hidden_states,
+ weight_scale=self.w13_weight_scale.to(torch.float32),
+ activation_scale=per_token_scale,
+ bias=None,
+ quant_scale=None,
+ quant_offset=None,
+ group_index=group_list,
+ activate_left=True,
+ quant_mode=1,
+ )
- # gmm2: down_proj
- hidden_states = torch_npu.npu_grouped_matmul(
- x=[hidden_states],
- weight=[self.w2_weight],
- scale=[self.w2_weight_scale.to(output_dtype)],
- per_token_scale=[swiglu_out_scale],
- split_item=2,
- group_list_type=group_list_type,
- group_type=0,
- group_list=group_list,
- output_dtype=output_dtype,
- )[0]
+ # gmm2: down_proj
+ hidden_states = torch_npu.npu_grouped_matmul(
+ x=[hidden_states],
+ weight=[self.w2_weight],
+ scale=[self.w2_weight_scale.to(output_dtype)],
+ per_token_scale=[swiglu_out_scale],
+ split_item=2,
+ group_list_type=group_list_type,
+ group_type=0,
+ group_list=group_list,
+ output_dtype=output_dtype,
+ )[0]
return hidden_states
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json
deleted file mode 100644
index 8e49def8d..000000000
--- a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json
+++ /dev/null
@@ -1,146 +0,0 @@
-{
- "1": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 64,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 4
- },
- "2": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
- "num_warps": 4,
- "num_stages": 3
- },
- "4": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 64,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
- "num_warps": 4,
- "num_stages": 4
- },
- "8": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
- "num_warps": 4,
- "num_stages": 3
- },
- "16": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
- "num_warps": 4,
- "num_stages": 3
- },
- "24": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
- "num_warps": 4,
- "num_stages": 3
- },
- "32": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
- "num_warps": 4,
- "num_stages": 3
- },
- "48": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "64": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "96": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
- "num_warps": 4,
- "num_stages": 3
- },
- "128": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
- "num_warps": 4,
- "num_stages": 3
- },
- "256": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
- "num_warps": 4,
- "num_stages": 3
- },
- "512": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
- "num_warps": 4,
- "num_stages": 3
- },
- "1024": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
- "num_warps": 4,
- "num_stages": 3
- },
- "1536": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
- "num_warps": 4,
- "num_stages": 3
- },
- "2048": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
- "num_warps": 4,
- "num_stages": 3
- },
- "3072": {
- "BLOCK_SIZE_M": 128,
- "BLOCK_SIZE_N": 64,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
- "num_warps": 4,
- "num_stages": 3
- },
- "4096": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
- "num_warps": 4,
- "num_stages": 3
- }
-}
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json
deleted file mode 100644
index 01689145a..000000000
--- a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json
+++ /dev/null
@@ -1,146 +0,0 @@
-{
- "1": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "2": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "4": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 4
- },
- "8": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "16": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "24": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "32": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "48": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "64": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "96": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "128": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 2
- },
- "256": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 2
- },
- "512": {
- "BLOCK_SIZE_M": 16,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "1024": {
- "BLOCK_SIZE_M": 32,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "1536": {
- "BLOCK_SIZE_M": 32,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
- "num_warps": 4,
- "num_stages": 3
- },
- "2048": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
- "num_warps": 4,
- "num_stages": 3
- },
- "3072": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
- "num_warps": 4,
- "num_stages": 3
- },
- "4096": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 128,
- "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
- "num_warps": 4,
- "num_stages": 4
- }
-}
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py
index 0c2939935..06e57f1e6 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py
@@ -51,14 +51,10 @@ def get_moe_configs(
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
# so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
- config_dir = os.environ.get(
- "SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
- )
-
triton_version = triton.__version__
version_dir = f"triton_{triton_version.replace('.', '_')}"
config_file_path = os.path.join(
- config_dir,
+ os.path.dirname(os.path.realpath(__file__)),
"configs",
version_dir,
json_file_name,
@@ -79,7 +75,7 @@ def get_moe_configs(
if try_triton_version == triton_version:
continue
try_config_file_path = os.path.join(
- config_dir,
+ os.path.dirname(os.path.realpath(__file__)),
"configs",
f"triton_{try_triton_version.replace('.', '_')}",
json_file_name,
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 241f8b142..81355c4f9 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -575,10 +575,7 @@ class FusedMoE(torch.nn.Module):
)
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
- if (
- should_use_flashinfer_trtllm_moe()
- and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
- ):
+ if should_use_flashinfer_trtllm_moe():
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py
index 3cc1d2344..1ea75d70c 100644
--- a/python/sglang/srt/layers/parameter.py
+++ b/python/sglang/srt/layers/parameter.py
@@ -7,7 +7,6 @@ from typing import Callable, Optional, Union
import torch
from torch.nn import Parameter
-from sglang.srt.layers.utils import pad_or_narrow_weight
from sglang.srt.utils import is_cpu
__all__ = [
@@ -157,17 +156,9 @@ class _ColumnvLLMParameter(BasevLLMParameter):
)
else:
if not use_presharded_weights:
- # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
- start_idx = tp_rank * shard_size
- end_idx = start_idx + shard_size
- if end_idx > loaded_weight.shape[self.output_dim]:
- loaded_weight = pad_or_narrow_weight(
- loaded_weight, self.output_dim, start_idx, shard_size
- )
- else:
- loaded_weight = loaded_weight.narrow(
- self.output_dim, start_idx, shard_size
- )
+ loaded_weight = loaded_weight.narrow(
+ self.output_dim, tp_rank * shard_size, shard_size
+ )
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@@ -267,17 +258,9 @@ class RowvLLMParameter(BasevLLMParameter):
return
else:
- # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
- start_idx = tp_rank * shard_size
- end_idx = start_idx + shard_size
- if end_idx > loaded_weight.shape[self.input_dim]:
- loaded_weight = pad_or_narrow_weight(
- loaded_weight, self.input_dim, start_idx, shard_size
- )
- else:
- loaded_weight = loaded_weight.narrow(
- self.input_dim, start_idx, shard_size
- )
+ loaded_weight = loaded_weight.narrow(
+ self.input_dim, tp_rank * shard_size, shard_size
+ )
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
index 145edbbdf..8afc15a73 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -30,7 +30,6 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
CompressedTensorsW8A8Fp8,
- CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8,
)
from sglang.srt.layers.quantization.compressed_tensors.utils import (
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
index 2476da700..c94575316 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -2,12 +2,10 @@
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
-from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsW8A8Fp8",
"CompressedTensorsW8A16Fp8",
- "CompressedTensorsW8A8Int8",
]
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
deleted file mode 100644
index 9bca2834d..000000000
--- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+++ /dev/null
@@ -1,173 +0,0 @@
-# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import Callable, Optional
-
-import torch
-from compressed_tensors.quantization import QuantizationStrategy
-from torch.nn import Parameter
-
-from sglang.srt.layers.parameter import (
- ChannelQuantScaleParameter,
- ModelWeightParameter,
- PerTensorScaleParameter,
-)
-from sglang.srt.layers.quantization.compressed_tensors.schemes import (
- CompressedTensorsScheme,
-)
-from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.layers.quantization.utils import requantize_with_max_scale
-from sglang.srt.utils import is_cuda
-
-_is_cuda = is_cuda()
-if _is_cuda:
- from sgl_kernel import int8_scaled_mm
-
-
-class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
-
- def __init__(
- self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
- ):
- self.strategy = strategy
- self.is_static_input_scheme = is_static_input_scheme
- self.input_symmetric = input_symmetric
-
- @classmethod
- def get_min_capability(cls) -> int:
- # lovelace and up
- return 89
-
- def process_weights_after_loading(self, layer) -> None:
- # If per tensor, when we have a fused module (e.g. QKV) with per
- # tensor scales (thus N scales being passed to the kernel),
- # requantize so we can always run per channel
- if self.strategy == QuantizationStrategy.TENSOR:
- max_w_scale, weight = requantize_with_max_scale(
- weight=layer.weight,
- weight_scale=layer.weight_scale,
- logical_widths=layer.logical_widths,
- )
-
- layer.weight = Parameter(weight.t(), requires_grad=False)
- layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
-
- # If channelwise, scales are already lined up, so just transpose.
- elif self.strategy == QuantizationStrategy.CHANNEL:
- weight = layer.weight
- weight_scale = layer.weight_scale.data
-
- layer.weight = Parameter(weight.t(), requires_grad=False)
- # required by torch.compile to be torch.nn.Parameter
- layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-
- else:
- raise ValueError(f"Unknown quantization strategy {self.strategy}")
-
- # INPUT SCALE
- if self.is_static_input_scheme and hasattr(layer, "input_scale"):
- if self.input_symmetric:
- layer.input_scale = Parameter(
- layer.input_scale.max(), requires_grad=False
- )
- else:
- input_scale = layer.input_scale
- input_zero_point = layer.input_zero_point
-
- # reconstruct the ranges
- int8_traits = torch.iinfo(torch.int8)
- azps = input_zero_point.to(dtype=torch.int32)
- range_max = (input_scale * (int8_traits.max - azps)).max()
- range_min = (input_scale * (int8_traits.min - azps)).min()
-
- scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
-
- # AZP loaded as int8 but used as int32
- azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
-
- layer.input_scale = Parameter(scale, requires_grad=False)
- layer.input_zero_point = Parameter(azp, requires_grad=False)
- else:
- layer.input_scale = None
- layer.input_zero_point = None
-
- # azp_adj is the AZP adjustment term, used to account for weights.
- # It does not depend on scales or azp, so it is the same for
- # static and dynamic quantization.
- # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
- # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
- if not self.input_symmetric:
- weight = layer.weight
- azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
- if self.is_static_input_scheme:
- # cutlass_w8a8 requires azp to be folded into azp_adj
- # in the per-tensor case
- azp_adj = layer.input_zero_point * azp_adj
- layer.azp_adj = Parameter(azp_adj, requires_grad=False)
- else:
- layer.azp_adj = None
-
- def create_weights(
- self,
- layer: torch.nn.Module,
- output_partition_sizes: list[int],
- input_size_per_partition: int,
- params_dtype: torch.dtype,
- weight_loader: Callable,
- **kwargs,
- ):
- output_size_per_partition = sum(output_partition_sizes)
- layer.logical_widths = output_partition_sizes
-
- # WEIGHT
- weight = ModelWeightParameter(
- data=torch.empty(
- output_size_per_partition, input_size_per_partition, dtype=torch.int8
- ),
- input_dim=1,
- output_dim=0,
- weight_loader=weight_loader,
- )
-
- layer.register_parameter("weight", weight)
-
- # WEIGHT SCALE
- if self.strategy == QuantizationStrategy.CHANNEL:
- weight_scale = ChannelQuantScaleParameter(
- data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
- output_dim=0,
- weight_loader=weight_loader,
- )
- else:
- assert self.strategy == QuantizationStrategy.TENSOR
- weight_scale = PerTensorScaleParameter(
- data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
- weight_loader=weight_loader,
- )
- layer.register_parameter("weight_scale", weight_scale)
-
- # INPUT SCALE
- if self.is_static_input_scheme:
- input_scale = PerTensorScaleParameter(
- data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
- )
- layer.register_parameter("input_scale", input_scale)
-
- if not self.input_symmetric:
- # Note: compressed-tensors stores the zp using the same dtype
- # as the weights
- # AZP loaded as int8 but used as int32
- input_zero_point = PerTensorScaleParameter(
- data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
- )
- layer.register_parameter("input_zero_point", input_zero_point)
-
- def apply_weights(
- self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
- ) -> torch.Tensor:
- # TODO: add cutlass_scaled_mm_azp support
- x_q, x_scale = per_token_quant_int8(x)
-
- return int8_scaled_mm(
- x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
- )
diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
index 62073e38c..662c70c34 100644
--- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
+++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
@@ -1,5 +1,7 @@
import logging
+import torch
+
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
logger = logging.getLogger(__name__)
@@ -13,6 +15,7 @@ def _compute_enable_deep_gemm():
try:
import deep_gemm
except ImportError:
+ logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
return False
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py
index caf323950..8643a3e36 100644
--- a/python/sglang/srt/layers/quantization/mxfp4.py
+++ b/python/sglang/srt/layers/quantization/mxfp4.py
@@ -843,18 +843,10 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
topk_weights = topk_weights.to(
torch.float32
) # aiter's moe_sorting requires topk_weights to be FP32
-
- if hasattr(torch, "float4_e2m1fn_x2"):
- w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
- w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
- else:
- w13_weight = layer.w13_weight
- w2_weight = layer.w2_weight
-
output = fused_moe(
x,
- w13_weight,
- w2_weight,
+ layer.w13_weight,
+ layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
diff --git a/python/sglang/srt/layers/quantization/quark/quark_moe.py b/python/sglang/srt/layers/quantization/quark/quark_moe.py
index 1f8a1abfe..f6e750a2c 100644
--- a/python/sglang/srt/layers/quantization/quark/quark_moe.py
+++ b/python/sglang/srt/layers/quantization/quark/quark_moe.py
@@ -183,17 +183,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
moe_runner_config = self.moe_runner_config
topk_weights, topk_ids, _ = topk_output
- if hasattr(torch, "float4_e2m1fn_x2"):
- w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
- w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
- else:
- w13_weight = layer.w13_weight
- w2_weight = layer.w2_weight
-
output = fused_moe(
x,
- w13_weight,
- w2_weight,
+ layer.w13_weight,
+ layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py
index 158ae6561..e95247041 100644
--- a/python/sglang/srt/layers/quantization/w4afp8.py
+++ b/python/sglang/srt/layers/quantization/w4afp8.py
@@ -19,6 +19,10 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import is_npu, set_weight_attrs
+_is_npu = is_npu()
+if not _is_npu:
+ from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py
index 17a79190d..5ccb0259d 100644
--- a/python/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -393,23 +393,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
x.dtype,
True, # is_vnni
)
+
x_q, x_scale = per_token_quant_int8(x)
- x_q_2d = x_q.view(-1, x_q.shape[-1])
- x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
- output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
-
- output = int8_scaled_mm(
- x_q_2d,
- layer.weight,
- x_scale_2d,
- layer.weight_scale,
- out_dtype=x.dtype,
- bias=bias,
+ return int8_scaled_mm(
+ x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
)
- return output.view(output_shape)
-
class W8A8Int8MoEMethod(FusedMoEMethodBase):
"""MoE method for INT8.
@@ -648,7 +638,6 @@ class NPU_W8A8LinearMethodImpl:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
- layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8LinearMethodMTImpl:
@@ -841,7 +830,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
- layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 2c7267529..f0e9e5a7b 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -12,7 +12,6 @@ from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
- get_compiler_backend,
is_cpu,
is_cuda,
is_hip,
@@ -27,19 +26,13 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
- from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
-else:
- FusedSetKVBufferArg = None
-
+ from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
if _use_aiter:
from aiter.rotary_embedding import get_rope as aiter_get_rope
if is_npu():
import torch_npu
- NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
- NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
-
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
@@ -149,13 +142,8 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
- fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
- assert (
- fused_set_kv_buffer_arg is None
- ), "fused_set_kv_buffer_arg is not supported for native implementation"
-
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
@@ -184,17 +172,12 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
- fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-npu implementation of forward()."""
- assert (
- fused_set_kv_buffer_arg is None
- ), "fused_set_kv_buffer_arg is not supported for npu implementation"
+ import os
if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
- return self.forward_native(
- positions, query, key, offsets, fused_set_kv_buffer_arg
- )
+ return self.forward_native(positions, query, key, offsets)
else:
rotary_mode = "half"
if self.is_neox_style:
@@ -219,12 +202,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
- fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
- assert (
- fused_set_kv_buffer_arg is None
- ), "fused_set_kv_buffer_arg is not supported for cpu implementation"
-
positions = torch.add(positions, offsets) if offsets is not None else positions
if _is_cpu_amx_available:
return torch.ops.sgl_kernel.rotary_embedding_cpu(
@@ -236,9 +214,7 @@ class RotaryEmbedding(CustomOp):
self.is_neox_style,
)
else:
- return self.forward_native(
- positions, query, key, offsets, fused_set_kv_buffer_arg
- )
+ return self.forward_native(positions, query, key, offsets)
def forward_cuda(
self,
@@ -246,7 +222,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
- fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
+ fused_set_kv_buffer_arg=None, # Optional[FusedSetKVBufferArg]
) -> Tuple[torch.Tensor, torch.Tensor]:
if _is_cuda and (self.head_size in [64, 128, 256, 512]):
apply_rope_with_cos_sin_cache_inplace(
@@ -1059,7 +1035,7 @@ class MRotaryEmbedding(RotaryEmbedding):
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
)
- @torch.compile(dynamic=True, backend=get_compiler_backend())
+ @torch.compile(dynamic=True)
def forward(
self,
positions: torch.Tensor,
@@ -1207,7 +1183,7 @@ class MRotaryEmbedding(RotaryEmbedding):
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
- elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
+ elif model_type == "qwen2_vl":
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
@@ -1918,30 +1894,17 @@ def apply_rotary_pos_emb_npu(
sin: torch.Tensor,
unsqueeze_dim=1,
) -> Tuple[torch.Tensor, torch.Tensor]:
- """Ascend implementation equivalent to apply_rotary_pos_emb_native.
-
- Args:
- q: [num_tokens, num_heads, head_size]
- k: [num_tokens, num_kv_heads, head_size]
- cos: [num_tokens, head_size]
- sin: [num_tokens, head_size]
- """
- if (
- cos.dim() != 2
- or q.dim() != 3
- or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
- or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
- ):
- # Note: num_heads and head_size of q must be less than 1000 and 896, respectively
+ if q.shape[1] != 128:
return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
- cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
- sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
- q = q.unsqueeze(0)
- k = k.unsqueeze(0)
- q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
- k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
- q_embed = q_embed.squeeze(0)
- k_embed = k_embed.squeeze(0)
+ cos = cos.unsqueeze(unsqueeze_dim)
+ cos = torch.transpose(cos, 1, 2)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ sin = torch.transpose(sin, 1, 2)
+ q = torch.transpose(q, 1, 2)
+ k = torch.transpose(k, 1, 2)
+ q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+ q_embed = torch.transpose(q_embed, 1, 2)
+ k_embed = torch.transpose(k_embed, 1, 2)
return q_embed, k_embed
diff --git a/python/sglang/srt/layers/utils.py b/python/sglang/srt/layers/utils.py
index 45e154791..d79ccc663 100644
--- a/python/sglang/srt/layers/utils.py
+++ b/python/sglang/srt/layers/utils.py
@@ -15,29 +15,6 @@ def get_layer_id(weight_name):
return None
-def pad_or_narrow_weight(
- loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
-) -> torch.Tensor:
- # Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
- valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
-
- if valid_size > 0:
- loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
- pad_shape = list(loaded_weight.shape)
- pad_shape[input_dim] = shard_size - valid_size
- pad = torch.zeros(
- pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
- )
- return torch.cat([loaded_slice, pad], dim=input_dim)
-
- # All padding
- pad_shape = list(loaded_weight.shape)
- pad_shape[input_dim] = shard_size
- return torch.zeros(
- pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
- )
-
-
class PPMissingLayer(torch.nn.Identity):
# Adapted from
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
index 1767c5ee4..951393929 100644
--- a/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
+++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
@@ -5,7 +5,7 @@ import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
-from sglang.srt.utils import cached_triton_kernel
+from sglang.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
diff --git a/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
index e0ef41fb7..8b170bfa4 100644
--- a/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
+++ b/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
@@ -3,7 +3,7 @@ import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
-from sglang.srt.utils import cached_triton_kernel
+from sglang.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index a78e140bb..384dceb31 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -275,17 +275,43 @@ class HiCacheController:
and self.storage_config.tp_rank != 0
)
- # Use storage backend factory for dynamic backend creation
- from sglang.srt.mem_cache.storage import StorageBackendFactory
+ if storage_backend == "file":
+ from sglang.srt.mem_cache.hicache_storage import HiCacheFile
- try:
- self.storage_backend = StorageBackendFactory.create_backend(
- storage_backend, self.storage_config, self.mem_pool_host
+ self.storage_backend = HiCacheFile(self.storage_config)
+ elif storage_backend == "nixl":
+ from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
+
+ self.storage_backend = HiCacheNixl()
+ elif storage_backend == "mooncake":
+ from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
+ MooncakeStore,
)
- except ValueError as e:
- raise ValueError(f"Failed to create storage backend: {e}") from e
- self.storage_backend.register_mem_pool_host(self.mem_pool_host)
+ self.storage_backend = MooncakeStore(self.storage_config)
+ self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+ assert self.mem_pool_host.layout == "page_first"
+ elif storage_backend == "hf3fs":
+ from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
+ HiCacheHF3FS,
+ )
+
+ if self.mem_pool_host.layout == "page_first":
+ bytes_per_page = (
+ mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
+ )
+ elif self.mem_pool_host.layout == "layer_first":
+ bytes_per_page = (
+ mem_pool_host.get_size_per_token() * mem_pool_host.page_size
+ )
+ dtype = mem_pool_host.dtype
+ self.storage_backend = HiCacheHF3FS.from_env_config(
+ bytes_per_page, dtype, self.storage_config
+ )
+ else:
+ raise NotImplementedError(
+ f"Unsupported storage backend: {storage_backend}"
+ )
self.enable_storage = True
# todo: threshold policy for prefetching
@@ -309,10 +335,18 @@ class HiCacheController:
# Select the get and set functions
self.page_get_func = self._generic_page_get
self.page_set_func = self._generic_page_set
-
- if self.storage_backend_type in ["hf3fs", "mooncake"]:
- self.page_get_func = self._page_get_zero_copy
- self.page_set_func = self._page_set_zero_copy
+ self.batch_exists_func = self.storage_backend.batch_exists
+ self.is_3fs_zerocopy = (
+ self.storage_backend_type == "hf3fs"
+ and self.mem_pool_host.layout == "page_first"
+ )
+ if self.storage_backend_type == "mooncake":
+ self.page_get_func = self._mooncake_page_get
+ self.page_set_func = self._mooncake_page_set
+ elif self.is_3fs_zerocopy:
+ self.page_get_func = self._3fs_zero_copy_page_get
+ self.page_set_func = self._3fs_zero_copy_page_set
+ self.batch_exists_func = self._3fs_zero_copy_batch_exists
self.device = self.mem_pool_device.device
self.layer_num = self.mem_pool_device.layer_num
@@ -436,6 +470,7 @@ class HiCacheController:
host_indices = self.mem_pool_host.alloc(len(device_indices))
if host_indices is None:
return None
+ self.mem_pool_host.protect_write(host_indices)
self.write_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority)
)
@@ -459,6 +494,7 @@ class HiCacheController:
self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend
)
+ self.mem_pool_host.complete_io(op.host_indices)
finish_event.record()
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
@@ -482,6 +518,7 @@ class HiCacheController:
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
if device_indices is None:
return None
+ self.mem_pool_host.protect_load(host_indices)
self.load_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority)
)
@@ -526,6 +563,7 @@ class HiCacheController:
self.io_backend,
)
producer_event.complete(i)
+ self.mem_pool_host.complete_io(op.host_indices)
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
# still alive when the load stream is executing.
@@ -543,16 +581,29 @@ class HiCacheController:
)
return producer_id
- def evict_device(self, device_indices: torch.Tensor) -> int:
- self.mem_pool_device_allocator.free(device_indices)
- return len(device_indices)
+ def evict_device(
+ self, device_indices: torch.Tensor, host_indices: torch.Tensor
+ ) -> int:
+ if self.mem_pool_host.is_synced(host_indices):
+ self.mem_pool_device_allocator.free(device_indices)
+ self.mem_pool_host.update_backup(host_indices)
+ return len(device_indices)
+ else:
+ raise ValueError(
+ f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+ )
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
if not backup_only:
raise ValueError("Other eviction policies are not supported yet.")
- self.mem_pool_host.free(host_indices)
- return len(host_indices)
+ if self.mem_pool_host.is_backup(host_indices):
+ self.mem_pool_host.free(host_indices)
+ return len(host_indices)
+ else:
+ raise ValueError(
+ f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+ )
def prefetch(
self,
@@ -579,19 +630,42 @@ class HiCacheController:
for chunk in chunks:
self.host_mem_release_queue.put(chunk)
- def _page_get_zero_copy(self, operation, hash_values, host_indices):
- results = self.storage_backend.batch_get_v1(hash_values, host_indices)
- inc = 0
- for i in range(len(hash_values)):
- if not results[i]:
- logger.warning(
- f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
- )
- break
- inc += self.page_size
- operation.increment(inc)
+ def _3fs_zero_copy_batch_exists(self, batch_hashes):
+ _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+ hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+ return hit_page_num
+
+ def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
+ hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
+ hash_values, host_indices
+ )
+ page_data = self.storage_backend.batch_get(hashes, dsts)
+ if page_data:
+ inc = self.page_size * len(hashes) // factor
+ operation.increment(inc)
+ else:
+ logger.warning(
+ f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
+ )
+
+ def _mooncake_page_get(self, operation, hash_values, host_indices):
+ key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+ hash_values,
+ host_indices,
+ self.storage_config.tp_rank,
+ )
+ get_result = self.storage_backend.batch_get(
+ key_strs,
+ target_locations=buffer_ptrs,
+ target_sizes=buffer_sizes,
+ )
+ if get_result != len(hash_values):
+ logger.warning(
+ f"Prefetch operation {operation.request_id} failed or partially failed."
+ )
+ if get_result != 0:
+ operation.increment(get_result * self.page_size)
- # todo: deprecate
def _generic_page_get(self, operation, hash_values, host_indices):
dummy_page_dst = [
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
@@ -681,7 +755,7 @@ class HiCacheController:
batch_tokens[i : i + self.page_size], last_hash
)
batch_hashes.append(last_hash)
- hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+ hit_page_num = self.batch_exists_func(batch_hashes)
hash_value.extend(batch_hashes[:hit_page_num])
storage_query_count += hit_page_num * self.page_size
if hit_page_num < len(batch_hashes):
@@ -750,16 +824,34 @@ class HiCacheController:
self.backup_queue.put(operation)
return operation.id
- # todo: deprecate
+ # non-zero copy
def _generic_page_set(self, hash_values, host_indices) -> bool:
data = [
- self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
+ self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
for i in range(len(hash_values))
]
return self.storage_backend.batch_set(hash_values, data)
- def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
- return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
+ # zero copy
+ def _mooncake_page_set(self, hash_values, host_indices) -> bool:
+ key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+ hash_values,
+ host_indices,
+ self.storage_config.tp_rank,
+ )
+ success = self.storage_backend.batch_set(
+ key_strs,
+ target_locations=buffer_ptrs,
+ target_sizes=buffer_sizes,
+ )
+ return success
+
+ # zero copy
+ def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
+ hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
+ hash_values, host_indices
+ )
+ return self.storage_backend.batch_set(hashes, dsts)
# Backup batch by batch
def _page_backup(self, operation):
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 436d62f27..86cfcf945 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -35,7 +35,6 @@ else:
Image = Any
-# Parameters for a session
@dataclass
class SessionParams:
id: Optional[str] = None
@@ -133,23 +132,18 @@ class GenerateReqInput:
# Conversation id used for tracking requests
conversation_id: Optional[str] = None
+ # Label for the request
+ label: Optional[str] = None
+
# Priority for the request
priority: Optional[int] = None
- # Extra key for classifying the request (e.g. cache_salt)
- extra_key: Optional[Union[List[str], str]] = None
-
- # Whether to disallow logging for this request (e.g. due to ZDR)
- no_logs: bool = False
-
- # For custom metric labels
- custom_labels: Optional[Dict[str, str]] = None
-
- # (Deprecated, please use custom_labels) Label for the request
- label: Optional[str] = None
- # (Internal) Whether to return bytes for image generation
+ # Image gen grpc migration
return_bytes: bool = False
+ # For customer metric labels
+ customer_labels: Optional[Dict[str, str]] = None
+
def contains_mm_input(self) -> bool:
return (
has_valid_data(self.image_data)
@@ -548,11 +542,8 @@ class GenerateReqInput:
self.data_parallel_rank if self.data_parallel_rank is not None else None
),
conversation_id=self.conversation_id,
- priority=self.priority,
- extra_key=self.extra_key,
- no_logs=self.no_logs,
- custom_labels=self.custom_labels,
label=self.label,
+ priority=self.priority,
return_bytes=self.return_bytes,
)
@@ -609,23 +600,18 @@ class TokenizedGenerateReqInput:
# For dp balance
dp_balance_id: int = -1
+ # Label for the request
+ label: Optional[str] = None
+
# Priority for the request
priority: Optional[int] = None
- # Extra key for classifying the request (e.g. cache_salt)
- extra_key: Optional[str] = None
-
- # Whether to disallow logging for this request (e.g. due to ZDR)
- no_logs: bool = False
+ # Image gen grpc migration
+ return_bytes: bool = False
# tracing context
trace_context: Optional[Dict] = None
- # (Deprecated, please use custom_labels) Label for the request
- label: Optional[str] = None
- # (Internal) Whether to return bytes for image generation
- return_bytes: bool = False
-
@dataclass
class BatchTokenizedGenerateReqInput:
diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py
index 41de295af..f495904d5 100644
--- a/python/sglang/srt/managers/mm_utils.py
+++ b/python/sglang/srt/managers/mm_utils.py
@@ -507,7 +507,6 @@ def embed_mm_inputs(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
- use_deepstack: bool = False,
) -> Optional[torch.Tensor]:
"""
Embed multimodal inputs and integrate them with text token embeddings.
@@ -523,7 +522,7 @@ def embed_mm_inputs(
Returns:
Combined embedding tensor with multimodal content integrated
"""
- other_info = {}
+
if mm_inputs_list is None:
return None
@@ -533,7 +532,7 @@ def embed_mm_inputs(
for mm_inputs in mm_inputs_list:
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
- embeddings, masks, deepstack_embeddings = [], [], []
+ embeddings, masks = [], []
# 2. Get multimodal embedding separately
# Try get mm embedding if any
for modality in Modality.all():
@@ -579,12 +578,6 @@ def embed_mm_inputs(
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
-
- if use_deepstack and embedding is not None:
- embedding, deepstack_embedding = (
- multimodal_model.separate_deepstack_embeds(embedding)
- )
- deepstack_embeddings += [deepstack_embedding]
embeddings += [embedding]
masks += [mask]
@@ -598,37 +591,13 @@ def embed_mm_inputs(
inputs_embeds = input_embedding(input_ids)
# 4. scatter embeddings into input embedding
-
- # deepstack embedding
- if use_deepstack:
- num_deepstack_embeddings = (
- len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
- )
- deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
- inputs_embeds.shape[-1] * num_deepstack_embeddings,
- )
-
- input_deepstack_embeds = torch.zeros(
- deepstack_embedding_shape,
- device=inputs_embeds.device,
- dtype=inputs_embeds.dtype,
- )
-
- other_info["input_deepstack_embeds"] = input_deepstack_embeds
-
- for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
+ for embedding, mask in zip(embeddings, masks):
if embedding is None or mask is None:
continue
# in-place update
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
-
- if use_deepstack:
- input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
- inputs_embeds.device, inputs_embeds.dtype
- )
-
- return inputs_embeds, other_info
+ return inputs_embeds
def general_mm_embed_routine(
@@ -640,7 +609,6 @@ def general_mm_embed_routine(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
- use_deepstack: bool = False,
**kwargs,
) -> torch.Tensor:
"""
@@ -652,7 +620,6 @@ def general_mm_embed_routine(
language_model: Base language model to use
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
placeholder_tokens: Token IDs for multimodal placeholders
- use_deepstack: Whether to use deepstack embeddings
**kwargs: Additional arguments passed to language model
Returns:
@@ -678,20 +645,16 @@ def general_mm_embed_routine(
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
if forward_batch.mm_inputs[i] is not None
]
- inputs_embeds, other_info = embed_mm_inputs(
+ inputs_embeds = embed_mm_inputs(
mm_inputs_list=mm_inputs_list,
extend_prefix_lens=extend_prefix_lens,
extend_seq_lens=extend_seq_lens,
input_ids=input_ids,
- multimodal_model=multimodal_model,
input_embedding=embed_tokens,
+ multimodal_model=multimodal_model,
data_embedding_func_mapping=data_embedding_funcs,
placeholder_tokens=placeholder_tokens,
- use_deepstack=use_deepstack,
)
- # add for qwen3_vl deepstack
- if use_deepstack:
- kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
# just being defensive here
forward_batch.mm_inputs = None
diff --git a/python/sglang/srt/managers/multimodal_processor.py b/python/sglang/srt/managers/multimodal_processor.py
index 7826241d0..bc060a5b3 100644
--- a/python/sglang/srt/managers/multimodal_processor.py
+++ b/python/sglang/srt/managers/multimodal_processor.py
@@ -12,7 +12,8 @@ logger = logging.getLogger(__name__)
PROCESSOR_MAPPING = {}
-def import_processors(package_name: str):
+def import_processors():
+ package_name = "sglang.srt.multimodal.processors"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
diff --git a/python/sglang/srt/managers/overlap_utils.py b/python/sglang/srt/managers/overlap_utils.py
deleted file mode 100644
index d512ae7ec..000000000
--- a/python/sglang/srt/managers/overlap_utils.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import torch
-
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
-from sglang.srt.utils import get_compiler_backend
-
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def _resolve_future_token_ids(input_ids, future_token_ids_map):
- input_ids[:] = torch.where(
- input_ids < 0,
- future_token_ids_map[torch.clamp(-input_ids, min=0)],
- input_ids,
- )
-
-
-class FutureMap:
- def __init__(
- self,
- max_running_requests: int,
- device: torch.device,
- ):
- self.future_ct = 0
- # A factor of 3 is used to avoid collision in the circular buffer.
- self.future_limit = max_running_requests * 3
- # A factor of 5 is used to ensure the buffer is large enough.
- self.future_buffer_len = max_running_requests * 5
- self.device = device
-
- self.token_ids_buf = torch.empty(
- (self.future_buffer_len,), dtype=torch.int64, device=self.device
- )
-
- def update_ct(self, bs: int) -> int:
- """Update the circular buffer pointer and return the current pointer."""
- cur_future_ct = self.future_ct
- self.future_ct = (cur_future_ct + bs) % self.future_limit
- return cur_future_ct
-
- def resolve_future(self, model_worker_batch: ModelWorkerBatch):
- input_ids = model_worker_batch.input_ids
- _resolve_future_token_ids(input_ids, self.token_ids_buf)
-
- def update_next_future(self, future_ct: int, bs: int):
- return torch.arange(
- -(future_ct + 1),
- -(future_ct + 1 + bs),
- -1,
- dtype=torch.int64,
- device=self.device,
- )
-
- def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
- self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 97f5d00ff..c47f39788 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -67,14 +67,14 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
- from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+ from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -90,7 +90,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
"disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache",
"enable_dp_lm_head",
- "enable_fp32_lm_head",
"flashinfer_mxfp4_moe_precision",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
@@ -113,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_custom_logit_processor",
"disaggregation_mode",
"enable_deterministic_inference",
+ "nsa_prefill",
+ "nsa_decode",
]
# Put some global args for easy access
@@ -492,7 +493,7 @@ class Req:
self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states
- # extra key for classifying the request (e.g. cache_salt)
+ # extra key for classifying the request (e.g. lora_id, cache_salt)
if lora_id is not None:
extra_key = (
extra_key or ""
@@ -608,8 +609,6 @@ class Req:
) = None
self.hidden_states: List[List[float]] = []
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
- self.output_topk_p = None
- self.output_topk_index = None
# Embedding (return values)
self.embedding = None
@@ -954,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
- spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
- None
- )
+ spec_info: Optional[
+ Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+ ] = None
# Whether to return hidden states
return_hidden_states: bool = False
@@ -1609,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if (
self.spec_algorithm.is_eagle()
or self.spec_algorithm.is_standalone()
- or self.spec_algorithm.is_ngram()
+ or self.spec_algorithm.is_lookahead()
):
# if spec decoding is used, the decode batch is prepared inside
# `forward_batch_speculative_generation` after running draft models.
@@ -1736,14 +1735,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
if self.spec_info:
- if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
- has_been_filtered = False
- else:
- has_been_filtered = True
- self.spec_info.filter_batch(
- new_indices=keep_indices_device,
- has_been_filtered=has_been_filtered,
- )
+ self.spec_info.filter_batch(keep_indices_device)
def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1992,9 +1984,9 @@ class ModelWorkerBatch:
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
- spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
- None
- )
+ spec_info: Optional[
+ Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
+ ] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = -1
diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 755ac29c8..07c2a06d9 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -318,6 +318,7 @@ class PrefillAdder:
new_token_ratio: float,
rem_input_tokens: int,
rem_chunk_tokens: Optional[int],
+ max_prefill_bs: Optional[int],
mixed_with_decode_tokens: int = 0,
priority_scheduling_preemption_threshold: int = 0,
):
@@ -358,6 +359,10 @@ class PrefillAdder:
priority_scheduling_preemption_threshold
)
+ self.max_prefill_bs = (
+ max_prefill_bs if max_prefill_bs is not None else 2147483647
+ )
+
def _get_running_request_total_token_offset(self, req: Req) -> int:
return (
min(
@@ -549,6 +554,9 @@ class PrefillAdder:
def add_one_req(
self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
):
+ if len(self.can_run_list) >= self.max_prefill_bs:
+ return AddReqResult.OTHER
+
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
return self.add_one_req_ignore_eos(req, has_chunked_req)
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 893a0b0a1..8d764c3e1 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -44,9 +44,6 @@ from sglang.srt.disaggregation.decode import (
DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin,
)
-from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
- DecodeKVCacheOffloadManager,
-)
from sglang.srt.disaggregation.prefill import (
PrefillBootstrapQueue,
SchedulerDisaggregationPrefillMixin,
@@ -262,7 +259,7 @@ class Scheduler(
self.enable_metrics_for_all_schedulers = (
server_args.enable_metrics_for_all_schedulers
)
- self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
+ self.enable_kv_cache_events = server_args.kv_events_config is not None
self.stream_interval = server_args.stream_interval
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
@@ -388,10 +385,10 @@ class Scheduler(
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
- elif self.spec_algorithm.is_ngram():
- from sglang.srt.speculative.ngram_worker import NGRAMWorker
+ elif self.spec_algorithm.is_lookahead():
+ from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker
- self.draft_worker = NGRAMWorker(
+ self.draft_worker = LOOKAHEADWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
@@ -556,11 +553,9 @@ class Scheduler(
# Init metrics stats
self.init_metrics(tp_rank, pp_rank, dp_rank)
+ self.init_kv_events(server_args.kv_events_config)
self.init_dp_balance(dp_balance_meta)
- if self.enable_kv_cache_events:
- self.init_kv_events(server_args.kv_events_config)
-
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
@@ -618,6 +613,8 @@ class Scheduler(
]
)
+ self.max_prefill_bs = server_args.max_prefill_bs
+
def init_deterministic_inference_config(self):
"""Initialize deterministic inference configuration for different attention backends."""
if not self.server_args.enable_deterministic_inference:
@@ -758,24 +755,6 @@ class Scheduler(
eviction_policy=server_args.radix_eviction_policy,
)
- if (
- server_args.disaggregation_mode == "decode"
- and server_args.disaggregation_decode_enable_offload_kvcache
- ):
- self.decode_offload_manager = DecodeKVCacheOffloadManager(
- req_to_token_pool=self.req_to_token_pool,
- token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
- tp_group=(
- self.attn_tp_cpu_group
- if self.server_args.enable_dp_attention
- else self.tp_cpu_group
- ),
- tree_cache=self.tree_cache,
- server_args=self.server_args,
- )
- else:
- self.decode_offload_manager = None
-
self.decode_mem_cache_buf_multiplier = (
1
if self.spec_algorithm.is_none()
@@ -806,7 +785,7 @@ class Scheduler(
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
- hidden_states_dtype=self.model_config.dtype,
+ dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)
@@ -826,7 +805,7 @@ class Scheduler(
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
draft_token_to_kv_pool=(
None
- if self.draft_worker is None or self.spec_algorithm.is_ngram()
+ if self.draft_worker is None or self.spec_algorithm.is_lookahead()
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -855,7 +834,7 @@ class Scheduler(
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
- hidden_states_dtype=self.model_config.dtype,
+ dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)
@@ -863,7 +842,7 @@ class Scheduler(
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
draft_token_to_kv_pool=(
None
- if self.draft_worker is None or self.spec_algorithm.is_ngram()
+ if self.draft_worker is None or self.spec_algorithm.is_lookahead()
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -1832,6 +1811,7 @@ class Scheduler(
self.new_token_ratio,
self.max_prefill_tokens,
self.chunked_prefill_size,
+ self.max_prefill_bs,
running_bs if self.is_mixed_chunk else 0,
self.priority_scheduling_preemption_threshold,
)
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index 5d8545dac..aa060af8a 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -250,13 +250,7 @@ class SchedulerOutputProcessorMixin:
req.check_finished()
if req.finished():
- if self.server_args.disaggregation_decode_enable_offload_kvcache:
- # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
- if not self.decode_offload_manager.offload_kv_cache(req):
- self.tree_cache.cache_finished_req(req)
- else:
- self.tree_cache.cache_finished_req(req)
-
+ self.tree_cache.cache_finished_req(req)
req.time_stats.completion_time = time.time()
if req.return_logprob and batch.spec_algorithm.is_none():
diff --git a/python/sglang/srt/managers/scheduler_profiler_mixin.py b/python/sglang/srt/managers/scheduler_profiler_mixin.py
index a71214ac0..e7ac8452d 100644
--- a/python/sglang/srt/managers/scheduler_profiler_mixin.py
+++ b/python/sglang/srt/managers/scheduler_profiler_mixin.py
@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
def start_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
- stage_str = f" for {stage.name}" if stage else ""
+ stage_str = f" for {stage.__str__()}" if stage else ""
logger.info(
f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
)
@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
if not Path(self.torch_profiler_output_dir).exists():
Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
- stage_suffix = f"-{stage.name}" if stage else ""
+ stage_suffix = f"-{stage.__str__()}" if stage else ""
logger.info("Stop profiling" + stage_suffix + "...")
if self.torch_profiler is not None:
self.torch_profiler.stop()
@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
if self.profiler_decode_ct == 0:
if self.profile_in_progress:
# force trace flush
- self.stop_profile(stage=ForwardMode.EXTEND)
+ self.stop_profile(ForwardMode.EXTEND)
self.start_profile(batch.forward_mode)
self.profiler_decode_ct += 1
if self.profiler_decode_ct > self.profiler_target_decode_ct:
@@ -294,6 +294,6 @@ class SchedulerProfilerMixin:
recv_req.profile_by_stage,
recv_req.profile_id,
)
- return self.start_profile()
+ return self.start_profile(True)
else:
return self.stop_profile()
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index cc4b8c038..09059277d 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -185,7 +185,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
if self.model_config.is_multimodal:
- import_processors("sglang.srt.multimodal.processors")
+ import_processors()
try:
_processor = get_processor(
server_args.tokenizer_path,
@@ -320,8 +320,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
}
- if server_args.tokenizer_metrics_allowed_custom_labels:
- for label in server_args.tokenizer_metrics_allowed_custom_labels:
+ if server_args.tokenizer_metrics_allowed_customer_labels:
+ for label in server_args.tokenizer_metrics_allowed_customer_labels:
labels[label] = ""
self.metrics_collector = TokenizerMetricsCollector(
server_args=server_args,
@@ -750,7 +750,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
return_hidden_states=obj.return_hidden_states,
data_parallel_rank=obj.data_parallel_rank,
priority=obj.priority,
- extra_key=obj.extra_key,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
@@ -1633,10 +1632,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
else 0
)
- custom_labels = getattr(state.obj, "custom_labels", None)
+ customer_labels = getattr(state.obj, "customer_labels", None)
labels = (
- {**self.metrics_collector.labels, **custom_labels}
- if custom_labels
+ {**self.metrics_collector.labels, **customer_labels}
+ if customer_labels
else self.metrics_collector.labels
)
if (
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 0d3f76658..6453b5675 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -91,6 +91,7 @@ class TpModelWorker:
else server_args.speculative_draft_model_revision
),
is_draft_model=is_draft_worker,
+ tp_rank=tp_rank,
)
self.model_runner = ModelRunner(
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 9ca68b0b8..d0b5e586d 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -36,11 +36,10 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
-from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import DynamicGradMode
+from sglang.srt.utils import DynamicGradMode, get_compiler_backend
from sglang.utils import get_exception_traceback
if TYPE_CHECKING:
@@ -49,6 +48,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+ input_ids[:] = torch.where(
+ input_ids < 0,
+ future_token_ids_map[torch.clamp(-input_ids, min=0)],
+ input_ids,
+ )
+
+
class TpModelWorkerClient:
"""A tensor parallel model worker."""
@@ -71,7 +79,11 @@ class TpModelWorkerClient:
self.gpu_id = gpu_id
# Init future mappings
- self.future_map = FutureMap(self.max_running_requests, self.device)
+ self.future_token_ids_ct = 0
+ self.future_token_ids_limit = self.max_running_requests * 3
+ self.future_token_ids_map = torch.empty(
+ (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
+ )
# Launch threads
self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
@@ -141,7 +153,7 @@ class TpModelWorkerClient:
batch_lists: List = [None] * 2
while True:
- model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
+ model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
if not model_worker_batch:
break
@@ -157,7 +169,8 @@ class TpModelWorkerClient:
copy_done = torch.get_device_module(self.device).Event()
# Resolve future tokens in the input
- self.future_map.resolve_future(model_worker_batch)
+ input_ids = model_worker_batch.input_ids
+ resolve_future_token_ids(input_ids, self.future_token_ids_map)
# Run forward
logits_output, next_token_ids, can_run_cuda_graph = (
@@ -174,9 +187,9 @@ class TpModelWorkerClient:
if model_worker_batch.is_prefill_only:
# For prefill-only requests, create dummy token IDs on CPU
next_token_ids = torch.zeros(bs, dtype=torch.long)
-
- # store the future indices into future map
- self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
+ self.future_token_ids_map[
+ future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+ ] = next_token_ids
# Copy results to the CPU
if model_worker_batch.return_logprob:
@@ -242,14 +255,20 @@ class TpModelWorkerClient:
sync_event.record(self.scheduler_stream)
# Push a new batch to the queue
- bs = len(model_worker_batch.seq_lens)
- cur_future_map_ct = self.future_map.update_ct(bs)
- self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
+ self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
- # get this forward batch's future token ids
- future_next_token_ids = self.future_map.update_next_future(
- cur_future_map_ct, bs
+ # Allocate output future objects
+ bs = len(model_worker_batch.seq_lens)
+ future_next_token_ids = torch.arange(
+ -(self.future_token_ids_ct + 1),
+ -(self.future_token_ids_ct + 1 + bs),
+ -1,
+ dtype=torch.int64,
+ device=self.device,
)
+ self.future_token_ids_ct = (
+ self.future_token_ids_ct + bs
+ ) % self.future_token_ids_limit
return None, future_next_token_ids, False
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
diff --git a/python/sglang/srt/mem_cache/allocator_ascend.py b/python/sglang/srt/mem_cache/allocator_ascend.py
index 2af138a6c..44d7b61c5 100644
--- a/python/sglang/srt/mem_cache/allocator_ascend.py
+++ b/python/sglang/srt/mem_cache/allocator_ascend.py
@@ -79,37 +79,48 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
)
num_new_pages = (
- (
- (seq_lens + self.page_size - 1) // self.page_size
- - (prefix_lens + self.page_size - 1) // self.page_size
- )
- .sum()
- .item()
- )
- if self.need_sort and num_new_pages > len(self.free_pages):
+ (seq_lens + self.page_size - 1) // self.page_size
+ - (prefix_lens + self.page_size - 1) // self.page_size
+ ).sum()
+ num_new_pages_item = num_new_pages.item()
+ if self.need_sort and num_new_pages_item > len(self.free_pages):
self.merge_and_sort_free()
- if num_new_pages > len(self.free_pages):
+ if num_new_pages_item > len(self.free_pages):
return None
out_indices = torch.empty(
- (extend_num_tokens,), dtype=torch.int32, device=self.device
+ (extend_num_tokens,), dtype=torch.int64, device=self.device
)
- alloc_extend_kernel_ascend(
- prefix_lens,
- seq_lens,
- last_loc,
- self.free_pages,
- out_indices,
- self.page_size,
- self.device,
- )
+ if num_new_pages_item < 200:
+ import sgl_kernel_npu
+
+ torch.ops.npu.alloc_extend(
+ prefix_lens,
+ seq_lens,
+ last_loc,
+ self.free_pages,
+ self.page_size,
+ out_indices,
+ num_new_pages,
+ )
+
+ else:
+ alloc_extend_kernel_ascend(
+ prefix_lens,
+ seq_lens,
+ last_loc,
+ self.free_pages,
+ out_indices,
+ self.page_size,
+ self.device,
+ )
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
- self.free_pages = self.free_pages[num_new_pages:]
+ self.free_pages = self.free_pages[num_new_pages_item:]
return out_indices
def alloc_decode(
diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py
index 8b21446b9..6ec077db5 100644
--- a/python/sglang/srt/mem_cache/hicache_storage.py
+++ b/python/sglang/srt/mem_cache/hicache_storage.py
@@ -7,8 +7,6 @@ from typing import Any, List, Optional
import torch
-from sglang.srt.mem_cache.memory_pool_host import HostKVCache
-
logger = logging.getLogger(__name__)
@@ -34,46 +32,15 @@ class HiCacheStorageConfig:
extra_config: Optional[dict] = None
-@dataclass
-class HiCacheStorageExtraInfo:
- extra_info: Optional[dict] = None
-
-
class HiCacheStorage(ABC):
"""
HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
It abstracts the underlying storage mechanism, allowing different implementations to be used.
"""
+ # todo, potentially pass model and TP configs into storage backend
# todo, the page size of storage backend does not have to be the same as the same as host memory pool
- def register_mem_pool_host(self, mem_pool_host: HostKVCache):
- self.mem_pool_host = mem_pool_host
-
- def batch_get_v1(
- self,
- keys: List[str],
- host_indices: torch.Tensor,
- extra_info: Optional[HiCacheStorageExtraInfo] = None,
- ) -> List[bool]:
- """
- Retrieve values for multiple keys.
- Returns a list of tensors or None for each key.
- """
- pass
-
- def batch_set_v1(
- self,
- keys: List[str],
- host_indices: torch.Tensor,
- extra_info: Optional[HiCacheStorageExtraInfo] = None,
- ) -> List[bool]:
- """
- Retrieve values for multiple keys.
- Returns a list of tensors or None for each key.
- """
- pass
-
@abstractmethod
def get(
self,
@@ -87,7 +54,6 @@ class HiCacheStorage(ABC):
"""
pass
- # TODO: Deprecate
@abstractmethod
def batch_get(
self,
@@ -115,7 +81,6 @@ class HiCacheStorage(ABC):
"""
pass
- # TODO: Deprecate
@abstractmethod
def batch_set(
self,
@@ -138,7 +103,6 @@ class HiCacheStorage(ABC):
"""
pass
- # TODO: Use a finer-grained return type (e.g., List[bool])
def batch_exists(self, keys: List[str]) -> int:
"""
Check if the keys exist in the storage.
@@ -150,9 +114,6 @@ class HiCacheStorage(ABC):
return i
return len(keys)
- def clear(self) -> None:
- pass
-
def get_stats(self):
return None
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index 75ff08fd6..9dfe9aca0 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -48,9 +48,9 @@ class HiRadixCache(RadixCache):
if hicache_io_backend == "direct":
if hicache_mem_layout == "page_first":
- hicache_mem_layout = "page_first_direct"
+ hicache_mem_layout = "layer_first"
logger.warning(
- "Page first layout is not supported with direct IO backend, switching to page first direct layout"
+ "Page first layout is not supported with direct IO backend, switching to layer first layout"
)
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
@@ -305,7 +305,7 @@ class HiRadixCache(RadixCache):
def _evict_backuped(self, node: TreeNode):
# evict a node already written to host
- num_evicted = self.cache_controller.evict_device(node.value)
+ num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
assert num_evicted > 0
self.evictable_size_ -= num_evicted
node.value = None
@@ -576,6 +576,8 @@ class HiRadixCache(RadixCache):
written_indices,
hash_value[: min_completed_tokens // self.page_size],
)
+ if len(written_indices):
+ self.cache_controller.mem_pool_host.update_prefetch(written_indices)
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
self.cache_controller.append_host_mem_release(
@@ -773,6 +775,7 @@ class HiRadixCache(RadixCache):
# change the reference if the node is evicted
# this often happens in the case of KV cache recomputation
node.value = value[:prefix_len]
+ self.token_to_kv_pool_host.update_synced(node.host_value)
self.evictable_size_ += len(node.value)
else:
self._inc_hit_count(node, chunked)
@@ -782,6 +785,7 @@ class HiRadixCache(RadixCache):
new_node = self._split_node(node.key, node, prefix_len)
if new_node.evicted:
new_node.value = value[:prefix_len]
+ self.token_to_kv_pool_host.update_synced(new_node.host_value)
self.evictable_size_ += len(new_node.value)
else:
self._inc_hit_count(new_node, chunked)
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 5b0f8a714..3627b3fb5 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -15,6 +15,8 @@ limitations under the License.
from __future__ import annotations
+from sglang.srt.layers.attention.nsa import index_buf_accessor
+from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
"""
@@ -37,6 +39,7 @@ import triton
import triton.language as tl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
+from sglang.srt.layers.attention.nsa.utils import NSA_KV_CACHE_STORE_FP8
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
@@ -1030,6 +1033,8 @@ class MLATokenToKVPool(KVCache):
enable_memory_saver: bool,
start_layer: Optional[int] = None,
end_layer: Optional[int] = None,
+ use_nsa: bool = False,
+ override_kv_cache_dim: Optional[int] = None,
):
super().__init__(
size,
@@ -1044,6 +1049,16 @@ class MLATokenToKVPool(KVCache):
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
+ self.use_nsa = use_nsa
+ # TODO do not hardcode
+ self.kv_cache_dim = (
+ 656
+ if use_nsa and NSA_KV_CACHE_STORE_FP8
+ else (kv_lora_rank + qk_rope_head_dim)
+ )
+
+ if use_nsa and NSA_KV_CACHE_STORE_FP8:
+ assert self.dtype == torch.float8_e4m3fn, f"{self.dtype=}"
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
@@ -1067,7 +1082,7 @@ class MLATokenToKVPool(KVCache):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
- (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+ (size + page_size, 1, self.kv_cache_dim),
dtype=self.store_dtype,
device=device,
)
@@ -1130,6 +1145,7 @@ class MLATokenToKVPool(KVCache):
cache_v: torch.Tensor,
):
layer_id = layer.layer_id
+ assert not (self.use_nsa and NSA_KV_CACHE_STORE_FP8)
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype:
@@ -1147,16 +1163,28 @@ class MLATokenToKVPool(KVCache):
cache_k_rope: torch.Tensor,
):
layer_id = layer.layer_id
- if cache_k_nope.dtype != self.dtype:
- cache_k_nope = cache_k_nope.to(self.dtype)
- cache_k_rope = cache_k_rope.to(self.dtype)
- if self.store_dtype != self.dtype:
- cache_k_nope = cache_k_nope.view(self.store_dtype)
- cache_k_rope = cache_k_rope.view(self.store_dtype)
- set_mla_kv_buffer_triton(
- self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
- )
+ if self.use_nsa and NSA_KV_CACHE_STORE_FP8:
+ # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
+ # TODO no need to cat
+ cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
+ cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
+ cache_k = cache_k.view(self.store_dtype)
+ self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
+ else:
+ if cache_k_nope.dtype != self.dtype:
+ cache_k_nope = cache_k_nope.to(self.dtype)
+ cache_k_rope = cache_k_rope.to(self.dtype)
+ if self.store_dtype != self.dtype:
+ cache_k_nope = cache_k_nope.view(self.store_dtype)
+ cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+ set_mla_kv_buffer_triton(
+ self.kv_buffer[layer_id - self.start_layer],
+ loc,
+ cache_k_nope,
+ cache_k_rope,
+ )
def get_cpu_copy(self, indices):
torch.cuda.synchronize()
@@ -1186,6 +1214,103 @@ class MLATokenToKVPool(KVCache):
torch.cuda.synchronize()
+class NSATokenToKVPool(MLATokenToKVPool):
+ def __init__(
+ self,
+ size: int,
+ page_size: int,
+ kv_lora_rank: int,
+ dtype: torch.dtype,
+ qk_rope_head_dim: int,
+ layer_num: int,
+ device: str,
+ index_head_dim: int,
+ enable_memory_saver: bool,
+ start_layer: Optional[int] = None,
+ end_layer: Optional[int] = None,
+ ):
+ super().__init__(
+ size,
+ page_size,
+ dtype,
+ kv_lora_rank,
+ qk_rope_head_dim,
+ layer_num,
+ device,
+ enable_memory_saver,
+ start_layer,
+ end_layer,
+ use_nsa=True,
+ )
+ # self.index_k_dtype = torch.float8_e4m3fn
+ # self.index_k_scale_dtype = torch.float32
+ self.index_head_dim = index_head_dim
+ # num head == 1 and head dim == 128 for index_k in NSA
+ assert index_head_dim == 128
+
+ self.quant_block_size = 128
+
+ assert self.page_size == 64
+ self.index_k_with_scale_buffer = [
+ torch.zeros(
+ # Layout:
+ # ref: test_attention.py :: kv_cache_cast_to_fp8
+ # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
+ # data: for page i,
+ # * buf[i, :page_size * head_dim] for fp8 data
+ # * buf[i, page_size * head_dim:].view(float32) for scale
+ (
+ (size + page_size + 1) // self.page_size,
+ self.page_size
+ * (index_head_dim + index_head_dim // self.quant_block_size * 4),
+ ),
+ dtype=torch.uint8,
+ device=device,
+ )
+ for _ in range(layer_num)
+ ]
+
+ def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
+ if self.layer_transfer_counter is not None:
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+ return self.index_k_with_scale_buffer[layer_id - self.start_layer]
+
+ def get_index_k_continuous(
+ self,
+ layer_id: int,
+ seq_len: int,
+ page_indices: torch.Tensor,
+ ):
+ buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+ return index_buf_accessor.GetK.execute(
+ self, buf, seq_len=seq_len, page_indices=page_indices
+ )
+
+ def get_index_k_scale_continuous(
+ self,
+ layer_id: int,
+ seq_len: int,
+ page_indices: torch.Tensor,
+ ):
+ buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+ return index_buf_accessor.GetS.execute(
+ self, buf, seq_len=seq_len, page_indices=page_indices
+ )
+
+ # TODO rename later (currently using a different name to avoid confusion)
+ def set_index_k_and_scale_buffer(
+ self,
+ layer_id: int,
+ loc: torch.Tensor,
+ index_k: torch.Tensor,
+ index_k_scale: torch.Tensor,
+ ) -> None:
+ buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
+ index_buf_accessor.SetKAndS.execute(
+ pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
+ )
+
+
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
def __init__(
self,
@@ -1194,6 +1319,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
dtype: torch.dtype,
kv_lora_rank: int,
qk_rope_head_dim: int,
+ index_head_dim: Optional[int],
layer_num: int,
device: str,
enable_memory_saver: bool,
@@ -1213,6 +1339,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
+ self.index_head_dim = index_head_dim
self.custom_mem_pool = None
@@ -1240,6 +1367,18 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
dtype=self.store_dtype,
device=self.device,
)
+ if self.index_head_dim is not None:
+ self.index_k_buffer = torch.zeros(
+ (
+ layer_num,
+ self.size // self.page_size + 1,
+ self.page_size,
+ 1,
+ self.index_head_dim,
+ ),
+ dtype=self.store_dtype,
+ device=self.device,
+ )
self._finalize_allocation_log(size)
@@ -1251,6 +1390,10 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
kv_size_bytes += get_tensor_size_bytes(k_cache)
for v_cache in self.v_buffer:
kv_size_bytes += get_tensor_size_bytes(v_cache)
+ if self.index_head_dim is not None:
+ assert hasattr(self, "index_k_buffer")
+ for index_k_cache in self.index_k_buffer:
+ kv_size_bytes += get_tensor_size_bytes(index_k_cache)
return kv_size_bytes
def get_kv_buffer(self, layer_id: int):
@@ -1277,6 +1420,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
return self.v_buffer[layer_id - self.start_layer]
+ def get_index_k_buffer(self, layer_id: int):
+ if self.layer_transfer_counter is not None:
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+ if self.store_dtype != self.dtype:
+ return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
+ return self.index_k_buffer[layer_id - self.start_layer]
+
# for disagg
def get_contiguous_buf_infos(self):
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
@@ -1289,6 +1440,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
self.v_buffer[i][0].nbytes for i in range(self.layer_num)
]
+ if self.index_head_dim is not None:
+ kv_data_ptrs += [
+ self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
+ ]
+ kv_data_lens += [
+ self.index_k_buffer[i].nbytes for i in range(self.layer_num)
+ ]
+ kv_item_lens += [
+ self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
+ ]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def set_kv_buffer(
@@ -1325,6 +1486,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
cache_v.view(-1, 1, self.qk_rope_head_dim),
)
+ def set_index_k_buffer(
+ self,
+ layer_id: int,
+ loc: torch.Tensor,
+ index_k: torch.Tensor,
+ ):
+ if index_k.dtype != self.dtype:
+ index_k = index_k.to(self.dtype)
+
+ if self.store_dtype != self.dtype:
+ index_k = index_k.view(self.store_dtype)
+
+ torch_npu.npu_scatter_nd_update_(
+ self.index_k_buffer[layer_id - self.start_layer].view(
+ -1, 1, self.index_head_dim
+ ),
+ loc.view(-1, 1),
+ index_k.view(-1, 1, self.index_head_dim),
+ )
+
class DoubleSparseTokenToKVPool(KVCache):
def __init__(
diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py
index f6d655af0..079dc0a64 100644
--- a/python/sglang/srt/mem_cache/memory_pool_host.py
+++ b/python/sglang/srt/mem_cache/memory_pool_host.py
@@ -31,13 +31,27 @@ if not (_is_npu or _is_xpu):
logger = logging.getLogger(__name__)
-def synchronized(func):
- @wraps(func)
- def wrapper(self, *args, **kwargs):
- with self.lock:
- return func(self, *args, **kwargs)
+class MemoryStateInt(IntEnum):
+ IDLE = 0
+ RESERVED = 1
+ PROTECTED = 2
+ SYNCED = 3
+ BACKUP = 4
- return wrapper
+
+def synchronized(debug_only=False):
+ def _decorator(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ if (not debug_only) or self.debug:
+ with self.lock:
+ return func(self, *args, **kwargs)
+ else:
+ return True
+
+ return wrapper
+
+ return _decorator
class HostKVCache(abc.ABC):
@@ -96,6 +110,7 @@ class HostKVCache(abc.ABC):
# A lock for synchronized operations on memory allocation and state transitions.
self.lock = threading.RLock()
+ self.debug = logger.isEnabledFor(logging.DEBUG)
self.clear()
@abc.abstractmethod
@@ -125,7 +140,7 @@ class HostKVCache(abc.ABC):
raise NotImplementedError()
@abc.abstractmethod
- def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
+ def get_flat_data_page(self, index) -> torch.Tensor:
"""
Get a flat data page from the host memory pool.
"""
@@ -146,7 +161,7 @@ class HostKVCache(abc.ABC):
"""
raise NotImplementedError()
- @synchronized
+ @synchronized()
def clear(self):
# Initialize memory states and tracking structures.
self.mem_state = torch.zeros(
@@ -157,7 +172,7 @@ class HostKVCache(abc.ABC):
def available_size(self):
return len(self.free_slots)
- @synchronized
+ @synchronized()
def alloc(self, need_size: int) -> Optional[torch.Tensor]:
assert (
need_size % self.page_size == 0
@@ -168,13 +183,92 @@ class HostKVCache(abc.ABC):
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
+ if self.debug:
+ self.mem_state[select_index] = MemoryStateInt.RESERVED
+
return select_index
- @synchronized
+ @synchronized()
def free(self, indices: torch.Tensor) -> int:
self.free_slots = torch.cat([self.free_slots, indices])
+ if self.debug:
+ self.mem_state[indices] = MemoryStateInt.IDLE
return len(indices)
+ @synchronized(debug_only=True)
+ def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+ assert len(indices) > 0, "The indices should not be empty"
+ states = self.mem_state[indices]
+ assert (
+ states == states[0]
+ ).all(), "The memory slots should have the same state {}".format(states)
+ return MemoryStateInt(states[0].item())
+
+ @synchronized(debug_only=True)
+ def is_reserved(self, indices: torch.Tensor) -> bool:
+ return self.get_state(indices) == MemoryStateInt.RESERVED
+
+ @synchronized(debug_only=True)
+ def is_protected(self, indices: torch.Tensor) -> bool:
+ return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+ @synchronized(debug_only=True)
+ def is_synced(self, indices: torch.Tensor) -> bool:
+ return self.get_state(indices) == MemoryStateInt.SYNCED
+
+ @synchronized(debug_only=True)
+ def is_backup(self, indices: torch.Tensor) -> bool:
+ return self.get_state(indices) == MemoryStateInt.BACKUP
+
+ @synchronized(debug_only=True)
+ def update_backup(self, indices: torch.Tensor):
+ if not self.is_synced(indices):
+ raise ValueError(
+ f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+ f"Current state: {self.get_state(indices)}"
+ )
+ self.mem_state[indices] = MemoryStateInt.BACKUP
+
+ @synchronized(debug_only=True)
+ def update_prefetch(self, indices: torch.Tensor):
+ if not self.is_reserved(indices):
+ raise ValueError(
+ f"The host memory slots should be in RESERVED state before turning into BACKUP. "
+ f"Current state: {self.get_state(indices)}"
+ )
+ self.mem_state[indices] = MemoryStateInt.BACKUP
+
+ @synchronized(debug_only=True)
+ def update_synced(self, indices: torch.Tensor):
+ self.mem_state[indices] = MemoryStateInt.SYNCED
+
+ @synchronized(debug_only=True)
+ def protect_write(self, indices: torch.Tensor):
+ if not self.is_reserved(indices):
+ raise ValueError(
+ f"The host memory slots should be RESERVED before write operations. "
+ f"Current state: {self.get_state(indices)}"
+ )
+ self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+ @synchronized(debug_only=True)
+ def protect_load(self, indices: torch.Tensor):
+ if not self.is_backup(indices):
+ raise ValueError(
+ f"The host memory slots should be in BACKUP state before load operations. "
+ f"Current state: {self.get_state(indices)}"
+ )
+ self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+ @synchronized(debug_only=True)
+ def complete_io(self, indices: torch.Tensor):
+ if not self.is_protected(indices):
+ raise ValueError(
+ f"The host memory slots should be PROTECTED during I/O operations. "
+ f"Current state: {self.get_state(indices)}"
+ )
+ self.mem_state[indices] = MemoryStateInt.SYNCED
+
class MHATokenToKVPoolHost(HostKVCache):
device_pool: MHATokenToKVPool
@@ -367,19 +461,16 @@ class MHATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported IO backend: {io_backend}")
- def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
+ def get_flat_data_page(self, index) -> torch.Tensor:
if self.layout == "layer_first":
- data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
+ return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
elif self.layout == "page_first":
- data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
+ return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
elif self.layout == "page_first_direct":
real_index = index // self.page_size
- data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
+ return self.kv_buffer[:, real_index : real_index + 1, :, :, :, :].flatten()
else:
raise ValueError(f"Unsupported layout: {self.layout}")
- if flat:
- data_page = data_page.flatten()
- return data_page
def get_dummy_flat_data_page(self) -> torch.Tensor:
return torch.zeros(
@@ -416,12 +507,9 @@ class MHATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported layout: {self.layout}")
- def get_page_buffer_meta(self, indices):
- """ "
- meta data for zero copy
- """
- assert len(indices) % self.page_size == 0
+ def get_buffer_meta(self, keys, indices, local_rank):
ptr_list = []
+ key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
indices = indices.tolist()
v_offset = (
@@ -431,52 +519,48 @@ class MHATokenToKVPoolHost(HostKVCache):
* self.head_dim
* self.dtype.itemsize
)
- if self.layout == "layer_first":
- for index in range(0, len(indices), self.page_size):
- for layer_id in range(self.layer_num):
- k_ptr = (
- kv_buffer_data_ptr
- + indices[index]
- * self.head_num
- * self.head_dim
- * self.dtype.itemsize
- + layer_id
- * self.size
- * self.head_num
- * self.head_dim
- * self.dtype.itemsize
- )
- v_ptr = k_ptr + v_offset
- ptr_list.append(k_ptr)
- ptr_list.append(v_ptr)
- element_size = (
- self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
- )
- element_size_list = [element_size] * len(ptr_list)
- elif self.layout in ["page_first", "page_first_direct"]:
- for index in range(0, len(indices), self.page_size):
- k_ptr = (
- kv_buffer_data_ptr
- + indices[index]
- * self.layer_num
- * self.head_num
- * self.head_dim
- * self.dtype.itemsize
- )
- v_ptr = k_ptr + v_offset
- ptr_list.append(k_ptr)
- ptr_list.append(v_ptr)
- element_size = (
- self.layer_num
- * self.dtype.itemsize
- * self.page_size
+ for index in range(0, len(indices), self.page_size):
+ k_ptr = (
+ kv_buffer_data_ptr
+ + indices[index]
+ * self.layer_num
* self.head_num
* self.head_dim
+ * self.dtype.itemsize
)
- element_size_list = [element_size] * len(ptr_list)
- else:
- raise ValueError(f"Unsupported layout: {self.layout}")
- return ptr_list, element_size_list
+ v_ptr = k_ptr + v_offset
+ ptr_list.append(k_ptr)
+ ptr_list.append(v_ptr)
+ key_ = keys[index // self.page_size]
+ key_list.append(f"{key_}_{local_rank}_k")
+ key_list.append(f"{key_}_{local_rank}_v")
+ element_size = (
+ self.layer_num
+ * self.dtype.itemsize
+ * self.page_size
+ * self.head_num
+ * self.head_dim
+ )
+ element_size_list = [element_size] * len(key_list)
+ return key_list, ptr_list, element_size_list
+
+ def get_buffer_with_hash(self, keys, indices=None):
+ assert self.layout == "page_first"
+ assert indices is None or (len(keys) == (len(indices) // self.page_size))
+
+ key_list = []
+ buf_list = []
+
+ for i in range(len(keys)):
+ key = keys[i]
+ key_list.append(f"{key}-k")
+ key_list.append(f"{key}-v")
+ if indices is not None:
+ index = indices[i * self.page_size]
+ buf_list.append(self.k_buffer[index : index + self.page_size])
+ buf_list.append(self.v_buffer[index : index + self.page_size])
+
+ return key_list, buf_list, 2
class MLATokenToKVPoolHost(HostKVCache):
@@ -652,19 +736,16 @@ class MLATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported IO backend: {io_backend}")
- def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
+ def get_flat_data_page(self, index) -> torch.Tensor:
if self.layout == "layer_first":
- data_page = self.kv_buffer[:, index : index + self.page_size, :, :]
+ return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
elif self.layout == "page_first":
- data_page = self.kv_buffer[index : index + self.page_size, :, :, :]
+ return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
elif self.layout == "page_first_direct":
real_index = index // self.page_size
- data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :]
+ return self.kv_buffer[real_index : real_index + 1, :, :, :, :].flatten()
else:
raise ValueError(f"Unsupported layout: {self.layout}")
- if flat:
- data_page = data_page.flatten()
- return data_page
def get_dummy_flat_data_page(self) -> torch.Tensor:
return torch.zeros(
@@ -706,51 +787,40 @@ class MLATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported layout: {self.layout}")
- def get_page_buffer_meta(self, indices):
- """ "
- meta data for zero copy
- """
- assert len(indices) % self.page_size == 0
+ def get_buffer_meta(self, keys, indices, local_rank):
ptr_list = []
+ key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
indices = indices.tolist()
- if self.layout == "layer_first":
- for index in range(0, len(indices), self.page_size):
- for layer_id in range(self.layer_num):
- k_ptr = (
- kv_buffer_data_ptr
- + indices[index]
- * (self.kv_lora_rank + self.qk_rope_head_dim)
- * self.dtype.itemsize
- + layer_id
- * self.size
- * (self.kv_lora_rank + self.qk_rope_head_dim)
- * self.dtype.itemsize
- )
- ptr_list.append(k_ptr)
- element_size = (
- self.dtype.itemsize
- * self.page_size
+ for index in range(0, len(indices), self.page_size):
+ k_ptr = (
+ kv_buffer_data_ptr
+ + indices[index]
+ * self.layer_num
* (self.kv_lora_rank + self.qk_rope_head_dim)
- )
- element_size_list = [element_size] * len(ptr_list)
- elif self.layout in ["page_first", "page_first_direct"]:
- for index in range(0, len(indices), self.page_size):
- k_ptr = (
- kv_buffer_data_ptr
- + indices[index]
- * self.layer_num
- * (self.kv_lora_rank + self.qk_rope_head_dim)
- * self.dtype.itemsize
- )
- ptr_list.append(k_ptr)
- element_size = (
- self.layer_num
* self.dtype.itemsize
- * self.page_size
- * (self.kv_lora_rank + self.qk_rope_head_dim)
)
- element_size_list = [element_size] * len(ptr_list)
- else:
- raise ValueError(f"Unsupported layout: {self.layout}")
- return ptr_list, element_size_list
+ ptr_list.append(k_ptr)
+ key_ = keys[index // self.page_size]
+ key_list.append(f"{key_}_k")
+ element_size = (
+ self.layer_num
+ * self.dtype.itemsize
+ * self.page_size
+ * (self.kv_lora_rank + self.qk_rope_head_dim)
+ )
+ element_size_list = [element_size] * len(key_list)
+ return key_list, ptr_list, element_size_list
+
+ def get_buffer_with_hash(self, keys, indices=None):
+ assert self.layout == "page_first"
+ assert indices is None or (len(keys) == (len(indices) // self.page_size))
+
+ buf_list = []
+
+ if indices is not None:
+ for i in range(len(keys)):
+ index = indices[i * self.page_size]
+ buf_list.append(self.kv_buffer[index : index + self.page_size])
+
+ return keys, buf_list, 1
diff --git a/python/sglang/srt/mem_cache/storage/__init__.py b/python/sglang/srt/mem_cache/storage/__init__.py
deleted file mode 100644
index 34ac35508..000000000
--- a/python/sglang/srt/mem_cache/storage/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to SGLang project
-
-"""Storage backend module for SGLang HiCache."""
-
-from .backend_factory import StorageBackendFactory
-
-__all__ = [
- "StorageBackendFactory",
-]
diff --git a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/README.md b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/README.md
deleted file mode 100644
index 16941967f..000000000
--- a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/README.md
+++ /dev/null
@@ -1,37 +0,0 @@
-# AIBrix KVCache as L3 KV Cache
-This document provides brief instructions for setting up a AIBrixKVCache storage backend + AIBrixKVCache + SGLang runtime environment from scratch, describing how to utilize AIBrixKVCache as the L3 KV cache for SGLang.
-The process consists of three main steps:
-
-## Step1:Install AIbrix KVCache
-Refer to the [AIBrix KVCache documentation](https://github.com/vllm-project/aibrix/blob/main/python/aibrix_kvcache/README.md) to install AIBrix KVCache.
-
-## Step2: Deploy AIBrix Distributed KVCache Storage
-
-AIBrix KVCache currently supports multiple distributed KVCache backends, including ByteDance's open-source Infinistore and the not-yet-open source PrisKV incubated by ByteDance's PrisDB & IAAS & DMI team.
-
-For the Infinistore installation process, please refer to [this link](https://github.com/bytedance/InfiniStore).
-
-PrisKV for AIBrix KVCache is currently in the open-source preparation stage, and no public documentation is available yet.
-
-
-## Step3: Deploy Model Serving
-
-For information on configuring a distributed KVCache backend for AIBrixKVCache, please refer to [this link](https://aibrix.readthedocs.io/latest/designs/aibrix-kvcache-offloading-framework.html)
-
-Using PrisKV as an example, the startup command is as follows:
-```bash
-export AIBRIX_KV_CACHE_OL_L1_CACHE_ENABLED="0"
-export AIBRIX_KV_CACHE_OL_L2_CACHE_BACKEND="PRIS"
-export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR="127.0.0.1"
-export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT="6379"
-export AIBRIX_KV_CACHE_OL_PRIS_PASSWORD="kvcache-redis"
-MODEL_LENGTH=32768&&NCCL_MIN_NCHANNELS=24&&NCCL_IB_QPS_PER_CONNECTION=8&&NCCL_DEBUG=INFO \
-python3 -m sglang.launch_server \
- --model-path /code/models/Qwen3-32B \
- --host 0.0.0.0 --port 8080 \
- --enable-hierarchical-cache \
- --hicache-storage-backend aibrix \
- --page-size 16 \
- --hicache-write-policy write_back \
- --enable-metrics --hicache-ratio=2
-```
diff --git a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py
deleted file mode 100644
index 59aacc11d..000000000
--- a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py
+++ /dev/null
@@ -1,151 +0,0 @@
-import logging
-from typing import Any, List, Optional
-
-import torch
-from aibrix_kvcache import (
- BaseKVCacheManager,
- BlockHashes,
- KVCacheBlockLayout,
- KVCacheBlockSpec,
- KVCacheConfig,
- KVCacheTensorSpec,
- ModelSpec,
-)
-from aibrix_kvcache.common.absl_logging import log_every_n_seconds
-
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
-from sglang.srt.mem_cache.memory_pool_host import HostKVCache
-
-logger = logging.getLogger(__name__)
-
-
-class AibrixKVCacheStorage(HiCacheStorage):
- def __init__(self, storage_config: HiCacheStorageConfig, mem_pool: HostKVCache):
- if storage_config is not None:
- self.is_mla_backend = storage_config.is_mla_model
- self.local_rank = storage_config.tp_rank
- else:
- self.is_mla_backend = False
- self.local_rank = 0
- kv_cache = mem_pool.device_pool
- self.page_size = mem_pool.page_size
- self.kv_cache_dtype = kv_cache.dtype
- self.layer_num = kv_cache.layer_num
- self.kv_head_ids = [
- self.local_rank * kv_cache.head_num + i for i in range(kv_cache.head_num)
- ]
- if not self.is_mla_backend:
- self.layer_ids = range(
- kv_cache.start_layer, kv_cache.end_layer
- ) # for pipeline parallel
-
- self.block_spec = KVCacheBlockSpec(
- block_ntokens=self.page_size,
- block_dtype=self.kv_cache_dtype,
- block_layout=KVCacheBlockLayout(KVCacheBlockLayout.NCLD),
- tensor_spec=KVCacheTensorSpec(
- heads=self.kv_head_ids,
- layers=self.layer_ids,
- head_size=kv_cache.head_dim,
- ),
- )
- logger.info(self.block_spec)
- config = KVCacheConfig(
- block_spec=self.block_spec, model_spec=ModelSpec(102400)
- )
- self.kv_cache_manager = BaseKVCacheManager(config)
- else:
- raise NotImplementedError(
- "MLA is not supported by AibrixKVCacheStorage yet."
- )
-
- def _aibrix_kvcache_metrics_report(self):
- self.kv_cache_manager.metrics.summary()
- self.kv_cache_manager.metrics.reset()
-
- def batch_get(
- self,
- keys: List[str],
- target_locations: List[torch.Tensor],
- target_sizes: Optional[Any] = None,
- ) -> List[torch.Tensor | None]:
- block_hash = BlockHashes(keys, self.page_size)
- status = self.kv_cache_manager.acquire(None, block_hash)
- log_every_n_seconds(
- logger, logging.INFO, self._aibrix_kvcache_metrics_report(), 1
- )
- if status.is_ok():
- num_fetched_tokens, handle = status.value
- kv_blocks = handle.to_tensors()
- assert len(kv_blocks) == len(target_locations)
- for i in range(len(kv_blocks)):
- assert (
- target_locations[i].nbytes == kv_blocks[i].nbytes
- ), f"{target_locations[i].nbytes}, {kv_blocks[i].nbytes}"
- target_locations[i].copy_(kv_blocks[i].flatten())
- handle.release()
- return target_locations
-
- return [None] * len(keys)
-
- def get(
- self,
- key: str,
- target_location: Optional[Any] = None,
- target_size: Optional[Any] = None,
- ) -> torch.Tensor | None:
- return self.batch_get([key], [target_location], [target_size])[0]
-
- def batch_set(
- self,
- keys: List[str],
- values: Optional[Any] = None,
- target_locations: Optional[Any] = None,
- target_sizes: Optional[Any] = None,
- ) -> bool:
- block_hash = BlockHashes(keys, self.page_size)
- status = self.kv_cache_manager.allocate_for(None, block_hash)
- if not status.is_ok():
- logger.warning(
- f"aibrix_kvcache set allocate failed, error_code {status.error_code}"
- )
- return False
- handle = status.value
- tensors = handle.to_tensors()
- if len(tensors) != len(values):
- logger.warning("aibrix_kvcache set allocate not enough")
- return False
- for i in range(len(tensors)):
- assert (
- tensors[i].nbytes == values[i].nbytes
- ), f"{tensors[i].nbytes}, {values[i].nbytes}"
- tensors[i].reshape(values[i].shape).copy_(values[i]).reshape(
- tensors[i].shape
- )
- status = self.kv_cache_manager.put(None, block_hash, handle)
- if not status.is_ok():
- logger.info(
- f"AIBrix KVCache Storage set failed, error_code {status.error_code}"
- )
- return False
- completed = status.value
- return completed == len(keys) * self.page_size
-
- def set(
- self,
- key: str,
- value: Optional[Any] = None,
- target_location: Optional[Any] = None,
- target_size: Optional[Any] = None,
- ) -> bool:
- return self.batch_set([key], [value], [target_location], [target_size])
-
- def batch_exists(self, keys: List[str]) -> int:
- block_hash = BlockHashes(keys, self.page_size)
- status = self.kv_cache_manager.exists(None, block_hash)
- if status.is_ok():
- return status.value // self.page_size
- return 0
-
- def exists(self, key: str) -> bool | dict:
- return self.batch_exists([key]) > 0
diff --git a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py b/python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py
deleted file mode 100644
index 2e54e9816..000000000
--- a/python/sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import logging
-import os
-
-import torch
-import torch.distributed
-from aibrix_kvcache import (
- BaseKVCacheManager,
- GroupAwareKVCacheManager,
- KVCacheBlockLayout,
- KVCacheBlockSpec,
- KVCacheConfig,
- KVCacheMetrics,
- KVCacheTensorSpec,
- ModelSpec,
- TokenListView,
-)
-from aibrix_kvcache.common.absl_logging import getLogger, log_every_n_seconds, log_if
-from aibrix_kvcache_storage import AibrixKVCacheStorage
-from torch.distributed import Backend, ProcessGroup
-
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
-from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
-from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
-
-logging.basicConfig(
- level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
-)
-
-logger = logging.getLogger(__name__)
-
-
-def setup():
- os.environ["RANK"] = "0"
- os.environ["WORLD_SIZE"] = "1"
- os.environ["MASTER_ADDR"] = "127.0.0.1"
- os.environ["MASTER_PORT"] = "63886"
-
-
-class AIBrixKVCacheStorageTest:
- def test_with_page_size(self):
- config = HiCacheStorageConfig(
- tp_rank=0,
- tp_size=1,
- is_mla_model=False,
- is_page_first_layout=True,
- model_name="test",
- )
- for page_size in range(1, 3):
- logger.info(f"page_size: {page_size}")
- batch_size = 2
- head_num = 1
- layer_num = 64
- head_dim = 128
- kv_cache = MHATokenToKVPool(
- 1024,
- page_size,
- torch.float16,
- head_num,
- head_dim,
- layer_num,
- "cpu",
- False,
- 0,
- layer_num,
- )
- mem_pool = MHATokenToKVPoolHost(kv_cache, 2, 0, page_size, "layer_first")
- query_length = batch_size * 2
- partial = batch_size
- self.aibrix_kvcache = AibrixKVCacheStorage(config, mem_pool)
- target_shape = (2, layer_num, page_size, head_num, head_dim)
- rand_tensor = [
- torch.rand(target_shape, dtype=torch.float16)
- for _ in range(query_length)
- ]
- keys = ["hash" + str(i) for i in range(query_length)]
- partial_keys = keys[batch_size:query_length]
- assert self.aibrix_kvcache.batch_exists(keys) == 0
- assert self.aibrix_kvcache.batch_set(keys, rand_tensor)
- get_tensor = [
- torch.rand(target_shape, dtype=torch.float16).flatten()
- for _ in range(query_length)
- ]
- self.aibrix_kvcache.batch_get(keys, get_tensor)
- for i in range(query_length):
- assert torch.equal(get_tensor[i], rand_tensor[i].flatten())
- ret = self.aibrix_kvcache.batch_exists(keys)
- assert self.aibrix_kvcache.batch_exists(keys) == query_length
- assert self.aibrix_kvcache.batch_exists(partial_keys) == partial
- partial_get_tensor = [
- torch.rand(target_shape, dtype=torch.float16).flatten()
- for _ in range(partial)
- ]
- self.aibrix_kvcache.batch_get(partial_keys, partial_get_tensor)
- for i in range(partial):
- assert torch.equal(
- partial_get_tensor[i], rand_tensor[i + partial].flatten()
- )
- log_every_n_seconds(
- logger,
- logging.INFO,
- self.aibrix_kvcache.kv_cache_manager.metrics.summary(),
- 1,
- )
-
-
-if __name__ == "__main__":
- setup()
- test = AIBrixKVCacheStorageTest()
- test.test_with_page_size()
diff --git a/python/sglang/srt/mem_cache/storage/backend_factory.py b/python/sglang/srt/mem_cache/storage/backend_factory.py
deleted file mode 100644
index a141afb21..000000000
--- a/python/sglang/srt/mem_cache/storage/backend_factory.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to SGLang project
-
-import importlib
-import logging
-from typing import TYPE_CHECKING, Any, Dict
-
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
-
-if TYPE_CHECKING:
- pass
-
-logger = logging.getLogger(__name__)
-
-
-class StorageBackendFactory:
- """Factory for creating storage backend instances with support for dynamic loading."""
-
- _registry: Dict[str, Dict[str, Any]] = {}
-
- @staticmethod
- def _load_backend_class(
- module_path: str, class_name: str, backend_name: str
- ) -> type[HiCacheStorage]:
- """Load and validate a backend class from module path."""
- try:
- module = importlib.import_module(module_path)
- backend_class = getattr(module, class_name)
- if not issubclass(backend_class, HiCacheStorage):
- raise TypeError(
- f"Backend class {class_name} must inherit from HiCacheStorage"
- )
- return backend_class
- except ImportError as e:
- raise ImportError(
- f"Failed to import backend '{backend_name}' from '{module_path}': {e}"
- ) from e
- except AttributeError as e:
- raise AttributeError(
- f"Class '{class_name}' not found in module '{module_path}': {e}"
- ) from e
-
- @classmethod
- def register_backend(cls, name: str, module_path: str, class_name: str) -> None:
- """Register a storage backend with lazy loading.
-
- Args:
- name: Backend identifier
- module_path: Python module path containing the backend class
- class_name: Name of the backend class
- """
- if name in cls._registry:
- logger.warning(f"Backend '{name}' is already registered, overwriting")
-
- def loader() -> type[HiCacheStorage]:
- """Lazy loader function to import the backend class."""
- return cls._load_backend_class(module_path, class_name, name)
-
- cls._registry[name] = {
- "loader": loader,
- "module_path": module_path,
- "class_name": class_name,
- }
-
- @classmethod
- def create_backend(
- cls,
- backend_name: str,
- storage_config: HiCacheStorageConfig,
- mem_pool_host: Any,
- **kwargs,
- ) -> HiCacheStorage:
- """Create a storage backend instance.
- Args:
- backend_name: Name of the backend to create
- storage_config: Storage configuration
- mem_pool_host: Memory pool host object
- **kwargs: Additional arguments passed to external backends
- Returns:
- Initialized storage backend instance
- Raises:
- ValueError: If backend is not registered and cannot be dynamically loaded
- ImportError: If backend module cannot be imported
- Exception: If backend initialization fails
- """
- # First check if backend is already registered
- if backend_name in cls._registry:
- registry_entry = cls._registry[backend_name]
- backend_class = registry_entry["loader"]()
- logger.info(
- f"Creating storage backend '{backend_name}' "
- f"({registry_entry['module_path']}.{registry_entry['class_name']})"
- )
- return cls._create_builtin_backend(
- backend_name, backend_class, storage_config, mem_pool_host
- )
-
- # Try to dynamically load backend from extra_config
- if backend_name == "dynamic" and storage_config.extra_config is not None:
- backend_config = storage_config.extra_config
- return cls._create_dynamic_backend(
- backend_config, storage_config, mem_pool_host, **kwargs
- )
-
- # Backend not found
- available_backends = list(cls._registry.keys())
-
- raise ValueError(
- f"Unknown storage backend '{backend_name}'. "
- f"Registered backends: {available_backends}. "
- )
-
- @classmethod
- def _create_dynamic_backend(
- cls,
- backend_config: Dict[str, Any],
- storage_config: HiCacheStorageConfig,
- mem_pool_host: Any,
- **kwargs,
- ) -> HiCacheStorage:
- """Create a backend dynamically from configuration."""
- required_fields = ["backend_name", "module_path", "class_name"]
- for field in required_fields:
- if field not in backend_config:
- raise ValueError(
- f"Missing required field '{field}' in backend config for 'dynamic' backend"
- )
-
- backend_name = backend_config["backend_name"]
- module_path = backend_config["module_path"]
- class_name = backend_config["class_name"]
-
- try:
- # Import the backend class
- backend_class = cls._load_backend_class(
- module_path, class_name, backend_name
- )
-
- logger.info(
- f"Creating dynamic storage backend '{backend_name}' "
- f"({module_path}.{class_name})"
- )
-
- # Create the backend instance with storage_config
- return backend_class(storage_config, kwargs)
- except Exception as e:
- logger.error(
- f"Failed to create dynamic storage backend '{backend_name}': {e}"
- )
- raise
-
- @classmethod
- def _create_builtin_backend(
- cls,
- backend_name: str,
- backend_class: type[HiCacheStorage],
- storage_config: HiCacheStorageConfig,
- mem_pool_host: Any,
- ) -> HiCacheStorage:
- """Create built-in backend with original initialization logic."""
- if backend_name == "file":
- return backend_class(storage_config)
- elif backend_name == "nixl":
- return backend_class()
- elif backend_name == "mooncake":
- backend = backend_class(storage_config)
- return backend
- elif backend_name == "aibrix":
- backend = backend_class(storage_config, mem_pool_host)
- return backend
- elif backend_name == "hf3fs":
- # Calculate bytes_per_page based on memory pool layout
- if mem_pool_host.layout == "page_first":
- bytes_per_page = (
- mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
- )
- elif mem_pool_host.layout == "layer_first":
- bytes_per_page = (
- mem_pool_host.get_size_per_token() * mem_pool_host.page_size
- )
-
- dtype = mem_pool_host.dtype
- return backend_class.from_env_config(bytes_per_page, dtype, storage_config)
- else:
- raise ValueError(f"Unknown built-in backend: {backend_name}")
-
-
-# Register built-in storage backends
-StorageBackendFactory.register_backend(
- "file", "sglang.srt.mem_cache.hicache_storage", "HiCacheFile"
-)
-
-StorageBackendFactory.register_backend(
- "nixl",
- "sglang.srt.mem_cache.storage.nixl.hicache_nixl",
- "HiCacheNixl",
-)
-
-StorageBackendFactory.register_backend(
- "mooncake",
- "sglang.srt.mem_cache.storage.mooncake_store.mooncake_store",
- "MooncakeStore",
-)
-
-StorageBackendFactory.register_backend(
- "hf3fs",
- "sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs",
- "HiCacheHF3FS",
-)
-
-StorageBackendFactory.register_backend(
- "aibrix",
- "sglang.srt.mem_cache.storage.aibrix_kvcache.aibrix_kvcache_storage",
- "AibrixKVCacheStorage",
-)
diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
index 2a159e493..9595e7204 100644
--- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
+++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -12,12 +12,7 @@ from typing import Any, List, Optional, Tuple
import torch
-from sglang.srt.mem_cache.hicache_storage import (
- HiCacheStorage,
- HiCacheStorageConfig,
- HiCacheStorageExtraInfo,
-)
-from sglang.srt.mem_cache.memory_pool_host import HostKVCache
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
from sglang.srt.metrics.collector import StorageMetrics
@@ -183,14 +178,11 @@ class HiCacheHF3FS(HiCacheStorage):
self.skip_backup = True
self.rank = 0
- self.is_zero_copy = False
-
logger.info(
f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
f"file_path={self.file_path}, "
f"file_size={self.file_size / (2 ** 30):.2f} GB, "
- f"num_pages={self.num_pages}, "
- f"is_mla_model={self.is_mla_model}"
+ f"num_pages={self.num_pages}"
)
self.ac = AtomicCounter(self.numjobs)
@@ -331,12 +323,25 @@ class HiCacheHF3FS(HiCacheStorage):
use_mock_client=use_mock_client,
)
+ def get(
+ self,
+ key: str,
+ target_location: Optional[Any] = None,
+ target_sizes: Optional[Any] = None,
+ ) -> torch.Tensor | None:
+ return self.batch_get(
+ [key],
+ [target_location] if target_location is not None else None,
+ [target_sizes] if target_sizes is not None else None,
+ )[0]
+
@synchronized()
- def _batch_get(
+ def batch_get(
self,
keys: List[str],
- values: List[torch.Tensor],
- ) -> List[bool]:
+ target_locations: Optional[Any] = None,
+ target_sizes: Optional[Any] = None,
+ ) -> List[torch.Tensor | None]:
page_indices = self.metadata_client.get_page_indices(self.rank, keys)
batch_indices, file_offsets = [], []
@@ -345,9 +350,15 @@ class HiCacheHF3FS(HiCacheStorage):
batch_indices.append(i)
file_offsets.append(page_index * self.bytes_per_page)
- for target_location in values:
- assert target_location.is_contiguous()
- file_results = values
+ if target_locations is not None:
+ for target_location in target_locations:
+ assert target_location.is_contiguous()
+ file_results = target_locations
+ else:
+ file_results = [
+ torch.empty(self.numel, dtype=self.dtype)
+ for _ in range(len(batch_indices))
+ ]
start_time = time.perf_counter()
@@ -368,10 +379,12 @@ class HiCacheHF3FS(HiCacheStorage):
ionum / (end_time - start_time) * self.gb_per_page
)
- results = [False] * len(keys)
- for batch_index, read_result in zip(batch_indices, read_results):
+ results = [None] * len(keys)
+ for batch_index, file_result, read_result in zip(
+ batch_indices, file_results, read_results
+ ):
if read_result == self.bytes_per_page:
- results[batch_index] = True
+ results[batch_index] = file_result
else:
logger.error(
f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
@@ -379,12 +392,28 @@ class HiCacheHF3FS(HiCacheStorage):
return results
+ def set(
+ self,
+ key: str,
+ value: Optional[Any] = None,
+ target_location: Optional[Any] = None,
+ target_sizes: Optional[Any] = None,
+ ) -> bool:
+ return self.batch_set(
+ [key],
+ [value] if value is not None else None,
+ [target_location] if target_location is not None else None,
+ [target_sizes] if target_sizes is not None else None,
+ )
+
@synchronized()
- def _batch_set(
+ def batch_set(
self,
keys: List[str],
values: Optional[Any] = None,
- ) -> List[bool]:
+ target_locations: Optional[Any] = None,
+ target_sizes: Optional[Any] = None,
+ ) -> bool:
# In MLA backend, only one rank needs to backup the KV cache
if self.skip_backup:
return True
@@ -445,7 +474,7 @@ class HiCacheHF3FS(HiCacheStorage):
self.rank, written_keys_to_confirm, pages_to_release
)
- return results
+ return all(results)
def delete(self, key: str) -> None:
self.metadata_client.delete_keys(self.rank, [key])
@@ -455,25 +484,21 @@ class HiCacheHF3FS(HiCacheStorage):
return result[0] if result else False
def batch_exists(self, keys: List[str]) -> int:
- factor = 1
- if self.is_zero_copy and not self.is_mla_model:
- keys = self._get_mha_zero_copy_keys(keys)
- factor = 2
-
results = self.metadata_client.exists(self.rank, keys)
+ for i in range(len(keys)):
+ if not results[i]:
+ return i
- i = 0
- while i < len(keys) and results[i]:
- i += 1
+ return len(keys)
- return i // factor
-
- def clear(self) -> None:
+ def clear(self) -> bool:
try:
self.metadata_client.clear(self.rank)
logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
+ return True
except Exception as e:
logger.error(f"Failed to clear HiCacheHF3FS: {e}")
+ return False
def close(self) -> None:
try:
@@ -496,139 +521,3 @@ class HiCacheHF3FS(HiCacheStorage):
self.prefetch_bandwidth.clear()
self.backup_bandwidth.clear()
return storage_metrics
-
- def register_mem_pool_host(self, mem_pool_host: HostKVCache):
- super().register_mem_pool_host(mem_pool_host)
- self.is_zero_copy = self.mem_pool_host.layout == "page_first"
- logger.info(f"{self.is_zero_copy=}")
-
- def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]:
- _keys = []
- for k in keys:
- _keys.append(f"{k}-k")
- _keys.append(f"{k}-v")
- return _keys
-
- def _get_mha_zero_copy_values(
- self, values: List[torch.Tensor]
- ) -> List[torch.Tensor]:
- _values = []
- for value in values:
- _values.append(value[0])
- _values.append(value[1])
- return _values
-
- def _batch_get_preprocess(self, keys, host_indices):
- page_num = len(host_indices) // self.mem_pool_host.page_size
- # host_indices to kv_buffer
- flat = not self.is_zero_copy
- values = (
- [
- self.mem_pool_host.get_data_page(host_indices[i * page_num], flat=flat)
- for i in range(page_num)
- ]
- if self.is_zero_copy
- else [
- self.mem_pool_host.get_dummy_flat_data_page() for _ in range(page_num)
- ]
- )
-
- if self.is_zero_copy and not self.is_mla_model:
- keys = self._get_mha_zero_copy_keys(keys)
- values = self._get_mha_zero_copy_values(values)
-
- return keys, values
-
- def _batch_get_postprocess(self, host_indices, values, results):
- page_num = len(host_indices) // self.mem_pool_host.page_size
-
- if self.is_zero_copy:
- if not self.is_mla_model:
- results = [
- (results[2 * i] and results[2 * i + 1]) for i in range(page_num)
- ]
- results = results[:page_num]
- return results
-
- for i in range(page_num):
- if not results[i]:
- break
- self.mem_pool_host.set_from_flat_data_page(
- host_indices[i * self.mem_pool_host.page_size], values[i]
- )
-
- return results
-
- def batch_get_v1(
- self,
- keys: List[str],
- host_indices: torch.Tensor,
- extra_info: Optional[HiCacheStorageExtraInfo] = None,
- ) -> List[bool]:
- keys, values = self._batch_get_preprocess(keys, host_indices)
- results = self._batch_get(keys, values)
- return self._batch_get_postprocess(host_indices, values, results)
-
- def _batch_set_preprocess(self, keys, host_indices):
- page_num = len(host_indices) // self.mem_pool_host.page_size
- # host_indices to kv_buffer
- flat = not self.is_zero_copy
- values = [
- self.mem_pool_host.get_data_page(host_indices[i * page_num], flat=flat)
- for i in range(page_num)
- ]
-
- if self.is_zero_copy and not self.is_mla_model:
- keys = self._get_mha_zero_copy_keys(keys)
- values = self._get_mha_zero_copy_values(values)
-
- return keys, values
-
- def batch_set_v1(
- self,
- keys: List[str],
- host_indices: torch.Tensor,
- extra_info: Optional[HiCacheStorageExtraInfo] = None,
- ) -> List[bool]:
- len_keys = len(keys)
- keys, values = self._batch_set_preprocess(keys, host_indices)
- results = self._batch_set(keys, values)
- return results
-
- # Deprecated
- def get(
- self,
- key: str,
- target_location: Optional[Any] = None,
- target_sizes: Optional[Any] = None,
- ) -> torch.Tensor | None:
- pass
-
- # Deprecated
- def batch_get(
- self,
- keys: List[str],
- target_locations: Optional[Any] = None,
- target_sizes: Optional[Any] = None,
- ) -> List[torch.Tensor | None] | int:
- pass
-
- # Deprecated
- def set(
- self,
- key: str,
- value: Optional[Any] = None,
- target_location: Optional[Any] = None,
- target_sizes: Optional[Any] = None,
- ) -> bool:
- pass
-
- # Deprecated
- def batch_set(
- self,
- keys: List[str],
- values: Optional[Any] = None,
- target_locations: Optional[Any] = None,
- target_sizes: Optional[Any] = None,
- ) -> bool:
- pass
diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
index 0b9db07f7..2704581e6 100644
--- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
+++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
@@ -7,12 +7,7 @@ from typing import Any, List, Optional
import torch
-from sglang.srt.mem_cache.hicache_storage import (
- HiCacheStorage,
- HiCacheStorageConfig,
- HiCacheStorageExtraInfo,
-)
-from sglang.srt.mem_cache.memory_pool_host import HostKVCache
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB
@@ -188,13 +183,7 @@ class MooncakeStore(HiCacheStorage):
assert self.store.is_exist(warmup_key) == 1
assert self.store.get(warmup_key) == warmup_value
- def register_mem_pool_host(self, mem_pool_host: HostKVCache):
- super().register_mem_pool_host(mem_pool_host)
- assert self.mem_pool_host.layout in [
- "page_first",
- "page_first_direct",
- ], "mooncake store storage backend only support page first or page first direct layout"
- buffer = self.mem_pool_host.kv_buffer
+ def register_buffer(self, buffer: torch.Tensor) -> None:
try:
buffer_ptr = buffer.data_ptr()
buffer_size = buffer.numel() * buffer.element_size()
@@ -205,97 +194,6 @@ class MooncakeStore(HiCacheStorage):
logger.error("Failed to register buffer to Mooncake Store: %s", err)
raise TypeError("Mooncake Store Register Buffer Error.") from err
- def _get_mha_buffer_meta(self, keys, indices):
- ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
- key_list = []
- for key_ in keys:
- key_list.append(f"{key_}_{self.local_rank}_k")
- key_list.append(f"{key_}_{self.local_rank}_v")
- assert len(key_list) == len(ptr_list)
- return key_list, ptr_list, element_size_list
-
- def _get_mla_buffer_meta(self, keys, indices):
- ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
- key_list = []
- for key_ in keys:
- key_list.append(f"{key_}_k")
- assert len(key_list) == len(ptr_list)
- return key_list, ptr_list, element_size_list
-
- def _batch_preprocess(self, keys, host_indices):
- assert len(keys) > 0
- assert len(keys) == len(host_indices) // self.mem_pool_host.page_size
- if self.is_mla_backend:
- return self._get_mla_buffer_meta(keys, host_indices)
- else:
- return self._get_mha_buffer_meta(keys, host_indices)
-
- def _batch_postprocess(self, results: List[int], is_set_operate=False):
- """
- refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h
- for batch_get_into, results is Vector of integers,
- where each element is the number of bytes read on success, or a negative value on error
- for batch_put_from, results is Vector of integers,
- where each element is 0 on success, or a negative value on error
- """
- if self.is_mla_backend:
- return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results]
- else:
- kv_pairs = zip(results[::2], results[1::2])
- return [
- (
- (k_res == 0 and v_res == 0)
- if is_set_operate
- else (k_res > 0 and v_res > 0)
- )
- for k_res, v_res in kv_pairs
- ]
-
- def batch_get_v1(
- self,
- keys: List[str],
- host_indices: torch.Tensor,
- extra_info: Optional[HiCacheStorageExtraInfo] = None,
- ) -> List[bool]:
- key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
- get_results = self._get_batch_zero_copy_impl(
- key_strs, buffer_ptrs, buffer_sizes
- )
- return self._batch_postprocess(get_results, is_set_operate=False)
-
- def batch_set_v1(
- self,
- keys: List[str],
- host_indices: torch.Tensor,
- extra_info: Optional[HiCacheStorageExtraInfo] = None,
- ) -> List[bool]:
- key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
- exist_result = self._batch_exist(key_strs)
-
- set_keys = []
- set_buffer_ptrs = []
- set_buffer_sizes = []
- set_indices = []
- set_results = [-1] * len(key_strs)
- for i in range(len(key_strs)):
- if exist_result[i] != 1:
- set_keys.append(key_strs[i])
- set_buffer_ptrs.append(buffer_ptrs[i])
- set_buffer_sizes.append(buffer_sizes[i])
- set_indices.append(i)
- else:
- set_results[i] = 0
-
- # Only set non-existing keys to storage
- if len(set_keys) > 0:
- put_results = self._put_batch_zero_copy_impl(
- set_keys, set_buffer_ptrs, set_buffer_sizes
- )
- for i in range(len(set_indices)):
- set_results[set_indices[i]] = put_results[i]
-
- return self._batch_postprocess(set_results, is_set_operate=True)
-
def set(
self,
key,
diff --git a/python/sglang/srt/metrics/utils.py b/python/sglang/srt/metrics/utils.py
index 4dc498df7..73c0b4e73 100644
--- a/python/sglang/srt/metrics/utils.py
+++ b/python/sglang/srt/metrics/utils.py
@@ -44,7 +44,7 @@ def generate_buckets(
return two_sides_exponential_buckets(float(middle), float(base), int(count))
if rule == "default":
return sorted(set(default_buckets))
- assert rule == "custom"
+ assert rule == "customer"
return sorted(set([float(x) for x in buckets_rule[1:]]))
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 2ed78ea58..364b3391e 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -167,6 +167,29 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
server_args = model_runner.server_args
capture_bs = server_args.cuda_graph_bs
+ if capture_bs is None:
+ if server_args.speculative_algorithm is None:
+ if server_args.disable_cuda_graph_padding:
+ capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
+ else:
+ capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
+ else:
+ # Since speculative decoding requires more cuda graph memory, we
+ # capture less.
+ capture_bs = (
+ list(range(1, 9))
+ + list(range(10, 33, 2))
+ + list(range(40, 64, 8))
+ + list(range(80, 161, 16))
+ )
+
+ gpu_mem = get_device_memory_capacity()
+ if gpu_mem is not None:
+ if gpu_mem > 90 * 1024: # H200, H20
+ capture_bs += list(range(160, 257, 8))
+ if gpu_mem > 160 * 1000: # B200, MI300
+ capture_bs += list(range(256, 513, 16))
+
if max(capture_bs) > model_runner.req_to_token_pool.size:
# In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
# is very small. We add more values here to make sure we capture the maximum bs.
@@ -182,6 +205,12 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
+ if server_args.cuda_graph_max_bs:
+ capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
+ if max(capture_bs) < server_args.cuda_graph_max_bs:
+ capture_bs += list(
+ range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
+ )
capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
capture_bs = list(sorted(set(capture_bs)))
assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
@@ -246,7 +275,7 @@ class CudaGraphRunner:
if (
model_runner.spec_algorithm.is_eagle()
or model_runner.spec_algorithm.is_standalone()
- or model_runner.spec_algorithm.is_ngram()
+ or model_runner.spec_algorithm.is_lookahead()
):
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen")
@@ -413,12 +442,12 @@ class CudaGraphRunner:
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
)
- is_ngram_supported = (
+ is_lookahead_supported = (
(
forward_batch.batch_size * self.num_tokens_per_bs
== forward_batch.input_ids.numel()
)
- if self.model_runner.spec_algorithm.is_ngram()
+ if self.model_runner.spec_algorithm.is_lookahead()
else True
)
@@ -427,7 +456,7 @@ class CudaGraphRunner:
and is_encoder_lens_supported
and is_tbo_supported
and capture_hidden_mode_matches
- and is_ngram_supported
+ and is_lookahead_supported
)
def capture(self) -> None:
@@ -437,7 +466,6 @@ class CudaGraphRunner:
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
)
- torch.cuda.memory._record_memory_history()
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
@@ -486,8 +514,6 @@ class CudaGraphRunner:
save_gemlite_cache()
if self.enable_profile_cuda_graph:
- torch.cuda.memory._dump_snapshot(f"cuda_graph_runner_memory_usage.pickle")
- torch.cuda.memory._record_memory_history(enabled=None)
log_message = (
"Sorted by CUDA Time:\n"
+ prof.key_averages(group_by_input_shape=True).table(
@@ -497,7 +523,6 @@ class CudaGraphRunner:
+ prof.key_averages(group_by_input_shape=True).table(
sort_by="cpu_time_total", row_limit=10
)
- + "\n\nMemory Usage is saved to cuda_graph_runner_memory_usage.pickle\n"
)
logger.info(log_message)
@@ -518,6 +543,9 @@ class CudaGraphRunner:
input_ids = self.input_ids[:num_tokens]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
+ seq_lens_cpu = self.seq_lens_cpu[
+ :bs
+ ] # TODO: Remove this after changing to real indexer
out_cache_loc = self.out_cache_loc[:num_tokens]
positions = self.positions[:num_tokens]
if self.is_encoder_decoder:
@@ -588,6 +616,7 @@ class CudaGraphRunner:
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
+ seq_lens_cpu=seq_lens_cpu, # TODO: Remove this after changing to real indexer
next_token_logits_buffer=next_token_logits_buffer,
orig_seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -842,10 +871,10 @@ class CudaGraphRunner:
seq_lens_cpu=None,
)
- elif self.model_runner.spec_algorithm.is_ngram():
- from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+ elif self.model_runner.spec_algorithm.is_lookahead():
+ from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
- spec_info = NgramVerifyInput(
+ spec_info = LookaheadVerifyInput(
draft_token=None,
tree_mask=self.custom_mask,
positions=None,
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 017b5863c..a019245c7 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_size,
set_dp_buffer_len,
)
-from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import (
- flatten_nested_list,
- get_compiler_backend,
- is_npu,
- support_triton,
-)
+from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -300,6 +294,7 @@ class ForwardBatch:
# For padding
padded_static_len: int = -1 # -1 if not padded
num_token_non_padded: Optional[torch.Tensor] = None # scalar tensor
+ num_token_non_padded_cpu: int = None
# For Qwen2-VL
mrope_positions: torch.Tensor = None
@@ -361,6 +356,7 @@ class ForwardBatch:
ret.num_token_non_padded = torch.tensor(
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True)
+ ret.num_token_non_padded_cpu = len(batch.input_ids)
# For MLP sync
if batch.global_num_tokens is not None:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 42916486a..6b37aea2f 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -33,7 +33,12 @@ import torch.distributed as dist
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.model_config import (
+ AttentionArch,
+ ModelConfig,
+ get_nsa_index_head_dim,
+ is_deepseek_nsa,
+)
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.connector import ConnectorType
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
@@ -60,10 +65,7 @@ from sglang.srt.eplb.expert_location import (
set_global_expert_location_metadata,
)
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
-from sglang.srt.layers.attention.attention_registry import (
- ATTENTION_BACKENDS,
- attn_backend_wrapper,
-)
+from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
@@ -98,6 +100,7 @@ from sglang.srt.mem_cache.memory_pool import (
HybridReqToTokenPool,
MHATokenToKVPool,
MLATokenToKVPool,
+ NSATokenToKVPool,
ReqToTokenPool,
SWAKVPool,
)
@@ -107,9 +110,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
-from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
- trigger_init_weights_send_group_for_remote_instance_request,
-)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.offloader import (
@@ -118,6 +118,9 @@ from sglang.srt.offloader import (
set_offloader,
)
from sglang.srt.patch_torch import monkey_patch_torch_reductions
+from sglang.srt.remote_instance_weight_loader_utils import (
+ trigger_init_weights_send_group_for_remote_instance_request,
+)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -160,6 +163,7 @@ MLA_ATTENTION_BACKENDS = [
"cutlass_mla",
"trtllm_mla",
"ascend",
+ "nsa",
]
@@ -182,13 +186,6 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
logger = logging.getLogger(__name__)
-if _is_npu:
- import torch_npu
-
- torch.npu.config.allow_internal_format = True
- torch_npu.npu.set_compile_mode(jit_compile=False)
-
-
class RankZeroFilter(logging.Filter):
"""Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
@@ -350,6 +347,7 @@ class ModelRunner:
if self.is_hybrid_gdn:
logger.warning("Hybrid GDN model detected, disable radix cache")
self.server_args.disable_radix_cache = True
+ self.server_args.attention_backend = "hybrid_linear_attn"
if self.server_args.max_mamba_cache_size is None:
if self.server_args.max_running_requests is not None:
self.server_args.max_mamba_cache_size = (
@@ -745,10 +743,6 @@ class ModelRunner:
load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir,
model_loader_extra_config=self.server_args.model_loader_extra_config,
- tp_rank=self.tp_rank,
- remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
- remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
- remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
)
if self.device == "cpu":
self.model_config = adjust_config_with_unaligned_cpu_tp(
@@ -1484,8 +1478,7 @@ class ModelRunner:
if self.max_total_num_tokens <= 0:
raise RuntimeError(
- f"Not enough memory. Please try to increase --mem-fraction-static. "
- f"Current value: {self.server_args.mem_fraction_static=}"
+ "Not enough memory. Please try to increase --mem-fraction-static."
)
# Initialize req_to_token_pool
@@ -1544,6 +1537,7 @@ class ModelRunner:
assert self.is_draft_worker
# Initialize token_to_kv_pool
+ is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
if self.server_args.attention_backend == "ascend":
if self.use_mla_backend:
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1552,6 +1546,7 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+ index_head_dim=self.model_config.index_head_dim,
layer_num=self.num_effective_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
@@ -1571,7 +1566,22 @@ class ModelRunner:
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
+ elif self.use_mla_backend and is_nsa_model:
+ self.token_to_kv_pool = NSATokenToKVPool(
+ self.max_total_num_tokens,
+ page_size=self.page_size,
+ dtype=self.kv_cache_dtype,
+ kv_lora_rank=self.model_config.kv_lora_rank,
+ qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+ layer_num=self.num_effective_layers,
+ device=self.device,
+ enable_memory_saver=self.server_args.enable_memory_saver,
+ start_layer=self.start_layer,
+ end_layer=self.end_layer,
+ index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
+ )
elif self.use_mla_backend:
+ assert not is_nsa_model
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
@@ -1650,9 +1660,10 @@ class ModelRunner:
# Initialize token_to_kv_pool_allocator
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None:
- if _is_npu and (
- self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
- ):
+ if _is_npu and self.server_args.attention_backend in [
+ "ascend",
+ "hybrid_linear_attn",
+ ]:
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
@@ -1765,8 +1776,7 @@ class ModelRunner:
def _get_attention_backend_from_str(self, backend_str: str):
if backend_str not in ATTENTION_BACKENDS:
raise ValueError(f"Invalid attention backend: {backend_str}")
- full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
- return attn_backend_wrapper(self, full_attention_backend)
+ return ATTENTION_BACKENDS[backend_str](self)
def init_double_sparsity_channel_config(self, selected_channel):
selected_channel = "." + selected_channel + "_proj"
diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py
index d7619b2d7..732f8900f 100644
--- a/python/sglang/srt/model_executor/npu_graph_runner.py
+++ b/python/sglang/srt/model_executor/npu_graph_runner.py
@@ -19,10 +19,8 @@ import logging
import threading
from typing import TYPE_CHECKING, Optional, Union
-import numpy as np
import torch
-from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
logger = logging.getLogger(__name__)
@@ -75,11 +73,16 @@ class NPUGraphRunner(CudaGraphRunner):
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
# Replay
- seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
- thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
- thread.start()
- self.graphs[self.bs].replay()
- thread.join()
+ if self.model_runner.model_config.index_head_dim is None:
+ seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+ self.bs - self.raw_bs
+ )
+ thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
+ thread.start()
+ self.graphs[self.bs].replay()
+ thread.join()
+ else:
+ self.graphs[self.bs].replay()
output = self.output_buffers[self.bs]
if isinstance(output, LogitsProcessorOutput):
diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 12b4575f9..ab9c69fc2 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -54,9 +54,6 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
-from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
- trigger_transferring_weights_request,
-)
from sglang.srt.model_loader.utils import (
get_model_architecture,
post_load_weights,
@@ -80,6 +77,9 @@ from sglang.srt.model_loader.weight_utils import (
safetensors_weights_iterator,
set_runai_streamer_env,
)
+from sglang.srt.remote_instance_weight_loader_utils import (
+ trigger_transferring_weights_request,
+)
from sglang.srt.utils import (
get_bool_env_var,
get_device_capability,
@@ -206,10 +206,7 @@ def _initialize_model(
if _is_npu:
packed_modules_mapping.update(
{
- "visual": {
- "qkv_proj": ["qkv"],
- "gate_up_proj": ["gate_proj", "up_proj"],
- },
+ "visual": {"qkv_proj": ["qkv"]},
"vision_model": {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"proj": ["out_proj"],
@@ -1420,7 +1417,7 @@ class RemoteInstanceModelLoader(BaseModelLoader):
f"load format {load_config.load_format}"
)
- model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"
+ model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
@@ -1442,12 +1439,11 @@ class RemoteInstanceModelLoader(BaseModelLoader):
def load_model_from_remote_instance(
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
) -> nn.Module:
- load_config = self.load_config
instance_ip = socket.gethostbyname(socket.gethostname())
start_build_group_tic = time.time()
client.build_group(
gpu_id=device_config.gpu_id,
- tp_rank=load_config.tp_rank,
+ tp_rank=model_config.tp_rank,
instance_ip=instance_ip,
)
torch.cuda.synchronize()
@@ -1456,13 +1452,13 @@ class RemoteInstanceModelLoader(BaseModelLoader):
f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
)
- if load_config.tp_rank == 0:
+ if model_config.tp_rank == 0:
t = threading.Thread(
target=trigger_transferring_weights_request,
args=(
- load_config.remote_instance_weight_loader_seed_instance_ip,
- load_config.remote_instance_weight_loader_seed_instance_service_port,
- load_config.remote_instance_weight_loader_send_weights_group_ports,
+ model_config.remote_instance_weight_loader_seed_instance_ip,
+ model_config.remote_instance_weight_loader_seed_instance_service_port,
+ model_config.remote_instance_weight_loader_send_weights_group_ports,
instance_ip,
),
)
diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py
index 44297d687..397d9e913 100644
--- a/python/sglang/srt/model_loader/weight_utils.py
+++ b/python/sglang/srt/model_loader/weight_utils.py
@@ -8,6 +8,7 @@ import hashlib
import json
import logging
import os
+import queue
import tempfile
from collections import defaultdict
from typing import (
@@ -37,8 +38,7 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.dp_attention import get_attention_tp_rank
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
-from sglang.srt.utils import find_local_repo_dir, print_warning_once
-from sglang.utils import is_in_ci
+from sglang.srt.utils import print_warning_once
logger = logging.getLogger(__name__)
@@ -236,89 +236,6 @@ def get_quant_config(
return quant_cls.from_config(config)
-def find_local_hf_snapshot_dir(
- model_name_or_path: str,
- cache_dir: Optional[str],
- allow_patterns: List[str],
- revision: Optional[str] = None,
-) -> Optional[str]:
- """If the weights are already local, skip downloading and returns the path."""
- if os.path.isdir(model_name_or_path):
- return None
-
- found_local_snapshot_dir = None
-
- # Check custom cache_dir (if provided)
- if cache_dir:
- try:
- repo_folder = os.path.join(
- cache_dir,
- huggingface_hub.constants.REPO_ID_SEPARATOR.join(
- ["models", *model_name_or_path.split("/")]
- ),
- )
- rev_to_use = revision
- if not rev_to_use:
- ref_main = os.path.join(repo_folder, "refs", "main")
- if os.path.isfile(ref_main):
- with open(ref_main) as f:
- rev_to_use = f.read().strip()
- if rev_to_use:
- rev_dir = os.path.join(repo_folder, "snapshots", rev_to_use)
- if os.path.isdir(rev_dir):
- found_local_snapshot_dir = rev_dir
- except Exception as e:
- logger.warning(
- "Failed to find local snapshot in custom cache_dir %s: %s",
- cache_dir,
- e,
- )
-
- # Check default HF cache as well
- if not found_local_snapshot_dir:
- try:
- rev_dir = find_local_repo_dir(model_name_or_path, revision)
- if rev_dir and os.path.isdir(rev_dir):
- found_local_snapshot_dir = rev_dir
- except Exception as e:
- logger.warning("Failed to find local snapshot in default HF cache: %s", e)
-
- # If local snapshot exists, validate it contains at least one weight file
- # matching allow_patterns before skipping download.
- if found_local_snapshot_dir is None:
- return None
-
- local_weight_files: List[str] = []
- try:
- for pattern in allow_patterns:
- local_weight_files.extend(
- glob.glob(os.path.join(found_local_snapshot_dir, pattern))
- )
- except Exception as e:
- logger.warning(
- "Failed to scan local snapshot %s with patterns %s: %s",
- found_local_snapshot_dir,
- allow_patterns,
- e,
- )
- local_weight_files = []
-
- if len(local_weight_files) > 0:
- logger.info(
- "Found local HF snapshot for %s at %s; skipping download.",
- model_name_or_path,
- found_local_snapshot_dir,
- )
- return found_local_snapshot_dir
- else:
- logger.info(
- "Local HF snapshot at %s has no files matching %s; will attempt download.",
- found_local_snapshot_dir,
- allow_patterns,
- )
- return None
-
-
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: Optional[str],
@@ -343,16 +260,6 @@ def download_weights_from_hf(
Returns:
str: The path to the downloaded model weights.
"""
-
- if is_in_ci():
- # If the weights are already local, skip downloading and returns the path.
- # This is used to skip too-many Huggingface API calls in CI.
- path = find_local_hf_snapshot_dir(
- model_name_or_path, cache_dir, allow_patterns, revision
- )
- if path is not None:
- return path
-
if not huggingface_hub.constants.HF_HUB_OFFLINE:
# Before we download we look at that is available:
fs = HfFileSystem()
diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py
index b6063aa2c..0797f4f6f 100644
--- a/python/sglang/srt/models/bailing_moe.py
+++ b/python/sglang/srt/models/bailing_moe.py
@@ -45,12 +45,12 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
- is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
+ ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -72,10 +72,6 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.utils import (
- create_fused_set_kv_buffer_arg,
- enable_fused_set_kv_buffer,
-)
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
LoraConfig = None
@@ -559,27 +555,8 @@ class BailingMoEAttention(nn.Module):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
- q, k = self.rotary_emb(
- positions,
- q,
- k,
- fused_set_kv_buffer_arg=(
- create_fused_set_kv_buffer_arg(
- value=v,
- layer=self.attn,
- forward_batch=forward_batch,
- )
- if enable_fused_set_kv_buffer(forward_batch)
- else None
- ),
- )
- context_layer = self.attn(
- q,
- k,
- v,
- forward_batch,
- save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
- )
+ q, k = self.rotary_emb(positions, q, k)
+ context_layer = self.attn(q, k, v, forward_batch)
attn_output, _ = self.dense(context_layer)
return attn_output
@@ -725,7 +702,7 @@ class BailingMoEModel(nn.Module):
self.embed_dim,
quant_config=quant_config,
prefix=add_prefix("word_embeddings", prefix),
- enable_tp=not is_dp_attention_enabled(),
+ use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
else:
self.word_embeddings = PPMissingLayer()
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 0db0ca164..3aab5e282 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -15,6 +15,7 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
+from __future__ import annotations
import concurrent.futures
import logging
@@ -25,9 +26,15 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
-from tqdm import tqdm
from transformers import PretrainedConfig
+from sglang.srt.configs.model_config import (
+ get_nsa_index_head_dim,
+ get_nsa_index_n_heads,
+ get_nsa_index_topk,
+ is_deepseek_nsa,
+)
+from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
@@ -47,6 +54,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
NPUFusedMLAPreprocess,
is_mla_preprocess_enabled,
)
+from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
from sglang.srt.layers.communicator import (
LayerCommunicator,
LayerScatterModes,
@@ -175,6 +183,11 @@ if _is_hip:
decode_attention_fwd_grouped_rope,
)
+if _is_npu:
+ import custom_ops
+ import sgl_kernel_npu
+ import torch_npu
+
_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -183,6 +196,7 @@ logger = logging.getLogger(__name__)
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
"fa3",
+ "nsa",
"flashinfer",
"cutlass_mla",
"trtllm_mla",
@@ -203,6 +217,9 @@ class AttnForwardMethod(IntEnum):
# Use absorbed multi-latent attention
MLA = auto()
+ # Use Deepseek V3.2 sparse multi-latent attention
+ NPU_MLA_SPARSE = auto()
+
# Use multi-head attention, but with KV cache chunked.
# This method can avoid OOM when prefix lengths are long.
MHA_CHUNKED_KV = auto()
@@ -245,9 +262,15 @@ def handle_ascend(attn, forward_batch):
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
- return AttnForwardMethod.MHA
+ if hasattr(attn, "indexer"):
+ return AttnForwardMethod.NPU_MLA_SPARSE
+ else:
+ return AttnForwardMethod.MHA
else:
- return AttnForwardMethod.MLA
+ if hasattr(attn, "indexer"):
+ return AttnForwardMethod.NPU_MLA_SPARSE
+ else:
+ return AttnForwardMethod.MLA
def _get_sum_extend_prefix_lens(forward_batch):
@@ -266,7 +289,7 @@ def _is_extend_without_speculative(forward_batch):
)
-def _handle_backend(attn, forward_batch, backend_name):
+def _handle_backend(attn: DeepseekV2AttentionMLA, forward_batch, backend_name):
sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
disable_ragged = (
backend_name in ["flashinfer", "flashmla"]
@@ -332,6 +355,10 @@ def handle_aiter(attn, forward_batch):
return AttnForwardMethod.MLA
+def handle_nsa(attn, forward_batch):
+ return AttnForwardMethod.MLA
+
+
def handle_triton(attn, forward_batch):
if (
_is_extend_without_speculative(forward_batch)
@@ -996,6 +1023,10 @@ class DeepseekV2AttentionMLA(nn.Module):
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
+ # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
+ if rope_scaling:
+ rope_scaling["rope_type"] = "deepseek_yarn"
+
# For tensor parallel attention
if self.q_lora_rank is not None:
self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -1033,6 +1064,26 @@ class DeepseekV2AttentionMLA(nn.Module):
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
)
+ self.use_nsa = is_deepseek_nsa(config)
+ if self.use_nsa:
+ self.indexer = Indexer(
+ hidden_size=hidden_size,
+ index_n_heads=get_nsa_index_n_heads(config),
+ index_head_dim=get_nsa_index_head_dim(config),
+ rope_head_dim=qk_rope_head_dim,
+ index_topk=get_nsa_index_topk(config),
+ q_lora_rank=q_lora_rank,
+ max_position_embeddings=max_position_embeddings,
+ rope_theta=rope_theta,
+ scale_fmt="ue8m0",
+ block_size=128,
+ rope_scaling=rope_scaling,
+ prefix=add_prefix("indexer", prefix),
+ quant_config=quant_config,
+ layer_id=layer_id,
+ alt_stream=alt_stream,
+ )
+
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1055,9 +1106,6 @@ class DeepseekV2AttentionMLA(nn.Module):
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
- if rope_scaling:
- rope_scaling["rope_type"] = "deepseek_yarn"
-
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
@@ -1184,8 +1232,8 @@ class DeepseekV2AttentionMLA(nn.Module):
self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
if self.is_mla_preprocess_enabled:
assert (
- quant_config.get_name() == "w8a8_int8"
- ), "MLA Preprocess only works with W8A8Int8"
+ quant_config is None or quant_config.get_name() == "w8a8_int8"
+ ), "MLA Preprocess only works with Unquant or W8A8Int8"
self.mla_preprocess = None
def dispatch_attn_forward_method(
@@ -1263,7 +1311,6 @@ class DeepseekV2AttentionMLA(nn.Module):
return hidden_states, None, forward_batch, None
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
-
if attn_forward_method == AttnForwardMethod.MHA:
inner_state = self.forward_normal_prepare(
positions, hidden_states, forward_batch, zero_allocator
@@ -1295,6 +1342,10 @@ class DeepseekV2AttentionMLA(nn.Module):
inner_state = self.mla_preprocess.forward(
positions, hidden_states, forward_batch, zero_allocator
)
+ elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+ inner_state = self.forward_npu_sparse_prepare(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
inner_state = self.forward_absorb_fused_mla_rope_prepare(
positions, hidden_states, forward_batch, zero_allocator
@@ -1320,6 +1371,8 @@ class DeepseekV2AttentionMLA(nn.Module):
return self.forward_normal_chunked_kv_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MLA:
return self.forward_absorb_core(*inner_state)
+ elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+ return self.forward_npu_sparse_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
return self.forward_absorb_fused_mla_rope_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
@@ -1412,6 +1465,7 @@ class DeepseekV2AttentionMLA(nn.Module):
):
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+ q_lora = None
if self.q_lora_rank is not None:
if (
(not isinstance(hidden_states, tuple))
@@ -1450,6 +1504,10 @@ class DeepseekV2AttentionMLA(nn.Module):
q = self.q_a_layernorm(q)
k_nope = self.kv_a_layernorm(k_nope)
+ # q_lora needed by indexer
+ if self.use_nsa:
+ q_lora = q
+
k_nope = k_nope.unsqueeze(1)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
@@ -1519,10 +1577,37 @@ class DeepseekV2AttentionMLA(nn.Module):
):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
- return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+ topk_indices = None
+ if q_lora is not None:
+ topk_indices = self.indexer(
+ x=hidden_states,
+ q_lora=q_lora,
+ positions=positions,
+ forward_batch=forward_batch,
+ layer_id=self.layer_id,
+ )
+
+ return (
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ forward_batch,
+ zero_allocator,
+ positions,
+ topk_indices,
+ )
def forward_absorb_core(
- self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+ self,
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ forward_batch,
+ zero_allocator,
+ positions,
+ topk_indices,
):
if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
extra_args = {}
@@ -1531,6 +1616,7 @@ class DeepseekV2AttentionMLA(nn.Module):
"cos_sin_cache": self.rotary_emb.cos_sin_cache,
"is_neox": self.rotary_emb.is_neox_style,
}
+
attn_output = self.attn_mqa(
q_nope_out,
k_nope,
@@ -1539,6 +1625,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_rope=q_pe,
k_rope=k_pe,
**extra_args,
+ **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
)
else:
if _use_aiter_gfx95:
@@ -1558,7 +1645,13 @@ class DeepseekV2AttentionMLA(nn.Module):
q = torch.cat([q_nope_out, q_pe], dim=-1)
k = torch.cat([k_nope, k_pe], dim=-1)
- attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
+ attn_output = self.attn_mqa(
+ q,
+ k,
+ k_nope,
+ forward_batch,
+ **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
+ )
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.use_deep_gemm_bmm:
@@ -1640,6 +1733,221 @@ class DeepseekV2AttentionMLA(nn.Module):
return output
+ def forward_npu_sparse_prepare(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ zero_allocator: BumpAllocator,
+ ):
+ """
+ Reuse `self.q_lora_rank is not None` branch from forward_absorb_prepare
+ """
+ if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
+ if self.mla_preprocess is None:
+ self.mla_preprocess = NPUFusedMLAPreprocess(
+ self.fused_qkv_a_proj_with_mqa,
+ self.q_a_layernorm,
+ self.kv_a_layernorm,
+ self.q_b_proj,
+ self.w_kc,
+ self.rotary_emb,
+ self.layer_id,
+ self.num_local_heads,
+ self.qk_nope_head_dim,
+ self.qk_rope_head_dim,
+ )
+ (
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ forward_batch,
+ zero_allocator,
+ positions,
+ ) = self.mla_preprocess.forward(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
+
+ fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+ q, _ = fused_qkv_a_proj_out.split(
+ [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+ )
+ q_lora = self.q_a_layernorm(q)
+ else:
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
+ if (
+ (not isinstance(hidden_states, tuple))
+ and hidden_states.shape[0] <= 16
+ and self.use_min_latency_fused_a_gemm
+ ):
+ fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+ hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+ )
+ else:
+ fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+ q, latent_cache = fused_qkv_a_proj_out.split(
+ [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+ )
+ k_nope = latent_cache[..., : self.kv_lora_rank]
+
+ # overlap qk norm
+ if self.alt_stream is not None and get_is_capture_mode():
+ current_stream = torch.cuda.current_stream()
+ self.alt_stream.wait_stream(current_stream)
+ q = self.q_a_layernorm(q)
+ with torch.cuda.stream(self.alt_stream):
+ k_nope = self.kv_a_layernorm(k_nope)
+ current_stream.wait_stream(self.alt_stream)
+ else:
+ if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+ q, k_nope = fused_rms_mxfp4_quant(
+ q,
+ self.q_a_layernorm.weight,
+ self.q_a_layernorm.variance_epsilon,
+ k_nope,
+ self.kv_a_layernorm.weight,
+ self.kv_a_layernorm.variance_epsilon,
+ )
+ else:
+ q = self.q_a_layernorm(q)
+ k_nope = self.kv_a_layernorm(k_nope)
+
+ q_lora = q.clone() # required for topk_indices
+ k_nope = k_nope.unsqueeze(1)
+ q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+
+ q_nope, q_pe = q.split(
+ [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+ k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
+
+ if self.use_deep_gemm_bmm:
+ q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+ per_token_group_quant_mla_deep_gemm_masked_fp8(
+ q_nope.transpose(0, 1)
+ )
+ )
+ q_nope_out = q_nope.new_empty(
+ (self.num_local_heads, aligned_m, self.kv_lora_rank)
+ )
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+ (q_nope_val, q_nope_scale),
+ (self.w_kc, self.w_scale_k),
+ q_nope_out,
+ masked_m,
+ expected_m,
+ )
+ q_nope_out = q_nope_out[:, :expected_m, :]
+ elif _is_hip:
+ # TODO(haishaw): add bmm_fp8 to ROCm
+ if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+ x = q_nope.transpose(0, 1)
+ q_nope_out = torch.empty(
+ x.shape[0],
+ x.shape[1],
+ self.w_kc.shape[2],
+ device=x.device,
+ dtype=torch.bfloat16,
+ )
+ batched_gemm_afp4wfp4_pre_quant(
+ x,
+ self.w_kc.transpose(-2, -1),
+ self.w_scale_k.transpose(-2, -1),
+ torch.bfloat16,
+ q_nope_out,
+ )
+ else:
+ q_nope_out = torch.bmm(
+ q_nope.to(torch.bfloat16).transpose(0, 1),
+ self.w_kc.to(torch.bfloat16) * self.w_scale,
+ )
+ elif self.w_kc.dtype == torch.float8_e4m3fn:
+ q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+ q_nope.transpose(0, 1),
+ zero_allocator.allocate(1),
+ )
+ q_nope_out = bmm_fp8(
+ q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+ )
+ else:
+ q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+
+ q_nope_out = q_nope_out.transpose(0, 1)
+
+ if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+ not _use_aiter or not _is_gfx95_supported
+ ):
+ q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+ # TODO: multi-stream indexer
+ topk_indices = self.indexer(
+ hidden_states, q_lora, positions, forward_batch, self.layer_id
+ )
+
+ return (
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ topk_indices,
+ forward_batch,
+ zero_allocator,
+ positions,
+ )
+
+ def forward_npu_sparse_core(
+ self,
+ q_pe,
+ k_pe,
+ q_nope_out,
+ k_nope,
+ topk_indices,
+ forward_batch,
+ zero_allocator,
+ positions,
+ ):
+ attn_output = self.attn_mqa(
+ q_nope_out.contiguous(),
+ k_nope.contiguous(),
+ k_nope.contiguous(),
+ forward_batch,
+ save_kv_cache=True, # False if forward_batch.forward_mode.is_extend() else True,
+ q_rope=q_pe.contiguous(),
+ k_rope=k_pe.contiguous(),
+ topk_indices=topk_indices,
+ )
+ attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+ attn_bmm_output = torch.empty(
+ (attn_output.shape[0], self.num_local_heads, self.v_head_dim),
+ dtype=attn_output.dtype,
+ device=attn_output.device,
+ )
+
+ if not forward_batch.forward_mode.is_decode():
+ attn_output = attn_output.transpose(0, 1)
+ torch.bmm(
+ attn_output,
+ self.w_vc,
+ out=attn_bmm_output.view(
+ -1, self.num_local_heads, self.v_head_dim
+ ).transpose(0, 1),
+ )
+ else:
+ attn_output = attn_output.contiguous()
+ torch.ops.npu.batch_matmul_transpose(
+ attn_output, self.w_vc, attn_bmm_output
+ )
+
+ attn_bmm_output = attn_bmm_output.reshape(
+ -1, self.num_local_heads * self.v_head_dim
+ )
+
+ output, _ = self.o_proj(attn_bmm_output)
+ return output
+
def forward_absorb_fused_mla_rope_prepare(
self,
positions: torch.Tensor,
@@ -2121,7 +2429,6 @@ class DeepseekV2DecoderLayer(nn.Module):
zero_allocator: BumpAllocator,
gemm_output_zero_allocator: BumpAllocator = None,
) -> torch.Tensor:
-
quant_format = (
"mxfp4"
if _is_gfx95_supported
@@ -2704,7 +3011,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
)
self_attn.w_vc = bind_or_assign(
- self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
+ self_attn.w_vc, w_vc.contiguous().transpose(1, 2).contiguous()
)
if (
hasattr(self_attn.kv_b_proj, "weight_scale")
@@ -3086,6 +3393,7 @@ BackendRegistry.register("cutlass_mla", handle_cutlass_mla)
BackendRegistry.register("fa4", handle_fa4)
BackendRegistry.register("trtllm_mla", handle_trtllm_mla)
BackendRegistry.register("aiter", handle_aiter)
+BackendRegistry.register("nsa", handle_nsa)
BackendRegistry.register("triton", handle_triton)
@@ -3093,4 +3401,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass
-EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
+class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
+ pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]
diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py
index a1c3bc0b1..5b6145aff 100644
--- a/python/sglang/srt/models/gemma3_causal.py
+++ b/python/sglang/srt/models/gemma3_causal.py
@@ -20,6 +20,7 @@ import torch.nn.functional as F
from torch import nn
from transformers import (
ROPE_INIT_FUNCTIONS,
+ AutoModel,
Gemma3TextConfig,
PretrainedConfig,
PreTrainedModel,
@@ -760,3 +761,4 @@ class Gemma3ForCausalLM(PreTrainedModel):
EntryClass = Gemma3ForCausalLM
+AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)
diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index d4cc9e1e6..867ffe91b 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -12,7 +12,7 @@
# limitations under the License.
# ==============================================================================
-"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""
+"""Inference-only GLM-4.5 model compatible with HuggingFace weights"""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple
@@ -785,9 +785,9 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
or self.config.architectures[0] != architecture
or self.config.n_shared_experts != 1
):
- disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+ disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif get_moe_expert_parallel_world_size() > 1:
- disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
+ disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py
index 4816f5775..399f0f4e0 100644
--- a/python/sglang/srt/models/glm4_moe_nextn.py
+++ b/python/sglang/srt/models/glm4_moe_nextn.py
@@ -12,7 +12,7 @@
# limitations under the License.
# ==============================================================================
-"""Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
+"""Inference-only GLM-4.5 NextN Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple
@@ -48,7 +48,7 @@ class Glm4MoeModelNextN(nn.Module):
super().__init__()
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
logger.warning(
- "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 / GLM-4.6 model."
+ "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model."
)
quant_config = None
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index 982400514..7231a5d75 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -66,10 +66,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.utils import (
- create_fused_set_kv_buffer_arg,
- enable_fused_set_kv_buffer,
-)
from sglang.srt.utils import (
LazyValue,
add_prefix,
@@ -197,6 +193,33 @@ class GptOssSparseMoeBlock(nn.Module):
return ans
+def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+ """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+ return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
+
+
+# TODO maybe move to a model-common utils
+def _create_fused_set_kv_buffer_arg(
+ value: torch.Tensor,
+ layer: RadixAttention,
+ forward_batch: ForwardBatch,
+):
+ layer_id = layer.layer_id
+ token_to_kv_pool = forward_batch.token_to_kv_pool
+
+ k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
+ v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
+
+ return FusedSetKVBufferArg(
+ value=value,
+ k_buffer=k_buffer.view(k_buffer.shape[0], -1),
+ v_buffer=v_buffer.view(v_buffer.shape[0], -1),
+ k_scale=layer.k_scale,
+ v_scale=layer.v_scale,
+ cache_loc=forward_batch.out_cache_loc,
+ )
+
+
class GptOssAttention(nn.Module):
def __init__(
self,
@@ -314,12 +337,12 @@ class GptOssAttention(nn.Module):
q,
k,
fused_set_kv_buffer_arg=(
- create_fused_set_kv_buffer_arg(
+ _create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
- if enable_fused_set_kv_buffer(forward_batch)
+ if _enable_fused_set_kv_buffer(forward_batch)
else None
),
)
@@ -333,7 +356,7 @@ class GptOssAttention(nn.Module):
attn_output = self.attn(
*inner_state,
sinks=self.sinks,
- save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+ save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
)
output, _ = self.o_proj(attn_output)
return output
diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py
index 72077d96a..f0184390c 100644
--- a/python/sglang/srt/models/mllama4.py
+++ b/python/sglang/srt/models/mllama4.py
@@ -291,7 +291,7 @@ class Llama4UnfoldConvolution(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.unfold(hidden_states)
- hidden_states = hidden_states.permute(0, 2, 1).contiguous()
+ hidden_states = hidden_states.permute(0, 2, 1)
hidden_states, _ = self.linear(hidden_states)
return hidden_states
@@ -446,20 +446,9 @@ class Llama4ForConditionalGeneration(nn.Module):
)
if self.has_vision:
- # TODO: make this more general
- ignore_quant_layers = getattr(config, "quantization_config", {}).get(
- "ignore", {}
- )
- if (
- "model.layers.vision_model*" in ignore_quant_layers
- and "model.layers.multi_modal_projector*" in ignore_quant_layers
- ):
- vision_quant_config = None
- else:
- vision_quant_config = quant_config
self.vision_model = Llama4VisionModel(
config.vision_config,
- quant_config=vision_quant_config,
+ quant_config=quant_config,
prefix=add_prefix("vision_model", prefix),
)
@@ -571,7 +560,7 @@ class Llama4ForConditionalGeneration(nn.Module):
forward_batch=forward_batch,
language_model=self.language_model,
data_embedding_funcs={
- Modality.IMAGE: image_embedding_func,
+ Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 531f5b6e9..256caee9c 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -454,6 +454,9 @@ class Qwen2ForCausalLM(nn.Module):
# For EAGLE3 support
self.capture_aux_hidden_states = False
+ # For EAGLE3 support
+ self.capture_aux_hidden_states = False
+
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embedding(input_ids)
diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py
index 6c70629c2..9afb2b1ab 100644
--- a/python/sglang/srt/models/qwen2_5_vl.py
+++ b/python/sglang/srt/models/qwen2_5_vl.py
@@ -265,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
self.window_size = vision_config.window_size
self.patch_size = vision_config.patch_size
- mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
+ mlp_hidden_size: int = vision_config.intermediate_size
self.patch_embed = Qwen2_5_VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index 32bda876a..bc5f054d7 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -1,5 +1,6 @@
# Adapted from qwen2.py
import logging
+from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
@@ -29,19 +30,12 @@ from sglang.srt.model_loader.weight_utils import (
)
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import (
- add_prefix,
- get_cmo_stream,
- is_cuda,
- is_npu,
- wait_cmo_stream,
-)
+from sglang.srt.utils import add_prefix, is_cuda
Qwen3Config = None
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
-_is_npu = is_npu()
class Qwen3Attention(nn.Module):
@@ -241,18 +235,9 @@ class Qwen3DecoderLayer(nn.Module):
# Fully Connected
hidden_states, residual = self.layer_communicator.prepare_mlp(
- hidden_states,
- residual,
- forward_batch,
- cache=(
- [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
- if _is_npu
- else None
- ),
+ hidden_states, residual, forward_batch
)
hidden_states = self.mlp(hidden_states)
- if _is_npu and get_cmo_stream():
- wait_cmo_stream()
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index d9ac4684e..9d92a3b13 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -60,10 +60,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.models.utils import (
- create_fused_set_kv_buffer_arg,
- enable_fused_set_kv_buffer,
-)
from sglang.srt.utils import (
add_prefix,
is_cuda,
@@ -416,20 +412,7 @@ class Qwen3MoeAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self._apply_qk_norm(q, k)
- q, k = self.rotary_emb(
- positions,
- q,
- k,
- fused_set_kv_buffer_arg=(
- create_fused_set_kv_buffer_arg(
- value=v,
- layer=self.attn,
- forward_batch=forward_batch,
- )
- if enable_fused_set_kv_buffer(forward_batch)
- else None
- ),
- )
+ q, k = self.rotary_emb(positions, q, k)
inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state
@@ -437,10 +420,7 @@ class Qwen3MoeAttention(nn.Module):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
- attn_output = self.attn(
- *inner_state,
- save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
- )
+ attn_output = self.attn(*inner_state)
output, _ = self.o_proj(attn_output)
return output
diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py
deleted file mode 100644
index a87d21e78..000000000
--- a/python/sglang/srt/models/qwen3_vl.py
+++ /dev/null
@@ -1,787 +0,0 @@
-# Copyright 2025 Qwen Team
-# Copyright 2025 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
-import logging
-from functools import lru_cache, partial
-from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from transformers.activations import ACT2FN
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
- Qwen2_5_VisionRotaryEmbedding,
-)
-
-from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
-from sglang.srt.hf_transformers_utils import get_processor
-from sglang.srt.layers.attention.vision import VisionAttention
-from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.pooler import Pooler, PoolingType
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.mm_utils import (
- MultiModalityDataPaddingPatternMultimodalTokens,
- general_mm_embed_routine,
-)
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
-from sglang.srt.models.qwen3 import Qwen3Model
-from sglang.srt.utils import add_prefix
-
-logger = logging.getLogger(__name__)
-
-# === Vision Encoder === #
-
-
-class Qwen3_VisionMLP(nn.Module):
-
- def __init__(
- self,
- in_features: int,
- hidden_features: int,
- bias: bool = True,
- hidden_act="silu",
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ):
- super().__init__()
- self.linear_fc1 = ColumnParallelLinear(
- in_features,
- hidden_features,
- bias=bias,
- quant_config=quant_config,
- prefix=add_prefix("linear_fc1", prefix),
- )
- self.linear_fc2 = RowParallelLinear(
- hidden_features,
- in_features,
- bias=bias,
- quant_config=quant_config,
- prefix=add_prefix("linear_fc2", prefix),
- )
- self.act = ACT2FN[hidden_act]
-
- def forward(self, x: torch.Tensor):
- x_fc1, _ = self.linear_fc1(x)
- mlp_output, _ = self.linear_fc2(self.act(x_fc1))
- return mlp_output
-
-
-class Qwen3VLVisionPatchEmbed(nn.Module):
- def __init__(self, config) -> None:
- super().__init__()
- self.patch_size = config.patch_size
- self.temporal_patch_size = config.temporal_patch_size
- self.in_channels = config.in_channels
- self.embed_dim = config.hidden_size
-
- kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
- self.proj = nn.Conv3d(
- self.in_channels,
- self.embed_dim,
- kernel_size=kernel_size,
- stride=kernel_size,
- bias=True,
- )
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- target_dtype = self.proj.weight.dtype
- hidden_states = hidden_states.view(
- -1,
- self.in_channels,
- self.temporal_patch_size,
- self.patch_size,
- self.patch_size,
- )
- hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(
- -1, self.embed_dim
- )
- return hidden_states
-
-
-class Qwen3_VisionBlock(nn.Module):
-
- def __init__(
- self,
- dim: int,
- num_heads: int,
- intermediate_dim: int,
- hidden_act="silu",
- norm_layer: Optional[Callable[[int], nn.Module]] = None,
- attn_implementation: Optional[str] = "sdpa",
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- if norm_layer is None:
- norm_layer = partial(nn.LayerNorm, eps=1e-6)
- self.norm1 = norm_layer(dim)
- self.norm2 = norm_layer(dim)
-
- if attn_implementation == "sdpa":
- softmax_in_single_precision = False
- qkv_backend = "sdpa"
- flatten_batch = True
- elif attn_implementation == "flash_attention_2":
- softmax_in_single_precision = False
- qkv_backend = "triton_attn"
- flatten_batch = True
- elif attn_implementation == "eager":
- softmax_in_single_precision = True
- qkv_backend = "sdpa"
- flatten_batch = True
- elif attn_implementation == "flash_attention_3":
- softmax_in_single_precision = False
- qkv_backend = "fa3"
- flatten_batch = True
-
- self.attn = VisionAttention(
- embed_dim=dim,
- num_heads=num_heads,
- projection_size=dim,
- use_qkv_parallel=True,
- rotary_embed="normal",
- proj_bias=True,
- qkv_backend=qkv_backend,
- softmax_in_single_precision=softmax_in_single_precision,
- flatten_batch=flatten_batch,
- quant_config=quant_config,
- prefix=add_prefix("attn", prefix),
- )
- self.mlp = Qwen3_VisionMLP(
- dim,
- intermediate_dim,
- hidden_act=hidden_act,
- bias=True,
- quant_config=quant_config,
- prefix=f"{prefix}.mlp",
- )
-
- def forward(
- self,
- x: torch.Tensor,
- cu_seqlens: torch.Tensor,
- position_embeddings: torch.Tensor,
- ) -> torch.Tensor:
- hidden_states = self.norm1(x)
- hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
- attn = self.attn(
- hidden_states,
- cu_seqlens=cu_seqlens,
- position_embeddings=position_embeddings,
- )
- attn = rearrange(attn, "b s ... -> s b ...")
- x = x + attn
- norm2 = self.norm2(x)
- mlp = self.mlp(norm2)
- x = x + mlp
- return x
-
-
-class Qwen3_VisionPatchMerger(nn.Module):
-
- def __init__(
- self,
- dim: int,
- context_dim: int,
- norm_layer: Optional[Callable[[int], nn.Module]] = None,
- spatial_merge_size: int = 2,
- use_postshuffle_norm: bool = False,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.hidden_size = context_dim * (spatial_merge_size**2)
-
- self.use_postshuffle_norm = use_postshuffle_norm
-
- if norm_layer is None:
- norm_layer = partial(nn.LayerNorm, eps=1e-6)
- self.norm = norm_layer(
- self.hidden_size if use_postshuffle_norm else context_dim
- )
- self.linear_fc1 = ColumnParallelLinear(
- self.hidden_size,
- self.hidden_size,
- bias=True,
- quant_config=quant_config,
- prefix=add_prefix("linear_fc1", prefix),
- )
- self.act_fn = nn.GELU()
- self.linear_fc2 = RowParallelLinear(
- self.hidden_size,
- dim,
- bias=True,
- quant_config=quant_config,
- prefix=add_prefix("linear_fc2", prefix),
- )
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- if self.use_postshuffle_norm:
- x = self.norm(x.view(-1, self.hidden_size))
- else:
- x = self.norm(x).view(-1, self.hidden_size)
-
- x_parallel, _ = self.linear_fc1(x)
- x_parallel = self.act_fn(x_parallel)
- out, _ = self.linear_fc2(x_parallel)
- return out
-
-
-class Qwen3_VisionTransformer(nn.Module):
-
- def __init__(
- self,
- vision_config: Qwen3VLVisionConfig,
- norm_eps: float = 1e-6,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.hidden_size = vision_config.hidden_size
- self.num_heads = vision_config.num_heads
- self.num_position_embeddings = vision_config.num_position_embeddings
- self.patch_size = vision_config.patch_size
- self.spatial_merge_size = vision_config.spatial_merge_size
- self.spatial_merge_unit = self.spatial_merge_size**2
- self.temporal_patch_size = vision_config.temporal_patch_size
- self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
- self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
- self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
-
- norm_layer = partial(nn.LayerNorm, eps=norm_eps)
- head_dim = self.hidden_size // self.num_heads
- self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
-
- self.blocks = nn.ModuleList(
- [
- Qwen3_VisionBlock(
- dim=self.hidden_size,
- num_heads=self.num_heads,
- intermediate_dim=vision_config.intermediate_size,
- hidden_act=vision_config.hidden_act,
- norm_layer=norm_layer,
- attn_implementation="flash_attention_3",
- quant_config=quant_config,
- prefix=add_prefix(f"blocks.{layer_idx}", prefix),
- )
- for layer_idx in range(vision_config.depth)
- ]
- )
- self.merger = Qwen3_VisionPatchMerger(
- dim=vision_config.out_hidden_size,
- context_dim=self.hidden_size,
- norm_layer=norm_layer,
- spatial_merge_size=self.spatial_merge_size,
- quant_config=quant_config,
- prefix=add_prefix("merger", prefix),
- )
-
- self.deepstack_merger_list = nn.ModuleList(
- [
- Qwen3_VisionPatchMerger(
- dim=vision_config.out_hidden_size,
- context_dim=self.hidden_size,
- spatial_merge_size=self.spatial_merge_size,
- use_postshuffle_norm=True,
- norm_layer=norm_layer,
- quant_config=quant_config,
- prefix=add_prefix(f"deepstack_merger_list.{layer_idx}", prefix),
- )
- for layer_idx in range(len(self.deepstack_visual_indexes))
- ]
- )
-
- @property
- def dtype(self) -> torch.dtype:
- return self.patch_embed.proj.weight.dtype
-
- @property
- def device(self) -> torch.device:
- return self.patch_embed.proj.weight.device
-
- def rot_pos_emb(self, grid_thw):
- pos_ids = []
- for t, h, w in grid_thw:
- hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
- hpos_ids = hpos_ids.reshape(
- h // self.spatial_merge_size,
- self.spatial_merge_size,
- w // self.spatial_merge_size,
- self.spatial_merge_size,
- )
- hpos_ids = hpos_ids.permute(0, 2, 1, 3)
- hpos_ids = hpos_ids.flatten()
-
- wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
- wpos_ids = wpos_ids.reshape(
- h // self.spatial_merge_size,
- self.spatial_merge_size,
- w // self.spatial_merge_size,
- self.spatial_merge_size,
- )
- wpos_ids = wpos_ids.permute(0, 2, 1, 3)
- wpos_ids = wpos_ids.flatten()
- pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
- pos_ids = torch.cat(pos_ids, dim=0)
- max_grid_size = grid_thw[:, 1:].max()
- rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
- rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
- return rotary_pos_emb
-
- def fast_pos_embed_interpolate(self, grid_thw):
- num_grid_per_side = int(self.num_position_embeddings**0.5)
-
- idx_list = [[] for _ in range(4)]
- weight_list = [[] for _ in range(4)]
-
- # TODO: use torch instand of np
- for t, h, w in grid_thw:
- h_idxs = np.linspace(0, num_grid_per_side - 1, h)
- w_idxs = np.linspace(0, num_grid_per_side - 1, w)
-
- h_idxs_floor = h_idxs.astype(int)
- w_idxs_floor = w_idxs.astype(int)
- h_idxs_ceil = (h_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1)
- w_idxs_ceil = (w_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1)
-
- dh = h_idxs - h_idxs_floor
- dw = w_idxs - w_idxs_floor
-
- idx_list[0].extend(
- ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_floor[None])
- .flatten()
- .tolist()
- * t
- )
- idx_list[1].extend(
- ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_ceil[None])
- .flatten()
- .tolist()
- * t
- )
- idx_list[2].extend(
- ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_floor[None])
- .flatten()
- .tolist()
- * t
- )
- idx_list[3].extend(
- ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_ceil[None])
- .flatten()
- .tolist()
- * t
- )
-
- weight_list[0].extend(
- ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t
- )
- weight_list[1].extend(((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
- weight_list[2].extend((dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
- weight_list[3].extend((dh[None].T * dw[None]).flatten().tolist() * t)
-
- device = self.pos_embed.weight.device
- dtype = self.pos_embed.weight.dtype
-
- p0 = (
- self.pos_embed(torch.tensor(idx_list[0], dtype=torch.long, device=device))
- * torch.tensor(weight_list[0], dtype=dtype, device=device)[:, None]
- )
- p1 = (
- self.pos_embed(torch.tensor(idx_list[1], dtype=torch.long, device=device))
- * torch.tensor(weight_list[1], dtype=dtype, device=device)[:, None]
- )
- p2 = (
- self.pos_embed(torch.tensor(idx_list[2], dtype=torch.long, device=device))
- * torch.tensor(weight_list[2], dtype=dtype, device=device)[:, None]
- )
- p3 = (
- self.pos_embed(torch.tensor(idx_list[3], dtype=torch.long, device=device))
- * torch.tensor(weight_list[3], dtype=dtype, device=device)[:, None]
- )
-
- patch_pos_embeds = p0 + p1 + p2 + p3
- patch_pos_embeds = patch_pos_embeds.split([t * h * w for t, h, w in grid_thw])
- patch_pos_embeds_permute = []
- m_size = self.spatial_merge_size
- for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
- pos_embed = (
- pos_embed.view(t, h // m_size, m_size, w // m_size, m_size, -1)
- .permute(0, 1, 3, 2, 4, 5)
- .flatten(0, 4)
- )
- patch_pos_embeds_permute.append(pos_embed)
- patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
- return patch_pos_embeds
-
- def forward(
- self,
- x: torch.Tensor,
- grid_thw: torch.Tensor,
- ) -> torch.Tensor:
- x = x.to(device=self.device, dtype=self.dtype)
- x = self.patch_embed(x)
-
- pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
- x = x + pos_embeds
- rotary_pos_emb = self.rot_pos_emb(grid_thw)
-
- seq_len, _ = x.size()
- rotary_pos_emb = rotary_pos_emb.to(x.device)
-
- rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
- emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
- position_embeddings = (emb.cos(), emb.sin())
-
- # compute cu_seqlens
- cu_seqlens = torch.cat(
- [
- torch.tensor([0], device=grid_thw.device),
- (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
- ]
- )
- cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
-
- # max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
- x = x.unsqueeze(1)
-
- deepstack_feature_lists = []
- num_deepstack_captured = 0
- for layer_num, blk in enumerate(self.blocks):
- x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
- if layer_num in self.deepstack_visual_indexes:
- deepstack_feature = self.deepstack_merger_list[num_deepstack_captured](
- x
- )
- deepstack_feature_lists.append(deepstack_feature)
- num_deepstack_captured += 1
- x = self.merger(x)
- hidden_states = torch.cat(
- [x] + deepstack_feature_lists, dim=1
- ) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
- return hidden_states
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- stacked_params_mapping = [
- # (param_name, shard_name, shard_id)
- ("attn.qkv.", "attn.q.", "q"),
- ("attn.qkv.", "attn.k.", "k"),
- ("attn.qkv.", "attn.v.", "v"),
- ]
- params_dict = dict(self.named_parameters(remove_duplicate=False))
- loaded_params: set[str] = set()
-
- for name, loaded_weight in weights:
- for param_name, weight_name, shard_id in stacked_params_mapping:
- if weight_name not in name:
- continue
- name = name.replace(weight_name, param_name)
-
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
- return loaded_params
-
-
-cached_get_processor = lru_cache(get_processor)
-
-
-class Qwen3LLMModel(Qwen3Model):
-
- def __init__(
- self,
- *,
- config: Qwen3VLConfig,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ):
- super().__init__(config=config, quant_config=quant_config, prefix=prefix)
- if not self.pp_group.is_first_rank:
- assert self.start_layer >= len(
- config.vision_config.deepstack_visual_indexes
- ), "start_layer should be greater than or equal to len(deepstack_visual_indexes)"
-
- self.hidden_size = config.hidden_size
- self.deepstack_embed_to_decoder_layer = range(
- len(config.vision_config.deepstack_visual_indexes)
- )
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- forward_batch: ForwardBatch,
- input_embeds: torch.Tensor = None,
- pp_proxy_tensors: Optional[PPProxyTensors] = None,
- input_deepstack_embeds: Optional[torch.Tensor] = None,
- ) -> Union[torch.Tensor, PPProxyTensors]:
-
- if self.pp_group.is_first_rank:
- if input_embeds is None:
- hidden_states = self.embed_tokens(input_ids)
- else:
- hidden_states = input_embeds
- residual = None
- else:
- assert pp_proxy_tensors is not None
- hidden_states = pp_proxy_tensors["hidden_states"]
- residual = pp_proxy_tensors["residual"]
-
- aux_hidden_states = []
- for layer_idx, layer in enumerate(
- self.layers[self.start_layer : self.end_layer]
- ):
- layer_idx = layer_idx + self.start_layer
- if layer_idx in self.layers_to_capture:
- aux_hidden_states.append(
- hidden_states + residual if residual is not None else hidden_states
- )
-
- hidden_states, residual = layer(
- positions,
- hidden_states,
- forward_batch,
- residual,
- )
-
- # process deepstack
- if (
- input_deepstack_embeds is not None
- and layer_idx in self.deepstack_embed_to_decoder_layer
- ):
- sep = self.hidden_size * layer_idx
- hidden_states = (
- hidden_states
- + input_deepstack_embeds[:, sep : sep + self.hidden_size]
- )
-
- if not self.pp_group.is_last_rank:
- return PPProxyTensors(
- {
- "hidden_states": hidden_states,
- "residual": residual,
- }
- )
- else:
- if hidden_states.shape[0] != 0:
- if residual is None:
- hidden_states = self.norm(hidden_states)
- else:
- hidden_states, _ = self.norm(hidden_states, residual)
-
- if len(aux_hidden_states) == 0:
- return hidden_states
-
- return hidden_states, aux_hidden_states
-
-
-class Qwen3VLForConditionalGeneration(nn.Module):
- def __init__(
- self,
- config: Qwen3VLConfig,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
-
- self.config = config
- self.visual = Qwen3_VisionTransformer(
- config.vision_config,
- norm_eps=getattr(config, "rms_norm_eps", 1e-6),
- # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
- # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
- quant_config=quant_config,
- prefix=add_prefix("visual", prefix),
- )
-
- self.model = Qwen3LLMModel(
- config=config,
- quant_config=quant_config,
- prefix=add_prefix("model", prefix),
- )
-
- if config.tie_word_embeddings:
- self.lm_head = self.model.embed_tokens
- else:
- self.lm_head = ParallelLMHead(
- config.vocab_size,
- config.hidden_size,
- quant_config=quant_config,
- prefix=add_prefix("lm_head", prefix),
- )
- self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-
- self.logits_processor = LogitsProcessor(config)
- self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
- # like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
- # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states
-
- # deepstack
- self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
- self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
-
- @property
- def use_deepstack(self) -> bool:
- return hasattr(self, "deepstack_visual_indexes")
-
- def separate_deepstack_embeds(self, embedding):
- assert (
- embedding.shape[-1] % (1 + self.num_deepstack_embeddings) == 0
- ), f"hidden_state of {embedding.shape} should be divisible by ({1 + self.num_deepstack_embeddings})"
-
- separate_index = self.config.hidden_size
- input_embeds = embedding[:, :separate_index]
- input_deepstack_embeds = embedding[:, separate_index:]
- return input_embeds, input_deepstack_embeds
-
- def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
- pattern = MultiModalityDataPaddingPatternMultimodalTokens()
- return pattern.pad_input_tokens(input_ids, mm_inputs)
-
- def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
- # in qwen-vl, last dim is the same
- pixel_values = torch.cat([item.feature for item in items], dim=0).type(
- self.visual.dtype
- )
- image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
- assert pixel_values.dim() == 2, pixel_values.dim()
- assert image_grid_thw.dim() == 2, image_grid_thw.dim()
- image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
- return image_embeds
-
- def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
- # in qwen-vl, last dim is the same
- pixel_values = torch.cat([item.feature for item in items], dim=0).type(
- self.visual.dtype
- )
- video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
- assert pixel_values.dim() == 2, pixel_values.dim()
- assert video_grid_thw.dim() == 2, video_grid_thw.dim()
- video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
- return video_embeds
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- forward_batch: ForwardBatch,
- get_embedding: bool = False,
- ):
- """Run forward pass for Qwen3-VL.
-
- Args:
- input_ids: Flattened (concatenated) input_ids corresponding to a
- batch.
- positions: Flattened (concatenated) position ids corresponding to a
- batch.
- **NOTE**: If mrope is enabled (default setting for Qwen2-VL
- opensource models), the shape will be `(3, seq_len)`,
- otherwise it will be `(seq_len,).
- (Use input_metadata.mrope_positions to replace it)
- """
- if self.is_mrope_enabled:
- positions = forward_batch.mrope_positions
-
- if not (
- forward_batch.forward_mode.is_decode()
- or not forward_batch.contains_image_inputs()
- ):
- if self.is_mrope_enabled:
- assert positions.ndim == 2 and positions.size(0) == 3, (
- "multimodal section rotary embedding requires "
- f"(3, seq_len) positions, but got {positions.size()}"
- )
-
- hidden_states = general_mm_embed_routine(
- input_ids=input_ids,
- forward_batch=forward_batch,
- language_model=self.model,
- multimodal_model=self,
- positions=positions,
- use_deepstack=self.use_deepstack,
- )
-
- if not get_embedding:
- return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
- )
- else:
- return self.pooler(hidden_states, forward_batch)
-
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
- stacked_params_mapping = [
- # (param_name, shard_name, shard_id)
- (".qkv_proj", ".q_proj", "q"),
- (".qkv_proj", ".k_proj", "k"),
- (".qkv_proj", ".v_proj", "v"),
- ("gate_up_proj", "up_proj", 1),
- ("gate_up_proj", "gate_proj", 0),
- ]
- params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in weights:
- if "rotary_emb.inv_freq" in name:
- continue
- if "language_model" in name:
- name = name.replace(r"model.language_model.", r"model.")
-
- for param_name, weight_name, shard_id in stacked_params_mapping:
- if weight_name not in name:
- continue
- if "visual" in name:
- continue
- name = name.replace(weight_name, param_name)
-
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- if "visual" in name:
- # adapt to VisionAttention
- name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
- name = name.replace(r"model.visual.", r"visual.")
-
- try:
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
- param = params_dict[name]
- except KeyError:
- print(params_dict.keys())
- raise
-
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
-
-
-EntryClass = Qwen3VLForConditionalGeneration
diff --git a/python/sglang/srt/models/qwen3_vl_moe.py b/python/sglang/srt/models/qwen3_vl_moe.py
deleted file mode 100644
index a88059916..000000000
--- a/python/sglang/srt/models/qwen3_vl_moe.py
+++ /dev/null
@@ -1,471 +0,0 @@
-# Copyright 2025 Qwen Team
-# Copyright 2025 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
-import logging
-from functools import lru_cache, partial
-from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from transformers import BatchFeature
-from transformers.activations import ACT2FN
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
- Qwen2_5_VisionRotaryEmbedding,
-)
-
-from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
-from sglang.srt.distributed import (
- get_moe_expert_parallel_world_size,
- get_pp_group,
- get_tensor_model_parallel_rank,
-)
-from sglang.srt.hf_transformers_utils import get_processor
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.pooler import Pooler, PoolingType
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.utils import get_layer_id
-from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.mm_utils import (
- MultiModalityDataPaddingPatternMultimodalTokens,
- general_mm_embed_routine,
-)
-from sglang.srt.managers.schedule_batch import (
- MultimodalDataItem,
- MultimodalInputs,
- global_server_args_dict,
-)
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
-from sglang.srt.models.qwen3_vl import (
- Qwen3_VisionTransformer,
- Qwen3VLForConditionalGeneration,
-)
-from sglang.srt.utils import add_prefix
-
-logger = logging.getLogger(__name__)
-
-cached_get_processor = lru_cache(get_processor)
-
-
-class Qwen3MoeLLMModel(Qwen3MoeModel):
- def __init__(
- self,
- *,
- config: Qwen3VLMoeConfig,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ):
- super().__init__(config=config, quant_config=quant_config, prefix=prefix)
-
- self.hidden_size = config.hidden_size
-
- def get_input_embeddings(self) -> nn.Embedding:
- return self.embed_tokens
-
- def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
- # in qwen-vl, last dim is the same
- pixel_values = torch.cat([item.feature for item in items], dim=0).type(
- self.visual.dtype
- )
- image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
- assert pixel_values.dim() == 2, pixel_values.dim()
- assert image_grid_thw.dim() == 2, image_grid_thw.dim()
- image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
- return image_embeds
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- forward_batch: ForwardBatch,
- input_embeds: torch.Tensor = None,
- pp_proxy_tensors: Optional[PPProxyTensors] = None,
- input_deepstack_embeds: Optional[torch.Tensor] = None,
- ) -> Union[torch.Tensor, PPProxyTensors]:
- if self.pp_group.is_first_rank:
- if input_embeds is None:
- hidden_states = self.embed_tokens(input_ids)
- else:
- hidden_states = input_embeds
- residual = None
- else:
- assert pp_proxy_tensors is not None
- hidden_states = pp_proxy_tensors["hidden_states"]
- residual = pp_proxy_tensors["residual"]
-
- aux_hidden_states = []
- for layer_idx, layer in enumerate(
- self.layers[self.start_layer : self.end_layer]
- ):
- layer_idx = layer_idx + self.start_layer
- if layer_idx in self.layers_to_capture:
- aux_hidden_states.append(
- hidden_states + residual if residual is not None else hidden_states
- )
-
- hidden_states, residual = layer(
- positions,
- hidden_states,
- forward_batch,
- residual,
- )
-
- # process deepstack
- if input_deepstack_embeds is not None and layer_idx in range(3):
- sep = self.hidden_size * layer_idx
- hidden_states = (
- hidden_states
- + input_deepstack_embeds[:, sep : sep + self.hidden_size]
- )
-
- if not self.pp_group.is_last_rank:
- return PPProxyTensors(
- {
- "hidden_states": hidden_states,
- "residual": residual,
- }
- )
- else:
- if hidden_states.shape[0] != 0:
- if residual is None:
- hidden_states = self.norm(hidden_states)
- else:
- hidden_states, _ = self.norm(hidden_states, residual)
-
- if len(aux_hidden_states) == 0:
- return hidden_states
-
- return hidden_states, aux_hidden_states
-
-
-class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
- def __init__(
- self,
- *,
- config: Qwen3VLMoeConfig,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ):
- super(Qwen3VLForConditionalGeneration, self).__init__()
- self.config = config
-
- self.visual = Qwen3_VisionTransformer(
- config.vision_config,
- norm_eps=getattr(config, "rms_norm_eps", 1e-6),
- # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
- # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
- quant_config=quant_config,
- prefix=add_prefix("visual", prefix),
- )
-
- self.model = Qwen3MoeLLMModel(
- config=config,
- quant_config=quant_config,
- prefix=add_prefix("model", prefix),
- )
-
- if config.tie_word_embeddings:
- self.lm_head = self.model.embed_tokens
- else:
- self.lm_head = ParallelLMHead(
- config.vocab_size,
- config.hidden_size,
- quant_config=quant_config,
- prefix=add_prefix("lm_head", prefix),
- )
- self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-
- self.logits_processor = LogitsProcessor(config)
- self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-
- # deepstack
- self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
- self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
-
- @property
- def use_deepstack(self) -> bool:
- return hasattr(self, "deepstack_visual_indexes")
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- forward_batch: ForwardBatch,
- get_embedding: bool = False,
- ):
- """Run forward pass for Qwen3-VL.
-
- Args:
- input_ids: Flattened (concatenated) input_ids corresponding to a
- batch.
- positions: Flattened (concatenated) position ids corresponding to a
- batch.
- **NOTE**: If mrope is enabled (default setting for Qwen2-VL
- opensource models), the shape will be `(3, seq_len)`,
- otherwise it will be `(seq_len,).
- (Use input_metadata.mrope_positions to replace it)
- """
- if self.is_mrope_enabled:
- positions = forward_batch.mrope_positions
-
- if not (
- forward_batch.forward_mode.is_decode()
- or not forward_batch.contains_image_inputs()
- ):
- if self.is_mrope_enabled:
- assert positions.ndim == 2 and positions.size(0) == 3, (
- "multimodal section rotary embedding requires "
- f"(3, seq_len) positions, but got {positions.size()}"
- )
-
- hidden_states = general_mm_embed_routine(
- input_ids=input_ids,
- forward_batch=forward_batch,
- language_model=self.model,
- multimodal_model=self,
- positions=positions,
- use_deepstack=self.use_deepstack,
- )
-
- if not get_embedding:
- return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
- )
- else:
- return self.pooler(hidden_states, forward_batch)
-
- def load_fused_expert_weights(
- self,
- name: str,
- params_dict: dict,
- loaded_weight: torch.Tensor,
- shard_id: str,
- num_experts: int,
- ):
- param = params_dict[name]
- # weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
- weight_loader = param.weight_loader
- ep_rank = get_tensor_model_parallel_rank()
- ep_size = get_moe_expert_parallel_world_size()
- if ep_size == 1:
- for expert_id in range(num_experts):
- curr_expert_weight = loaded_weight[expert_id]
- weight_loader(
- param,
- curr_expert_weight,
- name,
- shard_id,
- expert_id,
- )
- else:
- experts_per_ep = num_experts // ep_size
- start_expert = ep_rank * experts_per_ep
- end_expert = (
- (ep_rank + 1) * experts_per_ep
- if ep_rank != ep_size - 1
- else num_experts
- )
-
- for idx, expert_id in enumerate(range(start_expert, end_expert)):
- curr_expert_weight = loaded_weight[expert_id]
- weight_loader(
- param,
- curr_expert_weight,
- name,
- shard_id,
- idx,
- )
- return True
-
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
- stacked_params_mapping = [
- # (param_name, shard_name, shard_id)
- (".qkv_proj", ".q_proj", "q"),
- (".qkv_proj", ".k_proj", "k"),
- (".qkv_proj", ".v_proj", "v"),
- ("gate_up_proj", "up_proj", 1),
- ("gate_up_proj", "gate_proj", 0),
- ]
-
- expert_params_mapping = FusedMoE.make_expert_params_mapping(
- ckpt_gate_proj_name="gate_proj",
- ckpt_down_proj_name="down_proj",
- ckpt_up_proj_name="up_proj",
- num_experts=self.config.num_experts,
- )
-
- # Skip loading extra parameters for GPTQ/modelopt models.
- ignore_suffixes = (
- ".bias",
- "_bias",
- ".k_scale",
- "_k_scale",
- ".v_scale",
- "_v_scale",
- ".weight_scale",
- "_weight_scale",
- ".input_scale",
- "_input_scale",
- )
-
- is_fused_expert = False
- fused_expert_params_mapping = [
- ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
- ("experts.w2_weight", "experts.down_proj", 0, "w2"),
- ]
-
- num_experts = self.config.num_experts
-
- # Cache params_dict to avoid repeated expensive traversal of model parameters
- if not hasattr(self, "_cached_params_dict"):
- self._cached_params_dict = dict(self.named_parameters())
- params_dict = self._cached_params_dict
- for name, loaded_weight in weights:
- if "language_model" in name:
- name = name.replace(r"model.language_model.", r"model.")
-
- for param_name, weight_name, shard_id in stacked_params_mapping:
- if "experts.gate_up_proj" in name or "experts.down_proj" in name:
- is_fused_expert = True
- expert_params_mapping = fused_expert_params_mapping
-
- # Skip non-stacked layers and experts (experts handled below).
- if weight_name not in name:
- continue
- if "visual" in name:
- continue
-
- # We have mlp.experts[0].gate_proj in the checkpoint.
- # Since we handle the experts below in expert_params_mapping,
- # we need to skip here BEFORE we update the name, otherwise
- # name will be updated to mlp.experts[0].gate_up_proj, which
- # will then be updated below in expert_params_mapping
- # for mlp.experts[0].gate_gate_up_proj, which breaks load.
- if "mlp.experts" in name:
- continue
- name = name.replace(weight_name, param_name)
- # Skip loading extra parameters for GPTQ/modelopt models.
- if name.endswith(ignore_suffixes) and name not in params_dict:
- continue
- # [TODO] Skip layers that are on other devices (check if sglang has a similar function)
- # if is_pp_missing_parameter(name, self):
- # continue
-
- if name not in params_dict:
- continue
-
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- # Track if this is an expert weight to enable early skipping
- is_expert_weight = False
-
- for mapping in expert_params_mapping:
- param_name, weight_name, expert_id, shard_id = mapping
- if weight_name not in name:
- continue
- if "visual" in name:
- continue
- # Anyway, this is an expert weight and should not be
- # attempted to load as other weights later
- is_expert_weight = True
- name_mapped = name.replace(weight_name, param_name)
- if is_fused_expert:
- loaded_weight = loaded_weight.transpose(-1, -2) # no bias
- if "experts.gate_up_proj" in name:
- loaded_weight = loaded_weight.chunk(2, dim=-2)
- self.load_fused_expert_weights(
- name_mapped,
- params_dict,
- loaded_weight[0],
- "w1",
- num_experts,
- )
- self.load_fused_expert_weights(
- name_mapped,
- params_dict,
- loaded_weight[1],
- "w3",
- num_experts,
- )
- else:
- self.load_fused_expert_weights(
- name_mapped,
- params_dict,
- loaded_weight,
- shard_id,
- num_experts,
- )
- else:
- # Skip loading extra parameters for GPTQ/modelopt models.
- if (
- name_mapped.endswith(ignore_suffixes)
- and name_mapped not in params_dict
- ):
- continue
- param = params_dict[name_mapped]
- # We should ask the weight loader to return success or
- # not here since otherwise we may skip experts with
- # # other available replicas.
- weight_loader = param.weight_loader
- weight_loader(
- param,
- loaded_weight,
- name_mapped,
- shard_id=shard_id,
- expert_id=expert_id,
- )
- name = name_mapped
- break
- else:
- if is_expert_weight:
- # This is an expert weight but not mapped to this rank, skip all remaining processing
- continue
- if "visual" in name:
- # adapt to VisionAttention
- name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
- name = name.replace(r"model.visual.", r"visual.")
-
- # Skip loading extra parameters for GPTQ/modelopt models.
- if name.endswith(ignore_suffixes) and name not in params_dict:
- continue
-
- if name in params_dict.keys():
- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, loaded_weight)
- else:
- logger.warning(f"Parameter {name} not found in params_dict")
-
- # TODO mimic deepseek
- # Lazy initialization of expert weights cache to avoid slowing down load_weights
- # if not hasattr(self, "routed_experts_weights_of_layer"):
- # self.routed_experts_weights_of_layer = {
- # layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
- # for layer_id in range(self.start_layer, self.end_layer)
- # if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
- # }
-
-
-EntryClass = Qwen3VLMoeForConditionalGeneration
diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py
index 14b327bd1..00499ce66 100644
--- a/python/sglang/srt/models/torch_native_llama.py
+++ b/python/sglang/srt/models/torch_native_llama.py
@@ -66,8 +66,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
-tp_size: Optional[int] = None
-tp_rank: Optional[int] = None
+tp_size = get_tensor_model_parallel_world_size()
+tp_rank = get_tensor_model_parallel_rank()
def gate_up_proj_weight_loader(
@@ -341,13 +341,6 @@ class LlamaModel(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
-
- global tp_size, tp_rank
- if tp_size is None:
- tp_size = get_tensor_model_parallel_world_size()
- if tp_rank is None:
- tp_rank = get_tensor_model_parallel_rank()
-
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
diff --git a/python/sglang/srt/models/utils.py b/python/sglang/srt/models/utils.py
deleted file mode 100644
index f4c2a0e3e..000000000
--- a/python/sglang/srt/models/utils.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright 2023-2025 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import torch
-
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda
-
-_is_cuda = is_cuda()
-
-
-if _is_cuda:
- from sgl_kernel import FusedSetKVBufferArg
-
-
-def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
- """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
- return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
-
-
-def create_fused_set_kv_buffer_arg(
- value: torch.Tensor,
- layer: RadixAttention,
- forward_batch: ForwardBatch,
-):
- layer_id = layer.layer_id
- token_to_kv_pool = forward_batch.token_to_kv_pool
-
- k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
- v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
-
- return FusedSetKVBufferArg(
- value=value,
- k_buffer=k_buffer.view(k_buffer.shape[0], -1),
- v_buffer=v_buffer.view(v_buffer.shape[0], -1),
- k_scale=layer.k_scale,
- v_scale=layer.v_scale,
- cache_loc=forward_batch.out_cache_loc,
- )
diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py
index ef076ae09..e5da78368 100644
--- a/python/sglang/srt/multimodal/processors/base_processor.py
+++ b/python/sglang/srt/multimodal/processors/base_processor.py
@@ -234,14 +234,7 @@ class BaseMultimodalProcessor(ABC):
and isinstance(processor.image_processor, BaseImageProcessorFast)
and not self.server_args.disable_fast_image_processor
):
- if not _is_npu:
- kwargs["device"] = "cuda"
- elif processor.__class__.__name__ not in {
- "Qwen2_5_VLProcessor",
- "Qwen3VLProcessor",
- }:
- # Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend.
- kwargs["device"] = "npu"
+ kwargs["device"] = "cuda" if not _is_npu else "npu"
result = processor.__call__(
text=[input_text],
padding=True,
diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py
index ec5e574f4..facddfea5 100644
--- a/python/sglang/srt/multimodal/processors/qwen_vl.py
+++ b/python/sglang/srt/multimodal/processors/qwen_vl.py
@@ -12,8 +12,6 @@ from torchvision.transforms import InterpolationMode
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
-from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
-from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
@@ -211,12 +209,7 @@ async def preprocess_video(
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
- models = [
- Qwen2VLForConditionalGeneration,
- Qwen2_5_VLForConditionalGeneration,
- Qwen3VLForConditionalGeneration,
- Qwen3VLMoeForConditionalGeneration,
- ]
+ models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
diff --git a/python/sglang/srt/patch_torch.py b/python/sglang/srt/patch_torch.py
index 6dc329a9d..8d90ce4c0 100644
--- a/python/sglang/srt/patch_torch.py
+++ b/python/sglang/srt/patch_torch.py
@@ -17,18 +17,10 @@ import torch
from packaging import version
from torch.multiprocessing import reductions
-from sglang.srt.utils import is_npu
-
-_is_npu = is_npu()
-
def monkey_patch_torch_reductions():
"""Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed"""
- # Currently, NPU does not support UUID. This has been temporarily commented out, with support expected in the fourth quarter.
- if _is_npu:
- return
-
if hasattr(reductions, "_reduce_tensor_original"):
return
diff --git a/python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py b/python/sglang/srt/remote_instance_weight_loader_utils.py
similarity index 100%
rename from python/sglang/srt/model_loader/remote_instance_weight_loader_utils.py
rename to python/sglang/srt/remote_instance_weight_loader_utils.py
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index 0bff4d397..c644a9d7e 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -19,6 +19,7 @@ from sglang.srt.utils import get_bool_env_var
_SAMPLING_EPS = 1e-6
TOP_K_ALL = 1 << 30
+DEFAULT_SAMPLING_SEED = 42
class SamplingParams:
@@ -55,7 +56,7 @@ class SamplingParams:
custom_params: Optional[Dict[str, Any]] = None,
stream_interval: Optional[int] = None,
logit_bias: Optional[Dict[str, float]] = None,
- sampling_seed: int = 42,
+ sampling_seed: Optional[int] = None,
) -> None:
self.max_new_tokens = max_new_tokens
self.stop_strs = stop
@@ -83,6 +84,13 @@ class SamplingParams:
self.custom_params = custom_params
self.stream_interval = stream_interval
self.logit_bias = logit_bias
+ # Used for deterministic sampling
+ if (
+ get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE")
+ and sampling_seed is None
+ ):
+ # If deterministic inference is enabled and sampling_seed is not set, use the default seed
+ sampling_seed = DEFAULT_SAMPLING_SEED
self.sampling_seed = sampling_seed
# Process some special cases
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 701415390..31423c2df 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -19,6 +19,8 @@ import json
import logging
import os
import random
+import socket
+import sys
import tempfile
from typing import List, Literal, Optional, Union
@@ -51,6 +53,7 @@ from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
+
# Define constants
LOAD_FORMAT_CHOICES = [
"auto",
@@ -91,6 +94,7 @@ ATTENTION_BACKEND_CHOICES = [
"triton",
"torch_native",
"flex_attention",
+ "nsa",
# NVIDIA specific
"cutlass_mla",
"fa3",
@@ -100,6 +104,7 @@ ATTENTION_BACKEND_CHOICES = [
"trtllm_mla",
"trtllm_mha",
"dual_chunk_flash_attn",
+ "hybrid_linear_attn",
# AMD specific
"aiter",
"wave",
@@ -116,6 +121,11 @@ GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
+NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang"]
+
+NSA_DEFAULT_PREFILL = "flashmla_prefill"
+NSA_DEFAULT_DECODE = "fa3"
+
# Allow external code to add more choices
def add_load_format_choices(choices):
@@ -167,7 +177,6 @@ class ServerArgs:
quantization: Optional[str] = None
quantization_param_path: Optional[str] = None
kv_cache_dtype: str = "auto"
- enable_fp32_lm_head: bool = False
# Memory and scheduling
mem_fraction_static: Optional[float] = None
@@ -212,8 +221,8 @@ class ServerArgs:
show_time_cost: bool = False
enable_metrics: bool = False
enable_metrics_for_all_schedulers: bool = False
- tokenizer_metrics_custom_labels_header: str = "x-custom-labels"
- tokenizer_metrics_allowed_custom_labels: Optional[List[str]] = None
+ tokenizer_metrics_custom_labels_header: str = "x-customer-labels"
+ tokenizer_metrics_allowed_customer_labels: Optional[List[str]] = None
bucket_time_to_first_token: Optional[List[float]] = None
bucket_inter_token_latency: Optional[List[float]] = None
bucket_e2e_request_latency: Optional[List[float]] = None
@@ -286,14 +295,14 @@ class ServerArgs:
speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None
speculative_attention_mode: str = "prefill"
- # For ngram only
- speculative_ngram_min_match_window_size: int = 1
- speculative_ngram_max_match_window_size: int = 12
- speculative_ngram_min_bfs_breadth: int = 1
- speculative_ngram_max_bfs_breadth: int = 10
- speculative_ngram_match_type: Literal["BFS", "PROB"] = "BFS"
- speculative_ngram_branch_length: int = 18
- speculative_ngram_capacity: int = 10 * 1000 * 1000
+ # For lookahead only
+ speculative_lookahead_min_match_window_size: int = 1
+ speculative_lookahead_max_match_window_size: int = 12
+ speculative_lookahead_min_bfs_breadth: int = 1
+ speculative_lookahead_max_bfs_breadth: int = 10
+ speculative_lookahead_match_type: Literal["BFS", "PROB"] = "BFS"
+ speculative_lookahead_branch_length: int = 18
+ speculative_lookahead_capacity: int = 10 * 1000 * 1000
# Expert parallelism
ep_size: int = 1
@@ -325,10 +334,6 @@ class ServerArgs:
deepep_config: Optional[str] = None
moe_dense_tp_size: Optional[int] = None
- # Mamba cache
- max_mamba_cache_size: Optional[int] = None
- mamba_ssm_dtype: str = "float32"
-
# Hierarchical cache
enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
@@ -399,7 +404,7 @@ class ServerArgs:
enable_return_hidden_states: bool = False
scheduler_recv_interval: int = 1
numa_node: Optional[List[int]] = None
- enable_deterministic_inference: bool = False
+ max_prefill_bs: Optional[int] = None
# Dynamic batch tokenizer
enable_dynamic_batch_tokenizer: bool = False
@@ -420,14 +425,16 @@ class ServerArgs:
disaggregation_decode_dp: Optional[int] = None
disaggregation_prefill_pp: Optional[int] = 1
disaggregation_ib_device: Optional[str] = None
- disaggregation_decode_enable_offload_kvcache: bool = False
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
+
# FIXME: hack to reduce ITL when decode bs is small
disaggregation_decode_polling_interval: int = 1
- # For model weight update and weight loading
+ # For model weight update
custom_weight_loader: Optional[List[str]] = None
weight_loader_disable_mmap: bool = False
+
+ # Remote instance weight loading
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
@@ -436,80 +443,62 @@ class ServerArgs:
enable_pdmux: bool = False
sm_group_num: int = 3
- def __post_init__(self):
- """
- Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
- """
- # Handle deprecated arguments.
- self._handle_deprecated_args()
+ # Mamba cache
+ max_mamba_cache_size: Optional[int] = None
+ mamba_ssm_dtype: str = "float32"
- # Set missing default values.
- self._handle_missing_default_values()
+ # For deterministic inference
+ enable_deterministic_inference: bool = False
- # Get GPU memory capacity, which is a common dependency for several configuration steps.
- gpu_mem = get_device_memory_capacity(self.device)
+ # NSA attention backend
+ nsa_prefill: str = NSA_DEFAULT_PREFILL
+ nsa_decode: str = NSA_DEFAULT_DECODE
- # Handle memory-related, chunked prefill, and CUDA graph batch size configurations.
- self._handle_gpu_memory_settings(gpu_mem)
-
- # Handle device-specific backends.
- self._handle_hpu_backends()
- self._handle_cpu_backends()
-
- # Apply model-specific adjustments.
- self._handle_model_specific_adjustments()
-
- # Set kernel backends.
- self._handle_sampling_backend()
- self._handle_attention_backend_compatibility()
- self._handle_page_size()
- self._handle_amd_specifics()
- self._handle_grammar_backend()
-
- # Handle data parallelism.
- self._handle_data_parallelism()
-
- # Handle MoE configurations.
- self._handle_moe_kernel_config()
- self._handle_deepep_moe()
- self._handle_eplb_and_dispatch()
- self._handle_expert_distribution_metrics()
-
- # Handle pipeline parallelism.
- self._handle_pipeline_parallelism()
-
- # Handle Hicache settings.
- self._handle_hicache()
-
- # Handle speculative decoding logic.
- self._handle_speculative_decoding()
-
- # Handle model loading format.
- self._handle_load_format()
-
- # Handle PD disaggregation.
- self._handle_disaggregation()
-
- # Validate tokenizer settings.
- self._handle_tokenizer_batching()
-
- # Propagate environment variables.
- self._handle_environment_variables()
-
- # Validate cache settings.
- self._handle_cache_compatibility()
-
- # Validate metrics labels.
- self._handle_metrics_labels()
-
- # Handle deterministic inference.
- self._handle_deterministic_inference()
-
- # Handle any other necessary validations.
- self._handle_other_validations()
+ # Deprecated arguments
+ enable_ep_moe: bool = False
+ enable_deepep_moe: bool = False
+ enable_flashinfer_cutlass_moe: bool = False
+ enable_flashinfer_cutedsl_moe: bool = False
+ enable_flashinfer_trtllm_moe: bool = False
+ enable_triton_kernel_moe: bool = False
+ enable_flashinfer_mxfp4_moe: bool = False
def _handle_deprecated_args(self):
- pass
+ if self.enable_ep_moe:
+ self.ep_size = self.tp_size
+ print_deprecated_warning(
+ "NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
+ )
+ if self.enable_deepep_moe:
+ self.moe_a2a_backend = "deepep"
+ print_deprecated_warning(
+ "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
+ )
+ if self.enable_triton_kernel_moe:
+ self.moe_runner_backend = "triton_kernel"
+ print_deprecated_warning(
+ "NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
+ )
+ if self.enable_flashinfer_cutedsl_moe:
+ self.moe_runner_backend = "flashinfer_cutedsl"
+ print_deprecated_warning(
+ "NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead."
+ )
+ if self.enable_flashinfer_cutlass_moe:
+ self.moe_runner_backend = "flashinfer_cutlass"
+ print_deprecated_warning(
+ "NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
+ )
+ if self.enable_flashinfer_trtllm_moe:
+ self.moe_runner_backend = "flashinfer_trtllm"
+ print_deprecated_warning(
+ "NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
+ )
+ if self.enable_flashinfer_mxfp4_moe:
+ self.moe_runner_backend = "flashinfer_mxfp4"
+ print_deprecated_warning(
+ "NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead."
+ )
def _handle_missing_default_values(self):
if self.tokenizer_path is None:
@@ -521,174 +510,85 @@ class ServerArgs:
if self.random_seed is None:
self.random_seed = random.randint(0, 1 << 30)
- def _handle_gpu_memory_settings(self, gpu_mem):
- """
- Configure GPU memory-dependent settings including
- chunked_prefill_size, cuda_graph_max_bs, and mem_fraction_static.
-
- Here are our heuristics:
- - Set chunked_prefill_size and cuda_graph_max_bs based on the GPU memory capacity.
- This is because GPUs with more memory are generally more powerful, we need to use a larger
- chunked_prefill_size and a larger cuda_graph_max_bs to fully utilize the GPU.
- - Then set mem_fraction_static based on chunked_prefill_size and cuda_graph_max_bs.
-
- GPU memory capacity = model weights + KV cache pool + activations + cuda graph buffers
-
- The argument mem_fraction_static is defined as (model weights + KV cache pool) / GPU memory capacity,
- or equivalently, mem_fraction_static = (GPU memory capacity - activations - cuda graph buffers) / GPU memory capacity.
-
- In order to compute mem_fraction_static, we need to estimate the size of activations and cuda graph buffers.
- The activation memory is proportional to the chunked_prefill_size.
- The cuda graph memory is proportional to the cuda_graph_max_bs.
- We use reserved_mem = chunked_prefill_size * 1.5 + cuda_graph_max_bs * 2 to estimate the size of activations and cuda graph buffers in GB.
- and set mem_fraction_static = (GPU memory capacity - reserved_mem) / GPU memory capacity.
-
- The coefficient 1.5 is a heuristic value, in the future, we can do better estimation by looking at the model types, hidden sizes or even do a dummy run.
- """
- if gpu_mem is not None:
- if gpu_mem < 20 * 1024:
- # T4, 4080
- # (chunked_prefill_size 2k, cuda_graph_max_bs 8)
- if self.chunked_prefill_size is None:
- self.chunked_prefill_size = 2048
- if self.cuda_graph_max_bs is None:
- self.cuda_graph_max_bs = 8
- elif gpu_mem < 35 * 1024:
- # A10, 4090, 5090
- # (chunked_prefill_size 2k, cuda_graph_max_bs 16 if tp < 4 else 80)
- if self.chunked_prefill_size is None:
- self.chunked_prefill_size = 2048
- if self.cuda_graph_max_bs is None:
- # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM < 35GB, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance.
- # However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs
- # from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
- if self.tp_size < 4:
- self.cuda_graph_max_bs = 16
- else:
- self.cuda_graph_max_bs = 80
- elif gpu_mem < 60 * 1024:
- # A100 (40GB), L40,
- # (chunked_prefill_size 4k, cuda_graph_max_bs 32 if tp < 4 else 160)
- if self.chunked_prefill_size is None:
- self.chunked_prefill_size = 4096
- if self.cuda_graph_max_bs is None:
- if self.tp_size < 4:
- self.cuda_graph_max_bs = 32
- else:
- self.cuda_graph_max_bs = 160
- elif gpu_mem < 90 * 1024:
- # H100, A100
- # (chunked_prefill_size 8k, cuda_graph_max_bs 256 if tp < 4 else 512)
- if self.chunked_prefill_size is None:
- self.chunked_prefill_size = 8192
- if self.cuda_graph_max_bs is None:
- if self.tp_size < 4:
- self.cuda_graph_max_bs = 256
- else:
- self.cuda_graph_max_bs = 512
- elif gpu_mem < 160 * 1024:
- # H20, H200
- # (chunked_prefill_size 8k, cuda_graph_max_bs 256 if tp < 4 else 512)
- if self.chunked_prefill_size is None:
- self.chunked_prefill_size = 8192
- if self.cuda_graph_max_bs is None:
- if self.tp_size < 4:
- self.cuda_graph_max_bs = 256
- else:
- self.cuda_graph_max_bs = 512
- else:
- # B200, MI300
- # (chunked_prefill_size 16k, cuda_graph_max_bs 512)
- if self.chunked_prefill_size is None:
- self.chunked_prefill_size = 16384
- if self.cuda_graph_max_bs is None:
- self.cuda_graph_max_bs = 512
- else:
- # Fallback defaults when gpu_mem is None
- if self.chunked_prefill_size is None:
- self.chunked_prefill_size = 4096
- if self.cuda_graph_max_bs is None:
- self.cuda_graph_max_bs = 160
-
- # Set cuda graph batch sizes
- if self.cuda_graph_bs is None:
- self.cuda_graph_bs = self._generate_cuda_graph_batch_sizes()
- else:
- self.cuda_graph_max_bs = max(self.cuda_graph_bs)
-
+ def _handle_mem_fraction_static(self, gpu_mem):
if self.mem_fraction_static is None:
- # Constant meta data (e.g., from attention backend)
- reserved_mem = 512
- # For activation during large prefill
- if self.chunked_prefill_size > 0:
- reserved_mem += max(self.chunked_prefill_size, 2048) * 1.5
+ if gpu_mem is not None:
+ # GPU memory capacity = model weights + KV cache pool + activations + cuda graph buffers
+ # mem_fraction_static = (model weights + KV cache pool) / GPU memory capacity.
+
+ # We want mem_fraction_static to be as large as possible but still has enough room
+ # for activations and cuda graph buffers. We use the following heuristic to
+ # compute the needed size for activations and cuda graph buffers:
+ # - The size of the activation depends on the chunked_prefill_size and model size.
+ # - The size of cuda graph buffers depends on the cuda graph capture range and model size.
+ # For GPUs with more memory, we use a larger chunked_prefill_size and
+ # capture more cuda graphs, so they need to reserve more memory.
+ parallel_size = self.tp_size * self.pp_size
+
+ if gpu_mem < 20 * 1024:
+ # T4, 4080. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
+ reserved_mem = (2.8 + parallel_size / 10) * 1024
+ elif gpu_mem < 35 * 1024:
+ # A10, L40, 4090, 5090. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
+ reserved_mem = (2.8 + parallel_size / 10) * 1024
+ elif gpu_mem < 90 * 1024:
+ # H100, A100. (chunked_prefill_size 8k, cuda_graph_max_bs 160)
+ reserved_mem = (9.5 + parallel_size / 2) * 1024
+ elif gpu_mem < 100 * 1024:
+ # H20. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
+ reserved_mem = (12 + parallel_size / 2) * 1024
+ elif gpu_mem < 160 * 1024:
+ # H200. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
+ reserved_mem = (12 + parallel_size / 2) * 1024
+ else:
+ # B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
+ reserved_mem = 32 * 1024
+
+ # draft model and larger cuda graph buffers
+ if self.speculative_algorithm is not None:
+ if self.speculative_algorithm == "STANDALONE":
+ # Standalone speculative decoding needs more memory than other speculative
+ # decoding algorithms since the draft model is typically larger.
+ reserved_mem += 6 * 1024
+ elif self.speculative_algorithm != "LOOKAHEAD":
+ reserved_mem += 2 * 1024
+ if self.enable_dp_attention:
+ reserved_mem += 4 * 1024
+
+ self.mem_fraction_static = round((gpu_mem - reserved_mem) / gpu_mem, 3)
else:
- reserved_mem += max(self.max_prefill_tokens, 2048) * 1.5
- # For cuda graphs
- reserved_mem += self.cuda_graph_max_bs * 2
- # Some adjustments for large parallel size
- reserved_mem += self.tp_size * self.pp_size / 8 * 1024
+ self.mem_fraction_static = 0.88
- if self.enable_dp_attention:
- # DP attention needs more padding for some operations
- reserved_mem += self.cuda_graph_max_bs * self.dp_size * 3
-
- # DP attention uses much more memory for large cuda graph max bs,
- # likely due to some inefficiencies in torch allocator or our implementation.
- # So we need to reserve more memory.
- if self.cuda_graph_max_bs > 300:
- reserved_mem += self.cuda_graph_max_bs * self.dp_size * 1.5
-
- if gpu_mem > 60 * 1024:
- reserved_mem = max(reserved_mem, 10 * 1024)
-
- if self.speculative_algorithm is not None:
- if self.speculative_algorithm == "STANDALONE":
- # standalonedraft model and cuda graphs
- reserved_mem += 6 * 1024
- elif self.speculative_algorithm != "NGRAM":
- # eagle draft models and cuda graphs
- reserved_mem += 2 * 1024
-
- self.mem_fraction_static = round((gpu_mem - reserved_mem) / gpu_mem, 3)
-
- # Lazy init to avoid circular import
- # Multimodal models need more memory for the image processor
+ # Lazy init to avoid circular import.
from sglang.srt.configs.model_config import ModelConfig
model_config = ModelConfig.from_server_args(self)
if model_config.is_multimodal:
self.adjust_mem_fraction_for_vlm(model_config)
- def _generate_cuda_graph_batch_sizes(self):
- """
- Generate the list of batch sizes for CUDA graph capture based on cuda_graph_max_bs.
- This integrates the logic from cuda_graph_runner.py.
- """
- # Handle disable_cuda_graph_padding as the first condition for both spec and non-spec
- if self.disable_cuda_graph_padding:
- capture_bs = list(range(1, self.cuda_graph_max_bs + 1))
- elif self.speculative_algorithm is None:
- # Normal case: [1, 2, 4, 8, 12] + list(range(16, 257, 8)) + list(range(272, 512, 16)) + list(range(512, cuda_graph_max_bs + 1))
- capture_bs = (
- [1, 2, 4, 8, 12]
- + list(range(16, 257, 8))
- + list(range(272, 512, 16))
- + list(range(512, self.cuda_graph_max_bs + 1))
- )
- else:
- # Spec decoding case: list(range(1, 9, 1)) + list(range(10, 33, 2)) + list(range(40, 64, 4)) + list(range(72, 257, 8))
- capture_bs = (
- list(range(1, 9, 1))
- + list(range(10, 33, 2))
- + list(range(40, 64, 4))
- + list(range(72, 257, 8))
- + list(range(272, self.cuda_graph_max_bs + 1, 16))
- )
+ def _handle_chunked_prefill_size(self, gpu_mem):
+ if self.chunked_prefill_size is None:
+ if gpu_mem is not None:
+ # A10, L40, 4090
+ if gpu_mem < 35 * 1024:
+ self.chunked_prefill_size = 2048
+ # H100, H200, A100, H20
+ elif gpu_mem < 160 * 1024:
+ self.chunked_prefill_size = 8192
+ # B200, MI300
+ else:
+ self.chunked_prefill_size = 16384
+ else:
+ self.chunked_prefill_size = 4096
- capture_bs = [bs for bs in capture_bs if bs <= self.cuda_graph_max_bs]
-
- return capture_bs
+ def _handle_cuda_graph_max_bs(self, gpu_mem):
+ # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
+ if self.cuda_graph_max_bs is None:
+ if gpu_mem is not None and gpu_mem < 35 * 1024:
+ if self.tp_size < 4:
+ self.cuda_graph_max_bs = 8
+ else:
+ self.cuda_graph_max_bs = 80
def _handle_hpu_backends(self):
if self.device == "hpu":
@@ -701,84 +601,6 @@ class ServerArgs:
self.attention_backend = "intel_amx"
self.sampling_backend = "pytorch"
- def _handle_model_specific_adjustments(self):
- if parse_connector_type(self.model_path) == ConnectorType.INSTANCE:
- return
-
- hf_config = self.get_hf_config()
- model_arch = hf_config.architectures[0]
- if model_arch in ["GptOssForCausalLM"]:
- if self.attention_backend is None:
- if is_cuda() and is_sm100_supported():
- self.attention_backend = "trtllm_mha"
- elif is_cuda() and is_sm90_supported():
- self.attention_backend = "fa3"
- else:
- self.attention_backend = "triton"
- supported_backends = ["triton", "trtllm_mha", "fa3"]
- logger.info(
- f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
- )
- assert (
- self.attention_backend in supported_backends
- ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
-
- if is_sm100_supported():
- if not self.enable_dp_attention:
- self.enable_flashinfer_allreduce_fusion = True
- logger.info(
- "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
- )
- quantization_config = getattr(hf_config, "quantization_config", None)
- is_mxfp4_quant_format = (
- quantization_config is not None
- and quantization_config.get("quant_method") == "mxfp4"
- )
-
- if is_sm100_supported() and is_mxfp4_quant_format:
- self.moe_runner_backend = "flashinfer_mxfp4"
- logger.warning(
- "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
- )
- else:
- if self.moe_runner_backend == "triton_kernel":
- assert (
- self.ep_size == 1
- ), "Triton kernel MoE is only supported when ep_size == 1"
- if (
- self.moe_runner_backend == "auto"
- and self.ep_size == 1
- and is_triton_kernels_available()
- ):
- self.moe_runner_backend = "triton_kernel"
- logger.warning(
- "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
- )
- self.disable_hybrid_swa_memory = True
- if is_mxfp4_quant_format:
- # use bf16 for mxfp4 triton kernels
- self.dtype = "bfloat16"
-
- elif "Llama4" in model_arch and self.device != "cpu":
- assert self.attention_backend in {
- "fa3",
- "aiter",
- "triton",
- }, "fa3, aiter, or triton is required for Llama4 model"
- elif model_arch in [
- "Gemma2ForCausalLM",
- "Gemma3ForCausalLM",
- "Gemma3ForConditionalGeneration",
- "Gemma3nForCausalLM",
- "Gemma3nForConditionalGeneration",
- ]:
- # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
- # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
- logger.warning(
- f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
- )
- self.disable_hybrid_swa_memory = True
-
def _handle_sampling_backend(self):
if self.sampling_backend is None:
self.sampling_backend = (
@@ -801,7 +623,7 @@ class ServerArgs:
self.speculative_algorithm is None
), "Speculative decoding is currently not supported with Flex Attention backend"
- if is_npu() and self.attention_backend in ["ascend"]:
+ if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]:
logger.warning(
"At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
)
@@ -964,15 +786,8 @@ class ServerArgs:
def _handle_hicache(self):
if self.hicache_storage_backend == "mooncake":
- if self.hicache_mem_layout == "layer_first":
- if self.hicache_io_backend == "direct":
- self.hicache_mem_layout = "page_first_direct"
- elif self.hicache_io_backend == "kernel":
- self.hicache_mem_layout = "page_first"
- logger.warning(
- f"Mooncake storage backend does not support layer_first layout, "
- f"switching to {self.hicache_mem_layout} layout for {self.hicache_io_backend} io backend"
- )
+ self.hicache_io_backend = "kernel"
+ self.hicache_mem_layout = "page_first"
if self.hicache_mem_layout == "page_first_direct":
if self.hicache_io_backend != "direct":
@@ -1007,6 +822,7 @@ class ServerArgs:
model_arch = self.get_hf_config().architectures[0]
if model_arch in [
+ "DeepseekV32ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
"BailingMoeForCausalLM",
@@ -1058,23 +874,23 @@ class ServerArgs:
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
)
- if self.speculative_algorithm == "NGRAM":
+ if self.speculative_algorithm == "LOOKAHEAD":
if not self.device.startswith("cuda"):
raise ValueError(
- "Ngram speculative decoding only supports CUDA device."
+ "Lookahead speculative decoding only supports CUDA device."
)
if self.max_running_requests is None:
self.max_running_requests = 48
self.disable_overlap_schedule = True
self.enable_mixed_chunk = False
- self.speculative_eagle_topk = self.speculative_ngram_max_bfs_breadth
+ self.speculative_eagle_topk = self.speculative_lookahead_max_bfs_breadth
if self.speculative_num_draft_tokens is None:
self.speculative_num_draft_tokens = (
- self.speculative_ngram_max_match_window_size
+ self.speculative_lookahead_max_match_window_size
)
logger.warning(
"The overlap scheduler and mixed chunked prefill are disabled because of "
- "using ngram speculative decoding."
+ "using lookahead speculative decoding."
)
if (
@@ -1086,9 +902,9 @@ class ServerArgs:
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
)
if self.enable_dp_attention:
- # TODO: support dp attention for ngram speculative decoding
+ # TODO: support dp attention for lookahead speculative decoding
raise ValueError(
- "Currently ngram speculative decoding does not support dp attention."
+ "Currently lookahead speculative decoding does not support dp attention."
)
def _handle_load_format(self):
@@ -1166,55 +982,120 @@ class ServerArgs:
"and cannot be used at the same time. Please use only one of them."
)
- if (
- self.disaggregation_decode_enable_offload_kvcache
- and self.disaggregation_mode != "decode"
- ):
- raise ValueError(
- "The argument disaggregation-decode-enable-offload-kvcache is only supported for decode side."
- )
-
def _handle_metrics_labels(self):
if (
not self.tokenizer_metrics_custom_labels_header
- and self.tokenizer_metrics_allowed_custom_labels
+ and self.tokenizer_metrics_allowed_customer_labels
):
raise ValueError(
- "Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-custom-labels."
+ "Please set --tokenizer-metrics-custom-labels-header when setting --tokenizer-metrics-allowed-customer-labels."
)
def _handle_deterministic_inference(self):
if self.enable_deterministic_inference:
- # Check sampling backend
+ import importlib.util
+
+ if not importlib.util.find_spec("batch_invariant_ops"):
+ raise ValueError(
+ "batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
+ )
+
+ # Check some settings
self.sampling_backend = "pytorch"
logger.warning(
"Sampling backend is set to pytorch for deterministic inference."
)
-
- # Check attention backend
+ # Currently, only FA3 supports radix cache. Support for other backends is in progress
+ if self.attention_backend != "fa3":
+ self.disable_radix_cache = True
+ logger.warning(
+ "Currently radix cache is disabled for deterministic inference. It will be supported in the future."
+ )
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
raise ValueError(
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
)
- # Currently, only FA3 supports radix cache. Support for other backends is in progress
- if self.attention_backend != "fa3":
- self.disable_radix_cache = True
- logger.warning(
- f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
- )
-
- # Check TP size
- if self.tp_size > 1:
- os.environ["NCCL_ALGO"] = "allreduce:tree"
- self.disable_custom_all_reduce = True
- logger.warning(
- "NCCL_ALGO is set to 'allreduce:tree' and custom all reduce is disabled for deterministic inference when TP size > 1."
- )
-
def _handle_other_validations(self):
pass
+ def __post_init__(self):
+ """
+ Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
+ """
+ # Step 1: Handle deprecated arguments.
+ self._handle_deprecated_args()
+
+ # Step 2: Set missing default values.
+ self._handle_missing_default_values()
+
+ # Get GPU memory capacity, which is a common dependency for several configuration steps.
+ gpu_mem = get_device_memory_capacity(self.device)
+
+ # Step 3: Handle memory-related configurations.
+ self._handle_mem_fraction_static(gpu_mem)
+ self._handle_chunked_prefill_size(gpu_mem)
+
+ # Step 4: Handle CUDA graph settings.
+ self._handle_cuda_graph_max_bs(gpu_mem)
+
+ # Step 5: Handle device-specific backends.
+ self._handle_hpu_backends()
+ self._handle_cpu_backends()
+
+ # Step 6: Apply model-specific adjustments.
+ if parse_connector_type(self.model_path) != ConnectorType.INSTANCE:
+ self.model_specific_adjustments()
+
+ # Step 7: Set kernel backends.
+ self._handle_sampling_backend()
+ self._handle_attention_backend_compatibility()
+ self._handle_page_size()
+ self._handle_amd_specifics()
+ self._handle_grammar_backend()
+
+ # Step 8: Handle data parallelism.
+ self._handle_data_parallelism()
+
+ # Step 9: Handle MoE configurations.
+ self._handle_moe_kernel_config()
+ self._handle_deepep_moe()
+ self._handle_eplb_and_dispatch()
+ self._handle_expert_distribution_metrics()
+
+ # Step 10: Handle pipeline parallelism.
+ self._handle_pipeline_parallelism()
+
+ # Step 11: Handle Hicache settings.
+ self._handle_hicache()
+
+ # Step 12: Handle speculative decoding logic.
+ self._handle_speculative_decoding()
+
+ # Step 13: Handle model loading format.
+ self._handle_load_format()
+
+ # Step 14: Handle PD disaggregation.
+ self._handle_disaggregation()
+
+ # Step 15: Validate tokenizer settings.
+ self._handle_tokenizer_batching()
+
+ # Step 16: Propagate environment variables.
+ self._handle_environment_variables()
+
+ # Step 17: Validate cache settings.
+ self._handle_cache_compatibility()
+
+ # Step 18: Validate metrics labels.
+ self._handle_metrics_labels()
+
+ # Step 19: Handle deterministic inference.
+ self._handle_deterministic_inference()
+
+ # Step 20: Handle any other necessary validations.
+ self._handle_other_validations()
+
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and tokenizer
@@ -1225,6 +1106,24 @@ class ServerArgs:
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
required=True,
)
+ parser.add_argument(
+ "--remote-instance-weight-loader-seed-instance-ip",
+ type=str,
+ default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
+ help="The ip of the seed instance for loading weights from remote instance.",
+ )
+ parser.add_argument(
+ "--remote-instance-weight-loader-seed-instance-service-port",
+ type=int,
+ default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
+ help="The service port of the seed instance for loading weights from remote instance.",
+ )
+ parser.add_argument(
+ "--remote-instance-weight-loader-send-weights-group-ports",
+ type=json_list_type,
+ default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
+ help="The communication group ports for loading weights from remote instance.",
+ )
parser.add_argument(
"--tokenizer-path",
type=str,
@@ -1393,11 +1292,6 @@ class ServerArgs:
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
)
- parser.add_argument(
- "--enable-fp32-lm-head",
- action="store_true",
- help="If set, the LM head outputs (logits) are in FP32.",
- )
# Memory and scheduling
parser.add_argument(
@@ -1637,16 +1531,16 @@ class ServerArgs:
"--tokenizer-metrics-custom-labels-header",
type=str,
default=ServerArgs.tokenizer_metrics_custom_labels_header,
- help="Specify the HTTP header for passing custom labels for tokenizer metrics.",
+ help="Specify the HTTP header for passing customer labels for tokenizer metrics.",
)
parser.add_argument(
- "--tokenizer-metrics-allowed-custom-labels",
+ "--tokenizer-metrics-allowed-customer-labels",
type=str,
nargs="+",
- default=ServerArgs.tokenizer_metrics_allowed_custom_labels,
- help="The custom labels allowed for tokenizer metrics. The labels are specified via a dict in "
+ default=ServerArgs.tokenizer_metrics_allowed_customer_labels,
+ help="The customer labels allowed for tokenizer metrics. The labels are specified via a dict in "
"'--tokenizer-metrics-custom-labels-header' field in HTTP requests, e.g., {'label1': 'value1', 'label2': "
- "'value2'} is allowed if '--tokenizer-metrics-allowed-custom-labels label1 label2' is set.",
+ "'value2'} is allowed if '--tokenizer-metrics-allowed-customer-labels label1 label2' is set.",
)
parser.add_argument(
"--bucket-time-to-first-token",
@@ -1678,8 +1572,8 @@ class ServerArgs:
bucket_rule = (
"Supports 3 rule types: 'default' uses predefined buckets; 'tse ' "
"generates two sides exponential distributed buckets (e.g., 'tse 1000 2 8' generates buckets "
- "[984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]).); 'custom "
- " ...' uses custom bucket values (e.g., 'custom 10 50 100 500')."
+ "[984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]).); 'customer "
+ " ...' uses customer-defined bucket values (e.g., 'customer 10 50 100 500')."
)
parser.add_argument(
"--prompt-tokens-buckets",
@@ -1951,7 +1845,7 @@ class ServerArgs:
parser.add_argument(
"--mm-attention-backend",
type=str,
- choices=["sdpa", "fa3", "triton_attn", "ascend_attn"],
+ choices=["sdpa", "fa3", "triton_attn"],
default=ServerArgs.mm_attention_backend,
help="Set multimodal attention backend.",
)
@@ -1960,7 +1854,7 @@ class ServerArgs:
parser.add_argument(
"--speculative-algorithm",
type=str,
- choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"],
+ choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "LOOKAHEAD"],
help="Speculative algorithm.",
)
parser.add_argument(
@@ -2020,49 +1914,49 @@ class ServerArgs:
help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
default=ServerArgs.speculative_attention_mode,
)
- # Ngram speculative decoding
+ # Lookahead speculative decoding
parser.add_argument(
- "--speculative-ngram-min-match-window-size",
+ "--speculative-lookahead-min-match-window-size",
type=int,
- default=ServerArgs.speculative_ngram_min_match_window_size,
- help="The minimum window size for pattern matching in ngram speculative decoding.",
+ default=ServerArgs.speculative_lookahead_min_match_window_size,
+ help="The minimum window size for pattern matching in lookahead speculative decoding.",
)
parser.add_argument(
- "--speculative-ngram-max-match-window-size",
+ "--speculative-lookahead-max-match-window-size",
type=int,
- default=ServerArgs.speculative_ngram_max_match_window_size,
- help="The maximum window size for pattern matching in ngram speculative decoding.",
+ default=ServerArgs.speculative_lookahead_max_match_window_size,
+ help="The maximum window size for pattern matching in lookahead speculative decoding.",
)
parser.add_argument(
- "--speculative-ngram-min-bfs-breadth",
+ "--speculative-lookahead-min-bfs-breadth",
type=int,
- default=ServerArgs.speculative_ngram_min_bfs_breadth,
- help="The minimum breadth for BFS (Breadth-First Search) in ngram speculative decoding.",
+ default=ServerArgs.speculative_lookahead_min_bfs_breadth,
+ help="The minimum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
)
parser.add_argument(
- "--speculative-ngram-max-bfs-breadth",
+ "--speculative-lookahead-max-bfs-breadth",
type=int,
- default=ServerArgs.speculative_ngram_max_bfs_breadth,
- help="The maximum breadth for BFS (Breadth-First Search) in ngram speculative decoding.",
+ default=ServerArgs.speculative_lookahead_max_bfs_breadth,
+ help="The maximum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
)
parser.add_argument(
- "--speculative-ngram-match-type",
+ "--speculative-lookahead-match-type",
type=str,
choices=["BFS", "PROB"],
- default=ServerArgs.speculative_ngram_match_type,
+ default=ServerArgs.speculative_lookahead_match_type,
help="The match type for cache tree.",
)
parser.add_argument(
- "--speculative-ngram-branch-length",
+ "--speculative-lookahead-branch-length",
type=int,
- default=ServerArgs.speculative_ngram_branch_length,
- help="The branch length for ngram speculative decoding.",
+ default=ServerArgs.speculative_lookahead_branch_length,
+ help="The branch length for lookahead speculative decoding.",
)
parser.add_argument(
- "--speculative-ngram-capacity",
+ "--speculative-lookahead-capacity",
type=int,
- default=ServerArgs.speculative_ngram_capacity,
- help="The cache capacity for ngram speculative decoding.",
+ default=ServerArgs.speculative_lookahead_capacity,
+ help="The cache capacity for lookahead speculative decoding.",
)
# Expert parallelism
@@ -2256,12 +2150,9 @@ class ServerArgs:
parser.add_argument(
"--hicache-storage-backend",
type=str,
- choices=["file", "mooncake", "hf3fs", "nixl", "aibrix", "dynamic"],
+ choices=["file", "mooncake", "hf3fs", "nixl"],
default=ServerArgs.hicache_storage_backend,
- help="The storage backend for hierarchical KV cache. "
- "Built-in backends: file, mooncake, hf3fs, nixl, aibrix. "
- "For dynamic backend, use --hicache-storage-backend-extra-config to specify: "
- "backend_name (custom name), module_path (Python module path), class_name (backend class name).",
+ help="The storage backend for hierarchical KV cache.",
)
parser.add_argument(
"--hicache-storage-prefetch-policy",
@@ -2571,6 +2462,12 @@ class ServerArgs:
nargs="+",
help="Sets the numa node for the subprocesses. i-th element corresponds to i-th subprocess.",
)
+ parser.add_argument(
+ "--max-prefill-bs",
+ type=int,
+ default=ServerArgs.max_prefill_bs,
+ help="The maximum batch size for prefill requests.",
+ )
# Debug tensor dumps
parser.add_argument(
@@ -2661,11 +2558,6 @@ class ServerArgs:
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
)
- parser.add_argument(
- "--disaggregation-decode-enable-offload-kvcache",
- action="store_true",
- help="Enable async KV cache offloading on decode server (PD mode).",
- )
parser.add_argument(
"--num-reserved-decode-tokens",
type=int,
@@ -2692,24 +2584,6 @@ class ServerArgs:
action="store_true",
help="Disable mmap while loading weight using safetensors.",
)
- parser.add_argument(
- "--remote-instance-weight-loader-seed-instance-ip",
- type=str,
- default=ServerArgs.remote_instance_weight_loader_seed_instance_ip,
- help="The ip of the seed instance for loading weights from remote instance.",
- )
- parser.add_argument(
- "--remote-instance-weight-loader-seed-instance-service-port",
- type=int,
- default=ServerArgs.remote_instance_weight_loader_seed_instance_service_port,
- help="The service port of the seed instance for loading weights from remote instance.",
- )
- parser.add_argument(
- "--remote-instance-weight-loader-send-weights-group-ports",
- type=json_list_type,
- default=ServerArgs.remote_instance_weight_loader_send_weights_group_ports,
- help="The communication group ports for loading weights from remote instance.",
- )
# For PD-Multiplexing
parser.add_argument(
@@ -2732,48 +2606,56 @@ class ServerArgs:
help="Enable deterministic inference mode with batch invariant ops.",
)
+ # For NSA models
+ parser.add_argument(
+ "--nsa-prefill",
+ default=NSA_DEFAULT_PREFILL,
+ type=str,
+ choices=NSA_CHOICES,
+ )
+
+ parser.add_argument(
+ "--nsa-decode",
+ default=NSA_DEFAULT_DECODE,
+ type=str,
+ choices=NSA_CHOICES,
+ )
+
# Deprecated arguments
parser.add_argument(
"--enable-ep-moe",
- action=DeprecatedAction,
- help="NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead.",
+ action="store_true",
+ help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
)
parser.add_argument(
"--enable-deepep-moe",
- action=DeprecatedAction,
- help="NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead.",
+ action="store_true",
+ help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
)
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
- action=DeprecatedAction,
- help="NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead.",
+ action="store_true",
+ help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-cutedsl-moe",
- action=DeprecatedAction,
- help="NOTE: --enable-flashinfer-cutedsl-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutedsl' instead.",
+ action="store_true",
+ help="(Deprecated) Enable FlashInfer CuteDSL MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
- action=DeprecatedAction,
- help="NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead.",
+ action="store_true",
+ help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
)
parser.add_argument(
"--enable-triton-kernel-moe",
- action=DeprecatedAction,
- help="NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead.",
+ action="store_true",
+ help="(Deprecated) Use triton moe grouped gemm kernel.",
)
parser.add_argument(
"--enable-flashinfer-mxfp4-moe",
- action=DeprecatedAction,
- help="NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead.",
- )
-
- # Configuration file support
- parser.add_argument(
- "--config",
- type=str,
- help="Read CLI options from a config file. Must be a YAML file with configuration options.",
+ action="store_true",
+ help="(Deprecated) Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
)
@classmethod
@@ -2967,8 +2849,8 @@ class ServerArgs:
assert rule in [
"tse",
"default",
- "custom",
- ], f"Unsupported {arg_name} rule type: '{rule}'. Must be one of: 'tse', 'default', 'custom'"
+ "customer",
+ ], f"Unsupported {arg_name} rule type: '{rule}'. Must be one of: 'tse', 'default', 'customer'"
if rule == "tse":
assert (
@@ -2991,20 +2873,116 @@ class ServerArgs:
len(buckets_rule) == 1
), f"{arg_name} default rule should only have one parameter: ['default'], got {len(buckets_rule)}"
- elif rule == "custom":
+ elif rule == "customer":
assert (
len(buckets_rule) >= 2
- ), f"{arg_name} custom rule requires at least one bucket value: ['custom', value1, ...]"
+ ), f"{arg_name} customer rule requires at least one bucket value: ['customer', value1, ...]"
try:
bucket_values = [float(x) for x in buckets_rule[1:]]
except ValueError:
- assert False, f"{arg_name} custom rule bucket values must be numeric"
+ assert False, f"{arg_name} customer rule bucket values must be numeric"
assert len(set(bucket_values)) == len(
bucket_values
- ), f"{arg_name} custom rule bucket values should not contain duplicates"
+ ), f"{arg_name} customer rule bucket values should not contain duplicates"
assert all(
val >= 0 for val in bucket_values
- ), f"{arg_name} custom rule bucket values should be non-negative"
+ ), f"{arg_name} customer rule bucket values should be non-negative"
+
+ def model_specific_adjustments(self):
+ from sglang.srt.configs.model_config import is_deepseek_nsa
+
+ hf_config = self.get_hf_config()
+ model_arch = hf_config.architectures[0]
+ if model_arch in ["GptOssForCausalLM"]:
+ if self.attention_backend is None:
+ if is_cuda() and is_sm100_supported():
+ self.attention_backend = "trtllm_mha"
+ elif is_cuda() and is_sm90_supported():
+ self.attention_backend = "fa3"
+ else:
+ self.attention_backend = "triton"
+ supported_backends = ["triton", "trtllm_mha", "fa3"]
+ logger.info(
+ f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
+ )
+ assert (
+ self.attention_backend in supported_backends
+ ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+
+ if is_sm100_supported():
+ if not self.enable_dp_attention:
+ self.enable_flashinfer_allreduce_fusion = True
+ logger.info(
+ "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
+ )
+ quantization_config = getattr(hf_config, "quantization_config", None)
+ is_mxfp4_quant_format = (
+ quantization_config is not None
+ and quantization_config.get("quant_method") == "mxfp4"
+ )
+
+ if is_sm100_supported() and is_mxfp4_quant_format:
+ self.moe_runner_backend = "flashinfer_mxfp4"
+ logger.warning(
+ "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
+ )
+ else:
+ if self.moe_runner_backend == "triton_kernel":
+ assert (
+ self.ep_size == 1
+ ), "Triton kernel MoE is only supported when ep_size == 1"
+ if (
+ self.moe_runner_backend == "auto"
+ and self.ep_size == 1
+ and is_triton_kernels_available()
+ ):
+ self.moe_runner_backend = "triton_kernel"
+ logger.warning(
+ "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
+ )
+ self.disable_hybrid_swa_memory = True
+ if is_mxfp4_quant_format:
+ # use bf16 for mxfp4 triton kernels
+ self.dtype = "bfloat16"
+
+ elif "Llama4" in model_arch and self.device != "cpu":
+ assert self.attention_backend in {
+ "fa3",
+ "aiter",
+ "triton",
+ }, "fa3, aiter, or triton is required for Llama4 model"
+ elif model_arch in [
+ "Gemma2ForCausalLM",
+ "Gemma3ForCausalLM",
+ "Gemma3ForConditionalGeneration",
+ "Gemma3nForCausalLM",
+ "Gemma3nForConditionalGeneration",
+ ]:
+ # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
+ # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
+ logger.warning(
+ f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
+ )
+ self.disable_hybrid_swa_memory = True
+ elif is_deepseek_nsa(hf_config):
+ if (
+ self.attention_backend is None
+ and self.prefill_attention_backend is None
+ and self.decode_attention_backend is None
+ ):
+ self.attention_backend = "nsa"
+ logger.warning("Set nsa attention backend for DeepSeek NSA.")
+
+ if not is_npu():
+ self.enable_dp_attention = True
+ self.dp_size = self.tp_size
+ logger.warning("DP attention is enabled for DeepSeek NSA.")
+
+ self.page_size = 64
+ logger.warning("Setting page size to 64 for DeepSeek NSA.")
+
+ self.max_prefill_bs = 1
+ logger.warning("Setting maximum prefill batch size to 1 for DeepSeek NSA.")
def adjust_mem_fraction_for_vlm(self, model_config):
vision_config = getattr(model_config.hf_config, "vision_config", None)
@@ -3056,26 +3034,6 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
Returns:
The server arguments.
"""
- # Import here to avoid circular imports
- from sglang.srt.server_args_config_parser import ConfigArgumentMerger
-
- # Check for config file and merge arguments if present
- if "--config" in argv:
- # Extract boolean actions from the parser to handle them correctly
- parser = argparse.ArgumentParser()
- ServerArgs.add_cli_args(parser)
-
- # Get boolean action destinations
- boolean_actions = []
- for action in parser._actions:
- if hasattr(action, "dest") and hasattr(action, "action"):
- if action.action in ["store_true", "store_false"]:
- boolean_actions.append(action.dest)
-
- # Merge config file arguments with CLI arguments
- config_merger = ConfigArgumentMerger(boolean_actions=boolean_actions)
- argv = config_merger.merge_config_with_args(argv)
-
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(argv)
@@ -3217,6 +3175,7 @@ def auto_choose_speculative_params(self: ServerArgs):
# The default value for llama
return (5, 4, 8)
elif arch in [
+ "DeepseekV32ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV2ForCausalLM",
"GptOssForCausalLM",
diff --git a/python/sglang/srt/server_args_config_parser.py b/python/sglang/srt/server_args_config_parser.py
deleted file mode 100644
index 74dc67677..000000000
--- a/python/sglang/srt/server_args_config_parser.py
+++ /dev/null
@@ -1,146 +0,0 @@
-"""
-Configuration argument parser for command-line applications.
-Handles merging of YAML configuration files with command-line arguments.
-"""
-
-import logging
-from pathlib import Path
-from typing import Any, Dict, List, Union
-
-import yaml
-
-logger = logging.getLogger(__name__)
-
-
-class ConfigArgumentMerger:
- """Handles merging of configuration file arguments with command-line arguments."""
-
- def __init__(self, boolean_actions: List[str] = None):
- """Initialize with list of boolean action destinations."""
- self.boolean_actions = boolean_actions or []
-
- def merge_config_with_args(self, cli_args: List[str]) -> List[str]:
- """
- Merge configuration file arguments with command-line arguments.
-
- Configuration arguments are inserted after the subcommand to maintain
- proper precedence: CLI > Config > Defaults
-
- Args:
- cli_args: List of command-line arguments
-
- Returns:
- Merged argument list with config values inserted
-
- Raises:
- ValueError: If multiple config files specified or no config file provided
- """
- config_file_path = self._extract_config_file_path(cli_args)
- if not config_file_path:
- return cli_args
-
- config_args = self._parse_yaml_config(config_file_path)
- return self._insert_config_args(cli_args, config_args, config_file_path)
-
- def _extract_config_file_path(self, args: List[str]) -> str:
- """Extract the config file path from arguments."""
- config_indices = [i for i, arg in enumerate(args) if arg == "--config"]
-
- if len(config_indices) > 1:
- raise ValueError("Multiple config files specified! Only one allowed.")
-
- if not config_indices:
- return None
-
- config_index = config_indices[0]
- if config_index == len(args) - 1:
- raise ValueError("No config file specified after --config flag!")
-
- return args[config_index + 1]
-
- def _insert_config_args(
- self, cli_args: List[str], config_args: List[str], config_file_path: str
- ) -> List[str]:
- """Insert configuration arguments into the CLI argument list."""
- config_index = cli_args.index("--config")
-
- # Split arguments around config file
- before_config = cli_args[:config_index]
- after_config = cli_args[config_index + 2 :] # Skip --config and file path
-
- # Simple merge: config args + CLI args
- return config_args + before_config + after_config
-
- def _parse_yaml_config(self, file_path: str) -> List[str]:
- """
- Parse YAML configuration file and convert to argument list.
-
- Args:
- file_path: Path to the YAML configuration file
-
- Returns:
- List of arguments in format ['--key', 'value', ...]
-
- Raises:
- ValueError: If file is not YAML or cannot be read
- """
- self._validate_yaml_file(file_path)
-
- try:
- with open(file_path, "r") as file:
- config_data = yaml.safe_load(file)
- except Exception as e:
- logger.error(f"Failed to read config file {file_path}: {e}")
- raise
-
- # Handle empty files or None content
- if config_data is None:
- config_data = {}
-
- if not isinstance(config_data, dict):
- raise ValueError("Config file must contain a dictionary at root level")
-
- return self._convert_config_to_args(config_data)
-
- def _validate_yaml_file(self, file_path: str) -> None:
- """Validate that the file is a YAML file."""
- path = Path(file_path)
- if path.suffix.lower() not in [".yaml", ".yml"]:
- raise ValueError(f"Config file must be YAML format, got: {path.suffix}")
-
- if not path.exists():
- raise ValueError(f"Config file not found: {file_path}")
-
- def _convert_config_to_args(self, config: Dict[str, Any]) -> List[str]:
- """Convert configuration dictionary to argument list."""
- args = []
-
- for key, value in config.items():
- if isinstance(value, bool):
- self._add_boolean_arg(args, key, value)
- elif isinstance(value, list):
- self._add_list_arg(args, key, value)
- else:
- self._add_scalar_arg(args, key, value)
-
- return args
-
- def _add_boolean_arg(self, args: List[str], key: str, value: bool) -> None:
- """Add boolean argument to the list."""
- if key in self.boolean_actions:
- # For boolean actions, always add the flag and value
- args.extend([f"--{key}", str(value).lower()])
- else:
- # For regular booleans, only add flag if True
- if value:
- args.append(f"--{key}")
-
- def _add_list_arg(self, args: List[str], key: str, value: List[Any]) -> None:
- """Add list argument to the list."""
- if value: # Only add if list is not empty
- args.append(f"--{key}")
- args.extend(str(item) for item in value)
-
- def _add_scalar_arg(self, args: List[str], key: str, value: Any) -> None:
- """Add scalar argument to the list."""
- args.extend([f"--{key}", str(value)])
diff --git a/python/sglang/srt/speculative/cpp_ngram/.clang-format b/python/sglang/srt/speculative/cpp_lookahead/.clang-format
similarity index 100%
rename from python/sglang/srt/speculative/cpp_ngram/.clang-format
rename to python/sglang/srt/speculative/cpp_lookahead/.clang-format
diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp b/python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp
similarity index 91%
rename from python/sglang/srt/speculative/cpp_ngram/ngram.cpp
rename to python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp
index 51172c5dd..c47ebcd8d 100644
--- a/python/sglang/srt/speculative/cpp_ngram/ngram.cpp
+++ b/python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp
@@ -1,16 +1,16 @@
-#include "ngram.h"
+#include "lookahead.h"
#include
#include
-namespace ngram {
+namespace lookahead {
struct Node {
std::unordered_map next;
};
-Ngram::Result fillResult(int last_token, int draft_token_num, std::vector& tree, int root) {
- Ngram::Result info;
+Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector& tree, int root) {
+ Lookahead::Result info;
std::vector prevs;
info.token.reserve(draft_token_num);
prevs.reserve(draft_token_num);
@@ -50,7 +50,7 @@ Ngram::Result fillResult(int last_token, int draft_token_num, std::vector&
return info;
}
-Ngram::Ngram(size_t capacity, const Param& param) {
+Lookahead::Lookahead(size_t capacity, const Param& param) {
param_ = param;
nodes_.resize(capacity);
for (auto& node : nodes_) {
@@ -116,16 +116,17 @@ Ngram::Ngram(size_t capacity, const Param& param) {
}
quit_flag_ = false;
- insert_worker_ = std::thread(&Ngram::insert, this);
+ insert_worker_ = std::thread(&Lookahead::insert, this);
}
-Ngram::~Ngram() {
+Lookahead::~Lookahead() {
quit_flag_ = true;
insert_queue_.close();
insert_worker_.join();
}
-std::vector> Ngram::match(const std::vector& tokens, size_t batch_size) const {
+std::vector>
+Lookahead::match(const std::vector& tokens, size_t batch_size) const {
auto draft_token_num = param_.get_draft_token_num(batch_size);
auto min_match_window_size = param_.get_min_match_window_size(batch_size);
auto max_match_window_size = param_.max_match_window_size;
@@ -153,7 +154,7 @@ std::vector> Ngram::match(const std::vector= free_node_count_ + count)) {
throw std::runtime_error(
"Insufficient node size to release required nodes. "
@@ -176,13 +177,13 @@ void Ngram::squeeze(size_t count) {
}
}
-void Ngram::synchronize() const {
+void Lookahead::synchronize() const {
while (!insert_queue_.empty()) {
std::this_thread::sleep_for(std::chrono::microseconds(10));
}
}
-void Ngram::insert() {
+void Lookahead::insert() {
while (!quit_flag_) {
std::vector data;
if (!insert_queue_.dequeue(data)) {
@@ -238,13 +239,13 @@ void Ngram::insert() {
}
}
-void Ngram::asyncInsert(std::vector>&& tokens) {
+void Lookahead::asyncInsert(std::vector>&& tokens) {
for (auto&& token : tokens) {
insert_queue_.enqueue(std::move(token));
}
}
-Ngram::Result Ngram::matchBFS(const std::vector& tokens, size_t batch_size) const {
+Lookahead::Result Lookahead::matchBFS(const std::vector& tokens, size_t batch_size) const {
std::vector> nodes = match(tokens, batch_size);
double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
@@ -283,7 +284,7 @@ Ngram::Result Ngram::matchBFS(const std::vector& tokens, size_t batch_s
return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
-Ngram::Result Ngram::matchProb(const std::vector& tokens, size_t batch_size) const {
+Lookahead::Result Lookahead::matchProb(const std::vector& tokens, size_t batch_size) const {
std::vector> nodes = match(tokens, batch_size);
auto draft_token_num = param_.get_draft_token_num(batch_size);
@@ -345,10 +346,10 @@ Ngram::Result Ngram::matchProb(const std::vector& tokens, size_t batch_
return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
-Ngram::Result Ngram::batchMatch(const std::vector>& tokens) const {
+Lookahead::Result Lookahead::batchMatch(const std::vector>& tokens) const {
std::unique_lock lock(mutex_);
Result merged_result;
- auto match_func = param_.match_type == "BFS" ? &Ngram::matchBFS : &Ngram::matchProb;
+ auto match_func = param_.match_type == "BFS" ? &Lookahead::matchBFS : &Lookahead::matchProb;
for (const auto& tks : tokens) {
Result res = (this->*match_func)(tks, tokens.size());
merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end());
@@ -357,7 +358,7 @@ Ngram::Result Ngram::batchMatch(const std::vector>& tokens)
return merged_result;
}
-void Ngram::Result::truncate(size_t n) {
+void Lookahead::Result::truncate(size_t n) {
if (n < token.size()) {
int full_n = token.size();
for (int i = 1; i < n; ++i) {
@@ -368,4 +369,4 @@ void Ngram::Result::truncate(size_t n) {
}
}
-} // namespace ngram
+} // namespace lookahead
diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram.h b/python/sglang/srt/speculative/cpp_lookahead/lookahead.h
similarity index 91%
rename from python/sglang/srt/speculative/cpp_ngram/ngram.h
rename to python/sglang/srt/speculative/cpp_lookahead/lookahead.h
index bf0af0df9..9c6c82c92 100644
--- a/python/sglang/srt/speculative/cpp_ngram/ngram.h
+++ b/python/sglang/srt/speculative/cpp_lookahead/lookahead.h
@@ -15,7 +15,7 @@
#include "param.h"
#include "queue.h"
-namespace ngram {
+namespace lookahead {
struct TrieNode {
std::unordered_map child;
@@ -34,7 +34,7 @@ struct TrieNode {
std::multiset sorted_children;
};
-class Ngram {
+class Lookahead {
std::vector nodes_;
std::vector node_pool_;
size_t free_node_count_;
@@ -61,12 +61,12 @@ class Ngram {
std::vector> match_tmp_data_;
public:
- Ngram(size_t capacity, const Param& param);
- Ngram() = default;
- ~Ngram();
+ Lookahead(size_t capacity, const Param& param);
+ Lookahead() = default;
+ ~Lookahead();
- static Ngram& instance() {
- static Ngram instance;
+ static Lookahead& instance() {
+ static Lookahead instance;
return instance;
}
@@ -107,4 +107,4 @@ class Ngram {
void insert();
};
-} // namespace ngram
+} // namespace lookahead
diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_cache.py b/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py
similarity index 91%
rename from python/sglang/srt/speculative/cpp_ngram/ngram_cache.py
rename to python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py
index 8b1eb8eea..871b60878 100644
--- a/python/sglang/srt/speculative/cpp_ngram/ngram_cache.py
+++ b/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
+# from sglang.op.lookahead import Lookahead, Param
+
import logging
import os
from typing import List, Tuple
@@ -10,17 +12,17 @@ from torch.utils.cpp_extension import load
logger = logging.getLogger(__name__)
_abs_path = os.path.dirname(os.path.abspath(__file__))
-ngram_cache_cpp = load(
- name="ngram_cache_cpp",
+lookahead_cache_cpp = load(
+ name="lookahead_cache_cpp",
sources=[
- f"{_abs_path}/ngram_cache_binding.cpp",
- f"{_abs_path}/ngram.cpp",
+ f"{_abs_path}/lookahead_cache_binding.cpp",
+ f"{_abs_path}/lookahead.cpp",
],
extra_cflags=["-O3", "-std=c++20"],
)
-class NgramCache:
+class LookaheadCache:
def __init__(
self,
branch_length=18,
@@ -32,7 +34,7 @@ class NgramCache:
match_type="BFS",
capacity=1000000,
):
- param = ngram_cache_cpp.Param()
+ param = lookahead_cache_cpp.Param()
param.branch_length = branch_length
param.min_match_window_size = min_match_window_size
param.max_match_window_size = max_match_window_size
@@ -40,7 +42,7 @@ class NgramCache:
param.max_bfs_breadth = max_bfs_breadth
param.draft_token_num = draft_token_num
param.match_type = match_type
- self.cache = ngram_cache_cpp.Ngram(capacity, param)
+ self.cache = lookahead_cache_cpp.Lookahead(capacity, param)
self.default_mask = np.ones((1, 1), dtype=np.int64)
self.draft_token_num = draft_token_num
@@ -129,7 +131,7 @@ if __name__ == "__main__":
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
]
- cache = NgramCache(branch_length=12, draft_token_num=8)
+ cache = LookaheadCache(branch_length=12, draft_token_num=8)
cache.batch_put(token_ids)
cache.synchronize()
diff --git a/python/sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp b/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache_binding.cpp
similarity index 71%
rename from python/sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp
rename to python/sglang/srt/speculative/cpp_lookahead/lookahead_cache_binding.cpp
index ac5b931f9..8c48a66ae 100644
--- a/python/sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp
+++ b/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache_binding.cpp
@@ -1,19 +1,19 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
-#include "ngram.h"
+#include "lookahead.h"
-PYBIND11_MODULE(ngram_cache_cpp, m) {
- using namespace ngram;
+PYBIND11_MODULE(lookahead_cache_cpp, m) {
+ using namespace lookahead;
namespace py = pybind11;
m.doc() = "";
- py::class_(m, "Ngram")
+ py::class_(m, "Lookahead")
.def(py::init(), py::arg("capacity"), py::arg("param"))
- .def("asyncInsert", &Ngram::asyncInsert, "")
- .def("batchMatch", &Ngram::batchMatch, "")
- .def("reset", &Ngram::reset, "")
- .def("synchronize", &Ngram::synchronize, "");
+ .def("asyncInsert", &Lookahead::asyncInsert, "")
+ .def("batchMatch", &Lookahead::batchMatch, "")
+ .def("reset", &Lookahead::reset, "")
+ .def("synchronize", &Lookahead::synchronize, "");
py::class_(m, "Param")
.def(py::init<>())
@@ -35,9 +35,9 @@ PYBIND11_MODULE(ngram_cache_cpp, m) {
.def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
.def("detail", &Param::detail, "");
- py::class_(m, "Result")
+ py::class_(m, "Result")
.def(py::init<>())
- .def_readwrite("token", &Ngram::Result::token)
- .def_readwrite("mask", &Ngram::Result::mask)
- .def("truncate", &Ngram::Result::truncate);
+ .def_readwrite("token", &Lookahead::Result::token)
+ .def_readwrite("mask", &Lookahead::Result::mask)
+ .def("truncate", &Lookahead::Result::truncate);
}
diff --git a/python/sglang/srt/speculative/cpp_ngram/param.h b/python/sglang/srt/speculative/cpp_lookahead/param.h
similarity index 98%
rename from python/sglang/srt/speculative/cpp_ngram/param.h
rename to python/sglang/srt/speculative/cpp_lookahead/param.h
index 967832ad6..2d8b1f875 100644
--- a/python/sglang/srt/speculative/cpp_ngram/param.h
+++ b/python/sglang/srt/speculative/cpp_lookahead/param.h
@@ -9,7 +9,7 @@
#include
#include
-namespace ngram {
+namespace lookahead {
struct Param {
bool enable;
@@ -122,4 +122,4 @@ struct Param {
}
};
-} // namespace ngram
+} // namespace lookahead
diff --git a/python/sglang/srt/speculative/cpp_ngram/queue.h b/python/sglang/srt/speculative/cpp_lookahead/queue.h
similarity index 100%
rename from python/sglang/srt/speculative/cpp_ngram/queue.h
rename to python/sglang/srt/speculative/cpp_lookahead/queue.h
diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index 03270b48f..e6c55df18 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -13,7 +13,6 @@ import triton
import triton.language as tl
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
-from sglang.srt.environ import envs
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
@@ -24,7 +23,7 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
if is_cuda():
@@ -43,8 +42,8 @@ logger = logging.getLogger(__name__)
# Simulate acceptance length for benchmarking purposes
-SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
-SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
+SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
+SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
@@ -501,12 +500,13 @@ class EagleVerifyInput:
deterministic=True,
)
- if SIMULATE_ACC_LEN > 0.0:
+ if SIMULATE_ACC_LEN:
# Do simulation
accept_index = _generate_simulated_accept_index(
accept_index=accept_index,
predict=predict, # mutable
accept_length=accept_length, # mutable
+ simulate_acc_len=SIMULATE_ACC_LEN,
bs=bs,
spec_steps=self.spec_steps,
)
@@ -1131,16 +1131,14 @@ def _generate_simulated_accept_index(
accept_index,
predict,
accept_length,
+ simulate_acc_len,
bs,
spec_steps,
- simulate_acc_len: float = SIMULATE_ACC_LEN,
- simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
- assert simulate_acc_len > 0.0
-
- if simulate_acc_method == "multinomial":
+ simulate_acc_len_float = float(simulate_acc_len)
+ if SIMULATE_ACC_METHOD == "multinomial":
simulated_values = torch.normal(
- mean=simulate_acc_len,
+ mean=simulate_acc_len_float,
std=1.0,
size=(1,),
device="cpu",
@@ -1148,19 +1146,19 @@ def _generate_simulated_accept_index(
# clamp simulated values to be between 1 and self.spec_steps
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
simulate_acc_len = int(simulated_values.round().item())
- elif simulate_acc_method == "match-expected":
+ elif SIMULATE_ACC_METHOD == "match-expected":
# multinomial sampling does not match the expected length
# we keep it for the sake of compatibility of existing tests
# but it's better to use "match-expected" for the cases that need to
# match the expected length, One caveat is that this will only sample
# either round down or round up of the expected length
- simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
- lower = int(simulate_acc_len // 1)
+ simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
+ lower = int(simulate_acc_len_float // 1)
upper = lower + 1 if lower < spec_steps + 1 else lower
if lower == upper:
simulate_acc_len = lower
else:
- weight_upper = simulate_acc_len - lower
+ weight_upper = simulate_acc_len_float - lower
weight_lower = 1.0 - weight_upper
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
sampled_index = torch.multinomial(probs, num_samples=1)
diff --git a/python/sglang/srt/speculative/ngram_utils.py b/python/sglang/srt/speculative/lookahead_utils.py
similarity index 98%
rename from python/sglang/srt/speculative/ngram_utils.py
rename to python/sglang/srt/speculative/lookahead_utils.py
index d0e80c0a4..5ca6cb025 100644
--- a/python/sglang/srt/speculative/ngram_utils.py
+++ b/python/sglang/srt/speculative/lookahead_utils.py
@@ -42,7 +42,7 @@ elif is_hip():
@dataclass
-class NgramVerifyInput:
+class LookaheadVerifyInput:
def __init__(
self,
draft_token: torch.Tensor,
@@ -405,8 +405,8 @@ class NgramVerifyInput:
return logits_output, self.verified_id, self.accept_length.sum().item()
- def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
+ def filter_batch(self, new_indices: torch.Tensor):
pass
- def merge_batch(self, spec_info: NgramVerifyInput):
+ def merge_batch(self, spec_info: LookaheadVerifyInput):
pass
diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/lookahead_worker.py
similarity index 86%
rename from python/sglang/srt/speculative/ngram_worker.py
rename to python/sglang/srt/speculative/lookahead_worker.py
index cb0155911..040078ac7 100644
--- a/python/sglang/srt/speculative/ngram_worker.py
+++ b/python/sglang/srt/speculative/lookahead_worker.py
@@ -12,8 +12,8 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
-from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache
+from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import broadcast_pyobj
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
USE_FULL_MASK = True
-class NGRAMWorker:
+class LOOKAHEADWorker:
def __init__(
self,
server_args: ServerArgs,
@@ -38,9 +38,9 @@ class NGRAMWorker:
self.tp_rank = tp_rank
self.page_size = server_args.page_size
self.draft_token_num: int = server_args.speculative_num_draft_tokens
- self.branch_length: int = server_args.speculative_ngram_branch_length
+ self.branch_length: int = server_args.speculative_lookahead_branch_length
self.max_match_window_size: int = (
- server_args.speculative_ngram_max_match_window_size
+ server_args.speculative_lookahead_max_match_window_size
)
self.max_batch_size = target_worker.max_running_requests
@@ -48,18 +48,18 @@ class NGRAMWorker:
self._init_preallocated_tensors()
- self.ngram_cache = NgramCache(
- min_match_window_size=server_args.speculative_ngram_min_match_window_size,
- max_match_window_size=server_args.speculative_ngram_max_match_window_size,
- min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
- max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
- capacity=server_args.speculative_ngram_capacity,
- branch_length=server_args.speculative_ngram_branch_length,
+ self.lookahead_cache = LookaheadCache(
+ min_match_window_size=server_args.speculative_lookahead_min_match_window_size,
+ max_match_window_size=server_args.speculative_lookahead_max_match_window_size,
+ min_bfs_breadth=server_args.speculative_lookahead_min_bfs_breadth,
+ max_bfs_breadth=server_args.speculative_lookahead_max_bfs_breadth,
+ capacity=server_args.speculative_lookahead_capacity,
+ branch_length=server_args.speculative_lookahead_branch_length,
draft_token_num=server_args.speculative_num_draft_tokens,
)
def clear_cache_pool(self):
- self.ngram_cache.reset()
+ self.lookahead_cache.reset()
def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
seq2_len = len(seq2)
@@ -124,14 +124,14 @@ class NGRAMWorker:
) -> tuple[np.ndarray, np.ndarray]:
bs = batch.batch_size()
- self.ngram_cache.synchronize()
+ self.lookahead_cache.synchronize()
batch_tokens = []
for req in batch.reqs:
check_token = self._efficient_concat_last_n(
req.origin_input_ids, req.output_ids, self.max_match_window_size
)
batch_tokens.append(check_token)
- req_drafts, mask = self.ngram_cache.batch_get(batch_tokens)
+ req_drafts, mask = self.lookahead_cache.batch_get(batch_tokens)
total_draft_token_num = len(req_drafts)
# Check if speculative decoding is needed; here we always enforce it
@@ -184,9 +184,9 @@ class NGRAMWorker:
tree_mask.append(req_mask.flatten())
tree_mask = torch.cat(tree_mask, dim=0)
- batch.spec_algorithm = SpeculativeAlgorithm.NGRAM
+ batch.spec_algorithm = SpeculativeAlgorithm.LOOKAHEAD
batch.forward_mode = ForwardMode.TARGET_VERIFY
- batch.spec_info = NgramVerifyInput(
+ batch.spec_info = LookaheadVerifyInput(
draft_tokens,
tree_mask,
positions,
@@ -197,7 +197,7 @@ class NGRAMWorker:
)
batch.spec_info.prepare_for_verify(batch, self.page_size)
- def _update_ngram_cache(self, batch: ScheduleBatch):
+ def _update_lookahead_cache(self, batch: ScheduleBatch):
batch_tokens = []
for req in batch.reqs:
# FIXME: Whether to insert 'extend' into the cache or not, after testing,
@@ -209,7 +209,7 @@ class NGRAMWorker:
req.origin_input_ids, req.output_ids, self.branch_length
)
batch_tokens.append(put_ids)
- self.ngram_cache.batch_put(batch_tokens)
+ self.lookahead_cache.batch_put(batch_tokens)
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
self._prepare_for_speculative_decoding(batch)
@@ -227,7 +227,7 @@ class NGRAMWorker:
logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
batch, logits_output, self.page_size
)
- self._update_ngram_cache(batch)
+ self._update_lookahead_cache(batch)
batch.forward_mode = ForwardMode.DECODE
else:
diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py
index 64a02f19e..a865d0ff6 100644
--- a/python/sglang/srt/speculative/spec_info.py
+++ b/python/sglang/srt/speculative/spec_info.py
@@ -6,7 +6,7 @@ class SpeculativeAlgorithm(IntEnum):
EAGLE = auto()
EAGLE3 = auto()
STANDALONE = auto()
- NGRAM = auto()
+ LOOKAHEAD = auto()
def is_none(self):
return self == SpeculativeAlgorithm.NONE
@@ -20,8 +20,8 @@ class SpeculativeAlgorithm(IntEnum):
def is_standalone(self):
return self == SpeculativeAlgorithm.STANDALONE
- def is_ngram(self):
- return self == SpeculativeAlgorithm.NGRAM
+ def is_lookahead(self):
+ return self == SpeculativeAlgorithm.LOOKAHEAD
@staticmethod
def from_string(name: str):
@@ -29,7 +29,7 @@ class SpeculativeAlgorithm(IntEnum):
"EAGLE": SpeculativeAlgorithm.EAGLE,
"EAGLE3": SpeculativeAlgorithm.EAGLE3,
"STANDALONE": SpeculativeAlgorithm.STANDALONE,
- "NGRAM": SpeculativeAlgorithm.NGRAM,
+ "LOOKAHEAD": SpeculativeAlgorithm.LOOKAHEAD,
None: SpeculativeAlgorithm.NONE,
}
if name is not None:
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index 82717b382..e02bc1fd2 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -31,7 +31,7 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
+from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 8038ccf8a..973efa8e0 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -22,7 +22,6 @@ import ctypes
import dataclasses
import functools
import importlib
-import inspect
import io
import ipaddress
import itertools
@@ -195,7 +194,7 @@ _warned_bool_env_var_keys = set()
def get_bool_env_var(name: str, default: str = "false") -> bool:
- # FIXME: move your environment variable to sglang.srt.environ
+ # FIXME: move your environment variable to sglang.environ
value = os.getenv(name, default)
value = value.lower()
@@ -213,7 +212,7 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
def get_int_env_var(name: str, default: int = 0) -> int:
- # FIXME: move your environment variable to sglang.srt.environ
+ # FIXME: move your environment variable to sglang.environ
value = os.getenv(name)
if value is None or not value.strip():
return default
@@ -471,7 +470,7 @@ def is_pin_memory_available() -> bool:
class LayerFn(Protocol):
- def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+ def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...
def make_layers(
@@ -482,7 +481,7 @@ def make_layers(
prefix: str = "",
return_tuple: bool = False,
offloader_kwargs: Dict[str, Any] = {},
-) -> Tuple[int, int, torch.nn.ModuleList]:
+) -> Tuple[torch.nn.Module, int, int]:
"""Make a list of layers with the given layer function"""
# circula imports
from sglang.srt.distributed import get_pp_indices
@@ -518,50 +517,6 @@ def make_layers(
return modules, start_layer, end_layer
-cmo_stream = None
-
-
-def get_cmo_stream():
- """
- Cache Management Operation(CMO).
- Launch a new stream to prefetch the weight of matmul when running other
- AIV or communication kernels, aiming to overlap the memory access time.
- """
- global cmo_stream
- if cmo_stream is None:
- cmo_stream = torch.get_device_module().Stream()
- return cmo_stream
-
-
-def prepare_weight_cache(handle, cache):
- import torch_npu
-
- NPU_PREFETCH_MAX_SIZE_BYTES = (
- 1000000000 # 1GB, a large value to prefetch entire weight
- )
- stream = get_cmo_stream()
- stream.wait_stream(torch.npu.current_stream())
- with torch.npu.stream(stream):
- if isinstance(cache, list):
- for weight in cache:
- torch_npu.npu_prefetch(
- weight,
- handle,
- NPU_PREFETCH_MAX_SIZE_BYTES,
- )
- else:
- torch_npu.npu_prefetch(
- cache,
- handle,
- NPU_PREFETCH_MAX_SIZE_BYTES,
- )
-
-
-def wait_cmo_stream():
- cur_stream = torch.get_device_module().current_stream()
- cur_stream.wait_stream(get_cmo_stream())
-
-
def set_random_seed(seed: int) -> None:
"""Set the random seed for all libraries."""
random.seed(seed)
@@ -2054,6 +2009,13 @@ def set_uvicorn_logging_configs():
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+def get_ip() -> Optional[str]:
+ host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+ if host_ip:
+ return host_ip
+ return None
+
+
def get_open_port() -> int:
port = os.getenv("SGLANG_PORT")
if port is not None:
@@ -2393,10 +2355,8 @@ def get_local_ip_auto(fallback: str = None) -> str:
2. Network interface enumeration via get_local_ip_by_nic()
3. Remote connection method via get_local_ip_by_remote()
"""
- # Try environment variable
- host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
- if host_ip:
- return host_ip
+ if ip := get_ip():
+ return ip
logger.debug("get_ip failed")
# Fallback
if ip := get_local_ip_by_nic():
@@ -2460,7 +2420,7 @@ class BumpAllocator:
def log_info_on_rank0(logger, msg):
from sglang.srt.distributed import get_tensor_model_parallel_rank
- if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
+ if get_tensor_model_parallel_rank() == 0:
logger.info(msg)
@@ -3220,120 +3180,3 @@ def get_extend_input_len_swa_limit(
# and we can only free out-of-sliding-window kv indices after each prefill.
# 3. page_size is because we want to have 1 token extra for generated tokens.
return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
-
-
-class CachedKernel:
- """
- Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
-
- This wrapper caches compiled Triton kernels based on keys extracted by a
- user-provided key function to avoid redundant compilations.
- """
-
- def __init__(self, fn, key_fn=None):
- self.fn = fn
- assert isinstance(fn, triton.runtime.jit.JITFunction)
-
- original_fn = fn.fn
- self.signature = inspect.signature(original_fn)
- self.param_names = tuple(self.signature.parameters.keys())
- self.num_args = len(self.param_names)
-
- # Check that no parameters have default values
- for name, param in self.signature.parameters.items():
- assert (
- param.default is inspect.Parameter.empty
- ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
-
- functools.update_wrapper(self, original_fn)
- self.kernel_cache = {}
-
- # Store the key function
- self.key_fn = key_fn
-
- def __getitem__(self, grid):
- """
- Index with grid to get a launcher function.
- Returns a launcher that will handle caching based on the key function.
- """
- assert (
- isinstance(grid, tuple) and len(grid) <= 3
- ), "Grid must be a tuple with at most 3 dimensions."
-
- # Normalize grid once
- if len(grid) < 3:
- grid = grid + (1,) * (3 - len(grid))
-
- def launcher(*args, **kwargs):
- cache_key = self.key_fn(args, kwargs)
-
- cached_kernel = self.kernel_cache.get(cache_key)
-
- if cached_kernel is None:
- # First time: compile and cache the kernel
- cached_kernel = self.fn[grid](*args, **kwargs)
- self.kernel_cache[cache_key] = cached_kernel
- return cached_kernel
- else:
- # Use cached kernel
- all_args = self._build_args(args, kwargs)
- cached_kernel[grid](*all_args)
- return cached_kernel
-
- return launcher
-
- def _build_args(self, args, kwargs):
- """
- Build the complete argument list for kernel invocation.
- """
- complete_args = list(args)
-
- for i in range(len(args), self.num_args):
- name = self.param_names[i]
- value = kwargs.get(name, inspect.Parameter.empty)
- if value is not inspect.Parameter.empty:
- complete_args.append(value)
- else:
- raise ValueError(f"Missing argument: {name}")
-
- return complete_args
-
- def _clear_cache(self):
- """
- Clear the kernel cache for testing purposes.
- """
- self.kernel_cache.clear()
-
-
-def cached_triton_kernel(key_fn=None):
- """
- Decorator that enables key-based caching for Triton kernels using a key function.
-
- It essentially bypasses Triton's built-in caching mechanism, allowing users to
- define their own caching strategy based on kernel parameters. This helps reduce
- the heavy overheads of Triton kernel launch when the kernel specialization dispatch
- is simple.
-
- Usage:
- @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
- @triton.jit
- def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
- ...
-
- # Invoke normally
- my_kernel[grid](x, y, BLOCK_SIZE=1024)
-
- Args:
- key_fn: A function that takes (args, kwargs) and returns the cache key(s).
- The key can be a single value or a tuple of values.
-
- Returns:
- A decorator that wraps the kernel with caching functionality.
-
- Note: Kernels with default parameter values are not supported and will raise an assertion error.
- """
-
- def decorator(fn):
- return CachedKernel(fn, key_fn)
-
- return decorator
diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py
index 85f84c36b..9b788cc0a 100644
--- a/python/sglang/test/run_eval.py
+++ b/python/sglang/test/run_eval.py
@@ -60,11 +60,6 @@ def run_eval(args):
from sglang.test.simple_eval_humaneval import HumanEval
eval_obj = HumanEval(args.num_examples, args.num_threads)
- elif args.eval_name == "mmmu":
- # VLM MMMU evaluation with fixed 100 examples by default
- from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval
-
- eval_obj = MMMUVLMEval(args.num_examples, args.num_threads)
else:
raise ValueError(f"Invalid eval name: {args.eval_name}")
@@ -99,8 +94,6 @@ def run_eval(args):
print(f"Total latency: {latency:.3f} s")
print(f"Score: {metrics['score']:.3f}")
- if getattr(args, "return_latency", False):
- return metrics, latency
return metrics
diff --git a/python/sglang/test/simple_eval_common.py b/python/sglang/test/simple_eval_common.py
index b631d0778..1816a703e 100644
--- a/python/sglang/test/simple_eval_common.py
+++ b/python/sglang/test/simple_eval_common.py
@@ -136,7 +136,7 @@ class ChatCompletionSampler(SamplerBase):
self._pack_message("system", self.system_message)
] + message_list
trial = 0
- while trial < 6: # 126 seconds in total
+ while True:
try:
response = self.client.chat.completions.create(
model=self.model,
diff --git a/python/sglang/test/simple_eval_mmmu_vlm.py b/python/sglang/test/simple_eval_mmmu_vlm.py
deleted file mode 100644
index 2f64df004..000000000
--- a/python/sglang/test/simple_eval_mmmu_vlm.py
+++ /dev/null
@@ -1,441 +0,0 @@
-"""
-MMMU evaluation for VLMs using the run_eval simple-evals interface.
-
-"""
-
-from __future__ import annotations
-
-import base64
-import io
-from typing import List, Optional, Tuple
-
-from datasets import concatenate_datasets, load_dataset
-from PIL import Image
-
-from sglang.test import simple_eval_common as common
-from sglang.test.simple_eval_common import (
- HTML_JINJA,
- Eval,
- EvalResult,
- SamplerBase,
- SingleEvalResult,
- map_with_progress,
-)
-
-
-class MMMUVLMEval(Eval):
- DOMAIN_CAT2SUB_CAT = {
- "Art and Design": ["Art", "Art_Theory", "Design", "Music"],
- "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
- "Science": ["Biology", "Chemistry", "Geography", "Math", "Physics"],
- "Health and Medicine": [
- "Basic_Medical_Science",
- "Clinical_Medicine",
- "Diagnostics_and_Laboratory_Medicine",
- "Pharmacy",
- "Public_Health",
- ],
- "Humanities and Social Science": [
- "History",
- "Literature",
- "Sociology",
- "Psychology",
- ],
- "Tech and Engineering": [
- "Agriculture",
- "Architecture_and_Engineering",
- "Computer_Science",
- "Electronics",
- "Energy_and_Power",
- "Materials",
- "Mechanical_Engineering",
- ],
- }
-
- def __init__(
- self, num_examples: Optional[int] = 100, num_threads: int = 32, seed: int = 42
- ):
- """Create MMMU VLM eval (Math subset, 100 fixed samples by default)."""
- self.num_examples = num_examples
- self.num_threads = num_threads
- self.seed = seed
- # Prepare samples deterministically across all MMMU subjects (validation split)
- self.samples = self._prepare_mmmu_samples(self.num_examples)
-
- @staticmethod
- def _to_data_uri(image: Image.Image) -> str:
- if image.mode == "RGBA":
- image = image.convert("RGB")
- buf = io.BytesIO()
- image.save(buf, format="PNG")
- b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
- return f"data:image/png;base64,{b64}"
-
- @staticmethod
- def _build_mc_mapping(options: List[str]) -> Tuple[dict, List[str]]:
- index2ans = {}
- all_choices = []
- ch = ord("A")
- for opt in options:
- letter = chr(ch)
- index2ans[letter] = opt
- all_choices.append(letter)
- ch += 1
- return index2ans, all_choices
-
- def _prepare_mmmu_samples(self, k: int) -> List[dict]:
- # Subjects and domains copied from MMMU data_utils to categorize results
- subjects: List[str] = []
- for subs in self.DOMAIN_CAT2SUB_CAT.values():
- subjects.extend(subs)
-
- # Load validation split of each subject
- datasets = []
- for subj in subjects:
- try:
- d = load_dataset("MMMU/MMMU", subj, split="validation")
- # attach subject info via transform
- d = d.add_column("__subject__", [subj] * len(d))
- datasets.append(d)
- except Exception:
- continue
- if not datasets:
- raise RuntimeError("Failed to load MMMU datasets")
-
- merged = concatenate_datasets(datasets)
-
- # Deterministic selection: sort by id (fallback to subject+index)
- def _key(idx):
- ex = merged[idx]
- return str(ex.get("id", f"{ex['__subject__']}:{idx}"))
-
- order = sorted(range(len(merged)), key=_key)
- picked_indices = order[:k]
-
- samples: List[dict] = []
- for idx in picked_indices:
- ex = merged[idx]
- subject = ex["__subject__"]
- image = ex.get("image_1")
- if image is None or not hasattr(image, "convert"):
- continue
- data_uri = self._to_data_uri(image)
- question = ex.get("question", "")
- answer = ex.get("answer")
- raw_options = ex.get("options")
- question_type = "open"
- index2ans = None
- all_choices = None
- options = None
- if raw_options:
- try:
- options = (
- raw_options
- if isinstance(raw_options, list)
- else list(eval(raw_options))
- )
- if isinstance(options, list) and len(options) > 0:
- index2ans, all_choices = self._build_mc_mapping(options)
- question_type = "multiple-choice"
- except Exception:
- options = None
-
- # Build final textual prompt; include choices if MC
- prompt_text = f"Question: {question}\n\n"
- if options:
- letters = [chr(ord("A") + i) for i in range(len(options))]
- for letter, opt in zip(letters, options):
- prompt_text += f"{letter}) {opt}\n"
- prompt_text += "\nAnswer: "
-
- samples.append(
- {
- "id": ex.get("id", f"{subject}:{idx}"),
- "final_input_prompt": prompt_text,
- "image_data": data_uri,
- "answer": answer,
- "question_type": question_type,
- "index2ans": index2ans,
- "all_choices": all_choices,
- "category": subject,
- }
- )
-
- return samples
-
- @staticmethod
- def _split_prompt_for_image(prompt: str) -> tuple[str, str]:
- """Split a prompt containing an inline image tag into prefix and suffix.
-
- If no tag is present, treat the whole prompt as prefix and empty suffix.
- """
- if "<" in prompt and ">" in prompt:
- prefix = prompt.split("<")[0]
- suffix = prompt.split(">", 1)[1]
- return prefix, suffix
- return prompt, ""
-
- @staticmethod
- def build_chat_messages_from_prompt(prompt: str, image_data) -> List:
- """Split a prompt containing an inline image tag into prefix and suffix.
-
- If no tag is present, treat the whole prompt as prefix and empty suffix.
- """
- # Build a vision+text message for OpenAI-compatible API
- prefix, suffix = MMMUVLMEval._split_prompt_for_image(prompt)
-
- content: List[dict] = []
- if prefix:
- content.append({"type": "text", "text": prefix})
- content.append({"type": "image_url", "image_url": {"url": image_data}})
- if suffix:
- content.append({"type": "text", "text": suffix})
- prompt_messages = [{"role": "user", "content": content}]
-
- return prompt_messages
-
- def __call__(self, sampler: SamplerBase) -> EvalResult:
- def fn(sample: dict):
- prompt = sample["final_input_prompt"]
- image_data = sample["image_data"]
- prompt_messages = MMMUVLMEval.build_chat_messages_from_prompt(
- prompt, image_data
- )
-
- # Sample
- response_text = sampler(prompt_messages)
-
- # Parse and score
- gold = sample["answer"]
- if (
- sample["question_type"] == "multiple-choice"
- and sample["all_choices"]
- and sample["index2ans"]
- ):
- pred = _parse_multi_choice_response(
- response_text, sample["all_choices"], sample["index2ans"]
- )
- score = 1.0 if (gold is not None and pred == gold) else 0.0
- extracted_answer = pred
- else:
- parsed_list = _parse_open_response(response_text)
- score = (
- 1.0 if (gold is not None and _eval_open(gold, parsed_list)) else 0.0
- )
- extracted_answer = ", ".join(map(str, parsed_list))
-
- html_rendered = common.jinja_env.from_string(HTML_JINJA).render(
- prompt_messages=prompt_messages,
- next_message=dict(content=response_text, role="assistant"),
- score=score,
- correct_answer=gold,
- extracted_answer=extracted_answer,
- )
-
- convo = prompt_messages + [dict(content=response_text, role="assistant")]
- return SingleEvalResult(
- html=html_rendered,
- score=score,
- metrics={"__category__": sample["category"]},
- convo=convo,
- )
-
- results = map_with_progress(fn, self.samples, self.num_threads)
-
- # Build category table and overall accuracy
- # Gather per-sample correctness and category
- per_cat_total: dict[str, int] = {}
- per_cat_correct: dict[str, int] = {}
- htmls = []
- convos = []
- scores: List[float] = []
- for r in results:
- # __category__ stored under metrics
- cat = r.metrics.get("__category__") if r.metrics else None
- if cat is None:
- cat = "Unknown"
- per_cat_total[cat] = per_cat_total.get(cat, 0) + 1
- if r.score:
- per_cat_correct[cat] = per_cat_correct.get(cat, 0) + 1
- htmls.append(r.html)
- convos.append(r.convo)
- if r.score is not None:
- scores.append(r.score)
-
- evaluation_result = {}
- for cat, tot in per_cat_total.items():
- corr = per_cat_correct.get(cat, 0)
- acc = (corr / tot) if tot > 0 else 0.0
- evaluation_result[cat] = {"acc": round(acc, 3), "num_example": tot}
-
- printable_results = {}
- # Domains first
- for domain, cats in self.DOMAIN_CAT2SUB_CAT.items():
- acc_sum = 0.0
- num_sum = 0
- for cat in cats:
- if cat in evaluation_result:
- acc_sum += (
- evaluation_result[cat]["acc"]
- * evaluation_result[cat]["num_example"]
- )
- num_sum += evaluation_result[cat]["num_example"]
- if num_sum > 0:
- printable_results[f"Overall-{domain}"] = {
- "num": num_sum,
- "acc": round(acc_sum / num_sum, 3),
- }
- # add each sub-category row if present
- for cat in cats:
- if cat in evaluation_result:
- printable_results[cat] = {
- "num": evaluation_result[cat]["num_example"],
- "acc": evaluation_result[cat]["acc"],
- }
-
- # Overall
- total_num = sum(v["num_example"] for v in evaluation_result.values())
- overall_acc = (
- sum(v["acc"] * v["num_example"] for v in evaluation_result.values())
- / total_num
- if total_num > 0
- else 0.0
- )
- printable_results["Overall"] = {"num": total_num, "acc": round(overall_acc, 3)}
-
- # Build EvalResult
- return EvalResult(
- score=overall_acc, metrics=printable_results, htmls=htmls, convos=convos
- )
-
-
-def _parse_multi_choice_response(
- response: str, all_choices: List[str], index2ans: dict
-) -> str:
- # loosely adapted from benchmark mmmu eval
- for char in [",", ".", "!", "?", ";", ":", "'"]:
- response = response.strip(char)
- response = " " + response + " "
-
- # Prefer explicit letter with bracket e.g. (A)
- candidates: List[str] = []
- for choice in all_choices:
- if f"({choice})" in response:
- candidates.append(choice)
- if not candidates:
- for choice in all_choices:
- if f" {choice} " in response:
- candidates.append(choice)
- if not candidates and len(response.split()) > 5:
- # try match by option text
- for idx, ans in index2ans.items():
- if ans and ans.lower() in response.lower():
- candidates.append(idx)
- if not candidates:
- # fallback to first choice
- return all_choices[0]
- if len(candidates) == 1:
- return candidates[0]
- # choose the last occurrence
- starts = []
- for can in candidates:
- pos = response.rfind(f"({can})")
- if pos == -1:
- pos = response.rfind(f" {can} ")
- if pos == -1 and index2ans.get(can):
- pos = response.lower().rfind(index2ans[can].lower())
- starts.append(pos)
- return candidates[int(max(range(len(starts)), key=lambda i: starts[i]))]
-
-
-def _check_is_number(s: str) -> bool:
- try:
- float(s.replace(",", ""))
- return True
- except Exception:
- return False
-
-
-def _normalize_str(s: str):
- s = s.strip()
- if _check_is_number(s):
- s = s.replace(",", "")
- try:
- v = round(float(s), 2)
- return [v]
- except Exception:
- return [s.lower()]
- return [s.lower()] if len(s) > 1 else [" " + s, s + " "]
-
-
-def _extract_numbers(s: str) -> List[str]:
- import re as _re
-
- pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
- pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
- pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
- return (
- _re.findall(pattern_commas, s)
- + _re.findall(pattern_scientific, s)
- + _re.findall(pattern_simple, s)
- )
-
-
-def _parse_open_response(response: str) -> List[str]:
- import re as _re
-
- def get_key_subresponses(resp: str) -> List[str]:
- resp = resp.strip().strip(".").lower()
- subs = _re.split(r"\.\s(?=[A-Z])|\n", resp)
- indicators = [
- "could be ",
- "so ",
- "is ",
- "thus ",
- "therefore ",
- "final ",
- "answer ",
- "result ",
- ]
- keys = []
- for i, s in enumerate(subs):
- cands = [*indicators]
- if i == len(subs) - 1:
- cands.append("=")
- shortest = None
- for ind in cands:
- if ind in s:
- part = s.split(ind)[-1].strip()
- if not shortest or len(part) < len(shortest):
- shortest = part
- if shortest and shortest not in [":", ",", ".", "!", "?", ";", ":", "'"]:
- keys.append(shortest)
- return keys or [resp]
-
- key_resps = get_key_subresponses(response)
- pred_list = key_resps.copy()
- for r in key_resps:
- pred_list.extend(_extract_numbers(r))
- out = []
- for x in pred_list:
- out.extend(_normalize_str(x))
- # dedup
- return list(dict.fromkeys(out))
-
-
-def _eval_open(gold, preds: List[str]) -> bool:
- if isinstance(gold, list):
- norm_answers = []
- for ans in gold:
- norm_answers.extend(_normalize_str(ans))
- else:
- norm_answers = _normalize_str(gold)
- for p in preds:
- if isinstance(p, str):
- for na in norm_answers:
- if isinstance(na, str) and na in p:
- return True
- else:
- if p in norm_answers:
- return True
- return False
diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py
index 80202d15e..45271e116 100644
--- a/python/sglang/test/test_block_fp8.py
+++ b/python/sglang/test/test_block_fp8.py
@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
w_s,
)
- from deep_gemm import fp8_m_grouped_gemm_nt_masked
+ from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
with torch.inference_mode():
ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
- fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
+ m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
out = oe[:, :M, :]
self.assertTrue(
diff --git a/python/sglang/test/test_deterministic.py b/python/sglang/test/test_deterministic.py
index 286902677..8c4e45c7c 100644
--- a/python/sglang/test/test_deterministic.py
+++ b/python/sglang/test/test_deterministic.py
@@ -19,7 +19,7 @@ from sglang.profiler import run_profile
PROMPT_1 = "Tell me about Richard Feynman: "
PROMPT_2 = "Generate 1000 random numbers. Go directly into it, don't say Sure and don't say here are numbers. Just start with a number."
dirpath = os.path.dirname(__file__)
-with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
+with open("python/sglang/test/long_prompt.txt", "r") as f:
LONG_PROMPT = f.read()
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 2e9a16896..208b45578 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -14,12 +14,10 @@ import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
-from datetime import datetime
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Awaitable, Callable, List, Optional, Tuple
-from urllib.parse import quote
import aiohttp
import numpy as np
@@ -82,7 +80,7 @@ DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
"meta-llama/Llama-3.1-8B-Instruct"
)
DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
-DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
+DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
# Other use cases
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
@@ -1469,146 +1467,3 @@ def dump_bench_raw_result(
def _ensure_remove_suffix(text: str, suffix: str):
assert text.endswith(suffix)
return text.removesuffix(suffix)
-
-
-class ModelDeploySetup:
- def __init__(self, model_path: str, extra_args: List[str] = []):
- self.model_path = model_path
- if "--enable-multimodal" not in extra_args:
- extra_args.append("--enable-multimodal")
- if "--trust-remote-code" not in extra_args:
- extra_args.append("--trust-remote-code")
-
- self.extra_args = extra_args
-
-
-class ModelEvalMetrics:
- def __init__(self, accuracy: float, eval_time: float):
- self.accuracy = accuracy
- self.eval_time = eval_time
-
-
-def extract_trace_link_from_bench_one_batch_server_output(output: str) -> str:
- match = re.search(r"\[Profile\]\((.*?)\)", output)
- if match:
- trace_link = match.group(1)
- return trace_link
- return None
-
-
-def parse_models(model_string: str):
- return [model.strip() for model in model_string.split(",") if model.strip()]
-
-
-def check_evaluation_test_results(
- results,
- test_name,
- model_accuracy_thresholds,
- model_latency_thresholds=None,
- model_count=None,
-):
- """
- results: list of tuple of (model_path, accuracy, latency)
- """
- failed_models = []
- if model_latency_thresholds is not None:
- summary = " | model | status | score | score_threshold | latency | latency_threshold | \n"
- summary += "| ----- | ------ | ----- | --------------- | ------- | ----------------- | \n"
- else:
- summary = " | model | status | score | score_threshold | \n"
- summary += "| ----- | ------ | ----- | --------------- | \n"
-
- results_dict = {res[0]: (res[1], res[2]) for res in results}
-
- for model, accuracy_threshold in sorted(model_accuracy_thresholds.items()):
- latency_threshold = (
- model_latency_thresholds.get(model)
- if model_latency_thresholds is not None
- else 1e9
- )
-
- if model in results_dict:
- accuracy, latency = results_dict[model]
- is_success = accuracy >= accuracy_threshold and latency <= latency_threshold
- status_emoji = "✅" if is_success else "❌"
-
- if not is_success:
- if accuracy < accuracy_threshold:
- failed_models.append(
- f"\nScore Check Failed: {model}\n"
- f"Model {model} score ({accuracy:.4f}) is below threshold ({accuracy_threshold:.4f})"
- )
- if latency > latency_threshold:
- failed_models.append(
- f"\nLatency Check Failed: {model}\n"
- f"Model {model} latency ({latency:.4f}) is above threshold ({latency_threshold:.4f})"
- )
-
- if model_latency_thresholds is not None:
- line = f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold} | {latency} | {latency_threshold}\n"
- else:
- line = (
- f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold}\n"
- )
- else:
- status_emoji = "❌"
- failed_models.append(f"Model failed to launch or be evaluated: {model}")
- if model_latency_thresholds is not None:
- line = f"| {model} | {status_emoji} | N/A | {accuracy_threshold} | N/A | {latency_threshold}\n"
- else:
- line = f"| {model} | {status_emoji} | N/A | {accuracy_threshold}\n"
-
- summary += line
-
- print(summary)
-
- if is_in_ci():
- write_github_step_summary(f"## {test_name}\n{summary}")
-
- if failed_models:
- print("Some models failed the evaluation.")
- raise AssertionError("\n".join(failed_models))
-
-
-# Bench knobs for bench_one_batch_server (override by env)
-def _parse_int_list_env(name: str, default_val: str):
- val = os.environ.get(name, default_val)
- return [int(x) for x in val.split(",") if x]
-
-
-# Return filenames
-def find_traces_under_path(path: str) -> List[str]:
- results = []
- for _, dirs, files in os.walk(path):
- for file in files:
- if file.endswith(".trace.json.gz"):
- results.append(f"{file}")
- return results
-
-
-def write_results_to_json(model, metrics, mode="a"):
- result = {
- "timestamp": datetime.now().isoformat(),
- "model": model,
- "metrics": metrics,
- "score": metrics["score"],
- }
-
- if "latency" in metrics:
- result["latency"] = (metrics.get("latency"),)
-
- existing_results = []
- if mode == "a" and os.path.exists("results.json"):
- try:
- with open("results.json", "r") as f:
- existing_results = json.load(f)
- except json.JSONDecodeError:
- existing_results = []
-
- if isinstance(existing_results, list):
- existing_results.append(result)
- else:
- existing_results = [result]
-
- with open("results.json", "w") as f:
- json.dump(existing_results, f, indent=2)
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 1d62c5df8..91c3454a1 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -1,12 +1,13 @@
"""Common utilities"""
+import functools
import importlib
+import inspect
import json
import logging
import os
import random
import socket
-import ssl
import subprocess
import sys
import time
@@ -22,6 +23,7 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union
import numpy as np
import pybase64
import requests
+import triton
from IPython.display import HTML, display
from pydantic import BaseModel
from tqdm import tqdm
@@ -156,15 +158,7 @@ def http_request(
data = bytes(dumps(json), encoding="utf-8")
try:
- if sys.version_info >= (3, 13):
- # Python 3.13+: Use SSL context (cafile removed)
- if verify and isinstance(verify, str):
- context = ssl.create_default_context(cafile=verify)
- else:
- context = ssl.create_default_context()
- resp = urllib.request.urlopen(req, data=data, context=context)
- else:
- resp = urllib.request.urlopen(req, data=data, cafile=verify)
+ resp = urllib.request.urlopen(req, data=data, cafile=verify)
return HttpResponse(resp)
except urllib.error.HTTPError as e:
return HttpResponse(e)
@@ -549,3 +543,114 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
module_name, obj_name = qualname.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, obj_name)
+
+
+class CachedKernel:
+ """
+ Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
+
+ This wrapper caches compiled Triton kernels based on keys extracted by a
+ user-provided key function to avoid redundant compilations.
+ """
+
+ def __init__(self, fn, key_fn=None):
+ self.fn = fn
+ assert isinstance(fn, triton.runtime.jit.JITFunction)
+
+ original_fn = fn.fn
+ self.signature = inspect.signature(original_fn)
+ self.param_names = tuple(self.signature.parameters.keys())
+ self.num_args = len(self.param_names)
+
+ # Check that no parameters have default values
+ for name, param in self.signature.parameters.items():
+ assert (
+ param.default is inspect.Parameter.empty
+ ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
+
+ functools.update_wrapper(self, original_fn)
+ self.kernel_cache = {}
+
+ # Store the key function
+ self.key_fn = key_fn
+
+ def __getitem__(self, grid):
+ """
+ Index with grid to get a launcher function.
+ Returns a launcher that will handle caching based on the key function.
+ """
+ assert (
+ isinstance(grid, tuple) and len(grid) <= 3
+ ), "Grid must be a tuple with at most 3 dimensions."
+
+ # Normalize grid once
+ if len(grid) < 3:
+ grid = grid + (1,) * (3 - len(grid))
+
+ def launcher(*args, **kwargs):
+ cache_key = self.key_fn(args, kwargs)
+
+ cached_kernel = self.kernel_cache.get(cache_key)
+
+ if cached_kernel is None:
+ # First time: compile and cache the kernel
+ cached_kernel = self.fn[grid](*args, **kwargs)
+ self.kernel_cache[cache_key] = cached_kernel
+ return cached_kernel
+ else:
+ # Use cached kernel
+ all_args = self._build_args(args, kwargs)
+ cached_kernel[grid](*all_args)
+ return cached_kernel
+
+ return launcher
+
+ def _build_args(self, args, kwargs):
+ """
+ Build the complete argument list for kernel invocation.
+ """
+ complete_args = list(args)
+
+ for i in range(len(args), self.num_args):
+ name = self.param_names[i]
+ value = kwargs.get(name, inspect.Parameter.empty)
+ if value is not inspect.Parameter.empty:
+ complete_args.append(value)
+ else:
+ raise ValueError(f"Missing argument: {name}")
+
+ return complete_args
+
+
+def cached_triton_kernel(key_fn=None):
+ """
+ Decorator that enables key-based caching for Triton kernels using a key function.
+
+ It essentially bypasses Triton's built-in caching mechanism, allowing users to
+ define their own caching strategy based on kernel parameters. This helps reduce
+ the heavy overheads of Triton kernel launch when the kernel specialization dispatch
+ is simple.
+
+ Usage:
+ @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
+ @triton.jit
+ def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+ ...
+
+ # Invoke normally
+ my_kernel[grid](x, y, BLOCK_SIZE=1024)
+
+ Args:
+ key_fn: A function that takes (args, kwargs) and returns the cache key(s).
+ The key can be a single value or a tuple of values.
+
+ Returns:
+ A decorator that wraps the kernel with caching functionality.
+
+ Note: Kernels with default parameter values are not supported and will raise an assertion error.
+ """
+
+ def decorator(fn):
+ return CachedKernel(fn, key_fn)
+
+ return decorator