3 Commits

Author SHA1 Message Date
maxiao1
7993ed8ddd 适配deepseekv3.2
Some checks failed
CI Monitor / ci-monitor (push) Has been cancelled
Release Docker Images Nightly (AMD) / publish (all, gfx942) (push) Has been cancelled
Release Docker Images Nightly (AMD) / publish (all, gfx942-rocm700) (push) Has been cancelled
Release Docker Images Nightly (AMD) / publish (all, gfx950) (push) Has been cancelled
Release Docker Images Nightly (AMD) / publish (srt, gfx942) (push) Has been cancelled
Release Docker Images Nightly (AMD) / publish (srt, gfx942-rocm700) (push) Has been cancelled
Release Docker Images Nightly (AMD) / publish (srt, gfx950) (push) Has been cancelled
Release Docker Images Nightly (Ascend NPU) / build (8.2.rc1, 910b) (push) Has been cancelled
Release Docker Images Nightly (Ascend NPU) / build (8.2.rc1, a3) (push) Has been cancelled
Build and Push Development Docker Images / build-dev-x86 (map[tag:dev type:all version:12.9.1]) (push) Has been cancelled
Build and Push Development Docker Images / build-blackwell-arm (map[tag:blackwell-cu129 type:blackwell_aarch version:12.9.1]) (push) Has been cancelled
Build and Push Development Docker Images / create-manifests (map[arm64_tag:blackwell-cu129-arm64 tag:dev-manifest x86_tag:dev]) (push) Has been cancelled
Nightly Test / nightly-test-eval-text-models (push) Has been cancelled
Nightly Test / nightly-test-perf-text-models (push) Has been cancelled
Nightly Test / nightly-test-eval-vlms (push) Has been cancelled
Nightly Test / nightly-test-perf-vlms (push) Has been cancelled
Nightly Test (AMD) / nightly-test (linux-mi300-gpu-2) (push) Has been cancelled
Nightly Test (AMD) / nightly-test (linux-mi325-gpu-2-nightly) (push) Has been cancelled
Close Inactive Issues / close-inactive-issues (push) Has been cancelled
2025-10-03 20:01:17 +08:00
maxiao1
443a1b4ab3 Update pyproject_other.toml 2025-09-30 10:47:20 +00:00
maxiao
852a49c5cc adapt to dsv32 on dcu 2025-09-30 18:37:31 +08:00
167 changed files with 7597 additions and 8753 deletions

View File

@@ -57,7 +57,7 @@ dependencies = [
"uvicorn",
"uvloop",
"xgrammar==0.1.24",
"sgl-kernel==0.3.13",
"sgl-kernel==0.3.11",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
@@ -67,7 +67,7 @@ dependencies = [
"tiktoken",
"anthropic>=0.20.0",
"torch_memory_saver==0.0.8",
"nvidia-cutlass-dsl==4.2.1",
"nvidia-cutlass-dsl==4.2.0",
]
[project.optional-dependencies]
@@ -103,8 +103,8 @@ dev = ["sglang[test]", "sglang[decord]"]
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
"srt/layers/quantization/configs/*.json",
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
"srt/speculative/cpp_ngram/*.cpp",
"srt/speculative/cpp_ngram/*.h",
"srt/speculative/cpp_lookahead/*.cpp",
"srt/speculative/cpp_lookahead/*.h",
]
[tool.setuptools.packages.find]

View File

@@ -65,30 +65,29 @@ tracing = [
srt = [
"sglang[runtime_common]",
"sgl-kernel==0.3.13",
"sgl-kernel==0.3.11",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
"flashinfer_python==0.4.0rc1",
"flashinfer_python==0.3.1",
]
blackwell = [
"sglang[runtime_common]",
"sgl-kernel==0.3.13",
"sgl-kernel==0.3.11",
"torch==2.8.0",
"torchaudio==2.8.0",
"torchvision",
"cuda-python",
"flashinfer_python==0.4.0rc1",
"nvidia-cutlass-dsl==4.2.1",
"flashinfer_python==0.3.1",
"nvidia-cutlass-dsl==4.2.0",
]
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
srt_hip = [
"sglang[runtime_common]",
"torch",
"petit_kernel==0.0.2",
"wave-lang==3.7.0",
]

View File

@@ -443,9 +443,11 @@ def latency_test_run_once(
if profile:
profiler.stop()
trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
_save_profile_trace_results(profiler, trace_filename)
rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
_save_profile_trace_results(profiler, profile_filename)
rank_print(
f"torch profiler chrome trace for prefill saved to {profile_filename}"
)
# Decode
decode_latencies = []
@@ -477,10 +479,10 @@ def latency_test_run_once(
if profile and i == output_len / 2:
profiler.stop()
trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
_save_profile_trace_results(profiler, trace_filename)
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
_save_profile_trace_results(profiler, profile_filename)
rank_print(
f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
)
# Record decode timing from 2nd output

View File

@@ -9,7 +9,6 @@ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
"""
import argparse
@@ -20,17 +19,12 @@ import multiprocessing
import os
import random
import time
from typing import List, Optional, Tuple
from typing import List, Tuple
import numpy as np
import requests
from pydantic import BaseModel
from sglang.bench_serving import (
get_tokenizer,
sample_mmmu_requests,
sample_random_requests,
)
from sglang.bench_serving import get_tokenizer, sample_random_requests
from sglang.profiler import run_profile
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
@@ -38,108 +32,6 @@ from sglang.srt.utils import is_blackwell, kill_process_tree
from sglang.test.test_utils import is_in_ci, write_github_step_summary
class ProfileLinks(BaseModel):
"""Pydantic model for profile trace links."""
extend: Optional[str] = None
decode: Optional[str] = None
class BenchmarkResult(BaseModel):
"""Pydantic model for benchmark results table data, for a single isl and osl"""
model_path: str
run_name: str
batch_size: int
input_len: int
output_len: int
latency: float
ttft: float
input_throughput: float
output_throughput: float
overall_throughput: float
last_gen_throughput: float
acc_length: Optional[float] = None
profile_links: Optional[ProfileLinks] = None
@staticmethod
def help_str() -> str:
return f"""
Note: To view the traces through perfetto-ui, please:
1. open with Google Chrome
2. allow popup
"""
def to_markdown_row(
self, trace_dir, base_url: str = "", relay_base: str = ""
) -> str:
"""Convert this benchmark result to a markdown table row."""
# Calculate costs (assuming H100 pricing for now)
hourly_cost_per_gpu = 2 # $2/hour for one H100
hourly_cost = hourly_cost_per_gpu * 1 # Assuming tp_size = 1 for simplicity
input_util = 0.7
accept_length = (
round(self.acc_length, 2) if self.acc_length is not None else "n/a"
)
itl = 1 / (self.output_throughput / self.batch_size) * 1000
input_cost = 1e6 / (self.input_throughput * input_util) / 3600 * hourly_cost
output_cost = 1e6 / self.output_throughput / 3600 * hourly_cost
def get_perfetto_relay_link_from_trace_file(trace_file: str):
import os
from urllib.parse import quote
rel_path = os.path.relpath(trace_file, trace_dir)
raw_file_link = f"{base_url}/{rel_path}"
relay_link = (
f"{relay_base}?src={quote(raw_file_link, safe='')}"
if relay_base and quote
else raw_file_link
)
return relay_link
# Handle profile links
profile_link = "NA | NA"
if self.profile_links:
if self.profile_links.extend or self.profile_links.decode:
# Create a combined link or use the first available one
trace_files = [self.profile_links.extend, self.profile_links.decode]
trace_files_relay_links = [
f"[trace]({get_perfetto_relay_link_from_trace_file(trace_file)})"
for trace_file in trace_files
]
profile_link = " | ".join(trace_files_relay_links)
# Build the row
return f"| {self.batch_size} | {self.input_len} | {self.latency:.2f} | {self.input_throughput:.2f} | {self.output_throughput:.2f} | {accept_length} | {itl:.2f} | {input_cost:.2f} | {output_cost:.2f} | {profile_link} |\n"
@classmethod
def generate_markdown_report(
cls, trace_dir, results: List["BenchmarkResult"]
) -> str:
"""Generate a markdown report from a list of BenchmarkResult object from a single run."""
import os
summary = f"### {results[0].model_path}\n"
# summary += (
# f"Input lens: {result.input_len}. Output lens: {result.output_len}.\n"
# )
summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | profile (extend) | profile (decode)|\n"
summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ | --------------- | -------------- |\n"
# all results should share the same isl & osl
for result in results:
base_url = os.getenv("TRACE_BASE_URL", "").rstrip("/")
relay_base = os.getenv("PERFETTO_RELAY_URL", "").rstrip("/")
relay_base = "https://docs.sglang.ai/ci-data/pages/perfetto_relay.html"
# base_url = "https://github.com/sgl-project/ci-data/traces"
summary += result.to_markdown_row(trace_dir, base_url, relay_base)
return summary
@dataclasses.dataclass
class BenchArgs:
run_name: str = "default"
@@ -158,12 +50,8 @@ class BenchArgs:
profile: bool = False
profile_steps: int = 3
profile_by_stage: bool = False
profile_filename_prefix: str = None
append_to_github_summary: bool = True
dataset_path: str = ""
parallel_batch: bool = False
dataset_name: str = "random"
output_path: Optional[str] = None
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -179,13 +67,6 @@ class BenchArgs:
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
)
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
parser.add_argument(
"--dataset-name",
type=str,
default=BenchArgs.dataset_name,
choices=["mmmu", "random"],
help="Name of the dataset to benchmark on.",
)
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument(
"--client-stream-interval",
@@ -215,36 +96,14 @@ class BenchArgs:
help="Path to the dataset.",
)
parser.add_argument("--parallel-batch", action="store_true")
parser.add_argument(
"--profile-filename-prefix",
type=str,
default=BenchArgs.profile_filename_prefix,
)
parser.add_argument(
"--no-append-to-github-summary",
action="store_false",
dest="append_to_github_summary",
help="Disable appending the output of this run to github ci summary",
)
parser.add_argument(
"--output-path",
type=str,
default=BenchArgs.output_path,
help="Path to save benchmark results as JSON format. If not specified, results will only be saved to result-filename.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# use the default value's type to cast the args into correct types.
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
kwargs = {}
for attr, attr_type in attrs:
val = getattr(args, attr)
if attr_type is type(None):
kwargs[attr] = val
else:
kwargs[attr] = attr_type(val)
return cls(**kwargs)
return cls(
**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
)
def launch_server_internal(server_args):
@@ -289,35 +148,23 @@ def run_one_case(
run_name: str,
result_filename: str,
tokenizer,
dataset_name="",
profile: bool = False,
profile_steps: int = 3,
profile_by_stage: bool = False,
profile_filename_prefix: str = None,
dataset_path: str = "",
parallel_batch: bool = False,
):
requests.post(url + "/flush_cache")
# TODO: reuse bench_serving.get_dataset ?
if dataset_name == "mmmu":
input_requests = sample_mmmu_requests(
num_requests=batch_size,
tokenizer=tokenizer,
fixed_output_len=output_len,
apply_chat_template=True,
random_sample=False,
)
elif dataset_name == "random":
input_requests = sample_random_requests(
input_len=input_len,
output_len=output_len,
num_prompts=batch_size,
range_ratio=1.0,
tokenizer=tokenizer,
dataset_path=dataset_path,
random_sample=True,
return_text=False,
)
input_requests = sample_random_requests(
input_len=input_len,
output_len=output_len,
num_prompts=batch_size,
range_ratio=1.0,
tokenizer=tokenizer,
dataset_path=dataset_path,
random_sample=True,
return_text=False,
)
use_structured_outputs = False
if use_structured_outputs:
@@ -334,48 +181,26 @@ def run_one_case(
profile_link = None
if profile:
output_dir, profile_name = None, None
if profile_filename_prefix:
output_dir = os.path.dirname(profile_filename_prefix)
profile_name = os.path.basename(profile_filename_prefix)
profile_link: str = run_profile(
url,
profile_steps,
["CPU", "GPU"],
output_dir,
profile_name,
profile_by_stage,
url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
)
tic = time.perf_counter()
payload = {
"sampling_params": {
"temperature": temperature,
"max_new_tokens": output_len,
"ignore_eos": True,
"json_schema": json_schema,
"stream_interval": stream_interval,
},
"return_logprob": return_logprob,
"stream": True,
**({"parallel_batch": parallel_batch} if parallel_batch else {}),
}
if dataset_name == "mmmu":
# vlm
input_ids = []
for input_req in input_requests:
input_ids += [tokenizer.encode(input_req.prompt)]
payload["image_data"] = [req.image_data for req in input_requests]
else:
input_ids = [req.prompt for req in input_requests]
payload["input_ids"] = input_ids
response = requests.post(
url + "/generate",
json=payload,
json={
"input_ids": [req.prompt for req in input_requests],
"sampling_params": {
"temperature": temperature,
"max_new_tokens": output_len,
"ignore_eos": True,
"json_schema": json_schema,
"stream_interval": stream_interval,
},
"return_logprob": return_logprob,
"stream": True,
**({"parallel_batch": parallel_batch} if parallel_batch else {}),
},
stream=True,
)
@@ -439,100 +264,10 @@ def run_one_case(
overall_throughput,
last_gen_throughput,
acc_length,
profile_link,
profile_link if profile else None,
)
def save_results_as_json(result: List[Tuple], bench_args: BenchArgs, model: str):
"""Save benchmark results as JSON using Pydantic models."""
json_results = []
# Generate all parameter combinations to match with results
param_combinations = list(
itertools.product(
bench_args.batch_size, bench_args.input_len, bench_args.output_len
)
)
for i, (
batch_size,
latency,
ttft,
input_throughput,
output_throughput,
overall_throughput,
last_gen_throughput,
acc_length,
profile_link,
) in enumerate(result):
# Get the corresponding parameters for this result
bs, input_len, output_len = param_combinations[i]
# Parse profile links if available
profile_links = None
if profile_link:
profile_links = parse_profile_links(
profile_link, batch_size, input_len, output_len
)
benchmark_result = BenchmarkResult(
model_path=model,
run_name=bench_args.run_name,
batch_size=batch_size,
input_len=input_len,
output_len=output_len,
latency=latency,
ttft=ttft,
input_throughput=input_throughput,
output_throughput=output_throughput,
overall_throughput=overall_throughput,
last_gen_throughput=last_gen_throughput,
acc_length=acc_length,
profile_links=profile_links,
)
json_results.append(benchmark_result.model_dump())
# Save to JSON file
with open(bench_args.output_path, "w", encoding="utf-8") as f:
json.dump(json_results, f, indent=2, ensure_ascii=False)
print(f"Results saved as JSON to {bench_args.output_path}")
def parse_profile_links(
profile_dir: str, batch_size: int, input_len: int, output_len: int
) -> Optional[ProfileLinks]:
"""Parse profile directory to extract extend and decode trace file links."""
if not profile_dir or not os.path.exists(profile_dir):
return None
extend_link = None
decode_link = None
# Look for extend/prefill trace files
for file in os.listdir(profile_dir):
if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
if "extend" in file.lower() or "prefill" in file.lower():
extend_link = os.path.join(profile_dir, file)
elif "decode" in file.lower():
decode_link = os.path.join(profile_dir, file)
# If no specific extend/decode files found, try to find files with batch/input/output info
if not extend_link or not decode_link:
for file in os.listdir(profile_dir):
if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
if f"_batch{batch_size}_input{input_len}_output{output_len}_" in file:
if "prefill" in file.lower() or "extend" in file.lower():
extend_link = os.path.join(profile_dir, file)
elif "decode" in file.lower():
decode_link = os.path.join(profile_dir, file)
if extend_link or decode_link:
return ProfileLinks(extend=extend_link, decode=decode_link)
return None
def get_report_summary(
result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
):
@@ -623,7 +358,6 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
return_logprob=bench_args.return_logprob,
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
dataset_name=bench_args.dataset_name,
run_name="",
result_filename="",
tokenizer=tokenizer,
@@ -650,12 +384,10 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
stream_interval=bench_args.client_stream_interval,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name,
dataset_name=bench_args.dataset_name,
result_filename=bench_args.result_filename,
tokenizer=tokenizer,
dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch,
profile_filename_prefix=bench_args.profile_filename_prefix,
)
)
@@ -678,13 +410,11 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
run_name=bench_args.run_name,
result_filename=bench_args.result_filename,
tokenizer=tokenizer,
dataset_name=bench_args.dataset_name,
profile=bench_args.profile,
profile_steps=bench_args.profile_steps,
profile_by_stage=bench_args.profile_by_stage,
dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch,
profile_filename_prefix=bench_args.profile_filename_prefix,
)[-1],
)
)
@@ -697,16 +427,13 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
print(f"\nResults are saved to {bench_args.result_filename}")
# Save results as JSON if output_path is specified
if bench_args.output_path:
save_results_as_json(result, bench_args, model=server_args.model_path)
if not bench_args.show_report:
return
summary = get_report_summary(result, server_args, bench_args)
print(summary)
if is_in_ci() and bench_args.append_to_github_summary:
if is_in_ci():
write_github_step_summary(summary)

View File

@@ -208,10 +208,6 @@ async def async_request_openai_completions(
"ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body,
}
if request_func_input.image_data:
payload.update({"image_data": request_func_input.image_data})
headers = get_auth_headers()
output = RequestFuncOutput.init_new(request_func_input)
@@ -1763,9 +1759,7 @@ async def benchmark(
pbar.close()
if "sglang" in backend:
server_info = requests.get(
base_url + "/get_server_info", headers=get_auth_headers()
)
server_info = requests.get(base_url + "/get_server_info")
if server_info.status_code == 200:
server_info_json = server_info.json()
if "decode" in server_info_json:

View File

@@ -124,8 +124,6 @@ class Envs:
SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
# Model Parallel
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)

View File

@@ -37,8 +37,8 @@ class GlobalConfig:
)
# Runtime constants: others
self.retract_decode_steps = 20
self.flashinfer_workspace_size = int(
os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
self.flashinfer_workspace_size = os.environ.get(
"FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
)
# Output tokenization configs

View File

@@ -7,23 +7,9 @@ from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree
MOVE_ENVS_WARN = """
########################################################################
# For contributors and developers: #
# Please move environment variable definitions to sglang.srt.environ #
# using the following pattern: #
# SGLANG_XXX = EnvBool(False) #
# #
########################################################################
"""
if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])
from sglang.srt.server_args import print_deprecated_warning
print_deprecated_warning(MOVE_ENVS_WARN)
try:
launch_server(server_args)
finally:

View File

@@ -5,15 +5,6 @@ from typing import List, Optional, Tuple
import torch
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
try:
from lmslim import quant_ops
from lmslim import quant_tools
except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try:
import lightop
except Exception:
print("INFO: Please install lightop if you want to infer awq of marlin.\n")
logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = get_bool_env_var(
@@ -184,25 +175,3 @@ def mscclpp_allreduce(
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
def triton_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
best_config:Optional[list] = None) -> torch.Tensor:
return quant_ops.triton_scaled_mm(a, b,scale_a,scale_b,out_dtype,bias,best_config)
def triton_int8_gemm_helper(m: int,
n: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
use_bias: bool,
out_dtype: type[torch.dtype] = torch.float16,
device: str = "cuda:0",
best_config:Optional[list] = None,
repeat:Optional[int] = 2):
return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config,repeat)

View File

@@ -24,8 +24,6 @@ class LoadFormat(str, enum.Enum):
JAX = "jax"
REMOTE = "remote"
REMOTE_INSTANCE = "remote_instance"
RDMA = "rdma"
LOCAL_CACHED = "local_cached"
@dataclass
@@ -49,7 +47,6 @@ class LoadConfig:
checkpoints.
decryption_key_file: If set, decrypts the output files with a password read
from this file (after PBKDF2).
decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
"""
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
@@ -57,11 +54,6 @@ class LoadConfig:
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
decryption_key_file: Optional[str] = None
decrypt_max_concurrency: int = -1
tp_rank: Optional[int] = None
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}

View File

@@ -31,7 +31,7 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip, retry
from sglang.srt.utils import get_bool_env_var, is_hip
from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
@@ -48,6 +48,30 @@ class ModelImpl(str, Enum):
TRANSFORMERS = "transformers"
def is_deepseek_nsa(config: PretrainedConfig) -> bool:
return (
config.architectures is not None
and config.architectures[0]
in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
and getattr(config, "index_topk", None) is not None
)
def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_head_dim
def get_nsa_index_topk(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_topk
def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
assert is_deepseek_nsa(config)
return config.index_n_heads
class ModelConfig:
def __init__(
self,
@@ -64,20 +88,35 @@ class ModelConfig:
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
tp_rank: Optional[int] = None,
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
remote_instance_weight_loader_send_weights_group_ports: Optional[
List[int]
] = None,
) -> None:
# Parse args
self.model_path = model_path
self.revision = revision
self.quantization = quantization
self.is_draft_model = is_draft_model
self.model_impl = model_impl
self.tp_rank = tp_rank
self.remote_instance_weight_loader_seed_instance_ip = (
remote_instance_weight_loader_seed_instance_ip
)
self.remote_instance_weight_loader_seed_instance_service_port = (
remote_instance_weight_loader_seed_instance_service_port
)
self.remote_instance_weight_loader_send_weights_group_ports = (
remote_instance_weight_loader_send_weights_group_ports
)
# Get hf config
self._maybe_pull_model_tokenizer_from_remote()
self.maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args)
kwargs = {}
if override_config_file and override_config_file.strip():
kwargs["_configuration_file"] = override_config_file.strip()
self.hf_config = get_config(
self.model_path,
trust_remote_code=trust_remote_code,
@@ -85,7 +124,7 @@ class ModelConfig:
model_override_args=self.model_override_args,
**kwargs,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.hf_generation_config = get_generation_config(
self.model_path,
trust_remote_code=trust_remote_code,
@@ -93,25 +132,7 @@ class ModelConfig:
**kwargs,
)
# Set enable_multimodal
if enable_multimodal is None:
mm_disabled_models = [
"Gemma3ForConditionalGeneration",
"Llama4ForConditionalGeneration",
"Step3VLForConditionalGeneration",
]
if self.hf_config.architectures[0] in mm_disabled_models:
enable_multimodal = False
logger.info(
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
)
else:
enable_multimodal = True
# Config draft model
self._config_draft_model()
# Check model type
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None
)
@@ -127,70 +148,20 @@ class ModelConfig:
self.hf_config.architectures, self.hf_text_config.num_hidden_layers
)
)
self.is_generation = is_generation_model(
self.hf_config.architectures, is_embedding
)
self.is_multimodal = enable_multimodal and is_multimodal_model(
self.hf_config.architectures
)
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
self.hf_config.architectures
)
self.is_image_gen = enable_multimodal and is_image_gen_model(
self.hf_config.architectures
)
self.is_audio_model = enable_multimodal and is_audio_model(
self.hf_config.architectures
)
self.is_multimodal_chunked_prefill_supported = (
enable_multimodal
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
# Derive context length and model shapes
self._derive_context_length(context_length)
self._derive_model_shapes()
# Verify quantization
self._verify_quantization()
# Verify dual-chunk attention config
self._verify_dual_chunk_attention_config()
# Cache attributes
self.hf_eos_token_id = self._get_hf_eos_token_id()
# multimodal
self.image_token_id = getattr(
self.hf_config, "image_token_id", None
) or getattr(self.hf_config, "image_token_index", None)
@staticmethod
def from_server_args(
server_args: ServerArgs,
model_path: str = None,
model_revision: str = None,
**kwargs,
):
return ModelConfig(
model_path=model_path or server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=model_revision or server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
**kwargs,
)
def _config_draft_model(self):
is_draft_model = self.is_draft_model
if enable_multimodal is None:
mm_disabled_models = [
"Gemma3ForConditionalGeneration",
"Llama4ForConditionalGeneration",
"Step3VLForConditionalGeneration",
]
if self.hf_config.architectures[0] in mm_disabled_models:
enable_multimodal = False
logger.info(
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
)
else:
enable_multimodal = True
if (
is_draft_model
@@ -225,10 +196,31 @@ class ModelConfig:
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
self.hf_config.num_nextn_predict_layers = 1
def _derive_context_length(self, context_length: int):
is_draft_model = self.is_draft_model
derived_context_len = get_context_length(self.hf_text_config)
# Check model type
self.is_generation = is_generation_model(
self.hf_config.architectures, is_embedding
)
self.is_multimodal = enable_multimodal and is_multimodal_model(
self.hf_config.architectures
)
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
self.hf_config.architectures
)
self.is_image_gen = enable_multimodal and is_image_gen_model(
self.hf_config.architectures
)
self.is_audio_model = enable_multimodal and is_audio_model(
self.hf_config.architectures
)
self.is_multimodal_chunked_prefill_supported = (
enable_multimodal
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
# Derive context length
derived_context_len = get_context_length(self.hf_text_config)
if context_length is not None:
if context_length > derived_context_len:
reason = "Target model's" if is_draft_model else "User-specified"
@@ -242,11 +234,6 @@ class ModelConfig:
):
logger.warning(msg)
self.context_len = context_length
if is_draft_model:
self.hf_text_config.max_position_embeddings = context_length
logger.warning(
f"Overriding the draft model's max_position_embeddings to {context_length}."
)
else:
raise ValueError(
f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
@@ -256,10 +243,6 @@ class ModelConfig:
else:
self.context_len = derived_context_len
# Transfer context_len to HuggingFace config so models can access it
self.hf_config.context_len = self.context_len
def _derive_model_shapes(self):
# Unify the config keys for hf_text_config
self.head_dim = getattr(
self.hf_text_config,
@@ -270,6 +253,7 @@ class ModelConfig:
# FIXME: temporary special judge for MLA architecture
if (
"DeepseekV2ForCausalLM" in self.hf_config.architectures
or "DeepseekV32ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
or "LongcatFlashForCausalLM" in self.hf_config.architectures
@@ -282,6 +266,11 @@ class ModelConfig:
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.v_head_dim = self.hf_config.v_head_dim
self.index_head_dim = (
get_nsa_index_head_dim(self.hf_config)
if is_deepseek_nsa(self.hf_config)
else None
)
# Handle rope scaling with yarn
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
@@ -354,6 +343,45 @@ class ModelConfig:
)
self.vocab_size = self.hf_text_config.vocab_size
# Verify quantization
self._verify_quantization()
# Verify dual-chunk attention config
self._verify_dual_chunk_attention_config()
# Cache attributes
self.hf_eos_token_id = self.get_hf_eos_token_id()
# multimodal
self.image_token_id = getattr(
self.hf_config, "image_token_id", None
) or getattr(self.hf_config, "image_token_index", None)
@staticmethod
def from_server_args(
server_args: ServerArgs,
model_path: str = None,
model_revision: str = None,
**kwargs,
):
return ModelConfig(
model_path=model_path or server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=model_revision or server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
**kwargs,
)
def get_total_num_attention_heads(self) -> int:
return self.num_attention_heads
@@ -454,31 +482,13 @@ class ModelConfig:
from huggingface_hub import HfApi
hf_api = HfApi()
def check_hf_quant_config():
return hf_api.file_exists(
self.model_path, "hf_quant_config.json"
)
# Retry HF API call up to 3 times
file_exists = retry(
check_hf_quant_config,
max_retry=2,
initial_delay=1.0,
max_delay=5.0,
)
if file_exists:
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
quant_cfg = modelopt_quant_config
except huggingface_hub.errors.OfflineModeIsEnabled:
logger.warning(
"Offline mode is enabled, skipping hf_quant_config.json check"
)
except Exception as e:
logger.warning(
f"Failed to check hf_quant_config.json: {self.model_path} {e}"
)
pass
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
quant_config_file = os.path.join(
@@ -508,7 +518,6 @@ class ModelConfig:
"petit_nvfp4",
"quark",
"mxfp4",
"slimquant_w4a8_marlin",
]
optimized_quantization_methods = [
"fp8",
@@ -527,7 +536,6 @@ class ModelConfig:
"qoq",
"w4afp8",
"petit_nvfp4",
"slimquant_w4a8_marlin",
]
compatible_quantization_methods = {
"modelopt_fp4": ["modelopt"],
@@ -608,7 +616,7 @@ class ModelConfig:
"sparse_attention_enabled"
] = True
def _get_hf_eos_token_id(self) -> Optional[Set[int]]:
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
eos_ids = getattr(self.hf_config, "eos_token_id", None)
if eos_ids is not None:
# it can be either int or list of int
@@ -628,7 +636,7 @@ class ModelConfig:
eos_ids = eos_ids | generation_eos_ids
return eos_ids
def _maybe_pull_model_tokenizer_from_remote(self) -> None:
def maybe_pull_model_tokenizer_from_remote(self) -> None:
"""
Pull the model config files to a temporary
directory in case of remote.
@@ -771,8 +779,6 @@ multimodal_model_archs = [
"Qwen2AudioForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"InternS1ForConditionalGeneration",

View File

@@ -1,586 +0,0 @@
from typing import Optional, Union
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
class Qwen3VLVisionConfig(PretrainedConfig):
    """Configuration for the Qwen3-VL vision encoder.

    Stores the hyper-parameters of the vision tower: patch embedding geometry,
    transformer depth/width, spatial/temporal merge sizes, and the layer
    indexes whose hidden states are tapped as "deepstack" visual features.
    """

    model_type = "qwen3_vl"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=27,
        hidden_size=1152,
        hidden_act="gelu_pytorch_tanh",
        intermediate_size=4304,
        num_heads=16,
        in_channels=3,
        patch_size=16,
        spatial_merge_size=2,
        temporal_patch_size=2,
        out_hidden_size=3584,
        num_position_embeddings=2304,
        deepstack_visual_indexes=None,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.out_hidden_size = out_hidden_size
        self.num_position_embeddings = num_position_embeddings
        self.initializer_range = initializer_range
        # Fix: the previous `deepstack_visual_indexes=[8, 16, 24]` default was a
        # mutable list shared across every instance; materialize it per instance.
        self.deepstack_visual_indexes = (
            [8, 16, 24] if deepstack_visual_indexes is None else deepstack_visual_indexes
        )
class Qwen3VLTextConfig(PretrainedConfig):
    r"""Configuration for [`Qwen3VLTextModel`], the text backbone of Qwen3-VL.

    Instantiating with the defaults yields a configuration similar to
    Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
    Configuration objects inherit from [`PretrainedConfig`]; read its
    documentation for the inherited behavior (serialization, loading, etc.).

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the model.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of decoder layers.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads per layer.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            Number of key/value heads for Grouped Query Attention: equal to
            `num_attention_heads` gives MHA, `1` gives MQA, anything in between
            gives GQA. Falls back to `num_attention_heads` when `None`.
        head_dim (`int`, *optional*, defaults to 128):
            Dimension of each attention head.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            Non-linear activation of the decoder MLP.
        max_position_embeddings (`int`, *optional*, defaults to 128000):
            Maximum sequence length the model supports.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Stddev of the truncated-normal weight initializer.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model returns key/value caches (decoder only).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether input and output embeddings are tied.
        rope_theta (`float`, *optional*, defaults to 5000000.0):
            Base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            RoPE scaling configuration (`rope_type`, `factor`,
            `original_max_position_embeddings`, yarn/longrope/llama3 extras);
            see the transformers RoPE utilities for the accepted keys per type.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether q/k/v/o projections use a bias.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio applied to the attention probabilities.

    ```python
    >>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig
    >>> configuration = Qwen3VLTextConfig()
    >>> model = Qwen3VLTextModel(configuration)
    >>> configuration = model.config
    ```"""

    model_type = "qwen3_vl_text"
    base_config_key = "text_config"

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=128000,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=5000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Model shape.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        # Attention layout; a `None` key/value head count falls back to MHA
        # (kept for backward compatibility with older checkpoints).
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.head_dim = head_dim
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Remaining hyper-parameters.
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        # RoPE setup; the mrope keys are model-specific and excluded from the
        # generic rope validation.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class Qwen3VLConfig(PretrainedConfig):
    r"""Configuration for [`Qwen3VLModel`], combining a vision and a text backbone.

    Instantiating with the defaults yields a configuration similar to
    Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
    Configuration objects inherit from [`PretrainedConfig`]; read its
    documentation for the inherited behavior.

    Args:
        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`):
            The config object or dictionary of the text backbone.
        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
            The config object or dictionary of the vision backbone.
        image_token_id (`int`, *optional*, defaults to 151655):
            The image token index used to encode the image prompt.
        video_token_id (`int`, *optional*, defaults to 151656):
            The video token index used to encode the video prompt.
        vision_start_token_id (`int`, *optional*, defaults to 151652):
            The start token index delimiting visual content.
        vision_end_token_id (`int`, *optional*, defaults to 151653):
            The end token index delimiting visual content.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie the word embeddings.

    ```python
    >>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig
    >>> configuration = Qwen3VLConfig()
    >>> model = Qwen3VLForConditionalGeneration(configuration)
    >>> configuration = model.config
    ```"""

    model_type = "qwen3_vl"
    sub_configs = {
        "vision_config": Qwen3VLVisionConfig,
        "text_config": Qwen3VLTextConfig,
    }
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=151655,
        video_token_id=151656,
        vision_start_token_id=151652,
        vision_end_token_id=151653,
        tie_word_embeddings=False,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()
        else:
            # Fix: accept an already-constructed config object. Previously this
            # case left `self.vision_config` unset, failing later with AttributeError.
            self.vision_config = vision_config

        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            self.text_config = self.sub_configs["text_config"]()
        else:
            # Same pass-through fix as for `vision_config`.
            self.text_config = text_config

        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id

        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
class Qwen3VLMoeTextConfig(PretrainedConfig):
    r"""Configuration for [`Qwen3VLMoeTextModel`], the MoE text backbone of Qwen3-VL.

    Instantiating with the defaults yields a configuration similar to
    Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
    Configuration objects inherit from [`PretrainedConfig`]; read its
    documentation for the inherited behavior.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the model.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 5632):
            Dimension of the dense MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of decoder layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads per layer.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            Number of key/value heads for Grouped Query Attention: equal to
            `num_attention_heads` gives MHA, `1` gives MQA, anything in between
            gives GQA. Falls back to `num_attention_heads` when `None`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            Non-linear activation of the decoder MLP.
        max_position_embeddings (`int`, *optional*, defaults to 128000):
            Maximum sequence length the model supports.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Stddev of the truncated-normal weight initializer.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model returns key/value caches (decoder only).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether input and output embeddings are tied.
        rope_theta (`float`, *optional*, defaults to 5000000.0):
            Base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether q/k/v/o projections use a bias.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio applied to the attention probabilities.
        decoder_sparse_step (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 1408):
            Intermediate size of each routed expert.
        num_experts_per_tok (`int`, *optional*, defaults to 4):
            Number of experts selected per token.
        num_experts (`int`, *optional*, defaults to 60):
            Number of routed experts.
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the top-k routing probabilities.
        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
            Layer indexes (0 to num_layers-1) that use a dense MLP instead of
            the sparse MoE block; when empty, `decoder_sparse_step` alone
            determines sparsity.
        rope_scaling (`Dict`, *optional*):
            RoPE scaling configuration (`rope_type`, `factor`,
            `original_max_position_embeddings`, yarn/longrope/llama3 extras);
            see the transformers RoPE utilities for the accepted keys per type.
        head_dim (`int`, *optional*):
            Dimension of each head; defaults to `hidden_size // num_attention_heads`.

    ```python
    >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
    >>> configuration = Qwen3VLMoeConfig()
    >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
    >>> configuration = model.config
    ```"""

    model_type = "qwen3_vl_moe_text"
    base_config_key = "text_config"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen3VLMoe`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=5632,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_key_value_heads=16,
        hidden_act="silu",
        max_position_embeddings=128000,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=5000000.0,
        attention_bias=False,
        attention_dropout=0.0,
        decoder_sparse_step=1,
        moe_intermediate_size=1408,
        num_experts_per_tok=4,
        num_experts=60,
        norm_topk_prob=True,
        mlp_only_layers=None,
        rope_scaling=None,
        head_dim=None,
        **kwargs,
    ):
        # Model shape.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        # Attention layout; a `None` key/value head count falls back to MHA
        # (kept for backward compatibility with older checkpoints).
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Remaining hyper-parameters.
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        # RoPE setup; the mrope keys are model-specific and excluded from the
        # generic rope validation.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
        # MoE arguments.
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class Qwen3VLMoeVisionConfig(PretrainedConfig):
    """Configuration for the Qwen3-VL-MoE vision encoder.

    Stores the hyper-parameters of the vision tower: patch embedding geometry,
    transformer depth/width, spatial/temporal merge sizes, and the layer
    indexes whose hidden states are tapped as "deepstack" visual features.
    """

    model_type = "qwen3_vl_moe"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=27,
        hidden_size=1152,
        hidden_act="gelu_pytorch_tanh",
        intermediate_size=4304,
        num_heads=16,
        in_channels=3,
        patch_size=16,
        spatial_merge_size=2,
        temporal_patch_size=2,
        out_hidden_size=3584,
        num_position_embeddings=2304,
        deepstack_visual_indexes=None,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.out_hidden_size = out_hidden_size
        self.num_position_embeddings = num_position_embeddings
        self.initializer_range = initializer_range
        # Fix: the previous `deepstack_visual_indexes=[8, 16, 24]` default was a
        # mutable list shared across every instance; materialize it per instance.
        self.deepstack_visual_indexes = (
            [8, 16, 24] if deepstack_visual_indexes is None else deepstack_visual_indexes
        )
class Qwen3VLMoeConfig(PretrainedConfig):
    r"""Configuration for [`Qwen3VLMoeModel`], combining a vision and an MoE text backbone.

    Instantiating with the defaults yields a configuration similar to
    Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
    Configuration objects inherit from [`PretrainedConfig`]; read its
    documentation for the inherited behavior.

    Args:
        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):
            The config object or dictionary of the text backbone.
        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):
            The config object or dictionary of the vision backbone.
        image_token_id (`int`, *optional*, defaults to 151655):
            The image token index used to encode the image prompt.
        video_token_id (`int`, *optional*, defaults to 151656):
            The video token index used to encode the video prompt.
        vision_start_token_id (`int`, *optional*, defaults to 151652):
            The start token index delimiting visual content.
        vision_end_token_id (`int`, *optional*, defaults to 151653):
            The end token index delimiting visual content.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie the word embeddings.

    ```python
    >>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
    >>> configuration = Qwen3VLMoeConfig()
    >>> model = Qwen3VLMoeForConditionalGeneration(configuration)
    >>> configuration = model.config
    ```"""

    model_type = "qwen3_vl_moe"
    sub_configs = {
        "vision_config": Qwen3VLMoeVisionConfig,
        "text_config": Qwen3VLMoeTextConfig,
    }
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=151655,
        video_token_id=151656,
        vision_start_token_id=151652,
        vision_end_token_id=151653,
        tie_word_embeddings=False,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()
        else:
            # Fix: accept an already-constructed config object. Previously this
            # case left `self.vision_config` unset, failing later with AttributeError.
            self.vision_config = vision_config

        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            self.text_config = self.sub_configs["text_config"]()
        else:
            # Same pass-through fix as for `vision_config`.
            self.text_config = text_config

        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id

        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
# Public API of this module. Consistency fix: the module defines six public
# config classes but previously exported only four; the two text configs are
# public (referenced via `sub_configs` and usable standalone), so export them too.
__all__ = [
    "Qwen3VLMoeConfig",
    "Qwen3VLMoeTextConfig",
    "Qwen3VLMoeVisionConfig",
    "Qwen3VLConfig",
    "Qwen3VLTextConfig",
    "Qwen3VLVisionConfig",
]

View File

@@ -2,9 +2,19 @@ import logging
import os
from typing import List, Optional
import torch
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
try:
from mf_adapter import TransferEngine
import_error = None
except ImportError as e:
import_error = e
pass
logger = logging.getLogger(__name__)
@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine):
def __init__(
self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
):
try:
from mf_adapter import TransferEngine
except ImportError as e:
raise ImportError(
if import_error is not None:
logger.warning(
"Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
) from e
)
raise import_error
self.engine = TransferEngine()
self.hostname = hostname
@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine):
self.initialize()
def initialize(self) -> None:
from sglang.srt.layers.dp_attention import (
get_tensor_model_parallel_world_size,
get_tp_group,
)
transfer_protocol = self._get_transfer_protocol()
if transfer_protocol is None or transfer_protocol == "sdma":
trans_op_type = TransferEngine.TransDataOpType.SDMA
else:
trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
"""with device RDMA for PD transfer"""
tmp_tensor = torch.zeros(1, device="npu")
output_tensor_list = [
torch.empty_like(tmp_tensor)
for _ in range(get_tensor_model_parallel_world_size())
]
# Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization.
torch.distributed.all_gather(
output_tensor_list, tmp_tensor, group=get_tp_group().device_group
)
"""Initialize the ascend transfer instance."""
ret_value = self.engine.initialize(
self.store_url,
self.session_id,
self.role,
self.npu_id,
self.store_url, self.session_id, self.role, self.npu_id, trans_op_type
)
if ret_value != 0:
logger.error("Ascend Transfer Engine initialization failed.")
@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine):
ret_value = -1
if ret_value != 0:
logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
@staticmethod
def _get_transfer_protocol():
protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
allowed_protocols = {"device_rdma", "sdma"}
if protocol and protocol.lower() in allowed_protocols:
return protocol.lower()
else:
logger.warning(
"Invalid or no transfer protocol specified, using default protocol."
)
return None

View File

@@ -95,6 +95,14 @@ class CommonKVManager(BaseKVManager):
def _bind_server_socket(self):
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
@cache
def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint)
return socket
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
@@ -148,33 +156,6 @@ class CommonKVManager(BaseKVManager):
socket.connect(endpoint)
return socket
def get_mha_kv_ptrs_with_pp(
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
) -> Tuple[List[int], List[int], List[int], List[int], int]:
# pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
num_kv_layers = len(src_kv_ptrs) // 2
end_layer = start_layer + num_kv_layers
dst_num_total_layers = len(dst_kv_ptrs) // 2
src_k_ptrs = src_kv_ptrs[:num_kv_layers]
src_v_ptrs = src_kv_ptrs[num_kv_layers:]
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
layers_current_pp_stage = len(src_k_ptrs)
return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
def get_mla_kv_ptrs_with_pp(
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
) -> Tuple[List[int], List[int], int]:
# pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
end_layer = start_layer + len(src_kv_ptrs)
sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
layers_current_pp_stage = len(src_kv_ptrs)
return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
class CommonKVSender(BaseKVSender):

View File

@@ -609,21 +609,15 @@ class DecodeTransferQueue:
idx = decode_req.metadata_buffer_index
(
output_id,
cached_tokens,
output_token_logprobs_val,
output_token_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
output_topk_p,
output_topk_index,
output_hidden_states,
) = self.metadata_buffers.get_buf(idx)
decode_req.req.output_ids.append(output_id[0].item())
decode_req.req.cached_tokens = cached_tokens[0].item()
if not self.spec_algorithm.is_none():
decode_req.req.output_topk_p = output_topk_p
decode_req.req.output_topk_index = output_topk_index
decode_req.req.hidden_states_tensor = output_hidden_states
if decode_req.req.return_logprob:
decode_req.req.output_token_logprobs_val.append(
@@ -713,15 +707,12 @@ class SchedulerDisaggregationDecodeMixin:
elif prepare_mlp_sync_flag:
batch, _ = self._prepare_idle_batch_and_run(None)
queue_size = (
if batch is None and (
len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
)
if self.server_args.disaggregation_decode_enable_offload_kvcache:
queue_size += len(self.decode_offload_manager.ongoing_offload)
if batch is None and queue_size == 0:
== 0
):
self.self_check_during_idle()
self.last_batch = batch
@@ -790,15 +781,12 @@ class SchedulerDisaggregationDecodeMixin:
)
self.process_batch_result(tmp_batch, tmp_result)
queue_size = (
if batch is None and (
len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
)
if self.server_args.disaggregation_decode_enable_offload_kvcache:
queue_size += len(self.decode_offload_manager.ongoing_offload)
if batch is None and queue_size == 0:
== 0
):
self.self_check_during_idle()
self.last_batch = batch
@@ -917,6 +905,3 @@ class SchedulerDisaggregationDecodeMixin:
self.disagg_decode_transfer_queue.pop_transferred()
) # the requests which kv has arrived
self.waiting_queue.extend(alloc_reqs)
if self.server_args.disaggregation_decode_enable_offload_kvcache:
self.decode_offload_manager.check_offload_progress()

View File

@@ -1,185 +0,0 @@
import logging
import threading
import time
import torch
from sglang.srt.server_args import ServerArgs
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
MLATokenToKVPoolHost,
)
logger = logging.getLogger(__name__)
class DecodeKVCacheOffloadManager:
    """Manage decode-side KV cache offloading lifecycle and operations.

    Pipeline: when a request finishes, its KV pages are asynchronously
    copied device -> host via ``HiCacheController.write``; once the copy
    completes, the device pages are released and a host -> storage backup
    is triggered; once the backup completes, the host pages are freed.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_group: torch.distributed.ProcessGroup,
        tree_cache: BasePrefixCache,
        server_args: ServerArgs,
    ) -> None:
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = server_args.page_size
        self.server_args = server_args
        # Monotonically increasing id used to tag async offload operations.
        self.request_counter = 0
        self.tree_cache = tree_cache
        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
        # Mirror the device KV pool type with the matching host-side pool.
        if isinstance(kv_cache, MHATokenToKVPool):
            self.decode_host_mem_pool = MHATokenToKVPoolHost(
                kv_cache,
                server_args.hicache_ratio,
                server_args.hicache_size,
                self.page_size,
                server_args.hicache_mem_layout,
            )
        elif isinstance(kv_cache, MLATokenToKVPool):
            self.decode_host_mem_pool = MLATokenToKVPoolHost(
                kv_cache,
                server_args.hicache_ratio,
                server_args.hicache_size,
                self.page_size,
                server_args.hicache_mem_layout,
            )
        else:
            raise ValueError("Unsupported KV cache type for decode offload")
        self.tp_group = tp_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
        # The controller performs the actual device->host and host->storage
        # transfers asynchronously.
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            mem_pool_host=self.decode_host_mem_pool,
            page_size=self.page_size,
            tp_group=tp_group,
            io_backend=server_args.hicache_io_backend,
            load_cache_event=threading.Event(),
            storage_backend=server_args.hicache_storage_backend,
            model_name=server_args.served_model_name,
            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
        )
        # ack_id -> (req, host_indices, tokens, start_time): in-flight
        # device->host offloads.
        self.ongoing_offload = {}
        # ack_id -> (req_id, host_indices, start_time): in-flight
        # host->storage backups.
        self.ongoing_backup = {}
        logger.info("Enable offload kv cache for decode side")

    def offload_kv_cache(self, req) -> bool:
        """Offload a finished request's KV cache to storage.

        Returns True when the async device->host copy was enqueued; False
        when the request has nothing offloadable (no pool slot, empty or
        sub-page-sized token range) or host memory is exhausted.
        """
        if self.cache_controller is None or self.decode_host_mem_pool is None:
            return False
        # -1 means the request no longer owns a slot in the req->token pool.
        if req.req_pool_idx == -1:
            return False
        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
        if token_indices.dim() == 0 or token_indices.numel() == 0:
            logger.debug(
                f"Request {req.rid} has invalid token_indices: {token_indices}"
            )
            return False
        tokens = req.origin_input_ids + req.output_ids
        # Only whole pages can be offloaded; truncate to page alignment.
        aligned_len = (len(tokens) // self.page_size) * self.page_size
        if aligned_len == 0:
            return False
        token_indices = token_indices[:aligned_len]
        tokens = tokens[:aligned_len]

        # Asynchronously offload KV cache from device to host by cache controller
        self.request_counter += 1
        ack_id = self.request_counter
        # write() returns None when the host pool cannot hold the pages.
        host_indices = self.cache_controller.write(
            device_indices=token_indices.long(),
            node_id=ack_id,
        )
        if host_indices is None:
            logger.error(f"Not enough host memory for request {req.rid}")
            return False
        self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
        return True

    def check_offload_progress(self):
        """Check the progress of offload from device to host and backup from host to storage."""
        cc = self.cache_controller
        qsizes = torch.tensor(
            [
                len(cc.ack_write_queue),
                cc.ack_backup_queue.qsize(),
            ],
            dtype=torch.int,
        )
        if self.tp_world_size > 1:
            # Take the MIN across TP ranks so every rank drains the same
            # number of completions and stays in lockstep.
            torch.distributed.all_reduce(
                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
            )
        n_write, n_backup = map(int, qsizes.tolist())
        self._check_offload_progress(n_write)
        self._check_backup_progress(n_backup)

    def _check_offload_progress(self, finish_count):
        """Check the progress of offload from device to host."""
        while finish_count > 0:
            # FIFO: oldest enqueued write first.
            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
            # Block until the device->host copy for this batch completes.
            finish_event.synchronize()
            for ack_id in ack_list:
                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)

                # Release device
                self.tree_cache.cache_finished_req(req)

                # Trigger async backup from host to storage by cache controller
                self._trigger_backup(req.rid, host_indices, tokens, start_time)
            finish_count -= 1

    def _check_backup_progress(self, finish_count):
        """Check the progress of backup from host to storage."""
        for _ in range(finish_count):
            storage_operation = self.cache_controller.ack_backup_queue.get()
            ack_id = storage_operation.id
            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)

            # Release host memory
            self.decode_host_mem_pool.free(host_indices)

            logger.debug(
                f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
            )

    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
        """Trigger async backup from host to storage by cache controller."""

        # Generate page hashes and write to storage
        page_hashes = self._compute_prefix_hash(tokens)
        ack_id = self.cache_controller.write_storage(
            host_indices,
            tokens,
            hash_value=page_hashes,
        )
        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)

    def _compute_prefix_hash(self, tokens):
        # Each page hash chains the previous one, so a page's hash uniquely
        # identifies the full token prefix up to and including that page.
        last_hash = ""
        page_hashes = []
        for offset in range(0, len(tokens), self.page_size):
            page_tokens = tokens[offset : offset + self.page_size]
            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
            page_hashes.append(last_hash)
        return page_hashes

View File

@@ -125,33 +125,25 @@ class ScheduleBatchDisaggregationDecodeMixin:
req.grammar.finished = req.finished()
self.output_ids = torch.tensor(self.output_ids, device=self.device)
# Simulate the eagle run.
if self.spec_algorithm.is_eagle():
# Simulate the eagle run. We add mock data to hidden states for the
# ease of implementation now meaning the first token will have acc rate
# of 0.
if not self.spec_algorithm.is_none():
b = len(self.reqs)
topk = server_args.speculative_eagle_topk
topk_p = torch.stack(
[
torch.as_tensor(
req.output_topk_p[:topk],
device=self.device,
dtype=torch.float32,
)
for req in self.reqs
],
dim=0,
topk_p = torch.arange(
b * server_args.speculative_eagle_topk,
0,
-1,
device=self.device,
dtype=torch.float32,
)
topk_index = torch.stack(
[
torch.as_tensor(
req.output_topk_index[:topk],
device=self.device,
dtype=torch.int64,
)
for req in self.reqs
],
dim=0,
topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
topk_p /= b * server_args.speculative_eagle_topk
topk_index = torch.arange(
b * server_args.speculative_eagle_topk, device=self.device
)
topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)

View File

@@ -264,10 +264,12 @@ class MooncakeKVManager(CommonKVManager):
layers_params = None
# pp is not supported on the decode side yet
start_layer = self.kv_args.prefill_start_layer
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
if self.is_mla_backend:
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
)
src_kv_ptrs = self.kv_args.kv_data_ptrs
layers_per_pp_stage = len(src_kv_ptrs)
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [
(
@@ -275,12 +277,18 @@ class MooncakeKVManager(CommonKVManager):
dst_kv_ptrs[layer_id],
kv_item_len,
)
for layer_id in range(layers_current_pp_stage)
for layer_id in range(layers_per_pp_stage)
]
else:
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
)
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
dst_num_total_layers = num_kv_layers * self.pp_size
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
kv_item_len = self.kv_args.kv_item_lens[0]
layers_params = [
(
@@ -288,14 +296,14 @@ class MooncakeKVManager(CommonKVManager):
dst_k_ptrs[layer_id],
kv_item_len,
)
for layer_id in range(layers_current_pp_stage)
for layer_id in range(layers_per_pp_stage)
] + [
(
src_v_ptrs[layer_id],
dst_v_ptrs[layer_id],
kv_item_len,
)
for layer_id in range(layers_current_pp_stage)
for layer_id in range(layers_per_pp_stage)
]
assert layers_params is not None
@@ -393,9 +401,18 @@ class MooncakeKVManager(CommonKVManager):
num_heads_to_send = dst_heads_per_rank
dst_head_start_offset = 0
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
)
# pp is not supported on the decode side yet
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
dst_num_total_layers = num_kv_layers * self.pp_size
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
layers_per_pp_stage = len(src_k_ptrs)
start_layer = self.pp_rank * layers_per_pp_stage
end_layer = start_layer + layers_per_pp_stage
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
]
# Calculate precise byte offset and length for the sub-slice within the token
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
@@ -421,7 +438,7 @@ class MooncakeKVManager(CommonKVManager):
dst_head_slice_offset,
heads_bytes_per_token_to_send,
)
for layer_id in range(layers_current_pp_stage)
for layer_id in range(layers_per_pp_stage)
] + [
(
src_v_ptrs[layer_id],
@@ -432,7 +449,7 @@ class MooncakeKVManager(CommonKVManager):
dst_head_slice_offset,
heads_bytes_per_token_to_send,
)
for layer_id in range(layers_current_pp_stage)
for layer_id in range(layers_per_pp_stage)
]
def process_layer_tp_aware(layer_params):

View File

@@ -421,8 +421,6 @@ class SchedulerDisaggregationPrefillMixin:
last_hidden_index = (
hidden_state_offset + extend_input_len_per_req[i] - 1
)
req.output_topk_p = batch.spec_info.topk_p[i]
req.output_topk_index = batch.spec_info.topk_index[i]
if self.spec_algorithm.is_eagle3():
req.hidden_states_tensor = (
batch.spec_info.hidden_states[i].cpu().clone()

View File

@@ -85,7 +85,7 @@ class MetadataBuffers:
self,
size: int,
hidden_size: int,
hidden_states_dtype: torch.dtype,
dtype: torch.dtype,
max_top_logprobs_num: int = 128,
custom_mem_pool: torch.cuda.MemPool = None,
):
@@ -107,9 +107,7 @@ class MetadataBuffers:
# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
self.cached_tokens = torch.zeros(
(size, 16), dtype=torch.int32, device=device
)
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
@@ -122,49 +120,33 @@ class MetadataBuffers:
self.output_top_logprobs_idx = torch.zeros(
(size, max_top_logprobs_num), dtype=torch.int32, device=device
)
# For PD + spec decode
self.output_topk_p = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
self.output_topk_index = torch.zeros(
(size, 16), dtype=torch.int64, device=device
)
self.output_hidden_states = torch.zeros(
(size, hidden_size), dtype=hidden_states_dtype, device=device
(size, hidden_size), dtype=dtype, device=device
)
def get_buf_infos(self):
ptrs = [
self.output_ids.data_ptr(),
self.cached_tokens.data_ptr(),
self.output_token_logprobs_val.data_ptr(),
self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(),
self.output_top_logprobs_idx.data_ptr(),
self.output_topk_p.data_ptr(),
self.output_topk_index.data_ptr(),
self.output_hidden_states.data_ptr(),
]
data_lens = [
self.output_ids.nbytes,
self.cached_tokens.nbytes,
self.output_token_logprobs_val.nbytes,
self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes,
self.output_top_logprobs_idx.nbytes,
self.output_topk_p.nbytes,
self.output_topk_index.nbytes,
self.output_hidden_states.nbytes,
]
item_lens = [
self.output_ids[0].nbytes,
self.cached_tokens[0].nbytes,
self.output_token_logprobs_val[0].nbytes,
self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes,
self.output_top_logprobs_idx[0].nbytes,
self.output_topk_p[0].nbytes,
self.output_topk_index[0].nbytes,
self.output_hidden_states[0].nbytes,
]
return ptrs, data_lens, item_lens
@@ -172,20 +154,16 @@ class MetadataBuffers:
def get_buf(self, idx: int):
    """Return slot ``idx`` of every metadata buffer, in transfer order.

    The tuple order matches ``get_buf_infos``: output ids, cached tokens,
    token/top logprob values and indices, spec-decode topk data, and the
    hidden states row.
    """
    buffers = (
        self.output_ids,
        self.cached_tokens,
        self.output_token_logprobs_val,
        self.output_token_logprobs_idx,
        self.output_top_logprobs_val,
        self.output_top_logprobs_idx,
        self.output_topk_p,
        self.output_topk_index,
        self.output_hidden_states,
    )
    return tuple(buf[idx] for buf in buffers)
def set_buf(self, req: Req):
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
if req.return_logprob:
if req.output_token_logprobs_val: # not none or empty list
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -208,17 +186,8 @@ class MetadataBuffers:
] = torch.tensor(
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
)
# For PD + spec decode
# for PD + spec decode
if req.hidden_states_tensor is not None:
# speculative_eagle_topk should not be greater than 16 currently
topk = req.output_topk_p.size(0)
self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
req.output_topk_p
)
self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
req.output_topk_index
)
self.output_hidden_states[req.metadata_buffer_index].copy_(
req.hidden_states_tensor
)

View File

@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
assert_pkg_version(
"sgl-kernel",
"0.3.12",
"0.3.11",
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)

View File

@@ -4,7 +4,6 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
"""
import asyncio
import copy
import dataclasses
import logging
import os
@@ -12,8 +11,7 @@ import signal
import sys
import threading
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
import grpc
import zmq
@@ -81,10 +79,11 @@ class GrpcReqState:
last_completion_tokens: int = 1
# Streaming state
last_output_offset: int = 0
stream_finished: bool = False
input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming
# Token accumulation (for non-streaming)
# Output accumulation
text: str = ""
output_ids: List[int] = dataclasses.field(default_factory=list)
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
@@ -140,6 +139,8 @@ class GrpcRequestManager:
self.is_pause_cond = asyncio.Condition()
# Metrics
self.request_counter = 0
self.request_counter_lock = asyncio.Lock()
self.last_receive_tstamp = time.time()
# Crash dump for debugging
@@ -157,133 +158,22 @@ class GrpcRequestManager:
obj: TokenizedGenerateReqInput,
request_id: Optional[str] = None,
grpc_context: Optional[grpc.aio.ServicerContext] = None,
) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
) -> asyncio.Queue:
"""
Submit a generation request to the scheduler with n>1 parallel sampling support.
This method implements the same two-phase approach as tokenizer_manager.py:
1. Phase 1: Send prefix caching request (max_new_tokens=0)
2. Phase 2: Send n generation requests that reuse the cached prefix
Yields individual responses for streaming, or aggregated responses for non-streaming.
Submit a generation request to the scheduler.
Returns a queue for streaming outputs.
"""
n = getattr(obj.sampling_params, "n", 1)
if n <= 1:
async for response in self._handle_single_request(
obj, request_id, grpc_context
):
yield response
return
# N>1 handling - two-phase approach
logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
# Generate base request ID if not provided
if request_id is None:
base_request_id = f"grpc-{uuid.uuid4().hex}"
else:
base_request_id = request_id
# Phase 1: Cache the common prefix
logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
prefix_obj = copy.copy(obj)
prefix_obj.sampling_params = copy.copy(obj.sampling_params)
prefix_obj.sampling_params.max_new_tokens = 0 # Prefill-only
prefix_obj.sampling_params.n = 1 # Don't replicate prefix request
# Send prefix caching request and consume response
async for _ in self._handle_single_request(
prefix_obj, f"{base_request_id}-prefix", grpc_context
):
# Consume prefix response (usually just one chunk with finish_reason)
pass
logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
# Phase 2: Generate n parallel requests
logger.debug(f"Phase 2: Generating {n} parallel requests")
generators = []
request_ids = []
for i in range(n):
# Create individual generation request
gen_obj = copy.copy(obj)
gen_obj.sampling_params = copy.copy(obj.sampling_params)
gen_obj.sampling_params.n = 1 # Each request generates 1 response
gen_request_id = f"{base_request_id}-{i}"
request_ids.append(gen_request_id)
# Start generation request
generators.append(
self._handle_single_request(gen_obj, gen_request_id, grpc_context)
)
# Handle response aggregation
is_stream = getattr(obj, "stream", False)
if not is_stream:
# Non-streaming: collect all responses and return as batch
logger.debug(f"Non-streaming mode: collecting {n} responses")
responses = []
for generator in generators:
async for response in generator:
responses.append(response)
yield responses # Return all responses as a batch
else:
# Streaming mode: multiplex responses with index for ordering
logger.debug(f"Streaming mode: multiplexing {n} streams")
rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
# Create async tasks for all generators
task_map = {}
for generator in generators:
task = asyncio.create_task(generator.__anext__())
task_map[task] = generator
# Process responses as they arrive
while task_map:
done, _ = await asyncio.wait(
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in done:
generator = task_map.pop(task)
try:
response = await task
# Add index for client-side ordering
if isinstance(response, dict) and "meta_info" in response:
response_rid = response["meta_info"].get("id", "")
if response_rid in rid_to_index:
response["index"] = rid_to_index[response_rid]
yield response
# Create next task for this generator
next_task = asyncio.create_task(generator.__anext__())
task_map[next_task] = generator
except StopAsyncIteration:
# This generator is finished
pass
async def _handle_single_request(
self,
obj: TokenizedGenerateReqInput,
request_id: Optional[str] = None,
grpc_context: Optional[grpc.aio.ServicerContext] = None,
):
"""Handle a single request - core implementation without n>1 logic."""
# Generate request ID if not provided
if request_id is None:
request_id = f"grpc-{uuid.uuid4().hex}"
async with self.request_counter_lock:
request_id = f"grpc-{self.request_counter}"
self.request_counter += 1
obj.rid = request_id
# Create and register request state
# TODO: support log_request
# Create request state
state = GrpcReqState(
request_id=request_id,
grpc_context=grpc_context,
@@ -299,51 +189,19 @@ class GrpcRequestManager:
state.session_id = obj.session_params.session_id
state.is_session_request = True
# Register state
self.rid_to_state[request_id] = state
self.record_request_for_crash_dump(obj)
# Send to scheduler via ZMQ
try:
# Send to scheduler - let exceptions bubble up to grpc_server.py
await self._send_to_scheduler(obj)
is_stream = getattr(obj, "stream", False)
while True:
# Client cancelled - notify scheduler and exit
if grpc_context and grpc_context.cancelled():
await self.abort_request(request_id)
return
try:
response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
if is_stream:
yield response
# Non-streaming: yield final response with accumulated tokens from state
if isinstance(response, dict) and response.get("finished", False):
if not is_stream:
final_response = response.copy()
final_response["token_ids"] = state.output_ids
yield final_response
break
except asyncio.TimeoutError:
# Timeout waiting for response - abort and cleanup
logger.warning(
f"Timeout waiting for response for request {request_id}"
)
await self.abort_request(request_id)
return
finally:
# Always clean up request state when exiting
self._cleanup_request_state(request_id)
def _cleanup_request_state(self, request_id: str):
"""Clean up local request state (does not notify scheduler)."""
if request_id in self.rid_to_state:
except Exception as e:
# Clean up on failure
del self.rid_to_state[request_id]
raise RuntimeError(f"Failed to send request to scheduler: {e}")
return state.out_queue
async def embedding_request(
self,
@@ -356,7 +214,9 @@ class GrpcRequestManager:
"""
# Generate request ID if not provided
if request_id is None:
request_id = f"grpc-embed-{uuid.uuid4().hex}"
async with self.request_counter_lock:
request_id = f"grpc-embed-{self.request_counter}"
self.request_counter += 1
obj.rid = request_id
@@ -495,6 +355,7 @@ class GrpcRequestManager:
# Extract output for this request
output_data = {
"request_id": rid,
"text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
"finished": batch_out.finished_reasons[i] is not None,
"meta_info": {
@@ -506,9 +367,6 @@ class GrpcRequestManager:
if batch_out.completion_tokens
else 0
),
"cached_tokens": (
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
),
"finish_reason": (
str(batch_out.finished_reasons[i])
if batch_out.finished_reasons[i]
@@ -517,110 +375,29 @@ class GrpcRequestManager:
},
}
# Accumulate input logprobs (only once, usually in first chunk)
if batch_out.input_token_logprobs_val and i < len(
batch_out.input_token_logprobs_val
):
if not state.input_token_logprobs_val:
state.input_token_logprobs_val.extend(
batch_out.input_token_logprobs_val[i]
)
if batch_out.input_token_logprobs_idx and i < len(
batch_out.input_token_logprobs_idx
):
state.input_token_logprobs_idx.extend(
batch_out.input_token_logprobs_idx[i]
)
if batch_out.input_top_logprobs_val and i < len(
batch_out.input_top_logprobs_val
):
state.input_top_logprobs_val.extend(
batch_out.input_top_logprobs_val[i]
)
if batch_out.input_top_logprobs_idx and i < len(
batch_out.input_top_logprobs_idx
):
state.input_top_logprobs_idx.extend(
batch_out.input_top_logprobs_idx[i]
)
# Send input logprobs based on mode
if state.input_token_logprobs_val:
if state.obj.stream and not state.input_logprobs_sent:
# Streaming: send input logprobs once in first chunk that has them
output_data["input_logprobs"] = {
"token_logprobs_val": state.input_token_logprobs_val,
"token_logprobs_idx": state.input_token_logprobs_idx,
"top_logprobs_val": state.input_top_logprobs_val,
"top_logprobs_idx": state.input_top_logprobs_idx,
}
state.input_logprobs_sent = True
elif not state.obj.stream and output_data["finished"]:
# Non-streaming: send input logprobs in final chunk
output_data["input_logprobs"] = {
"token_logprobs_val": state.input_token_logprobs_val,
"token_logprobs_idx": state.input_token_logprobs_idx,
"top_logprobs_val": state.input_top_logprobs_val,
"top_logprobs_idx": state.input_top_logprobs_idx,
}
# Add output logprobs if available (RAW - no detokenization!)
# Add logprobs if available
if batch_out.output_token_logprobs_val and i < len(
batch_out.output_token_logprobs_val
):
# Accumulate in state first
state.output_token_logprobs_val.extend(
batch_out.output_token_logprobs_val[i]
)
if batch_out.output_token_logprobs_idx and i < len(
batch_out.output_token_logprobs_idx
):
state.output_token_logprobs_idx.extend(
batch_out.output_token_logprobs_idx[i]
)
if batch_out.output_top_logprobs_val and i < len(
batch_out.output_top_logprobs_val
):
state.output_top_logprobs_val.extend(
output_data["logprobs"] = {
"tokens": batch_out.output_token_logprobs_val[i],
"top_logprobs": (
batch_out.output_top_logprobs_val[i]
)
if batch_out.output_top_logprobs_idx and i < len(
batch_out.output_top_logprobs_idx
):
state.output_top_logprobs_idx.extend(
batch_out.output_top_logprobs_idx[i]
)
if batch_out.output_top_logprobs_val
and i < len(batch_out.output_top_logprobs_val)
else None
),
}
if state.obj.stream:
# For streaming: send incremental logprobs (only new tokens in this chunk)
# NOTE: this is different than TokenizerManager, which always accumulates
def get_part(attr_name):
source_list = getattr(batch_out, attr_name, None)
return (
source_list[i]
if source_list and i < len(source_list)
else []
)
# Update state
if output_data["text"]:
state.text += output_data["text"][state.last_output_offset :]
state.last_output_offset = len(output_data["text"])
output_data["output_logprobs"] = {
"token_logprobs_val": batch_out.output_token_logprobs_val[i],
"token_logprobs_idx": get_part("output_token_logprobs_idx"),
"top_logprobs_val": get_part("output_top_logprobs_val"),
"top_logprobs_idx": get_part("output_top_logprobs_idx"),
}
elif output_data["finished"]:
# Non-streaming: send cumulative output logprobs in final chunk
output_data["output_logprobs"] = {
"token_logprobs_val": state.output_token_logprobs_val,
"token_logprobs_idx": state.output_token_logprobs_idx,
"top_logprobs_val": state.output_top_logprobs_val,
"top_logprobs_idx": state.output_top_logprobs_idx,
}
# Update state for accumulation
if output_data["token_ids"]:
state.output_ids.extend(output_data["token_ids"])
# Send to output queue
await state.out_queue.put(output_data)
# Handle completion

View File

@@ -181,34 +181,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
# Convert gRPC request to internal format
tokenized_req = self._convert_generate_request(request)
# Submit to request manager (automatically handles n>1)
response_generator = self.request_manager.generate_request(
# Submit to request manager
output_queue = await self.request_manager.generate_request(
obj=tokenized_req,
request_id=request.request_id,
grpc_context=context,
)
async for output in response_generator:
# Handle batch responses (for n>1 non-streaming)
if isinstance(output, list):
for batch_output in output:
if "error" in batch_output:
yield sglang_scheduler_pb2.GenerateResponse(
request_id=request.request_id,
error=sglang_scheduler_pb2.GenerateError(
message=batch_output["error"],
http_status_code=(
"500" if "abort" not in batch_output else "499"
),
),
)
else:
# All non-error batch outputs are final responses
yield self._create_completion_response(
request.request_id, batch_output
)
else:
# Handle single response (for streaming or n=1 non-streaming)
# Stream outputs
while True:
try:
# Get output with timeout
output = await asyncio.wait_for(output_queue.get(), timeout=4)
# Check for errors
if "error" in output:
yield sglang_scheduler_pb2.GenerateResponse(
request_id=request.request_id,
@@ -219,13 +205,27 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
),
),
)
elif output.get("finished", False):
break
# Check if finished
if output.get("finished", False):
# Send completion
yield self._create_completion_response(
request.request_id, output
)
break
else:
# Send chunk
yield self._create_chunk_response(request.request_id, output)
except asyncio.TimeoutError:
# Check if context is still active
if context.cancelled():
# Abort the request
await self.request_manager.abort_request(request.request_id)
break
continue
except Exception as e:
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
yield sglang_scheduler_pb2.GenerateResponse(
@@ -266,6 +266,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
prompt_tokens=result.get("prompt_tokens", 0),
cached_tokens=0,
embedding_dim=len(result["embedding"]),
generation_time=time.time() - self.start_time,
),
)
@@ -321,14 +322,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
logger.info(f"Sending health check request to request manager...")
# Submit and wait for response
output_generator = self.request_manager.generate_request(
output_queue = await self.request_manager.generate_request(
health_request, request_id=rid
)
try:
# Get first response with timeout
# Wait for response with configurable timeout
response = await asyncio.wait_for(
output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
)
# Clean up
@@ -403,8 +404,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
return_logprob=grpc_req.return_logprob,
logprob_start_len=grpc_req.logprob_start_len or -1,
top_logprobs_num=grpc_req.top_logprobs_num or 0,
stream=grpc_req.stream or False,
lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
stream=True, # Always stream for gRPC
lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
token_ids_logprob=(
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
),
@@ -437,7 +438,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
regex = None
json_schema = None
ebnf_grammar = None
structural_tag = None
if grpc_params.HasField("regex"):
regex = grpc_params.regex
@@ -445,8 +445,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
json_schema = grpc_params.json_schema
elif grpc_params.HasField("ebnf_grammar"):
ebnf_grammar = grpc_params.ebnf_grammar
elif grpc_params.HasField("structural_tag"):
structural_tag = grpc_params.structural_tag
return SGLSamplingParams(
temperature=grpc_params.temperature or 1.0,
@@ -458,74 +456,33 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
repetition_penalty=grpc_params.repetition_penalty or 1.0,
max_new_tokens=grpc_params.max_new_tokens or 128,
min_new_tokens=grpc_params.min_new_tokens or 0,
stop=list(grpc_params.stop) if grpc_params.stop else [],
stop=list(grpc_params.stop) if grpc_params.stop else None,
stop_token_ids=(
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
),
skip_special_tokens=grpc_params.skip_special_tokens,
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
regex=regex,
json_schema=json_schema,
ebnf=ebnf_grammar,
structural_tag=structural_tag,
n=grpc_params.n or 1,
ignore_eos=grpc_params.ignore_eos,
)
def _convert_logprobs_to_proto(
self, logprobs_data: Dict
) -> Optional[sglang_scheduler_pb2.LogProbs]:
"""Convert logprobs dict to proto LogProbs format (transport RAW data only)."""
if not logprobs_data:
return None
token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
# Build TopLogProbs entries
top_logprobs_proto = []
if top_logprobs_val and top_logprobs_idx:
for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
top_logprobs_proto.append(
sglang_scheduler_pb2.TopLogProbs(
values=val_list,
token_ids=idx_list,
)
)
return sglang_scheduler_pb2.LogProbs(
token_logprobs=token_logprobs_val,
token_ids=token_logprobs_idx,
top_logprobs=top_logprobs_proto,
)
def _create_chunk_response(
self, request_id: str, output: Dict
) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a streaming chunk response."""
meta_info = output.get("meta_info", {})
# Convert output logprobs if present
output_logprobs_proto = self._convert_logprobs_to_proto(
output.get("output_logprobs")
)
# Convert input logprobs if present (only in first chunk)
input_logprobs_proto = self._convert_logprobs_to_proto(
output.get("input_logprobs")
)
return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id,
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
token_ids=output.get("token_ids", []),
prompt_tokens=meta_info.get("prompt_tokens", 0),
completion_tokens=meta_info.get("completion_tokens", 0),
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
text=output.get("text", ""),
prompt_tokens=0,
completion_tokens=len(output.get("token_ids", [])),
cached_tokens=0,
generation_time=time.time() - self.start_time,
queue_time=0.0,
),
)
@@ -534,56 +491,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
) -> sglang_scheduler_pb2.GenerateResponse:
"""Create a completion response."""
# Extract meta info and finish reason details
# Determine finish reason
finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
meta_info = output.get("meta_info", {})
finish_reason_data = meta_info.get("finish_reason")
# Determine finish reason, default is stop
finish_reason = "stop"
if finish_reason_data:
if isinstance(finish_reason_data, dict):
finish_reason_type = finish_reason_data.get("type")
else:
# Handle legacy string format
finish_reason_type = finish_reason_data
if finish_reason_type == "length":
finish_reason = "length"
elif finish_reason_type == "abort":
finish_reason = "abort"
# Extract matched_stop information
matched_stop_kwargs = {}
if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
matched = finish_reason_data["matched"]
if isinstance(matched, int):
matched_stop_kwargs["matched_token_id"] = matched
elif isinstance(matched, str):
matched_stop_kwargs["matched_stop_str"] = matched
# Convert output logprobs if present
output_logprobs_proto = self._convert_logprobs_to_proto(
output.get("output_logprobs")
)
# Convert input logprobs if present
input_logprobs_proto = self._convert_logprobs_to_proto(
output.get("input_logprobs")
)
if meta_info.get("finish_reason") == "length":
finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
elif meta_info.get("finish_reason") == "eos_token":
finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
return sglang_scheduler_pb2.GenerateResponse(
request_id=request_id,
complete=sglang_scheduler_pb2.GenerateComplete(
output_ids=output.get("token_ids", []),
output_text=output.get("text", ""),
finish_reason=finish_reason,
prompt_tokens=meta_info.get("prompt_tokens", 0),
completion_tokens=meta_info.get(
"completion_tokens", len(output.get("token_ids", []))
),
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
**matched_stop_kwargs,
),
)

View File

@@ -16,7 +16,7 @@
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
from typing import Any, Dict, List, Optional, TypeAlias, Union
from openai.types.responses import (
ResponseFunctionToolCall,
@@ -228,15 +228,11 @@ class CompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[Union[List[str], str]] = None
# Cache salt for request caching
cache_salt: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
# For custom metric labels
custom_labels: Optional[Dict[str, str]] = None
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
@field_validator("max_tokens")
@classmethod
@@ -343,7 +339,7 @@ class FunctionResponse(BaseModel):
"""Function response."""
name: Optional[str] = None
arguments: Optional[str | Dict[str, Any]] = None
arguments: Optional[str] = None
class ToolCall(BaseModel):
@@ -392,7 +388,7 @@ class Function(BaseModel):
"""Function descriptions."""
description: Optional[str] = Field(default=None, examples=[None])
name: str
name: Optional[str] = None
parameters: Optional[object] = None
strict: bool = False
@@ -549,10 +545,6 @@ class ChatCompletionRequest(BaseModel):
# For request id
rid: Optional[Union[List[str], str]] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[Union[List[str], str]] = None
# Cache salt for request caching
cache_salt: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
@@ -786,13 +778,6 @@ class ResponsesRequest(BaseModel):
description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
)
priority: int = Field(default=0, description="Request priority")
extra_key: Optional[str] = Field(
default=None,
description="Extra key for classifying the request (e.g. cache_salt)",
)
cache_salt: Optional[str] = Field(
default=None, description="Cache salt for request caching"
)
# SGLang-specific sampling parameters
frequency_penalty: float = 0.0
@@ -943,16 +928,6 @@ class MessageProcessingResult:
tool_call_constraint: Optional[Any] = None
class ToolCallProcessingResult(NamedTuple):
    """Result of processing tool calls in a response.

    Bundles the parsed tool calls together with the leftover text and the
    (possibly updated) finish-reason metadata so callers receive all three
    in a single return value.
    """

    # Parsed ToolCall objects; None signals that parsing failed.
    tool_calls: Optional[
        List[Any]
    ]  # List of ToolCall objects or None if parsing failed
    # Text remaining after the tool-call payload has been stripped out.
    remaining_text: str  # Text remaining after parsing tool calls
    # Finish-reason dictionary, updated to reflect tool-call processing.
    finish_reason: Dict[str, Any]  # Updated finish reason dictionary
class ResponseReasoningTextContent(BaseModel):
    """A reasoning-text content part of a response message."""

    # The reasoning text produced by the model.
    text: str
    # Discriminator field; always "reasoning_text" for this content type.
    type: Literal["reasoning_text"] = "reasoning_text"

View File

@@ -27,10 +27,10 @@ class OpenAIServingBase(ABC):
self.tokenizer_manager = tokenizer_manager
self.allowed_custom_labels = (
set(
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
)
if isinstance(self.tokenizer_manager.server_args, ServerArgs)
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
else None
)
@@ -62,12 +62,6 @@ class OpenAIServingBase(ABC):
return self.create_error_response(
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
)
except ValueError as e:
return self.create_error_response(
message=str(e),
err_type="BadRequest",
status_code=400,
)
except Exception as e:
logger.exception(f"Error in request: {e}")
return self.create_error_response(
@@ -92,19 +86,6 @@ class OpenAIServingBase(ABC):
return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
"""Compute the final extra_key by concatenating cache_salt and extra_key if both are provided."""
parts = []
for key in ["cache_salt", "extra_key"]:
value = getattr(request, key, None)
if value:
if not isinstance(value, str):
raise TypeError(
f"Value of {key} must be a string, but got {type(value).__name__}"
)
parts.append(value)
return "".join(parts) if parts else None
@abstractmethod
def _convert_to_internal_request(
self,
@@ -184,14 +165,14 @@ class OpenAIServingBase(ABC):
)
return json.dumps({"error": error.model_dump()})
def extract_custom_labels(self, raw_request):
def extract_customer_labels(self, raw_request):
if (
not self.allowed_custom_labels
or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
):
return None
custom_labels = None
customer_labels = None
header = (
self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
)
@@ -206,9 +187,9 @@ class OpenAIServingBase(ABC):
raw_labels = None
if isinstance(raw_labels, dict):
custom_labels = {
customer_labels = {
label: value
for label, value in raw_labels.items()
if label in self.allowed_custom_labels
}
return custom_labels
return customer_labels

View File

@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Uni
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from jsonschema import Draft202012Validator, SchemaError
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
@@ -26,8 +25,6 @@ from sglang.srt.entrypoints.openai.protocol import (
LogProbs,
MessageProcessingResult,
ToolCall,
ToolCallProcessingResult,
ToolChoice,
TopLogprob,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
@@ -36,10 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.function_call.json_array_parser import JsonArrayParser
from sglang.srt.function_call.utils import get_json_schema_constraint
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.parser.conversation import generate_chat_conv
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
@@ -64,7 +58,6 @@ class OpenAIServingChat(OpenAIServingBase):
super().__init__(tokenizer_manager)
self.template_manager = template_manager
self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
def _request_id_prefix(self) -> str:
return "chatcmpl-"
@@ -81,23 +74,6 @@ class OpenAIServingChat(OpenAIServingBase):
):
return "Tools cannot be empty if tool choice is set to required."
if request.tool_choice is not None and not isinstance(request.tool_choice, str):
if not request.tools:
return "Tools cannot be empty if tool choice is set to a specific tool."
tool_name = request.tool_choice.function.name
tool_exists = any(tool.function.name == tool_name for tool in request.tools)
if not tool_exists:
return f"Tool '{tool_name}' not found in tools list."
# Validate tool definitions
for i, tool in enumerate(request.tools or []):
if tool.function.parameters is None:
continue
try:
Draft202012Validator.check_schema(tool.function.parameters)
except SchemaError as e:
return f"Tool {i} function has invalid 'parameters' schema: {str(e)}"
max_output_tokens = request.max_completion_tokens or request.max_tokens
server_context_length = self.tokenizer_manager.server_args.context_length
if (
@@ -152,8 +128,8 @@ class OpenAIServingChat(OpenAIServingBase):
else:
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
# Extract custom labels from raw request headers
custom_labels = self.extract_custom_labels(raw_request)
# Extract customer labels from raw request headers
customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput(
**prompt_kwargs,
@@ -173,9 +149,8 @@ class OpenAIServingChat(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
extra_key=self._compute_extra_key(request),
priority=request.priority,
custom_labels=custom_labels,
customer_labels=customer_labels,
)
return adapted_request, request
@@ -213,14 +188,6 @@ class OpenAIServingChat(OpenAIServingBase):
tool_call_constraint = parser.get_structure_constraint(
request.tool_choice
)
# Handle JSON schema constraint directly for required or named tool choice
if request.tool_choice == "required" or isinstance(
request.tool_choice, ToolChoice
):
json_schema = get_json_schema_constraint(
request.tools, request.tool_choice
)
tool_call_constraint = ("json_schema", json_schema)
# Use chat template
if self.template_manager.chat_template_name is None:
@@ -468,10 +435,6 @@ class OpenAIServingChat(OpenAIServingBase):
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value.model_dump(by_alias=True)
)
elif constraint_type == "json_schema":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value
)
else:
sampling_params[constraint_type] = constraint_value
return sampling_params
@@ -564,7 +527,10 @@ class OpenAIServingChat(OpenAIServingBase):
stream_buffers[index] = stream_buffer + delta
# Handle reasoning content
if self.reasoning_parser and request.separate_reasoning:
if (
self.tokenizer_manager.server_args.reasoning_parser
and request.separate_reasoning
):
reasoning_text, delta = self._process_reasoning_stream(
index, delta, reasoning_parser_dict, content, request
)
@@ -754,7 +720,7 @@ class OpenAIServingChat(OpenAIServingBase):
# Handle reasoning content
reasoning_text = None
reasoning_parser = self.reasoning_parser
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
if reasoning_parser and request.separate_reasoning:
is_force_reasoning = (
self.template_manager.force_reasoning
@@ -782,13 +748,8 @@ class OpenAIServingChat(OpenAIServingBase):
and request.tools
and self.tool_call_parser
):
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
tool_calls, text, finish_reason = self._process_tool_calls(
text,
request.tools,
finish_reason,
request.tool_choice,
history_tool_calls_cnt,
text, request.tools, finish_reason
)
choice_data = ChatCompletionResponseChoice(
@@ -878,76 +839,13 @@ class OpenAIServingChat(OpenAIServingBase):
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
return ChoiceLogprobs(content=token_logprobs)
def _process_tool_call_id(
self,
call_item: ToolCallItem,
history_tool_calls_cnt: int,
) -> str:
"""Process for generating a new and unique `tool_call_id`"""
if self.tool_call_parser != "kimi_k2":
# A simple uuid is sufficient for all models except for Kimi-K2.
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
return tool_call_id
else:
# Align with Kimi-K2 format: functions.{name}:{index}
# Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
# Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
logger.debug(
f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
)
return tool_call_id
def _process_tool_calls(
self,
text: str,
tools: List[Any],
finish_reason: Dict[str, Any],
tool_choice: Optional[Union[str, ToolChoice]] = None,
history_tool_calls_cnt: int = 0,
) -> ToolCallProcessingResult:
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
"""Process tool calls in the response"""
# Handle required or named tool choice
if tool_choice == "required" or (
isinstance(tool_choice, ToolChoice) and tool_choice.type == "function"
):
# Set finish reason to tool_calls since we're processing tool calls
if finish_reason["type"] == "stop":
finish_reason["type"] = "tool_calls"
finish_reason["matched"] = None
try:
# For required tool choice, we expect a JSON array of tool calls
tool_call_data = json.loads(text)
tool_calls = []
for i, tool in enumerate(tool_call_data):
# Create a ToolCallItem from the JSON data
call_info = ToolCallItem(
tool_index=i, # Use the loop index as tool_index
name=tool["name"],
parameters=json.dumps(tool["parameters"], ensure_ascii=False),
)
tool_id = self._process_tool_call_id(
call_info, history_tool_calls_cnt
)
tool_calls.append(
ToolCall(
id=tool_id,
index=i,
function=FunctionResponse(
name=tool["name"],
arguments=json.dumps(
tool["parameters"], ensure_ascii=False
),
),
)
)
return ToolCallProcessingResult(tool_calls, "", finish_reason)
except json.JSONDecodeError as e:
logger.error(f"Tool call parsing error: {e}")
return ToolCallProcessingResult(None, text, finish_reason)
# Use parser since output is not constrained by JSON schema
parser = FunctionCallParser(tools, self.tool_call_parser)
if parser.has_tool_call(text):
if finish_reason["type"] == "stop":
@@ -957,9 +855,15 @@ class OpenAIServingChat(OpenAIServingBase):
text, call_info_list = parser.parse_non_stream(text)
tool_calls = []
for call_info in call_info_list:
tool_id = self._process_tool_call_id(
call_info, history_tool_calls_cnt
)
# For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
if (
self.tool_call_parser == "kimi_k2"
and call_info.name is not None
):
tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
else:
tool_id = f"call_{uuid.uuid4().hex[:24]}"
tool_calls.append(
ToolCall(
id=tool_id,
@@ -969,13 +873,13 @@ class OpenAIServingChat(OpenAIServingBase):
),
)
)
return ToolCallProcessingResult(tool_calls, text, finish_reason)
return tool_calls, text, finish_reason
except Exception as e:
logger.error(f"Tool call parsing error: {e}")
# Return error but don't fail the whole request
return ToolCallProcessingResult(None, text, finish_reason)
return None, text, finish_reason
return ToolCallProcessingResult(None, text, finish_reason)
return None, text, finish_reason
def _process_streaming_logprobs(
self, content: Dict[str, Any], n_prev_token: int
@@ -1008,33 +912,13 @@ class OpenAIServingChat(OpenAIServingBase):
or self._get_enable_thinking_from_request(request)
)
reasoning_parser_dict[index] = ReasoningParser(
self.reasoning_parser,
self.tokenizer_manager.server_args.reasoning_parser,
request.stream_reasoning,
is_force_reasoning,
)
reasoning_parser = reasoning_parser_dict[index]
return reasoning_parser.parse_stream_chunk(delta)
def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
"""Counts the number of tool calls in the request's message history.
NOTE: This method is only useful for models that include self-increasing
history tool call idx in tool calls id, such as kimi-k2
Args:
request: The chat completion request object.
Returns:
The total number of tool calls in the history, or 0 if not applicable.
"""
messages = getattr(request, "messages", [])
idx = 0
for msg in messages:
if msg.role == "assistant":
tool_calls = getattr(msg, "tool_calls", None)
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
return idx
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
@@ -1048,11 +932,11 @@ class OpenAIServingChat(OpenAIServingBase):
"""
if hasattr(request, "chat_template_kwargs") and request.chat_template_kwargs:
# For Qwen3 models, `enable_thinking` is supported.
if self.reasoning_parser in ["qwen3", "glm45"]:
return request.chat_template_kwargs.get("enable_thinking", False)
if request.chat_template_kwargs.get("enable_thinking") is not None:
return request.chat_template_kwargs.get("enable_thinking")
# For DeepSeek-V3.1 models, `thinking` is supported.
elif self.reasoning_parser in ["deepseek-v3"]:
return request.chat_template_kwargs.get("thinking", False)
elif request.chat_template_kwargs.get("thinking") is not None:
return request.chat_template_kwargs.get("thinking")
else:
return False
return False
@@ -1068,25 +952,13 @@ class OpenAIServingChat(OpenAIServingBase):
):
"""Process tool calls in streaming response"""
if index not in parser_dict:
# Use JSON detector directly for required or named tool choice
if request.tool_choice == "required" or isinstance(
request.tool_choice, ToolChoice
):
parser_dict[index] = JsonArrayParser()
else:
parser_dict[index] = FunctionCallParser(
tools=request.tools,
tool_call_parser=self.tool_call_parser,
)
parser_dict[index] = FunctionCallParser(
tools=request.tools,
tool_call_parser=self.tool_call_parser,
)
parser = parser_dict[index]
# Handle both FunctionCallParser and JsonArrayParser
if isinstance(parser, JsonArrayParser):
result = parser.parse_streaming_increment(delta, request.tools)
normal_text, calls = result.normal_text, result.calls
else:
normal_text, calls = parser.parse_stream_chunk(delta)
normal_text, calls = parser.parse_stream_chunk(delta)
# Yield normal text
if normal_text:
@@ -1104,7 +976,6 @@ class OpenAIServingChat(OpenAIServingBase):
yield f"data: {chunk.model_dump_json()}\n\n"
# Yield tool calls
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
for call_item in calls:
# Mark that this choice has tool calls
has_tool_calls[index] = True
@@ -1112,9 +983,11 @@ class OpenAIServingChat(OpenAIServingBase):
# Tool call ID should be generated only once per tool call
if call_item.name:
# First chunk: include ID and function name
tool_call_id = self._process_tool_call_id(
call_item, history_tool_calls_cnt
)
if self.tool_call_parser == "kimi_k2":
# Align with Kimi-K2 format: functions.{name}:{index}
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
else:
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
function_name = call_item.name
else:
# Subsequent chunks: null ID and name for argument deltas
@@ -1145,7 +1018,7 @@ class OpenAIServingChat(OpenAIServingBase):
def _check_for_unstreamed_tool_args(
self,
parser: Union[FunctionCallParser, JsonArrayParser],
parser: FunctionCallParser,
content: Dict[str, Any],
request: ChatCompletionRequest,
index: int,
@@ -1155,31 +1028,30 @@ class OpenAIServingChat(OpenAIServingBase):
when generation finishes. This ensures tool calls are properly completed
even if the model generates the final arguments in the last chunk.
"""
# Get the detector - either from FunctionCallParser or directly if json detector
detector = parser.detector if hasattr(parser, "detector") else parser
# Only check if we have tool calls and the detector has tracked data
# Only check if we have tool calls and the parser has tracked data
if (
not hasattr(detector, "prev_tool_call_arr")
or not detector.prev_tool_call_arr
not hasattr(parser.detector, "prev_tool_call_arr")
or not parser.detector.prev_tool_call_arr
):
return None
if (
not hasattr(detector, "streamed_args_for_tool")
or not detector.streamed_args_for_tool
not hasattr(parser.detector, "streamed_args_for_tool")
or not parser.detector.streamed_args_for_tool
):
return None
# Get the last tool call that was being processed
tool_index = len(detector.prev_tool_call_arr) - 1
if tool_index < 0 or tool_index >= len(detector.streamed_args_for_tool):
tool_index = len(parser.detector.prev_tool_call_arr) - 1
if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
return None
# Get expected vs actual arguments
expected_args = detector.prev_tool_call_arr[tool_index].get("arguments", {})
expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
"arguments", {}
)
expected_call = json.dumps(expected_args, ensure_ascii=False)
actual_call = detector.streamed_args_for_tool[tool_index]
actual_call = parser.detector.streamed_args_for_tool[tool_index]
# Check if there are remaining arguments to send
remaining_call = (

View File

@@ -90,8 +90,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
else:
prompt_kwargs = {"input_ids": prompt}
# Extract custom labels from raw request headers
custom_labels = self.extract_custom_labels(raw_request)
# Extract customer labels from raw request headers
customer_labels = self.extract_customer_labels(raw_request)
adapted_request = GenerateReqInput(
**prompt_kwargs,
@@ -107,9 +107,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
rid=request.rid,
extra_key=self._compute_extra_key(request),
priority=request.priority,
custom_labels=custom_labels,
customer_labels=customer_labels,
)
return adapted_request, request

View File

@@ -245,7 +245,6 @@ class OpenAIServingResponses(OpenAIServingChat):
sampling_params=sampling_params,
stream=request.stream,
rid=request.request_id,
extra_key=self._compute_extra_key(request),
background=request.background,
)
@@ -1251,7 +1250,6 @@ class OpenAIServingResponses(OpenAIServingChat):
sampling_params=sampling_params,
stream=adapted_request.stream,
rid=request_id,
extra_key=adapted_request.extra_key,
return_logprob=adapted_request.return_logprob,
logprob_start_len=adapted_request.logprob_start_len,
top_logprobs_num=adapted_request.top_logprobs_num,

View File

@@ -231,7 +231,6 @@ class ExpertLocationMetadata:
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=(
compute_logical_to_rank_dispatch_physical_map(
server_args=server_args,
logical_to_all_physical_map=logical_to_all_physical_map,
num_gpus=ep_size,
num_physical_experts=num_physical_experts,
@@ -341,7 +340,6 @@ def _pad_nested_array(arr, pad_value):
# TODO optimize performance (rewrite and/or run in separate process with overlap)
def compute_logical_to_rank_dispatch_physical_map(
server_args: ServerArgs,
logical_to_all_physical_map: torch.Tensor,
num_gpus: int,
num_physical_experts: int,
@@ -350,9 +348,7 @@ def compute_logical_to_rank_dispatch_physical_map(
):
r = random.Random(seed)
num_local_gpu_physical_experts = num_physical_experts // num_gpus
num_gpus_per_node = server_args.ep_size // server_args.nnodes
num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
num_local_physical_experts = num_physical_experts // num_gpus
num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
dtype = logical_to_all_physical_map.dtype
@@ -376,28 +372,13 @@ def compute_logical_to_rank_dispatch_physical_map(
physical_expert_id
for physical_expert_id in candidate_physical_expert_ids
if _compute_gpu_id_of_physical_expert(
physical_expert_id, num_local_gpu_physical_experts
physical_expert_id, num_local_physical_experts
)
== gpu_id
]
if len(same_gpu_physical_expert_ids) > 0:
# 1. Prefer same-GPU experts
output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
else:
# 2. Otherwise, prefer same-node experts
node_id = gpu_id // num_gpus_per_node
same_node_physical_expert_ids = [
physical_expert_id
for physical_expert_id in candidate_physical_expert_ids
if _compute_node_id_of_physical_expert(
physical_expert_id, num_local_node_physical_experts
)
== node_id
]
if len(same_node_physical_expert_ids) > 0:
output_partial[gpu_id] = same_node_physical_expert_ids[0]
# 3. Fill remaining slots with fair random choices
num_remain = torch.sum(output_partial == -1).item()
output_partial[output_partial == -1] = torch.tensor(
_fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
@@ -423,15 +404,9 @@ def _logical_to_all_physical_raw(
def _compute_gpu_id_of_physical_expert(
physical_expert_id: int, num_local_gpu_physical_experts: int
physical_expert_id: int, num_local_physical_experts: int
) -> int:
return physical_expert_id // num_local_gpu_physical_experts
def _compute_node_id_of_physical_expert(
physical_expert_id: int, num_local_host_physical_experts: int
) -> int:
return physical_expert_id // num_local_host_physical_experts
return physical_expert_id // num_local_physical_experts
def _fair_choices(arr: List, k: int, r: random.Random) -> List:

View File

@@ -20,7 +20,6 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.function_call.step3_detector import Step3Detector
from sglang.srt.function_call.utils import get_json_schema_constraint
logger = logging.getLogger(__name__)
@@ -179,8 +178,8 @@ class FunctionCallParser:
strict_tag = self.get_structure_tag()
return ("structural_tag", strict_tag)
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
json_schema = get_json_schema_constraint(self.tools, tool_choice)
return ("json_schema", json_schema)
ebnf = self.get_ebnf(tool_choice)
return ("ebnf", ebnf) if ebnf is not None else None
def get_ebnf(
self, tool_choice: Union[ToolChoice, Literal["required"]]

View File

@@ -39,7 +39,7 @@ def parse_arguments(json_value):
class Glm4MoeDetector(BaseFormatDetector):
"""
Detector for GLM-4.5 and GLM-4.6 models.
Detector for GLM-4.5 models.
Assumes function call format:
<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
"""
@@ -53,7 +53,7 @@ class Glm4MoeDetector(BaseFormatDetector):
self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
"""Check if the text contains a glm-4.5 format tool call."""
return self.bot_token in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
@@ -102,7 +102,7 @@ class Glm4MoeDetector(BaseFormatDetector):
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format.
Streaming incremental parsing tool calls for GLM-4.5 format.
"""
self._buffer += new_text
current_text = self._buffer

View File

@@ -1,63 +0,0 @@
import json
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import StreamingParseResult
class JsonArrayParser(BaseFormatDetector):
    """
    Parser for JSON array tool calls when JSON schema constraints are active.

    Used when tool_choice="required" or a specific tool is named, bypassing
    model-specific parsers in favor of direct JSON array parsing.
    """

    def __init__(self):
        super().__init__()
        # Tokens delimiting a JSON array of tool calls: "[ {...}, {...} ]".
        self.bot_token = "["
        self.eot_token = "]"
        self.tool_call_separator = ","

    def has_tool_call(self, text: str) -> bool:
        """Return True when *text* contains a JSON tool call (array or single object)."""
        return any(marker in text for marker in ("[", "{"))

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Non-streaming parsing is not supported under JSON schema constraints."""
        raise NotImplementedError(
            "Detect and parse not supported for JSON schema constraints."
        )

    def build_ebnf(self, tools: List[Tool]) -> str:
        """EBNF generation is handled directly by the constraint backends."""
        raise NotImplementedError(
            "EBNF generation is not supported for JSON schema constraints."
        )

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """Delegate streaming incremental parsing to the base detector."""
        return super().parse_streaming_increment(new_text, tools)

    def structure_info(self) -> callable:
        """StructureInfo creation is handled directly by the constraint backends."""
        raise NotImplementedError("structure_info not used for JSON schema constraints")

View File

@@ -1,13 +1,10 @@
import json
from json import JSONDecodeError, JSONDecoder
from json.decoder import WHITESPACE
from typing import Any, List, Literal, Optional, Tuple, Union
from typing import Any, Tuple
import partial_json_parser
from partial_json_parser.core.options import Allow
from sglang.srt.entrypoints.openai.protocol import Tool, ToolChoice
def _find_common_prefix(s1: str, s2: str) -> str:
prefix = ""
@@ -40,12 +37,10 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
"""
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except (JSONDecodeError, IndexError) as e:
msg = getattr(e, "msg", str(e))
if "Extra data" in msg or "pop from empty list" in msg:
start = WHITESPACE.match(input_str, 0).end()
obj, end = JSONDecoder().raw_decode(input_str, start)
return obj, end
except JSONDecodeError as e:
if "Extra data" in e.msg:
dec = JSONDecoder()
return dec.raw_decode(input_str)
raise
@@ -55,89 +50,3 @@ def _is_complete_json(input_str: str) -> bool:
return True
except JSONDecodeError:
return False
def _get_tool_schema_defs(tools: List[Tool]) -> dict:
"""
Get consolidated $defs from all tools, validating for conflicts.
Args:
tools: List of tools to process
Returns:
Dictionary of consolidated $defs from all tools
Raises:
ValueError: If conflicting $defs are found
"""
all_defs = {}
for tool in tools:
if tool.function.parameters is None:
continue
defs = tool.function.parameters.get("$defs", {})
for def_name, def_schema in defs.items():
if def_name in all_defs and all_defs[def_name] != def_schema:
raise ValueError(
f"Tool definition '{def_name}' has "
"multiple schemas, which is not "
"supported."
)
else:
all_defs[def_name] = def_schema
return all_defs
def _get_tool_schema(tool: Tool) -> dict:
return {
"properties": {
"name": {"type": "string", "enum": [tool.function.name]},
"parameters": (
tool.function.parameters
if tool.function.parameters
else {"type": "object", "properties": {}}
),
},
"required": ["name", "parameters"],
}
def get_json_schema_constraint(
tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
) -> Optional[dict]:
"""
Get the JSON schema constraint for the specified tool choice.
Args:
tool_choice: The tool choice specification
Returns:
JSON schema dict, or None if no valid tools found
"""
if isinstance(tool_choice, ToolChoice):
# For specific function choice, return the user's parameters schema directly
fn_name = tool_choice.function.name
for tool in tools:
if tool.function.name == fn_name:
return {
"type": "array",
"minItems": 1,
"maxItems": 1,
"items": _get_tool_schema(tool),
}
return None
elif tool_choice == "required":
json_schema = {
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": [_get_tool_schema(tool) for tool in tools],
},
}
json_schema_defs = _get_tool_schema_defs(tools)
if json_schema_defs:
json_schema["$defs"] = json_schema_defs
return json_schema
return None

View File

@@ -36,9 +36,9 @@ message SamplingParams {
float presence_penalty = 6;
float repetition_penalty = 7;
optional int32 max_new_tokens = 8;
int32 max_new_tokens = 8;
repeated string stop = 9;
repeated uint32 stop_token_ids = 10;
repeated int32 stop_token_ids = 10;
bool skip_special_tokens = 11;
bool spaces_between_special_tokens = 12;
@@ -47,24 +47,24 @@ message SamplingParams {
string regex = 13;
string json_schema = 14;
string ebnf_grammar = 15;
string structural_tag = 16;
}
// LoRA adapter
string lora_path = 17;
string lora_path = 16;
// Speculative decoding
int32 n = 18; // Number of samples
int32 n = 17; // Number of samples
// Token healing
bool token_healing = 19;
bool token_healing = 18;
// Additional parameters
int32 min_new_tokens = 20;
bool ignore_eos = 21;
bool no_stop_trim = 22;
int32 stream_interval = 23;
map<string, float> logit_bias = 24;
int32 min_new_tokens = 19;
bool ignore_eos = 20;
bool no_stop_trim = 21;
int32 stream_interval = 22;
map<string, float> logit_bias = 23;
string structural_tag = 24;
// Custom parameters for extensibility
google.protobuf.Struct custom_params = 25;
@@ -98,7 +98,7 @@ message GenerateRequest {
bool return_logprob = 5;
int32 logprob_start_len = 6;
int32 top_logprobs_num = 7;
repeated uint32 token_ids_logprob = 8;
repeated int32 token_ids_logprob = 8;
bool return_hidden_states = 9;
// For disaggregated serving
@@ -122,14 +122,11 @@ message GenerateRequest {
// For load balancing
int32 dp_balance_id = 17;
// Whether client wants streaming response
bool stream = 18;
}
message TokenizedInput {
string original_text = 1; // For reference
repeated uint32 input_ids = 2;
repeated int32 input_ids = 2;
}
message MultimodalInputs {
@@ -166,50 +163,51 @@ message GenerateResponse {
}
message GenerateStreamChunk {
// Generated tokens (incremental chunk)
repeated uint32 token_ids = 1;
// Generated token
int32 token_id = 1;
string text = 2;
// Cumulative counts
int32 prompt_tokens = 2;
int32 completion_tokens = 3;
int32 cached_tokens = 4;
// Output logprobs (if requested) - incremental for streaming
LogProbs output_logprobs = 5;
// Hidden states (if requested)
repeated float hidden_states = 6;
// Input logprobs (if requested) - only in first chunk
LogProbs input_logprobs = 7;
}
message GenerateComplete {
// Final output
repeated uint32 output_ids = 1;
// Finish reason as OpenAI-compatible string ("stop", "length", "abort")
string finish_reason = 2;
// Token usage counts
int32 prompt_tokens = 3;
int32 completion_tokens = 4;
int32 cached_tokens = 5;
// Output logprobs if requested (cumulative)
LogProbs output_logprobs = 6;
// Logprobs (if requested)
LogProbs logprobs = 6;
// Hidden states (if requested)
repeated float hidden_states = 7;
// Metadata
float generation_time = 8; // Time to generate this token
int32 queue_time = 9; // Time spent in queue
}
message GenerateComplete {
// Final output
repeated int32 output_ids = 1;
string output_text = 2;
// Finish reason
enum FinishReason {
// The model generated a stop sequence.
STOP = 0;
// The model reached the maximum generation length.
LENGTH = 1;
// The model generated an end-of-sequence (EOS) token.
EOS_TOKEN = 2;
// The model generated a user-provided stop string.
STOP_STR = 3;
// The request was aborted by the user or system.
ABORT = 4;
}
FinishReason finish_reason = 3;
// All logprobs if requested
repeated LogProbs all_logprobs = 11;
// All hidden states if requested
repeated HiddenStates all_hidden_states = 7;
// Matched stop information (for stop sequences)
oneof matched_stop {
uint32 matched_token_id = 8;
string matched_stop_str = 9;
}
// Input logprobs if requested (for prompt tokens)
LogProbs input_logprobs = 10;
repeated HiddenStates all_hidden_states = 12;
}
message GenerateError {
@@ -224,11 +222,15 @@ message LogProbs {
// Top logprobs at each position
repeated TopLogProbs top_logprobs = 3;
// Decoded text for tokens
repeated string token_texts = 4;
}
message TopLogProbs {
repeated float values = 1;
repeated int32 token_ids = 2;
repeated string token_texts = 3;
}
message HiddenStates {
@@ -283,9 +285,10 @@ message EmbedComplete {
// Additional metadata
int32 embedding_dim = 4;
float generation_time = 5;
// For batch embeddings
repeated Embedding batch_embeddings = 5;
repeated Embedding batch_embeddings = 6;
}
message Embedding {

File diff suppressed because one or more lines are too long

View File

@@ -3,6 +3,7 @@ import datetime
from google.protobuf import timestamp_pb2 as _timestamp_pb2
from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
@@ -11,7 +12,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class SamplingParams(_message.Message):
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
class LogitBiasEntry(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
@@ -34,7 +35,6 @@ class SamplingParams(_message.Message):
REGEX_FIELD_NUMBER: _ClassVar[int]
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
LORA_PATH_FIELD_NUMBER: _ClassVar[int]
N_FIELD_NUMBER: _ClassVar[int]
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
@@ -43,6 +43,7 @@ class SamplingParams(_message.Message):
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
temperature: float
top_p: float
@@ -59,7 +60,6 @@ class SamplingParams(_message.Message):
regex: str
json_schema: str
ebnf_grammar: str
structural_tag: str
lora_path: str
n: int
token_healing: bool
@@ -68,8 +68,9 @@ class SamplingParams(_message.Message):
no_stop_trim: bool
stream_interval: int
logit_bias: _containers.ScalarMap[str, float]
structural_tag: str
custom_params: _struct_pb2.Struct
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class DisaggregatedParams(_message.Message):
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
@@ -82,7 +83,7 @@ class DisaggregatedParams(_message.Message):
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
class GenerateRequest(_message.Message):
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
@@ -100,7 +101,6 @@ class GenerateRequest(_message.Message):
LORA_ID_FIELD_NUMBER: _ClassVar[int]
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
STREAM_FIELD_NUMBER: _ClassVar[int]
request_id: str
tokenized: TokenizedInput
mm_inputs: MultimodalInputs
@@ -118,8 +118,7 @@ class GenerateRequest(_message.Message):
lora_id: str
data_parallel_rank: int
dp_balance_id: int
stream: bool
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
class TokenizedInput(_message.Message):
__slots__ = ("original_text", "input_ids")
@@ -162,46 +161,52 @@ class GenerateResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message):
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
__slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
TEXT_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
token_ids: _containers.RepeatedScalarFieldContainer[int]
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
token_id: int
text: str
prompt_tokens: int
completion_tokens: int
cached_tokens: int
output_logprobs: LogProbs
logprobs: LogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
input_logprobs: LogProbs
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
generation_time: float
queue_time: int
def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
__slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
STOP: _ClassVar[GenerateComplete.FinishReason]
LENGTH: _ClassVar[GenerateComplete.FinishReason]
EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
STOP_STR: _ClassVar[GenerateComplete.FinishReason]
ABORT: _ClassVar[GenerateComplete.FinishReason]
STOP: GenerateComplete.FinishReason
LENGTH: GenerateComplete.FinishReason
EOS_TOKEN: GenerateComplete.FinishReason
STOP_STR: GenerateComplete.FinishReason
ABORT: GenerateComplete.FinishReason
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int]
finish_reason: str
prompt_tokens: int
completion_tokens: int
cached_tokens: int
output_logprobs: LogProbs
output_text: str
finish_reason: GenerateComplete.FinishReason
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
matched_token_id: int
matched_stop_str: str
input_logprobs: LogProbs
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details")
@@ -214,22 +219,26 @@ class GenerateError(_message.Message):
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
class LogProbs(_message.Message):
__slots__ = ("token_logprobs", "token_ids", "top_logprobs")
__slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class TopLogProbs(_message.Message):
__slots__ = ("values", "token_ids")
__slots__ = ("values", "token_ids", "token_texts")
VALUES_FIELD_NUMBER: _ClassVar[int]
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
values: _containers.RepeatedScalarFieldContainer[float]
token_ids: _containers.RepeatedScalarFieldContainer[int]
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...
token_texts: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
class HiddenStates(_message.Message):
__slots__ = ("values", "layer", "position")
@@ -274,18 +283,20 @@ class EmbedResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
class EmbedComplete(_message.Message):
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings")
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
embedding: _containers.RepeatedScalarFieldContainer[float]
prompt_tokens: int
cached_tokens: int
embedding_dim: int
generation_time: float
batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
class Embedding(_message.Message):
__slots__ = ("values", "index")

View File

@@ -1,6 +1,3 @@
# This file is auto-generated. Do not edit manually.
# Regenerate with: python compile_proto.py
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

View File

@@ -119,6 +119,37 @@ def get_hf_text_config(config: PretrainedConfig):
return config
def _load_deepseek_v32_model(
model_path: str,
trust_remote_code: bool = False,
revision: Optional[str] = None,
**kwargs,
):
# first get the local path
local_path = download_from_hf(model_path)
# then load the config file in json
config_file = os.path.join(local_path, "config.json")
if not os.path.exists(config_file):
raise RuntimeError(f"Can't find config file in {local_path}.")
with open(config_file, "r") as f:
config_json = json.load(f)
config_json["architectures"] = ["DeepseekV3ForCausalLM"]
config_json["model_type"] = "deepseek_v3"
tmp_path = os.path.join(local_path, "_tmp_config_folder")
os.makedirs(tmp_path, exist_ok=True)
unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
with open(unique_path, "w") as f:
json.dump(config_json, f)
return AutoConfig.from_pretrained(
unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
@lru_cache_frozenset(maxsize=32)
def get_config(
model: str,
@@ -140,9 +171,17 @@ def get_config(
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
model = client.get_local_dir()
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
except ValueError as e:
if not "deepseek_v32" in str(e):
raise e
config = _load_deepseek_v32_model(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if (
config.architectures is not None
and config.architectures[0] == "Phi4MMForCausalLM"

View File

@@ -619,11 +619,7 @@ class AiterAttnBackend(AttentionBackend):
assert len(k.shape) == 3
assert len(v.shape) == 3
if (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
if forward_batch.forward_mode.is_extend():
if kv_indices.shape[0] == 0:
o = flash_attn_varlen_func(
q,

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
import custom_ops
import torch
import torch_npu
from torch.nn.functional import scaled_dot_product_attention
@@ -36,6 +37,8 @@ class ForwardMetadata:
seq_lens_cpu_int: Optional[torch.Tensor] = None
seq_lens_cpu_list: Optional[List[int]] = None
seq_lens_list_cumsum: Optional[List[int]] = None
seq_lens: Optional[torch.Tensor] = None
actual_seq_lengths_q: Optional[torch.Tensor] = None
class AscendAttnBackend(AttentionBackend):
@@ -67,6 +70,9 @@ class AscendAttnBackend(AttentionBackend):
if self.use_mla:
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.q_head_dim = (
self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
)
self.native_attn = TorchNativeAttnBackend(model_runner)
self.graph_metadata = {}
self.max_context_len = model_runner.model_config.context_len
@@ -102,10 +108,6 @@ class AscendAttnBackend(AttentionBackend):
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
if forward_batch.is_extend_in_batch:
seq_lens_list_cumsum[-1] = (
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
) * tp_size
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
self.graph_mode = False
@@ -133,6 +135,10 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
metadata.seq_lens = seq_lens
metadata.actual_seq_lengths_q = torch.tensor(
[1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
)
self.graph_metadata[bs] = metadata
self.forward_metadata = metadata
@@ -161,6 +167,8 @@ class AscendAttnBackend(AttentionBackend):
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
metadata.block_tables[bs:, :].fill_(0)
metadata.seq_lens[:bs].copy_(seq_lens[:bs])
self.forward_metadata = metadata
self.graph_mode = True
@@ -168,6 +176,64 @@ class AscendAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_sparse(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi_head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: torch.Tensor = None,
):
is_prefill = forward_batch.forward_mode.is_extend()
if save_kv_cache:
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
)
q_nope, q_pe = q, q_rope
k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
block_table = self.forward_metadata.block_tables
if is_prefill:
actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
else:
if self.forward_metadata.actual_seq_lengths_q is None:
actual_seq_qlen = (
torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
)
else:
actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
if self.forward_metadata.seq_lens_cpu_int is None:
actual_seq_lengths_kv = self.forward_metadata.seq_lens
else:
actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
attn_out = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope,
key=k_nope,
value=k_nope,
query_rope=q_pe,
key_rope=k_pe,
sparse_indices=topk_indices,
scale_value=layer.scaling,
actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
block_table=block_table,
sparse_block_size=1,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
return attn_out
def forward_extend(
self,
q,
@@ -176,7 +242,23 @@ class AscendAttnBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi_head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
):
if topk_indices is not None:
return self.forward_sparse(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope,
k_rope,
topk_indices,
)
if not self.use_mla:
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -437,10 +519,23 @@ class AscendAttnBackend(AttentionBackend):
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
topk_indices: Optional[torch.Tensor] = None,
):
if is_mla_preprocess_enabled():
# MLAPO does saving kv_cache
save_kv_cache = False
if topk_indices is not None:
return self.forward_sparse(
q,
k,
v,
layer,
forward_batch,
save_kv_cache,
q_rope,
k_rope,
topk_indices,
)
if self.graph_mode:
return self.forward_decode_graph(

View File

@@ -1,7 +1,3 @@
import logging
logger = logging.getLogger(__name__)
ATTENTION_BACKENDS = {}
@@ -66,6 +62,13 @@ def create_ascend_backend(runner):
return AscendAttnBackend(runner)
@register_attention_backend("nsa")
def create_nsa_backend(runner):
    """Create the Native Sparse Attention (NSA) backend for `runner`.

    Imported lazily so the NSA dependency is only pulled in when the
    "nsa" backend is actually selected.
    """
    from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend

    return NativeSparseAttnBackend(runner)
@register_attention_backend("triton")
def create_triton_backend(runner):
assert not runner.model_config.is_encoder_decoder, (
@@ -162,37 +165,35 @@ def create_dual_chunk_flash_attn_backend(runner):
return DualChunkFlashAttentionBackend(runner)
def attn_backend_wrapper(runner, full_attn_backend):
"""
Wrapper for special models like hybrid GDN, so we don't
need to change the code of the original attention backend.
"""
assert not (
runner.is_hybrid_gdn and runner.use_mla_backend
), "hybrid_gdn can only be used with non-MLA models."
@register_attention_backend("hybrid_linear_attn")
def create_hybrid_linear_attn_backend(runner):
assert (
runner.is_hybrid_gdn
), "hybrid_linear_attn backend can only be used with hybrid GDN models."
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
MambaAttnBackend,
)
from sglang.srt.utils import is_blackwell, is_npu
# wrap for hybrid GDN models
if runner.is_hybrid_gdn:
from sglang.srt.utils import is_blackwell, is_npu
if is_npu():
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
if is_blackwell():
assert (
runner.server_args.attention_backend == "triton"
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
MambaAttnBackend,
full_attn_backend = AscendAttnBackend(runner)
elif is_blackwell():
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
full_attn_backend = TritonAttnBackend(runner)
else:
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)
linear_attn_backend = MambaAttnBackend(runner)
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
)
full_attn_backend = FlashAttentionBackend(runner)
return full_attn_backend
linear_attn_backend = MambaAttnBackend(runner)
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
)

View File

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
if TYPE_CHECKING:
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
def support_triton(self):
"""Check if the current backend supports triton."""
return True
def get_indexer_metadata(
    self,
    layer_id: int,
    forward_batch: ForwardBatch,
) -> Optional[BaseIndexerMetadata]:
    """Return the sparse-indexer metadata for the given layer.

    The base implementation returns None, meaning this backend does not
    support an indexer; backends with NSA support override this.
    """
    return None

View File

@@ -692,13 +692,8 @@ class FlashAttentionBackend(AttentionBackend):
k_descale, v_descale = None, None
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
# 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
if (
self.kv_cache_dtype_str != "auto"
and layer.head_dim <= 256
and self.fa_impl_ver != 4
):
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
if layer.k_scale is not None:
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)

View File

@@ -29,7 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.utils import (
get_int_env_var,
is_flashinfer_available,
@@ -344,7 +344,9 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
@@ -451,7 +453,9 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
@@ -669,7 +673,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -684,7 +690,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -710,7 +718,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -760,7 +770,9 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -794,7 +806,9 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
seq_lens_cpu: Optional[torch.Tensor],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
@@ -905,7 +919,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
# Keep the signature for type checking. It will be assigned during runtime.
@@ -921,7 +937,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
if use_ragged:
@@ -959,7 +977,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
@@ -1006,7 +1026,9 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
@@ -1049,7 +1071,9 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
):
@@ -1078,7 +1102,7 @@ class FlashInferIndicesUpdaterPrefill:
custom_mask = None
else:
assert isinstance(
spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
)
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(

View File

@@ -5,13 +5,20 @@ Support attention backend for FlashMLA.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union
import torch
import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.layers.attention.nsa.utils import (
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
NSA_KV_CACHE_STORE_FP8,
compute_nsa_seqlens,
)
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -74,10 +81,17 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.scaling = model_runner.model_config.scaling
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
self.nsa_index_topk = (
get_nsa_index_topk(model_runner.model_config.hf_config)
if self.use_nsa
else None
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
@@ -100,10 +114,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
self.num_q_heads,
1,
mla_metadata, num_splits = _get_mla_metadata_wrapped(
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
seq_len_q=1,
num_heads_q=self.num_q_heads,
num_heads_k=1,
nsa_index_topk=self.nsa_index_topk,
)
self.forward_metadata = FlashMLADecodeMetadata(
mla_metadata,
@@ -130,10 +146,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads,
1,
mla_metadata, num_splits = _get_mla_metadata_wrapped(
cache_seqlens=seq_lens.to(torch.int32),
seq_len_q=self.num_draft_tokens,
num_heads_q=self.num_q_heads,
num_heads_k=1,
nsa_index_topk=self.nsa_index_topk,
)
# Use FlashMLADecodeMetadata which has the attributes forward_extend expects
@@ -162,20 +180,28 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
cuda_graph_kv_indices = block_kv_indices
if self.num_draft_tokens:
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
torch.ones(
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
),
self.num_draft_tokens * self.num_q_heads,
1,
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
_get_mla_metadata_wrapped(
cache_seqlens=torch.ones(
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
),
seq_len_q=self.num_draft_tokens,
num_heads_q=self.num_q_heads,
num_heads_k=1,
nsa_index_topk=self.nsa_index_topk,
)
)
else:
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
torch.ones(
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
),
self.num_q_heads,
1,
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
_get_mla_metadata_wrapped(
cache_seqlens=torch.ones(
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
),
seq_len_q=1,
num_heads_q=self.num_q_heads,
num_heads_k=1,
nsa_index_topk=self.nsa_index_topk,
)
)
self.cuda_graph_kv_indices = cuda_graph_kv_indices
@@ -201,10 +227,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_q_heads,
1,
mla_metadata, num_splits = _get_mla_metadata_wrapped(
cache_seqlens=seq_lens.to(torch.int32),
seq_len_q=1,
num_heads_q=self.num_q_heads,
num_heads_k=1,
nsa_index_topk=self.nsa_index_topk,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -226,10 +254,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads,
1,
mla_metadata, num_splits = _get_mla_metadata_wrapped(
cache_seqlens=seq_lens.to(torch.int32),
seq_len_q=self.num_draft_tokens,
num_heads_q=self.num_q_heads,
num_heads_k=1,
nsa_index_topk=self.nsa_index_topk,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -275,10 +305,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_q_heads,
1,
mla_metadata, num_splits = _get_mla_metadata_wrapped(
cache_seqlens=seq_lens.to(torch.int32),
seq_len_q=1,
num_heads_q=self.num_q_heads,
num_heads_k=1,
nsa_index_topk=self.nsa_index_topk,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -300,10 +332,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads,
1,
mla_metadata, num_splits = _get_mla_metadata_wrapped(
cache_seqlens=seq_lens.to(torch.int32),
seq_len_q=self.num_draft_tokens,
num_heads_q=self.num_q_heads,
num_heads_k=1,
nsa_index_topk=self.nsa_index_topk,
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -335,6 +369,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
topk_indices: Optional[torch.Tensor] = None,
):
cache_loc = forward_batch.out_cache_loc
@@ -349,13 +384,14 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
k_cache = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
if self.data_type == torch.float8_e4m3fn:
if (not self.use_nsa) and self.data_type == torch.float8_e4m3fn:
reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
o, _ = flash_mla_with_kvcache(
q=reshape_q_fp8,
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
k_cache=k_cache,
block_table=self.forward_metadata.block_kv_indices[:bs],
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
@@ -369,17 +405,49 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
block_table = self.forward_metadata.block_kv_indices[:bs]
cache_seqlens = forward_batch.seq_lens.to(torch.int32)
extra_kwargs: Dict
if self.use_nsa:
assert topk_indices is not None
extra_kwargs = dict(
indices=_compute_indices_in_kvcache(
block_table=block_table,
topk_indices=topk_indices.to(torch.int32),
page_size=self.page_size,
),
# doc says it is not used, but if pass in None then error
block_table=block_table,
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
)
cache_seqlens = compute_nsa_seqlens(
cache_seqlens, nsa_index_topk=self.nsa_index_topk
)
else:
extra_kwargs = dict(
block_table=block_table,
causal=True,
)
if (
self.use_nsa
and NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8
and not NSA_KV_CACHE_STORE_FP8
):
# inefficiently quantize the whole cache
k_cache = quantize_k_cache(k_cache)
# todo: need check all causal True or False?
o, _ = flash_mla_with_kvcache(
q=reshape_q,
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
block_table=self.forward_metadata.block_kv_indices[:bs],
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
k_cache=k_cache,
cache_seqlens=cache_seqlens,
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=layer.scaling,
causal=True,
**extra_kwargs,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -539,3 +607,52 @@ class FlashMLAMultiStepDraftBackend:
)
self.common_template(forward_batch, call_fn)
def _get_mla_metadata_wrapped(
    *,
    cache_seqlens: torch.Tensor,
    seq_len_q: int,
    num_heads_q: int,
    num_heads_k: int,
    nsa_index_topk: Optional[int],
):
    """Dispatch to flash_mla.get_mla_metadata with the right keyword set.

    The NSA (sparse) variant of the kernel takes extra arguments
    (``topk``, ``is_fp8_kvcache``, ``num_heads_q``) and names the
    per-head-k token count differently from the dense variant, so the
    two call shapes are centralized here.

    Args:
        cache_seqlens: int32 tensor of per-sequence KV lengths.
        seq_len_q: number of query tokens per sequence (e.g. draft tokens).
        num_heads_q: number of query heads.
        num_heads_k: number of key heads.
        nsa_index_topk: NSA index top-k, or None for dense MLA.

    Returns:
        The (tile_scheduler_metadata, num_splits) pair from get_mla_metadata.
    """
    # The original code asserted the branch condition inside each branch
    # (`assert nsa_index_topk is not None` under the identical `if`), which
    # is a no-op; the redundant asserts are removed.
    if nsa_index_topk is not None:
        return get_mla_metadata(
            cache_seqlens=cache_seqlens,
            # NOTE(review): upstream doc says `num_q_tokens_per_q_seq *
            # num_heads_q // num_heads_k`, but the parameter name suggests a
            # per-head-k count — confirm against the flash_mla version in use.
            num_q_tokens_per_head_k=seq_len_q * num_heads_q // num_heads_k,
            num_heads_k=num_heads_k,
            num_heads_q=num_heads_q,
            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
            topk=nsa_index_topk,
        )
    # Dense path: the non-NSA signature uses `num_heads_per_head_k`.
    return get_mla_metadata(
        cache_seqlens=cache_seqlens,
        num_heads_per_head_k=seq_len_q * num_heads_q // num_heads_k,
        num_heads_k=num_heads_k,
    )
# TODO speedup
def _compute_indices_in_kvcache(block_table, topk_indices, page_size):
topk_indices_safe = topk_indices.masked_fill(topk_indices == -1, 0)
idx0 = torch.arange(block_table.size(0), device=topk_indices_safe.device).unsqueeze(
1
)
block_idx = block_table[idx0, topk_indices_safe // page_size]
offset = topk_indices_safe % page_size
indices_in_kvcache = block_idx * page_size + offset
# the kernel requires invalid entry to be -1
assert indices_in_kvcache.shape == topk_indices.shape
indices_in_kvcache[topk_indices == -1] = -1
# return: (batch_size, seqlen_q_ori, topk)
indices_in_kvcache = indices_in_kvcache[:, None, :]
return indices_in_kvcache

View File

@@ -3,6 +3,7 @@ from typing import Optional, Union
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -138,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
return backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
def get_indexer_metadata(
    self, layer_id: int, forward_batch: ForwardBatch
) -> Optional[BaseIndexerMetadata]:
    """Delegate indexer-metadata lookup to the backend picked for this mode."""
    return self._select_backend(forward_batch.forward_mode).get_indexer_metadata(
        layer_id, forward_batch
    )

View File

@@ -0,0 +1,121 @@
import math
from typing import Optional, Tuple, List
import torch
def cdiv(x: int, y: int):
    """Ceiling division: the smallest integer >= x / y (for positive y)."""
    quotient = (x + y - 1) // y
    return quotient
def native_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference of the sparse MLA attention forward.

    Args:
        q: [s_q, h_q, d_qk] queries.
        kv: [s_kv, 1, d_qk] shared latent KV (single head); the first
            d_v channels double as the values.
        indices: [s_q, 1, topk] per-query token indices into kv; entries
            outside [0, s_kv) are treated as invalid and masked out.
        sm_scale: softmax scale applied to the logits.
        d_v: value dimension (prefix of d_qk used as V).

    Returns:
        (max_logits, lse, result): max_logits and lse are [s_q, h_q]
        (both in base-2 log space), result is [s_q, h_q, d_v] float32.
    """
    s_q, _, d_qk = q.size()
    s_kv = kv.size(0)
    topk = indices.size(-1)

    # Base-2 logsumexp: log2(sum(2^a)), computed via the natural-log version.
    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)

    indices = indices[:, 0, :]  # [s_q, topk]
    invalid_indices_mask = (indices < 0) | (indices >= s_kv)
    qs = q.float()  # [s_q, h_q, d_qk]
    kvs = kv[:, 0, :].float()  # [s_kv, d_qk]
    # Gather the top-k KV rows per query. Invalid slots read row 0, but their
    # scores are masked to -inf below, so the placeholder value never matters.
    kvs = torch.index_select(
        kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()
    ).view(s_q, topk, d_qk)  # [s_q, topk, d_qk]
    attn_score = qs @ kvs.transpose(1, 2)  # [s_q, h_q, topk]
    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float('-inf'))
    # Fold sm_scale in and convert logits to base-2 so exp2 applies directly.
    attn_score *= sm_scale * math.log2(math.e)
    max_logits = torch.max(attn_score, dim=-1)[0]  # [s_q, h_q]
    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
    attn_score = torch.exp2(attn_score - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    # Values are the first d_v channels of the gathered KV rows.
    result = attn_score @ kvs[:, :, :d_v]
    return (max_logits, lse, result)
def native_mla_with_kvcache(
    q: torch.Tensor,  # [batch_size, s_q, h_q, d]
    blocked_k: torch.Tensor,  # [?, block_size, h_kv, d]
    block_table: torch.Tensor,  # [batch_size, ?]
    cache_seqlens: torch.Tensor,  # [batch_size]
    dv: int,
    is_causal: bool,
    indices: Optional[torch.Tensor] = None  # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A reference implementation in PyTorch.

    Runs attention over a paged KV cache, sequence by sequence, either
    causally or restricted to the per-query top-k positions in `indices`
    (the two modes are mutually exclusive). Returns:
      - out: [batch_size, s_q, h_q, dv] in bfloat16,
      - lse: [batch_size, h_q, s_q] in float32 (natural-log space;
        +inf for queries with no attendable key).
    """
    def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
        # Boolean [s_q, s_k] mask marking the attendable (top-k) positions of
        # each query; -1 entries in `indices` are padding and are skipped.
        mask = torch.zeros(s_q, s_k, dtype=torch.bool)
        for i in range(s_q):
            cur_indices = indices[i]
            valid_indices = cur_indices[cur_indices != -1]
            mask[i, valid_indices] = True
        return mask

    def scaled_dot_product_attention(
        batch_idx: int,
        query: torch.Tensor,  # [h_q, s_q, d]
        kv: torch.Tensor,  # [h_kv, s_k, d]
        dv: int,
        is_causal,
        indices: Optional[torch.Tensor],  # [s_q, topk]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        h_q = query.size(0)
        h_kv = kv.size(0)
        s_q = query.shape[-2]
        s_k = kv.shape[-2]
        query = query.float()
        kv = kv.float()
        if h_kv != 1:
            # Expand KV heads to match query heads (GQA-style sharing).
            kv = kv.repeat_interleave(h_q // h_kv, dim=0)
        # Zero out NaNs (e.g. uninitialized cache slots) before the matmul.
        kv[kv != kv] = 0.0
        attn_weight = query @ kv.transpose(-2, -1)  # [h_q, s_q, s_k]
        if (is_causal and query.size(1) > 1) or indices is not None:
            mask = torch.ones(s_q, s_k, dtype=torch.bool)
            if is_causal:
                assert indices is None
                mask = mask.tril(diagonal=s_k - s_q)
            if indices is not None:
                mask &= get_topk_attn_mask(s_q, s_k, indices)
            attn_bias = torch.zeros(s_q, s_k, dtype=torch.float)
            attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
            # NOTE(review): `q` here is the outer function's tensor (closure
            # capture), not the local `query` — the bias is cast to the
            # caller's input dtype; confirm this is intentional.
            attn_weight += attn_bias.to(q.dtype)
        # Scaling after the bias is safe: bias entries are only 0 or -inf.
        attn_weight /= math.sqrt(query.size(-1))
        lse = attn_weight.logsumexp(dim=-1)  # [h_q, s_q]
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
        output = attn_weight @ kv[..., :dv]  # [h_q, s_q, dv]
        # Correct for q tokens which has no attendable k
        lonely_q_mask = (lse == float("-inf"))
        output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
        lse[lonely_q_mask] = float("+inf")
        return output, lse

    b, s_q, h_q, d = q.size()
    block_size = blocked_k.size(1)
    h_kv = blocked_k.size(2)
    cache_seqlens_cpu = cache_seqlens.cpu()
    out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
    lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32)
    for i in range(b):
        cur_len = cache_seqlens_cpu[i].item()
        cur_num_blocks = cdiv(cur_len, block_size)
        cur_block_indices = block_table[i][0: cur_num_blocks]
        # Gather this sequence's KV pages and trim the last page's padding.
        cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
        cur_out, cur_lse = scaled_dot_product_attention(
            i,
            q[i].transpose(0, 1),
            cur_kv.transpose(0, 1),
            dv,
            is_causal,
            indices[i] if indices is not None else None
        )
        out_ref[i] = cur_out.transpose(0, 1)
        lse_ref[i] = cur_lse
    out_ref = out_ref.to(torch.bfloat16)
    return out_ref, lse_ref

View File

@@ -76,12 +76,14 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
self.rotary_emb = rotary_emb
self.layer_id = layer_id
self.has_preprocess_weights = False
self.dtype = None
self.q_lora_rank = self.q_b_proj.input_size # 1536
self.kv_lora_rank = self.kv_a_layernorm.hidden_size # 512
self.num_local_heads = num_local_heads # tp
self.qk_nope_head_dim = qk_nope_head_dim # 128
self.qk_rope_head_dim = qk_rope_head_dim # 64
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
def preprocess_weights(self, hidden_states):
self.dummy = torch.empty(
@@ -236,7 +238,83 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
return k_cache, v_cache, slot_mapping
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
def forward_absorb_prepare_npu_rms_norm_cache(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    forward_batch,
    zero_allocator,
):
    """Prepare absorbed-MLA inputs on NPU via npu_kv_rmsnorm_rope_cache.

    Projects hidden states into q_nope/q_pe, applies rotary embedding to
    the rope part with npu_interleave_rope, and — in one fused
    npu_kv_rmsnorm_rope_cache call — rms-norms the latent KV, applies
    rope to its rope part, and scatters both into the paged caches.

    Returns the tuple (q_pe, k_rope, q_nope, k_nope, forward_batch,
    zero_allocator, positions) consumed by the attention core.
    """
    bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape
    self.dtype = hidden_states.dtype
    self.cos, self.sin = self.get_sin_cos(positions)
    self.kvCache, self.kvCacheRope, self.slotmapping = (
        self.get_kv_cache_and_cache_idx(forward_batch)
    )
    # NOTE(review): this path only flips the flag without preparing any
    # weights, unlike forward_mlapo which calls preprocess_weights() —
    # confirm this is intentional.
    if not self.has_preprocess_weights:
        self.has_preprocess_weights = True
    cos, sin = self.cos, self.sin
    if self.q_lora_rank is not None:
        # Fused QKV-a projection: split into the low-rank q part and the
        # latent KV (+rope) part.
        fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]
        q_lowrank, latent_cache = fused_qkv_a_proj_out.split(
            [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
        )
        q = self.q_a_layernorm(q_lowrank)
        q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
    else:
        q = self.q_proj(hidden_states)[0].view(
            -1, self.num_local_heads, self.qk_head_dim
        )
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
    q_nope, q_pe = torch.split(
        q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
    )  # b*s,n,d
    q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)
    # Absorb w_kc into q_nope (MLA weight-absorption for decode).
    q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)
    q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
    cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
    sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
    q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin)  # (B,N,S,D)
    q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)
    latent_cache = latent_cache.view(
        -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
    )  # (B*S,N,1,D)
    cache_mode = "PA_BNSD"
    # Reshape the paged caches to (num_pages, page_size, 1, dim) as the
    # fused op expects.
    self.kvCache = self.kvCache.view(
        -1,
        forward_batch.attn_backend.page_size,
        1,
        forward_batch.attn_backend.kv_lora_rank,
    )
    self.kvCacheRope = self.kvCacheRope.view(
        -1,
        forward_batch.attn_backend.page_size,
        1,
        forward_batch.attn_backend.qk_rope_head_dim,
    )
    # Fused rmsnorm + rope + cache scatter at `slotmapping`.
    k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
        latent_cache,
        self.kv_a_layernorm.weight,
        cos,
        sin,
        self.slotmapping.to(torch.int64),
        self.kvCacheRope,
        self.kvCache,
        epsilon=self.kv_a_layernorm.variance_epsilon,
        cache_mode=cache_mode,
    )
    return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)
def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):
input_dtype = hidden_states.dtype
if not self.has_preprocess_weights:
self.preprocess_weights(hidden_states)
@@ -298,3 +376,18 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
zero_allocator,
positions,
)
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
    """Dispatch MLA preprocessing.

    Uses the fused MLAPO path when qkv_a_proj is quantized as w8a8_int8,
    otherwise the npu_kv_rmsnorm_rope_cache path.
    """
    quant_method = self.qkv_a_proj.quant_method
    is_w8a8_int8 = (
        hasattr(quant_method, "quantization_config")
        and quant_method.quantization_config.get_name() == "w8a8_int8"
    )
    args = (positions, hidden_states, forward_batch, zero_allocator)
    if is_w8a8_int8:
        return self.forward_mlapo(*args)
    return self.forward_absorb_prepare_npu_rms_norm_cache(*args)

View File

@@ -0,0 +1,3 @@
from .topk import fast_topk, fast_topk_transform
__all__ = ["fast_topk", "fast_topk_transform"]

View File

@@ -0,0 +1,505 @@
#include <ATen/core/TensorBase.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/python.h>
namespace {
// Number of indices emitted per row.
constexpr int TopK = 2048;
// Threads per block; device-side loops stride by this amount.
constexpr int kThreadsPerBlock = 1024;
constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t); // 128KB

// Launch parameters shared by the fast top-k kernels.
struct FastTopKParams {
  const float *__restrict__ input; // [B, input_stride] scores
  int32_t *__restrict__ indices;   // [B, TopK] output indices
  int32_t *__restrict__ lengths;   // [B] valid length per row
  int64_t input_stride;            // row stride of `input`
  bool use_tilelang;               // NOTE(review): consumer not visible here
};
// when length <= TopK, we can directly write the indices
__device__ void naive_topk_cuda(const float *__restrict__ score,
                                int32_t *__restrict__ indice, int32_t length) {
  // Identity mapping for the first `length` slots, -1 padding afterwards.
  for (int slot = static_cast<int>(threadIdx.x); slot < TopK;
       slot += kThreadsPerBlock) {
    if (slot < length) {
      indice[slot] = slot;
    } else {
      indice[slot] = -1;
    }
  }
}
// keep the first `length` entries, set others to -1
__device__ void
naive_topk_transform(const float *__restrict__ score, int32_t length,
                     int32_t *__restrict__ dst_page_table,
                     const int32_t *__restrict__ src_page_table) {
  // Copy the valid prefix of the page table; pad the rest with -1.
  for (int32_t pos = threadIdx.x; pos < TopK; pos += kThreadsPerBlock) {
    if (pos < length) {
      dst_page_table[pos] = src_page_table[pos];
    } else {
      dst_page_table[pos] = -1;
    }
  }
}
// Map a float to an order-preserving 8-bit radix key: round to fp16, flip
// bits so larger floats compare as larger unsigned keys, keep the top byte.
__device__ __forceinline__ uint8_t convert_to_uint8(float x) {
  const uint16_t raw = __half_as_ushort(__float2half_rn(x));
  uint16_t ordered;
  if (raw & 0x8000) {
    ordered = static_cast<uint16_t>(~raw & 0xFFFF); // negative: invert all bits
  } else {
    ordered = static_cast<uint16_t>(raw | 0x8000); // non-negative: set sign bit
  }
  return static_cast<uint8_t>(ordered >> 8);
}
// Order-preserving transform of a float's raw IEEE-754 bits into uint32
// radix space (larger floats map to larger unsigned integers).
__device__ __forceinline__ uint32_t convert_to_uint32(float x) {
  const uint32_t raw = __float_as_uint(x);
  if (raw & 0x80000000u) {
    return ~raw & 0xFFFFFFFFu; // negative: invert all bits
  }
  return raw | 0x80000000u; // non-negative: set the sign bit
}
// One radix pass of a block-cooperative top-k selection.
//
// Builds a shared-memory histogram of the current digit (`loader(idx)`) over
// `length` candidate indices (produced by `indexer(i)`), finds the threshold
// bin such that all strictly-larger bins together hold fewer than `topk`
// elements, emits every index above the threshold into `index` (advancing
// `s_counter` atomically), and stashes indices falling exactly on the
// threshold bin into `s_remain_idx` for the next, finer pass — or emits them
// directly when Is_Epilogue is true.
//
// Returns the number of elements still to be selected from the threshold bin
// (<= 0 means selection is already complete).
template <bool Is_Epilogue = false, typename Indexer, typename Loader,
          int LENGTH, int MAX_REMAIN>
__device__ __forceinline__ auto
radix_topk(Indexer indexer, Loader loader, uint32_t length, int topk,
           int *__restrict__ index, int &__restrict__ s_counter,
           int (&__restrict__ s_histogram)[LENGTH],
           int &__restrict__ s_remain_cnt,
           int (&__restrict__ s_remain_idx)[MAX_REMAIN]) -> int {
  constexpr auto RADIX = LENGTH - 1;
  static_assert(RADIX > 1 && (RADIX & (RADIX - 1)) == 0,
                "RADIX must be power of 2");
  static_assert(RADIX <= kThreadsPerBlock);
  __shared__ uint32_t s_threshold_bin_id;
  const auto tx = threadIdx.x;
  // Clear the histogram (one bin per thread; LENGTH == RADIX + 1).
  if (tx < RADIX + 1)
    s_histogram[tx] = 0;
  __syncthreads();
  /// NOTE: Use uint32_t as the index
  for (auto i = tx; i < length; i += kThreadsPerBlock) {
    const auto idx = indexer(i);
    const auto bin = loader(idx);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();
  // cumsum (descending)
  if (tx == 0) {
    s_histogram[RADIX] = 0;
    s_remain_cnt = 0;
    // After this loop, s_histogram[i] holds the count of elements in bins
    // >= i (suffix sum over the descending order).
    for (int i = RADIX - 2; i >= 0; --i) {
      s_histogram[i] += s_histogram[i + 1];
    }
    // threshold bin: first bin whose suffix count reaches topk.
    for (int i = 0; i < RADIX; i++) {
      if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
        s_threshold_bin_id = i;
        break;
      }
    }
  }
  __syncthreads();
  const auto threshold_bin = s_threshold_bin_id;
  // How many elements of the threshold bin are still needed.
  const auto new_topk = topk - s_histogram[threshold_bin + 1];
  for (auto i = tx; i < length; i += kThreadsPerBlock) {
    const auto idx = indexer(i);
    const auto bin_id = static_cast<uint32_t>(loader(idx));
    if (bin_id > threshold_bin) {
      // Definitely in the top-k: emit directly.
      index[::atomicAdd(&s_counter, 1)] = idx;
    } else if (bin_id == threshold_bin && new_topk > 0) {
      if constexpr (Is_Epilogue) {
        index[::atomicAdd(&s_counter, 1)] = idx;
      } else {
        // Defer ties on the threshold bin to the next, finer radix pass.
        if (const auto cnt = ::atomicAdd(&s_remain_cnt, 1);
            C10_LIKELY(cnt < MAX_REMAIN)) {
          s_remain_idx[cnt] = idx;
        }
      }
    }
  }
  __syncthreads();
  return new_topk;
}
// Radix-select the top-`topk` scores of `input[0..length)`; the winning
// indices (in no particular order) are written to `index`.
// The first pass buckets by an order-preserving 8-bit key of the
// half-rounded value; passes 0..3 then refine the surviving boundary
// bucket byte-by-byte using the full 32-bit order-preserving key,
// ping-ponging candidate indices between the two halves of the dynamic
// shared-memory buffer.
__device__ void fast_topk_cuda(const float *__restrict__ input,
                               int *__restrict__ index, int length,
                               int topk = TopK) {
  constexpr auto RADIX = 256;
  // Capacity of each of the two ping-pong candidate buffers.
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
  __shared__ int s_histogram[RADIX + 1];
  __shared__ int s_num_input[2]; // candidate counts for the two buffers
  __shared__ int s_counter;      // next free slot in `index`
  // allocate for two rounds
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
  // NOTE(review): written by every thread (same value) before the first
  // __syncthreads() inside radix_topk — confirm this is the intended init.
  s_counter = 0;
  // collect candidates
  const auto indexer = [](int idx) { return idx; };
  const auto loader = [&input](int idx) {
    return convert_to_uint8(input[idx]);
  };
  int new_topk = radix_topk(indexer, loader, length, topk, index, s_counter,
                            s_histogram, s_num_input[0], s_input_idx[0]);
  if (new_topk <= 0)
    return;
  // round 0: refine by the most significant byte of the 32-bit key
  const auto indexer_0 = [](int idx) { return s_input_idx[0][idx]; };
  const auto loader_0 = [&input](int idx) {
    return (convert_to_uint32(input[idx]) >> 24) & 0xFF;
  };
  new_topk = radix_topk(indexer_0, loader_0, s_num_input[0], new_topk, index,
                        s_counter, s_histogram, s_num_input[1], s_input_idx[1]);
  if (new_topk <= 0)
    return;
  // round 1: next byte, candidates now in buffer 1
  const auto indexer_1 = [](int idx) { return s_input_idx[1][idx]; };
  const auto loader_1 = [&input](int idx) {
    return (convert_to_uint32(input[idx]) >> 16) & 0xFF;
  };
  new_topk = radix_topk(indexer_1, loader_1, s_num_input[1], new_topk, index,
                        s_counter, s_histogram, s_num_input[0], s_input_idx[0]);
  if (new_topk <= 0)
    return;
  // round 2
  const auto loader_2 = [&input](int idx) {
    return (convert_to_uint32(input[idx]) >> 8) & 0xFF;
  };
  new_topk = radix_topk(indexer_0, loader_2, s_num_input[0], new_topk, index,
                        s_counter, s_histogram, s_num_input[1], s_input_idx[1]);
  if (new_topk <= 0)
    return;
  // round 3: least significant byte
  const auto loader_3 = [&input](int idx) {
    return convert_to_uint32(input[idx]) & 0xFF;
  };
  // epilogue: remaining ties are taken greedily
  radix_topk<true>(indexer_1, loader_3, s_num_input[1], new_topk, index,
                   s_counter, s_histogram, s_num_input[0], s_input_idx[0]);
}
// Variant of fast_topk_cuda mirroring the tilelang-generated kernel: the
// same algorithm (8-bit coarse pass followed by four byte-wise refinement
// rounds over the 32-bit order-preserving key) but with the per-round logic
// inlined into a single unrolled loop instead of the radix_topk helper.
// NOTE(review): loops stride by BLOCK_SIZE (1024) rather than the launch
// block size — confirm this matches kThreadsPerBlock at the launch site.
__device__ void fast_topk_cuda_tl(const float *__restrict__ input,
                                  int *__restrict__ index, int length,
                                  int topk = TopK) {
  constexpr auto BLOCK_SIZE = 1024;
  constexpr auto RADIX = 256;
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
  __shared__ int s_threshold_bin_id;
  __shared__ int s_histogram[RADIX + 1];
  __shared__ int s_num_input[2]; // candidate counts for the two buffers
  __shared__ int s_counter;      // next free output slot
  // allocate for two rounds
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
  int tx = threadIdx.x;
  // stage 1: 8bit coarse histogram
  if (tx < RADIX + 1)
    s_histogram[tx] = 0;
  __syncthreads();
  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin = convert_to_uint8(input[idx]);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();
  // cumsum (descending)
  if (tx == 0) {
    for (int i = RADIX - 2; i >= 0; --i) {
      s_histogram[i] += s_histogram[i + 1];
    }
    // threshold bin: bucket where the descending cumsum crosses topk
    for (int i = 0; i < RADIX; i++) {
      if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
        s_threshold_bin_id = i;
        break;
      }
    }
    s_num_input[0] = 0;
    s_counter = 0;
  }
  __syncthreads();
  int threshold_bin = s_threshold_bin_id;
  // Output slots still to be filled from inside the threshold bucket.
  int new_topk = topk - s_histogram[threshold_bin + 1];
  // collect candidates
  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin_id = static_cast<int>(convert_to_uint8(input[idx]));
    if (bin_id > threshold_bin) {
      // Definitely in the top-k: emit directly.
      int pos = ::atomicAdd(&s_counter, 1);
      index[pos] = idx;
    } else if (bin_id == threshold_bin && new_topk > 0) {
      // Boundary bucket: stage for refinement (overflow silently dropped).
      int pos = ::atomicAdd(&s_num_input[0], 1);
      if (pos < SMEM_INPUT_SIZE) {
        [[likely]] s_input_idx[0][pos] = idx;
      }
    }
  }
  __syncthreads();
  // stage 2: refine with 8bit radix passes
#pragma unroll 4
  for (int round = 0; round < 4; ++round) {
    if (new_topk <= 0)
      break;
    int r_idx = round % 2; // input ping-pong buffer for this round
    // reset
    if (tx < RADIX + 1)
      s_histogram[tx] = 0;
    __syncthreads();
    int num_input = s_num_input[r_idx];
    // Histogram of the current key byte (most significant byte first).
    for (int i = tx; i < num_input; i += BLOCK_SIZE) {
      int idx = s_input_idx[r_idx][i];
      uint32_t bin32 =
          (convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
      ::atomicAdd(&s_histogram[bin32], 1);
    }
    __syncthreads();
    if (tx == 0) {
      for (int i = RADIX - 2; i >= 0; --i)
        s_histogram[i] += s_histogram[i + 1];
      for (int i = 0; i < RADIX; i++) {
        if (s_histogram[i] >= new_topk && s_histogram[i + 1] < new_topk) {
          s_threshold_bin_id = i;
          break;
        }
      }
      s_num_input[r_idx ^ 1] = 0;
    }
    __syncthreads();
    new_topk -= s_histogram[s_threshold_bin_id + 1];
    int threshold_bin = s_threshold_bin_id;
    for (int i = tx; i < num_input; i += BLOCK_SIZE) {
      int idx = s_input_idx[r_idx][i];
      uint32_t bin32 =
          (convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
      if (bin32 > threshold_bin) {
        int pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      } else if (bin32 == threshold_bin && new_topk > 0) {
        if (round == 3) {
          // Last byte: remaining ties are taken greedily.
          int pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        } else {
          int pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
          if (pos < SMEM_INPUT_SIZE)
            s_input_idx[r_idx ^ 1][pos] = idx;
        }
      }
    }
    __syncthreads();
  }
}
// One block per batch row: select the TopK indices of `input[bid]` into
// `indices[bid]`. Rows with length <= TopK take the naive path; longer
// rows use the radix-select path (tilelang-style or plain CUDA variant,
// chosen by the `use_tilelang` flag in the params).
__global__ void topk_kernel(const FastTopKParams params) {
  const auto &[input, indices, lengths, input_stride, use_tilelang] = params;
  const auto bid = blockIdx.x;
  const auto length = *(lengths + bid);
  const auto indice = indices + bid * TopK;
  const auto score = input + bid * input_stride;
  if (length <= TopK) {
    return naive_topk_cuda(score, indice, length);
  } else {
    if (use_tilelang) {
      return fast_topk_cuda_tl(score, indice, length);
    } else {
      return fast_topk_cuda(score, indice, length);
    }
  }
}
// Decode-path fused top-k + page-table gather: for each batch row, select
// the TopK score indices and immediately translate them through that row's
// page table (dst[i] = src[topk_index_i]). Short rows go through the naive
// transform, which pads missing entries with -1.
__global__ void topk_kernel_transform_decode( // decode
    const FastTopKParams params, int32_t *__restrict__ dst_page_table,
    const int32_t *__restrict__ src_page_table, const int64_t src_stride) {
  const auto &[input, _, lengths, input_stride, use_tilelang] = params;
  const auto bid = blockIdx.x;
  const auto tid = threadIdx.x;
  const auto length = *(lengths + bid);
  const auto src_page_entry = src_page_table + bid * src_stride;
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;
  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    __shared__ int s_indices[TopK];
    if (use_tilelang) {
      fast_topk_cuda_tl(score, s_indices, length);
    } else {
      fast_topk_cuda(score, s_indices, length);
    }
    // copy src[s_indices] to dst, we manually unroll here
    // NOTE(review): relies on the select routines ending in __syncthreads()
    // so s_indices is fully written before the reads below — confirm for
    // both variants.
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}
// Prefill-path fused top-k + page-table gather. One block per expanded
// token (gridDim.x == cu_seqlens[prefill_bs]); each block first resolves
// which original request it belongs to by scanning cu_seqlens, then
// behaves like the decode kernel using that request's page-table row.
__global__ void topk_kernel_transform_prefill( // prefill
    const FastTopKParams params, int32_t *__restrict__ dst_page_table,
    const int32_t *__restrict__ src_page_table, const int64_t src_stride,
    const int32_t *__restrict__ cu_seqlens, const int64_t prefill_bs) {
  const auto &[input, _, lengths, input_stride, use_tilelang] = params;
  const auto bid = blockIdx.x;
  const auto tid = threadIdx.x;
  const auto length = *(lengths + bid);
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;
  /// NOTE: prefill bs is usually small, we can just use a simple loop here
  /// We ensure that last cu_seqlens is equal to number of blocks launched
  assert(gridDim.x == cu_seqlens[prefill_bs] &&
         "Invalid cu_seqlens in topk-transform-prefill");
  // Thread 0 resolves the owning request; the resulting page-table row
  // pointer is broadcast to the whole block through shared memory.
  __shared__ const int32_t *s_src_page_entry;
  if (tid == 0) {
    for (int64_t offset = 0; offset < prefill_bs; ++offset) {
      if (bid < cu_seqlens[offset + 1]) {
        s_src_page_entry = src_page_table + offset * src_stride;
        break;
      }
    }
  }
  __syncthreads();
  const auto src_page_entry = s_src_page_entry;
  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    __shared__ int s_indices[TopK];
    if (use_tilelang) {
      fast_topk_cuda_tl(score, s_indices, length);
    } else {
      fast_topk_cuda(score, s_indices, length);
    }
    // copy src[s_indices] to dst, we manually unroll here
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}
// Validate the score/lengths (and optional indices) tensors and pack their
// raw pointers into a FastTopKParams POD for kernel launch.
// `score` must be 2D with unit stride along dim 1 (rows may be strided);
// `lengths` is a contiguous (B,) int32 tensor; `indices`, when provided,
// must be a contiguous (B, TopK) int32 tensor (otherwise the pointer is
// left null for the transform kernels that do not use it).
auto get_params(at::Tensor score, at::Tensor lengths, bool use_tilelang,
                std::optional<at::Tensor> indices_opt = std::nullopt)
    -> FastTopKParams {
  const auto B = score.size(0);
  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
  TORCH_CHECK(lengths.size(0) == B);
  int32_t *indices_data_ptr = nullptr;
  if (indices_opt.has_value()) {
    const auto &indices = indices_opt.value();
    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
    TORCH_CHECK(indices.size(0) == B);
    TORCH_CHECK(indices.size(1) == TopK);
    indices_data_ptr = indices.data_ptr<int32_t>();
  }
  return FastTopKParams{
      .input = score.data_ptr<float>(),
      .indices = indices_data_ptr,
      .lengths = lengths.data_ptr<int32_t>(),
      .input_stride = score.stride(0),
      .use_tilelang = use_tilelang,
  };
}
// Raise kernel `f`'s dynamic shared-memory limit to `max_dynamic_smem`.
// The function-local static ensures cudaFuncSetAttribute runs only on the
// first call per (kernel, size) instantiation; subsequent calls re-check
// the cached result.
// NOTE(review): the error message says "set_up_kernel_once" while the
// function is named setup_kernel_smem_once — confirm which is intended.
template <auto *f, size_t max_dynamic_smem>
auto setup_kernel_smem_once() -> void {
  [[maybe_unused]]
  static const auto result = [] {
    return ::cudaFuncSetAttribute(
        f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
  }();
  TORCH_CHECK(result == cudaSuccess,
              "set_up_kernel_once failed:", ::cudaGetErrorString(result));
}
// Host-side launcher for the standalone top-k kernel: one block per batch
// row, kSmem bytes of dynamic shared memory (limit raised once per process
// via setup_kernel_smem_once), on the current PyTorch CUDA stream.
auto fast_topk_interface(at::Tensor score, at::Tensor indices,
                         at::Tensor lengths, bool use_tilelang) -> void {
  const auto params = get_params(score, lengths, use_tilelang, indices);
  const auto B = score.size(0);
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  setup_kernel_smem_once<topk_kernel, kSmem>();
  topk_kernel<<<grid, block, kSmem, stream>>>(params);
  // Surface launch-time errors (bad config etc.) immediately.
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess,
              "topk kernel failed:", ::cudaGetErrorString(result));
}
// Host-side launcher for the fused top-k + page-table gather. Dispatches to
// the decode kernel when every expanded row is its own request
// (prefill_bs == B), otherwise to the prefill kernel which maps expanded
// rows back to requests through cu_seqlens.
auto fast_topk_transform_interface(at::Tensor score, at::Tensor lengths,
                                   at::Tensor dst_page_table,
                                   at::Tensor src_page_table,
                                   at::Tensor cu_seqlens,
                                   bool use_tilelang) -> void {
  // indices output is not needed here; the kernels write dst_page_table.
  const auto params = get_params(score, lengths, use_tilelang);
  const auto B = score.size(0);
  TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
  TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
  TORCH_CHECK(cu_seqlens.dim() == 1 && cu_seqlens.is_contiguous());
  const auto prefill_bs = cu_seqlens.size(0) - 1;
  TORCH_CHECK(dst_page_table.size(0) == B);
  TORCH_CHECK(dst_page_table.size(1) == TopK);
  TORCH_CHECK(src_page_table.size(0) == prefill_bs);
  TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs
  // launch kernel
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  const auto src_stride = src_page_table.stride(0);
  // dispatch to decode or prefill
  const auto is_decode = (prefill_bs == B);
  if (is_decode) {
    setup_kernel_smem_once<topk_kernel_transform_decode, kSmem>();
    topk_kernel_transform_decode<<<grid, block, kSmem, stream>>>(
        params, dst_page_table.data_ptr<int32_t>(),
        src_page_table.data_ptr<int32_t>(), src_stride);
  } else {
    setup_kernel_smem_once<topk_kernel_transform_prefill, kSmem>();
    topk_kernel_transform_prefill<<<grid, block, kSmem, stream>>>(
        params, dst_page_table.data_ptr<int32_t>(),
        src_page_table.data_ptr<int32_t>(), src_stride,
        cu_seqlens.data_ptr<int32_t>(), prefill_bs);
  }
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess,
              "topk kernel failed:", ::cudaGetErrorString(result));
}
} // namespace
// Python bindings: expose the two host-side launchers.
PYBIND11_MODULE(topk_kernel, m) {
  m.def("fast_topk", &fast_topk_interface);
  m.def("fast_topk_transform", &fast_topk_transform_interface);
}

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
from typing import Any
import torch
from .utils import load_kernel_module
def _load_topk_module() -> Any:
    """JIT-compile (if needed) and return the top-k CUDA extension module."""
    module = load_kernel_module("topk.cu", "topk_kernel")
    return module
# TODO(dark): figure out why my cuda impl is a little slower....
# I believe it has something to do with unrolling loops (?)
_USE_TL = True
def fast_topk(
    score: torch.Tensor,
    indices: torch.Tensor,
    lengths: torch.Tensor,
) -> torch.Tensor:
    """Per-row top-k over `score`, delegating to the compiled extension."""
    module = _load_topk_module()
    return module.fast_topk(score, indices, lengths, _USE_TL)
def fast_topk_transform(
    score: torch.Tensor,
    lengths: torch.Tensor,
    dst_page_table: torch.Tensor,
    src_page_table: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    """Fused top-k + page-table gather, delegating to the compiled extension."""
    module = _load_topk_module()
    return module.fast_topk_transform(
        score, lengths, dst_page_table, src_page_table, cu_seqlens, _USE_TL
    )

View File

@@ -0,0 +1,44 @@
from __future__ import annotations
import os
from functools import lru_cache
from typing import Any, Iterable
@lru_cache()
def _prepare_for_load() -> str:
import os
import warnings
warnings.filterwarnings(
"ignore", category=UserWarning, module="torch.utils.cpp_extension"
)
return os.path.dirname(os.path.abspath(__file__))
@lru_cache()
def _load_kernel_module_cached(
    sources: tuple[str, ...],
    name: str,
    build: str,
    cflags: tuple[str, ...],
    cuda_flags: tuple[str, ...],
    ldflags: tuple[str, ...],
) -> Any:
    """Cached worker: JIT-compile and load the extension (hashable args only)."""
    from torch.utils.cpp_extension import load

    abs_path = _prepare_for_load()
    build_dir = f"{abs_path}/{build}"
    os.makedirs(build_dir, exist_ok=True)
    return load(
        name=name,
        sources=[f"{abs_path}/csrc/{p}" for p in sources],
        # Empty flag tuples fall back to the package defaults.
        extra_cflags=list(cflags) or ["-O3", "-std=c++17"],
        extra_cuda_cflags=list(cuda_flags) or ["-O3", "-std=c++17"],
        extra_ldflags=list(ldflags) or None,
        build_directory=build_dir,
    )


def load_kernel_module(
    path: str | Iterable[str],
    name: str,
    *,
    build: str = "build",
    cflags: Iterable[str] | None = None,
    cuda_flags: Iterable[str] | None = None,
    ldflags: Iterable[str] | None = None,
) -> Any:
    """Build (once per process) and return the JIT-compiled extension `name`.

    :param path: source file name(s) under ``csrc/``.
    :param name: extension module name passed to ``torch.utils.cpp_extension.load``.
    :param build: build directory name under the package root.
    :param cflags / cuda_flags / ldflags: extra compiler/linker flags; empty
        or None falls back to ``-O3 -std=c++17`` (and no extra ldflags).

    Fix: ``@lru_cache`` used to sit directly on this function, so passing a
    list for ``path`` or any flag raised ``TypeError: unhashable type``.
    All iterables are now normalized to tuples before hitting the cache.
    """
    sources = (path,) if isinstance(path, str) else tuple(path)
    return _load_kernel_module_cached(
        sources,
        name,
        build,
        tuple(cflags or ()),
        tuple(cuda_flags or ()),
        tuple(ldflags or ()),
    )

View File

@@ -0,0 +1,163 @@
import torch
import triton
import triton.language as tl
from sglang.srt.layers.attention.nsa.utils import NSA_DEQUANT_K_CACHE_FAST
def dequantize_k_cache(quant_k_cache):
    """Dispatch k-cache de-quantization to the fast (Triton) or slow path."""
    impl = (
        _dequantize_k_cache_fast_wrapped
        if NSA_DEQUANT_K_CACHE_FAST
        else _dequantize_k_cache_slow
    )
    return impl(quant_k_cache)
def _dequantize_k_cache_slow(
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
dv: int = 512,
tile_size: int = 128,
d: int = 576,
) -> torch.Tensor:
"""
De-quantize the k-cache
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, _ = quant_k_cache.shape
assert h_k == 1
result = torch.empty(
(num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
)
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles):
cur_nope = input_nope[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_nope * cur_scales
)
result = result.view(num_blocks, block_size, 1, d)
return result
def _dequantize_k_cache_fast_wrapped(
    quant_k_cache: torch.Tensor,
    dv: int = 512,
    tile_size: int = 128,
) -> torch.Tensor:
    """Adapt the 4D cache layout to the 2D fast kernel and back."""
    # TODO the final API may be 2D instead of 4D, thus we convert them here
    num_blocks, block_size, _, dim_quant = quant_k_cache.shape
    assert dv == 512
    assert dim_quant == 656
    assert tile_size == 128
    flat_out = _dequantize_k_cache_fast(quant_k_cache.view((-1, dim_quant)))
    return flat_out.view(num_blocks, block_size, 1, -1)
def _dequantize_k_cache_fast(quant_k_cache: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    """Triton-backed de-quantization of the flattened (num_tokens, 656) cache.

    Layout per token row: 512 fp8 nope values, 4 float32 scales stored as
    raw bytes (one per 128-wide tile), then 64 bfloat16 rope values stored
    as raw bytes. Returns a (num_tokens, 576) bfloat16 tensor.
    """
    num_tokens, dim_quant = quant_k_cache.shape
    assert quant_k_cache.dtype == torch.float8_e4m3fn
    dim_nope = 512
    dim_rope = 64
    num_tiles = dim_nope // group_size
    assert dim_quant == 656
    output = torch.empty(
        (num_tokens, dim_nope + dim_rope),
        dtype=torch.bfloat16,
        device=quant_k_cache.device,
    )
    # One program per (token, 128-wide slice): 4 nope slices + 1 rope slice.
    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
    assert num_blocks_per_token == 5
    assert dim_nope % group_size == 0
    NUM_NOPE_BLOCKS = dim_nope // group_size
    input_nope_q = quant_k_cache[:, :dim_nope]
    # Scales live after the nope bytes; reinterpret their bytes as float32.
    input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
        torch.float32
    )
    input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)
    _dequantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
        output,
        input_nope_q,
        input_nope_s,
        input_rope,
        output.stride(0),
        input_nope_q.stride(0),
        input_nope_s.stride(0),
        input_rope.stride(0),
        NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
        GROUP_SIZE=group_size,
        DIM_NOPE=dim_nope,
        DIM_ROPE=dim_rope,
    )
    return output
@triton.jit
def _dequantize_k_cache_fast_kernel(
    output_ptr,
    input_nope_q_ptr,
    input_nope_s_ptr,
    input_rope_ptr,
    output_stride_0: int,
    input_nope_q_stride_0: int,
    input_nope_s_stride_0: int,
    input_rope_stride_0: int,
    NUM_NOPE_BLOCKS: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    DIM_NOPE: tl.constexpr,
    DIM_ROPE: tl.constexpr,
):
    # Grid: (num_tokens, NUM_NOPE_BLOCKS + rope blocks); axis 1 selects
    # which GROUP_SIZE-wide slice of the token this program handles.
    token_id = tl.program_id(0)
    raw_block_id = tl.program_id(1)
    if raw_block_id < NUM_NOPE_BLOCKS:
        # a. dequant nope: out = quantized value * per-tile float32 scale
        effective_block_id = raw_block_id
        offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs_q < DIM_NOPE
        ptr_q = input_nope_q_ptr + token_id * input_nope_q_stride_0 + offs_q
        ptr_s = input_nope_s_ptr + token_id * input_nope_s_stride_0 + effective_block_id
        y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
        y_s = tl.load(ptr_s)
        y = (y_q * y_s).to(output_ptr.dtype.element_ty)
        dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
        tl.store(dst_ptr, y, mask=mask)
    else:
        # b. copy rope: already bfloat16, placed after the nope section
        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_ROPE
        src_ptr = input_rope_ptr + token_id * input_rope_stride_0 + offs
        dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs
        data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
        tl.store(dst_ptr, data, mask=mask)
# Intentional guard: this file has no standalone entry point; the unit test
# for this code path lives alongside the quantization implementation.
if __name__ == "__main__":
    raise Exception("UT is in quant_k_cache.py")

View File

@@ -0,0 +1,135 @@
# fallback_fp8.py
# PyTorch fallback implementation for DeepGEMM-like fp8 logits ops
from sglang.srt.utils import ceil_div
import torch
@torch.no_grad()
def fallback_fp8_mqa_logits(q: torch.Tensor,
                            kv: torch.Tensor,
                            weights: torch.Tensor,
                            ks: torch.Tensor,
                            ke: torch.Tensor, cost_only: bool = False) -> torch.Tensor:
    """PyTorch (FP32, no real fp8) fallback for the fp8 MQA logits kernel.

    Args:
        q: (M, H, D) queries.
        kv: (N, D) keys.
        weights: (M, H) per-head weights applied after ReLU.
        ks: (M,) inclusive start of each query's valid key range.
        ke: (M,) exclusive end of each query's valid key range.
        cost_only: when True, skip the matmul and return only the total
            number of valid (query, key) pairs.

    Returns:
        (M, N) float32 logits with -inf outside [ks, ke), or a scalar pair
        count when ``cost_only`` is set.
    """
    seq_len_kv = kv.shape[0]
    if cost_only:
        start = ks.clamp(min=0, max=seq_len_kv)
        end = ke.clamp(min=0, max=seq_len_kv)
        count_ones_per_row = (end - start).clamp(min=0)
        return count_ones_per_row.sum()
    q = q.float()
    k = kv.float()
    # Build the validity mask on the same device as the inputs. This was
    # previously hard-coded to device='cuda', which broke CPU execution and
    # mixed devices when the inputs lived on a non-default GPU.
    positions = torch.arange(0, seq_len_kv, device=kv.device)
    mask_lo = positions[None, :] >= ks[:, None]
    mask_hi = positions[None, :] < ke[:, None]
    mask = mask_lo & mask_hi
    # Weighted, ReLU-gated similarity summed over heads.
    score = torch.einsum('mhd,nd->hmn', q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float('-inf'))
    return logits
# """
# PyTorch fallback for fp8_mqa_logits.
# No real fp8 used, just FP32.
# Args:
# q: (M, H, D) query
# k: (N, D) key
# weights: (M, H)
# ks: (M,) int32
# ke: (M,) int32
# Returns:
# logits: (M, N) with -inf outside of valid range
# """
# M, H, D = q.shape
# N = k[0].shape[0]
# logits = torch.full((M, N), float("-inf"), dtype=torch.float32, device=q.device)
# # for i in range(M):
# # start = max(ks[i].item(), 0)
# # end = min(ke[i].item(), N)
# # if start >= end:
# # continue
# # qi = q[i] # (H, D)
# # ki = k[start:end] # (L, D)
# # sim = torch.matmul(qi, ki.T) # (H, L)
# # weighted_sim = (sim.relu() * weights[i].unsqueeze(-1)).sum(dim=0) # (L,)
# # logits[i, start:end] = weighted_sim
# return logits
@torch.no_grad()
def fallback_fp8_paged_mqa_logits(q: torch.Tensor,
                                  kv_cache: torch.Tensor,
                                  weights: torch.Tensor,
                                  context_lens: torch.Tensor,
                                  block_tables: torch.Tensor,
                                  max_model_len: int) -> torch.Tensor:
    """PyTorch (FP32) fallback for fp8_paged_mqa_logits.

    q: (batch_size, next_n, heads, dim) queries for the last next_n tokens
    of each request; kv_cache: (num_blocks, block_size, 1, dim) paged keys;
    weights: (batch_size * next_n, heads); block_tables maps each request to
    its physical cache blocks. Returns (batch_size * next_n, max_model_len)
    float32 logits, -inf outside each query's causal/context window.
    """
    batch_size, next_n, heads, dim = q.size()
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32)
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        # Absolute positions of this request's next_n query tokens.
        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
        weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous()
        # Walk the request's page table one physical block at a time.
        for block_rk in range(ceil_div(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device=q.device)
            # Valid if the key is inside the context and not in the future.
            mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None])
            s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), float('-inf'))
            # ReLU gate, head weights, then sum over heads.
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf'))
    return logits
"""
PyTorch fallback for fp8_paged_mqa_logits.
No real fp8 used, just FP32.
Args:
q: (B, N, H, D)
kv_cache: (num_blocks, block_size, 1, D)
weights: (B * N, H)
context_lens: (B,)
block_tables: (B, max_blocks)
max_model_len: int
Returns:
logits: (B * N, max_model_len)
"""
B, N, H, D = q.shape
block_size = kv_cache.shape[1]
logits = torch.full((B * N, max_model_len), float("-inf"), dtype=torch.float32, device=q.device)
for i in range(B):
ctx_len = context_lens[i].item()
q_offsets = torch.arange(ctx_len - N, ctx_len, device=q.device)
weight_slice = weights[i * N:(i + 1) * N, :].transpose(0, 1).contiguous()
for br in range((ctx_len + block_size - 1) // block_size):
blk_idx = block_tables[i, br].item()
if blk_idx < 0:
continue
qx = q[i] # (N, H, D)
kx = kv_cache[blk_idx] # (block_size, 1, D)
kx = kx.squeeze(1) # (block_size, D)
k_offsets = torch.arange(br * block_size, (br + 1) * block_size, device=q.device)
mask = (k_offsets[None, :] < ctx_len) & (k_offsets[None, :] <= q_offsets[:, None]) # (N, block_size)
s = torch.where(mask[None, :, :],
torch.einsum('nhd,ld->hnl', qx, kx),
torch.full((H, N, block_size), float("-inf"), device=q.device))
s = s.relu() * weight_slice[..., None]
logits_slice = s.sum(dim=0) # (N, block_size)
mask_block = (k_offsets[None, :] <= q_offsets[:, None])
logits[i * N:(i + 1) * N, br * block_size:(br + 1) * block_size] = \
torch.where(mask_block, logits_slice, float("-inf"))
return logits

View File

@@ -0,0 +1,354 @@
from typing import TYPE_CHECKING
import torch
import triton
import triton.language as tl
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
"""
k: data, 128 item per token, fp8
s: scale, 1 item per token, fp32
"""
class GetK:
    """Gather per-token fp8 index-k rows out of the paged index buffer."""
    @classmethod
    def execute(cls, *args, **kwargs):
        # Dispatch to the vectorized implementation; `slow` is the reference.
        return cls.torch_fast(*args, **kwargs)
    @classmethod
    def slow(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        # Reference implementation: copy page by page with a Python loop.
        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
        seq_len_ = num_pages * pool.page_size
        index_k_fp8 = torch.empty(
            (seq_len_, pool.index_head_dim),
            dtype=torch.uint8,
            device=pool.device,
        )
        for i in range(num_pages):
            page_index = page_indices[i]
            # k bytes occupy the first page_size * index_head_dim bytes of a page.
            index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
                page_index
            ][: pool.page_size * pool.index_head_dim].view(-1, pool.index_head_dim)
        return index_k_fp8[:seq_len]
    @classmethod
    def torch_fast(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        """
        Vectorized gather: build flat byte offsets for every wanted byte and
        do a single fancy-indexing read.

        :param page_indices: (num_pages,), int32
        :return: (seq_len, index_head_dim), uint8
        """
        # can handle per 128B instead of per element
        # page_indices: (num_pages,), element := a page index
        buf_numel_per_page = buf.shape[1]
        num_k_bytes_per_page = pool.page_size * pool.index_head_dim
        num_k_bytes_per_token = pool.index_head_dim
        # buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8
        # flat_buf: (whatever,), uint8
        flat_buf = buf.flatten()
        # flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an index into flat_buf that we want to access
        flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange(
            num_k_bytes_per_page, dtype=torch.int32, device="cuda"
        )[None, :]
        # Keep only the bytes belonging to the first seq_len tokens.
        flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token]
        out = flat_buf[flat_indices]
        return out.view(-1, 128)
class GetS:
    """Gather per-token scale bytes (stored float32) from the paged buffer."""
    @classmethod
    def execute(cls, *args, **kwargs):
        # Dispatch to the vectorized implementation; `slow` is the reference.
        return cls.torch_fast(*args, **kwargs)
    @classmethod
    def slow(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        # Reference implementation: copy page by page with a Python loop.
        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
        seq_len_ = num_pages * pool.page_size
        assert pool.index_head_dim // pool.quant_block_size == 1
        index_k_scale_fp8 = torch.empty(
            (seq_len_, 4),
            dtype=torch.uint8,
            device=pool.device,
        )
        for i in range(num_pages):
            page_index = page_indices[i]
            # Scale bytes follow the k region within each page.
            index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
                page_index
            ][pool.page_size * pool.index_head_dim :].view(-1, 4)
        return index_k_scale_fp8[:seq_len]
    @classmethod
    def torch_fast(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        """
        Vectorized gather of the scale bytes, mirroring GetK.torch_fast.

        :param page_indices: (num_pages,), int32
        :return: (seq_len, index_head_dim // quant_block_size), uint8
        """
        buf_numel_per_page = buf.shape[1]
        # Everything after the k bytes in a page is scale data.
        num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim
        num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4
        s_offset_in_page = pool.page_size * pool.index_head_dim
        flat_buf = buf.flatten()
        # Flat offsets of every scale byte of every requested page.
        flat_indices = (
            (page_indices * buf_numel_per_page)[:, None]
            + torch.arange(num_s_bytes_per_page, dtype=torch.int32, device="cuda")[
                None, :
            ]
            + s_offset_in_page
        )
        flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token]
        out = flat_buf[flat_indices]
        return out.view(-1, 4)
class SetK:
    """Scatter per-token fp8 index-k rows into the paged index buffer."""
    @classmethod
    def execute(cls, *args, buf, **kwargs):
        return cls.torch_fast(*args, **kwargs, buf=buf)
    @classmethod
    def slow(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        # Reference implementation: one token at a time.
        for i in range(len(loc)):
            page_index = loc[i] // pool.page_size
            offset = loc[i] % pool.page_size
            buf[
                page_index,
                offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim,
            ] = index_k[i].view(torch.uint8)
    @classmethod
    def torch_fast(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        # Vectorized scatter: compute, for every byte to write, its flat
        # offset into the buffer, then do one indexed assignment.
        (num_tokens_to_write,) = loc.shape
        buf_numel_per_page = buf.shape[1]
        num_k_bytes_per_token = pool.index_head_dim
        # loc: (num_tokens_to_write,), int32, element := the token index to write to
        loc_page_index = loc // pool.page_size
        loc_token_offset_in_page = loc % pool.page_size
        flat_buf = buf.flatten()
        flat_indices = (
            (loc_page_index * buf_numel_per_page)[:, None]
            + (loc_token_offset_in_page * num_k_bytes_per_token)[:, None]
            + torch.arange(num_k_bytes_per_token, dtype=torch.int32, device="cuda")[
                None, :
            ]
        )
        num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token
        flat_indices = flat_indices.flatten()[:num_k_bytes_total]
        flat_buf[flat_indices] = index_k.view(torch.uint8).flatten()
class SetS:
    """Scatter per-token float32 scale bytes into the paged index buffer."""
    @classmethod
    def execute(cls, *args, buf, **kwargs):
        return cls.torch_fast(*args, **kwargs, buf=buf)
    @classmethod
    def slow(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k_scale: torch.Tensor,
    ):
        # Reference implementation: one token at a time; scales live after
        # the k region of each page, 4 bytes per token.
        for i in range(len(loc)):
            page_index = loc[i] // pool.page_size
            offset = loc[i] % pool.page_size
            start = pool.page_size * pool.index_head_dim
            buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = (
                index_k_scale[i].view(torch.uint8)
            )
    @classmethod
    def torch_fast(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k_scale: torch.Tensor,
    ):
        # Vectorized scatter mirroring SetK.torch_fast, but targeting the
        # 4-byte scale slot of each token.
        (num_tokens_to_write,) = loc.shape
        buf_numel_per_page = buf.shape[1]
        num_s_bytes_per_token = 4
        s_offset_in_page = pool.page_size * pool.index_head_dim
        # loc: (num_tokens_to_write,), int32, element := the token index to write to
        loc_page_index = loc // pool.page_size
        loc_token_offset_in_page = loc % pool.page_size
        flat_buf = buf.flatten()
        flat_indices = (
            (loc_page_index * buf_numel_per_page)[:, None]
            + s_offset_in_page
            + (loc_token_offset_in_page * num_s_bytes_per_token)[:, None]
            + torch.arange(num_s_bytes_per_token, dtype=torch.int32, device="cuda")[
                None, :
            ]
        )
        number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token
        flat_indices = flat_indices.flatten()[:number_s_bytes_total]
        flat_buf[flat_indices] = index_k_scale.view(torch.uint8).flatten()
class SetKAndS:
    """Write both the fp8 index-k data and its fp32 scale in one fused op."""

    @classmethod
    def execute(cls, *args, buf, **kwargs):
        # Flip `if 0` to `if 1` to cross-check the fused Triton path against
        # the separate SetK/SetS reference path.
        if 0:
            # print("SetK, SetS comparison test")
            buf_cloned = buf.clone()
            cls.vanilla(*args, **kwargs, buf=buf)
            cls.triton(*args, **kwargs, buf=buf_cloned)
            def _clear_token_0(target):
                # Zero token 0's k bytes and scale so it is excluded from the diff.
                target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0
            _clear_token_0(buf)
            _clear_token_0(buf_cloned)
            # Fix: Tensor has .tolist(), not .to_list() — the old spelling
            # raised AttributeError whenever this debug branch was enabled.
            assert torch.all(
                buf == buf_cloned
            ), f"{buf=} {buf_cloned=} {kwargs['loc'].tolist()=}"
            return
        cls.triton(*args, **kwargs, buf=buf)

    @classmethod
    def vanilla(cls, pool, buf, loc, index_k, index_k_scale):
        # Reference path: two separate scatters.
        SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k)
        SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale)

    @classmethod
    def triton(cls, pool, buf, loc, index_k, index_k_scale):
        # Fused path: a single Triton kernel writes k and scale together.
        _set_k_and_s_triton(
            buf=buf,
            loc=loc,
            index_k=index_k,
            index_k_scale=index_k_scale,
            page_size=pool.page_size,
        )
def _set_k_and_s_triton(
    buf: torch.Tensor,
    loc: torch.Tensor,
    index_k: torch.Tensor,
    index_k_scale: torch.Tensor,
    page_size: int,
):
    """
    :param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8
    :param loc: (num_tokens_to_write,), int, element := the token index to write to
    :param index_k: (num_tokens_to_write, 128 elem), fp8
    :param index_k_scale: (num_tokens_to_write, 1 elem), fp32
    :return:
    """
    num_pages, buf_numel_per_page = buf.shape
    (num_tokens_to_write,) = loc.shape
    num_tokens_to_write_, index_head_dim = index_k.shape
    num_tokens_to_write__, scale_dim = index_k_scale.shape
    # Layout/shape invariants the kernel's constexpr arguments rely on.
    assert buf_numel_per_page == 64 * (128 + 4)
    assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
    assert index_head_dim == 128
    assert scale_dim == 1
    assert page_size == 64
    assert buf.dtype == torch.uint8
    assert loc.dtype == torch.int64, f"{loc.dtype=}"  # can be int32
    assert index_k.dtype == torch.float8_e4m3fn
    assert index_k_scale.dtype == torch.float32
    assert buf.is_contiguous()
    assert loc.is_contiguous()
    assert index_k.is_contiguous()
    assert index_k_scale.is_contiguous()
    # Two reinterpretations of the same bytes: fp8 view for the k data and
    # fp32 view for the scale slots.
    buf_fp8 = buf.view(torch.float8_e4m3fn)
    buf_fp32 = buf.view(torch.float32)
    # One program per token to write.
    _set_k_and_s_triton_kernel[(num_tokens_to_write,)](
        buf_fp8,
        buf_fp32,
        loc,
        index_k,
        index_k_scale,
        index_k.stride(0),
        PAGE_SIZE=page_size,
        BUF_NUMEL_PER_PAGE=buf_numel_per_page,
        NUM_K_ELEMS_PER_TOKEN=index_head_dim,
        S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim,
    )
@triton.jit
def _set_k_and_s_triton_kernel(
    buf_fp8_ptr,
    buf_fp32_ptr,
    loc_ptr,
    index_k_ptr,
    index_k_scale_ptr,
    index_k_ptr_stride_0,
    PAGE_SIZE: tl.constexpr,
    BUF_NUMEL_PER_PAGE: tl.constexpr,
    NUM_K_ELEMS_PER_TOKEN: tl.constexpr,
    S_OFFSET_NBYTES_IN_PAGE: tl.constexpr,
):
    # One program per token: copy its 128 fp8 k elements and 1 fp32 scale
    # into the destination page/slot derived from `loc`.
    token_id = tl.program_id(0)
    loc = tl.load(loc_ptr + token_id)
    in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
    # no need for `mask`, since we read 128B for k and 4B for scale, both pow of 2
    k = tl.load(index_k_ptr + in_k_offsets)
    k_scale = tl.load(index_k_scale_ptr + token_id)
    loc_page_index = loc // PAGE_SIZE
    loc_token_offset_in_page = loc % PAGE_SIZE
    out_k_offsets = (
        loc_page_index * BUF_NUMEL_PER_PAGE
        + loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN
        + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
    )
    # "//4" b/c it is fp32 instead of uint8
    out_s_offset = (
        loc_page_index * BUF_NUMEL_PER_PAGE // 4
        + S_OFFSET_NBYTES_IN_PAGE // 4
        + loc_token_offset_in_page
    )
    tl.store(buf_fp8_ptr + out_k_offsets, k)
    tl.store(buf_fp32_ptr + out_s_offset, k_scale)

View File

@@ -0,0 +1,709 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from sglang.srt.layers.attention.nsa.fallback_fp8 import fallback_fp8_mqa_logits, fallback_fp8_paged_mqa_logits
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.utils import add_prefix, is_npu
if not is_npu():
from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
#import deep_gemm
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import add_prefix, align, is_cuda
# try:
# import deep_gemm_v32
# except ImportError as e:
# print("Error when importing deep_gemm_v32, try deep_gemm")
# try:
# import deep_gemm as deep_gemm_v32
# except ImportError as e:
# print("Error when importing deep_gemm, skip")
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0
class BaseIndexerMetadata(ABC):
    """Abstract contract for the metadata an attention backend hands to the
    NSA indexer. Concrete backends implement these accessors plus the top-k
    post-processing hook."""

    @abstractmethod
    def get_seqlens_int32(self) -> torch.Tensor:
        """
        Return: (batch_size,) int32 tensor
        """

    @abstractmethod
    def get_page_table_64(self) -> torch.Tensor:
        """
        Return: (batch_size, num_blocks) int32, page table.
        The page size of the table is 64.
        """

    @abstractmethod
    def get_seqlens_expanded(self) -> torch.Tensor:
        """
        Return: (sum_extend_seq_len,) int32 tensor
        """

    @abstractmethod
    def topk_transform(
        self,
        logits: torch.Tensor,
        topk: int,
    ) -> torch.Tensor:
        """
        Perform topk selection on the logits and possibly transform the result.

        NOTE that attention backend may override this function to do some
        transformation, which means the result of this topk_transform may not
        be the topk indices of the input logits.

        Return: Anything, since it will be passed to the attention backend
        for further processing on sparse attention computation.
        Don't assume it is the topk indices of the input logits.
        """
def hadamard_transform_pytorch(x: torch.Tensor, scale: float) -> torch.Tensor:
    """
    A native PyTorch implementation of the Fast Hadamard Transform that mimics
    the behavior of the custom CUDA kernel's call signature.

    Args:
        x (torch.Tensor): Input tensor of shape (*, N), where N is a power of 2.
        scale (float): The normalization factor to multiply the result by.

    Returns:
        torch.Tensor: The Hadamard transformed tensor (Sylvester ordering).
    """
    # Base case: H_1 is the identity, but the caller's normalization factor
    # must still be honored (previously `scale` was silently dropped when
    # the last dimension was 1).
    if x.shape[-1] == 1:
        return x * scale
    # Split the tensor into two halves and recurse with scale=1.0 so the
    # normalization is applied exactly once, at the top level.
    half_size = x.shape[-1] // 2
    a = x[..., :half_size]
    b = x[..., half_size:]
    a_transformed = hadamard_transform_pytorch(a, scale=1.0)
    b_transformed = hadamard_transform_pytorch(b, scale=1.0)
    # Butterfly combine: [a + b, a - b] yields the Sylvester-ordered FHT.
    combined = torch.cat(
        [a_transformed + b_transformed, a_transformed - b_transformed], dim=-1
    )
    # Apply the scale only at the final step
    return combined * scale
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Apply a normalized Hadamard rotation to bf16 activations along the
    last dimension (scale = 1/sqrt(hidden_size))."""
    assert x.dtype == torch.bfloat16
    dim = x.size(-1)
    is_power_of_two = (dim & (dim - 1)) == 0
    assert is_power_of_two, "Hidden size must be a power of 2 for Hadamard transform."
    return hadamard_transform_pytorch(x, scale=dim**-0.5)
class V32LayerNorm(nn.Module):
    """Layer normalization computed in fp32 with learnable affine parameters,
    cast back to the input dtype on the way out."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Parameters are kept in fp32 regardless of the activation dtype.
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        # Normalize in fp32 for numerical stability, then restore x's dtype.
        normalized = F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps)
        return normalized.type_as(x)
class Indexer(CustomOp):
    """DeepSeek-V3.2 NSA "lightning indexer".

    Scores cached KV tokens against the current query tokens and returns the
    top-k token indices the sparse-attention backend should attend to.
    Queries come from the low-rank ``q_lora`` activations via ``wq_b``; keys
    are a single-head projection of the hidden states (``wk`` + LayerNorm),
    quantized to fp8 with per-token scales and stored in the NSA KV pool.
    """

    def __init__(
        self,
        hidden_size: int,
        index_n_heads: int,
        index_head_dim: int,
        rope_head_dim: int,
        index_topk: int,
        q_lora_rank: int,
        max_position_embeddings: int,
        rope_theta: float,
        layer_id: int,
        scale_fmt: Optional[str],
        block_size: int = 128,
        rope_scaling: Optional[Dict[str, Any]] = None,
        prefix: str = "",
        quant_config: Optional[QuantizationConfig] = None,
        alt_stream: Optional[torch.cuda.Stream] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = index_n_heads
        self.head_dim = index_head_dim
        self.rope_head_dim = rope_head_dim
        self.index_topk = index_topk
        self.q_lora_rank = q_lora_rank
        self.layer_id = layer_id
        self.alt_stream = alt_stream
        if not is_npu():
            # Used to split SMs between the two CUDA streams in dual-stream mode.
            self.sm_count = torch.cuda.get_device_properties(0).multi_processor_count
            self.half_device_sm_count = align(self.sm_count // 2, 8)
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wq_b", prefix),
        )
        self.wk = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wk", prefix),
        )
        self.k_norm = V32LayerNorm(self.head_dim)
        # NOTE: weight_proj is not quantized
        self.weights_proj = ReplicatedLinear(
            self.hidden_size,
            self.n_heads,
            bias=False,
            prefix=add_prefix("weights_proj", prefix),
        )
        self.rotary_emb = get_rope_wrapper(
            rope_head_dim,
            rotary_dim=rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,  # type: ignore
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )
        self.block_size = block_size
        self.scale_fmt = scale_fmt
        self.softmax_scale = self.head_dim**-0.5

    def _forward_fake(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ):
        """Debug stand-in for the real indexer: emit dense causal indices.

        Row i selects every visible token for query i and pads the tail with
        -1, i.e. "attend to everything" without touching the KV pool.
        """
        bs = x.shape[0]
        assert self.index_topk == 2048
        ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
            None, ...
        ].repeat(bs, 1)
        if forward_batch.forward_mode.is_extend():
            assert (
                forward_batch.extend_seq_lens_cpu is not None
                and forward_batch.seq_lens_cpu is not None
            )
            which = 0
            for i, (kv_len, qo_len) in enumerate(
                zip(
                    forward_batch.seq_lens_cpu.tolist(),
                    forward_batch.extend_seq_lens_cpu,
                    strict=True,
                )
            ):
                # One output row per new query token j; token j may see [0, j].
                for j in range(kv_len - qo_len, kv_len):
                    ans[which, j + 1 :] = -1
                    which += 1
            assert which == ans.shape[0]
        else:
            assert forward_batch.seq_lens_cpu is not None
            for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
                ans[i, seq_len:] = -1
        return ans

    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Per-head gating weights for the index logits, pre-folded with
        q_scale and the softmax scale so the logits kernel applies a single
        multiplier."""
        weights, _ = self.weights_proj(x)
        weights = weights * self.n_heads**-0.5
        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
        return weights

    def _get_q_k_bf16(
        self,
        q_lora: torch.Tensor,
        x: torch.Tensor,
        positions: torch.Tensor,
        enable_dual_stream: bool,
    ):
        """Project q_lora -> per-head queries and x -> single-head key, apply
        partial RoPE to the first rope_head_dim dims, then a Hadamard rotation.

        In dual-stream mode the Q and K paths run on two CUDA streams with the
        deep_gemm SM count halved so they can overlap.
        """
        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            with deep_gemm_wrapper.configure_deep_gemm_num_sms(
                self.half_device_sm_count
            ):
                query, _ = self.wq_b(q_lora)
                query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
                q_rope, _ = torch.split(
                    query,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )
            with torch.cuda.stream(self.alt_stream):
                key, _ = self.wk(x)
                key = self.k_norm(key)
                k_rope, _ = torch.split(
                    key,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )
            current_stream.wait_stream(self.alt_stream)
        else:
            query, _ = self.wq_b(q_lora)
            if dumper._enable:
                after_wq_b = query.clone()
            query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
            q_rope, _ = torch.split(
                query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )
            key, _ = self.wk(x)
            if dumper._enable:
                after_wk = key.clone()
            key = self.k_norm(key)
            if dumper._enable:
                after_k_norm = key.clone()
            k_rope, _ = torch.split(
                key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )
        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
        # Write the rotated slices back in place; only the leading
        # rope_head_dim dims carry positional information.
        query[..., : self.rope_head_dim] = q_rope
        key[..., : self.rope_head_dim] = k_rope
        if dumper._enable:
            q_before_hadamard = query.clone()
            k_before_hadamard = key.clone()
        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            query = rotate_activation(query)
            with torch.cuda.stream(self.alt_stream):
                key = rotate_activation(key)
            current_stream.wait_stream(self.alt_stream)
        else:
            query = rotate_activation(query)
            key = rotate_activation(key)
        return query, key

    def _get_topk_paged(
        self,
        forward_batch: ForwardBatch,
        layer_id: int,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        metadata: BaseIndexerMetadata,
    ) -> torch.Tensor:
        """Top-k selection for decode: score queries against the paged fp8
        index-K cache (64-token pages, 128 fp8 bytes + 4 scale bytes per
        token) and hand the logits to the backend's topk_transform."""
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
        page_size = forward_batch.token_to_kv_pool.page_size
        # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm_v32
        assert page_size == 64, "only support page size 64"
        # NOTE(dark): this support extend/decode/decode+graph
        block_tables = metadata.get_page_table_64()
        max_seq_len = block_tables.shape[1] * page_size
        kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
            layer_id=layer_id
        )
        blocksize = page_size
        seqlens_32 = metadata.get_seqlens_int32()
        # NOTE(dark): 132 is SM count on H200/B200, not magic number
        # schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
        #     seqlens_32, blocksize, self.sm_count
        # )
        assert len(q_fp8.shape) == 3
        q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
        assert len(kv_cache_fp8.shape) == 2
        block_kv = 64
        num_heads_kv = 1
        # 132 = 128 fp8 K bytes + 4 bytes of fp32 scale per token.
        head_dim_with_sf = 132
        kv_cache_fp8 = kv_cache_fp8.view(
            kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
        )
        assert len(weights.shape) == 3
        weights = weights.squeeze(2)
        logits = fallback_fp8_paged_mqa_logits(
            q_fp8,
            kv_cache_fp8,
            weights,
            seqlens_32,
            block_tables,
            max_seq_len,
        )
        # NOTE(dark): logits should be cleaned in topk_transform
        topk_result = metadata.topk_transform(logits, self.index_topk)
        return topk_result

    def _get_topk_ragged(
        self,
        forward_batch: ForwardBatch,
        layer_id: int,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        metadata: BaseIndexerMetadata,
    ) -> torch.Tensor:
        """Top-k selection for extend/prefill: gather every request's cached
        index-K (fp8 + per-token scales) into one ragged concatenation, then
        score it against the queries with per-token [ks, ke) key windows."""
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
        page_size = forward_batch.token_to_kv_pool.page_size
        assert page_size == 64, "only support page size 64"
        assert len(weights.shape) == 3
        weights = weights.squeeze(-1)
        k_fp8_list = []
        k_scale_list = []
        ks_list = []
        offset = 0
        block_tables = metadata.get_page_table_64()
        assert (
            forward_batch.seq_lens_cpu is not None
            and forward_batch.extend_seq_lens_cpu is not None
        )
        for i in range(forward_batch.batch_size):
            seq_len = forward_batch.seq_lens_cpu[i].item()
            assert isinstance(seq_len, int)
            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
                layer_id,
                seq_len,
                block_tables[i],
            )
            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
                layer_id,
                seq_len,
                block_tables[i],
            )
            extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
            # ks[i]: start offset of this request's keys in the concatenation.
            # NOTE(review): offset advances by extend_seq_len while the
            # concatenated key chunks are seq_len rows long — these only agree
            # when seq_len == extend_seq_len (no prefix cache); confirm.
            ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
            k_fp8_list.append(k_fp8)
            k_scale_list.append(k_scale)
            ks_list.append(ks)
            offset += extend_seq_len
        k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
        k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
        # NOTE(review): kv_fp8 is built but unused — fallback_fp8_mqa_logits
        # below receives only k_fp8 without k_scale; verify scales are either
        # applied inside the fallback or intentionally dropped here.
        kv_fp8 = (k_fp8, k_scale)
        ks = torch.cat(ks_list, dim=0)
        seq_lens_expanded = metadata.get_seqlens_expanded()
        ke = ks + seq_lens_expanded
        logits = fallback_fp8_mqa_logits(
            q_fp8,
            k_fp8,
            weights,
            ks,
            ke
        )
        assert logits.shape[0] == len(seq_lens_expanded)
        topk_result = metadata.topk_transform(logits, self.index_topk)
        return topk_result

    def _forward(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """Full indexer step: build fp8 Q/K, write K (+scale) into the KV
        pool, compute gated logits, and return top-k indices (or None when
        the backend skips NSA for this batch)."""
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
        metadata = forward_batch.attn_backend.get_indexer_metadata(
            layer_id, forward_batch
        )
        enable_dual_stream = (
            NSA_DUAL_STREAM
            and self.alt_stream is not None
            and get_is_capture_mode()
            and q_lora.shape[0] > 0
            and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
        )
        # skip NSA if attention backend choose to skip this batch
        if metadata is None:
            return None
        if not NSA_USE_REAL_INDEXER:  # temporary
            return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
        # NOTE(review): the four assignments below are dead defaults — both
        # branches of the following if/else overwrite them via act_quant, and
        # "q_fp8"/"k_fp8" here are actually float32 casts, not fp8. Looks like
        # leftover scaffolding; confirm before removing.
        q_fp8 = query.to(torch.float32)
        k_fp8 = key.to(torch.float32)
        q_scale = torch.ones((query.shape[0], 1), dtype=torch.float32, device="cuda")
        k_scale = torch.ones((key.shape[0], 1), dtype=torch.float32, device="cuda")
        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
            with torch.cuda.stream(self.alt_stream):
                k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
            current_stream.wait_stream(self.alt_stream)
        else:
            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
            k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
        # k_fp8: (seq_len, head_dim) fp8_e4m3fn
        # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
        # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
        # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
        forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
            layer_id=layer_id,
            loc=forward_batch.out_cache_loc,
            index_k=k_fp8,
            index_k_scale=k_scale,
        )
        weights = self._get_logits_head_gate(x, q_scale)
        assert forward_batch.seq_lens_cpu is not None
        if len(forward_batch.seq_lens_cpu) == 0:
            # this seems b/c max-pad, no worries?
            # if x.shape[0] != 0:
            #     print(
            #         "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result"
            #     )
            return torch.full(
                (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
            )
        if forward_batch.forward_mode.is_decode_or_idle():
            topk_result = self._get_topk_paged(
                forward_batch, layer_id, q_fp8, weights, metadata
            )
        else:
            topk_result = self._get_topk_ragged(
                forward_batch, layer_id, q_fp8, weights, metadata
            )
        return topk_result

    def forward_cuda(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """CUDA entry point; delegates to the shared _forward implementation."""
        return self._forward(x, q_lora, positions, forward_batch, layer_id)

    def forward_npu(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> torch.Tensor:
        """Ascend NPU entry point: builds Q/K with torch_npu RoPE, caches K,
        and calls the custom npu_lightning_indexer op for top-k selection.
        Optionally shards work across attention-TP ranks (enable_index_cp)."""
        # custom_ops registers torch.ops.custom.npu_lightning_indexer.
        import custom_ops
        import torch_npu

        from sglang.srt.layers.dp_attention import (
            get_attention_tp_rank,
            get_attention_tp_size,
        )
        from sglang.srt.utils import get_bool_env_var

        if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:
            actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens
        else:
            actual_seq_lengths_kv = (
                forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int
            )
        enable_index_cp = (
            get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
        )
        is_prefill = forward_batch.forward_mode.is_extend()
        attention_tp_rank = get_attention_tp_rank()
        attention_tp_size = get_attention_tp_size()
        cos_sin = self.rotary_emb.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
        sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
        if is_prefill and enable_index_cp:
            # Each TP rank only rotates its own contiguous slice of tokens.
            slice_length = cos.shape[0] // attention_tp_size
            cos = cos[
                slice_length
                * attention_tp_rank : slice_length
                * (attention_tp_rank + 1)
            ]
            sin = sin[
                slice_length
                * attention_tp_rank : slice_length
                * (attention_tp_rank + 1)
            ]
        slot_mapping = forward_batch.out_cache_loc
        block_table = forward_batch.attn_backend.forward_metadata.block_tables
        bs = x.shape[0]
        q = self.wq_b(q_lora)[0]  # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
        q = q.view(bs, self.n_heads, self.head_dim)  # [bs, 64, 128]
        q_pe, q_nope = torch.split(
            q,
            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
            dim=-1,
        )  # [bs, 64, 64 + 64]
        q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
        q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view(
            bs, self.n_heads, self.rope_head_dim
        )  # [bs, n, d]
        q = torch.cat([q_pe, q_nope], dim=-1)
        k_proj = self.wk(x)[0]  # [b, s, 7168] @ [7168, 128] = [b, s, 128]
        k = self.k_norm(k_proj)
        k_pe, k_nope = torch.split(
            k,
            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
            dim=-1,
        )  # [bs, 64 + 64]
        k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
        k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view(
            bs, 1, self.rope_head_dim
        )  # [bs, 1, d]
        k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1)  # [bs, 1, 128]
        if is_prefill and enable_index_cp:
            # All-gather keys so every rank caches the full key set.
            k, local_k = (
                torch.empty(
                    (k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]),
                    dtype=k.dtype,
                    device=k.device,
                ),
                k,
            )
            get_attention_tp_group().all_gather_into_tensor(k, local_k)
        forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k)
        # NOTE(review): indexer_input is never used after this point.
        indexer_input = {}
        if is_prefill:
            actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device)
            # Cumulative query counts per request ("TND" layout expects
            # exclusive prefix boundaries).
            actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to(
                device=q.device
            )
            if enable_index_cp:
                # Shift boundaries into this rank's token window, clamped to [0, bs].
                actual_seq_lengths_q -= bs * attention_tp_rank
                actual_seq_lengths_q = torch.max(
                    actual_seq_lengths_q,
                    torch.zeros_like(actual_seq_lengths_q).to(
                        device=actual_seq_lengths_q.device
                    ),
                )
                actual_seq_lengths_q = torch.min(
                    actual_seq_lengths_q,
                    torch.full(actual_seq_lengths_q.shape, bs).to(
                        device=actual_seq_lengths_q.device
                    ),
                )
        else:
            if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
                # Decode: one query token per request -> boundaries 1, 2, ..., bs.
                actual_seq_lengths_q = torch.tensor(
                    [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device
                )
            else:
                actual_seq_lengths_q = (
                    forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
                )
        past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)
        x = x.view(-1, self.hidden_size)
        weights = self.weights_proj(x)[0]
        block_table = (
            block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
        )
        topk_indices = torch.ops.custom.npu_lightning_indexer(
            query=q.view(-1, self.n_heads, self.head_dim),
            key=past_key_states,
            weights=weights,
            actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),
            actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32),
            block_table=block_table,
            layout_query="TND",
            layout_key="PA_BSND",
            sparse_count=self.index_topk,
            sparse_mode=3,
        )
        if is_prefill and enable_index_cp:
            # Gather every rank's top-k rows so all ranks return the full result.
            topk_indices, local_topk_indices = (
                torch.empty(
                    (
                        topk_indices.shape[0] * attention_tp_size,
                        topk_indices.shape[1],
                        topk_indices.shape[2],
                    ),
                    dtype=topk_indices.dtype,
                    device=topk_indices.device,
                ),
                topk_indices,
            )
            get_attention_tp_group().all_gather_into_tensor(
                topk_indices, local_topk_indices
            )
        return topk_indices

View File

@@ -0,0 +1,255 @@
import torch
import triton
import triton.language as tl
from sglang.srt.layers.attention.nsa.utils import NSA_QUANT_K_CACHE_FAST
def quantize_k_cache(cache_k):
    """Quantize an MLA k-cache, dispatching to the Triton fast path or the
    slow reference implementation based on NSA_QUANT_K_CACHE_FAST."""
    # TODO upstream can skip concat([k_nope, k_pe]) since we split them here
    impl = _quantize_k_cache_fast_wrapped if NSA_QUANT_K_CACHE_FAST else _quantize_k_cache_slow
    return impl(cache_k)
# Copied from original
def _quantize_k_cache_slow(
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
dv: int = 512,
tile_size: int = 128,
) -> torch.Tensor:
"""
Quantize the k-cache
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, d = input_k_cache.shape
assert h_k == 1
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size()
result = torch.empty(
(num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
dtype=torch.float8_e4m3fn,
device=input_k_cache.device,
)
result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = (
torch.abs(
input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
)
.max(dim=-1)
.values
/ 448.0
) # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (
input_k_cache[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].float()
/ cur_scale_factors_inv.float()
).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_quantized_nope
)
result = result.view(num_blocks, block_size, 1, -1)
return result
def _quantize_k_cache_fast_wrapped(
    input_k_cache: torch.Tensor,
    dv: int = 512,
    tile_size: int = 128,
) -> torch.Tensor:
    """Adapter from the 4D paged-cache layout to the 2D fast quantizer.

    TODO: the final API may be 2D instead of 4D; the 4D<->2D reshapes live here.
    """
    num_blocks, block_size, _, dim_nope_and_rope = input_k_cache.shape
    assert dv == 512
    assert dim_nope_and_rope == 512 + 64
    assert tile_size == 128
    tokens = input_k_cache.view((-1, dim_nope_and_rope))
    # Deliberately split into nope/rope halves so upstream can eventually
    # pass two tensors instead of concatenating them first.
    quantized = _quantize_k_cache_fast(k_nope=tokens[:, :dv], k_rope=tokens[:, dv:])
    return quantized.view(num_blocks, block_size, 1, -1)
def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128):
    """Per-group fp8 quantization of k_nope plus a raw bf16 copy of k_rope.

    :param k_nope: (num_tokens, dim_nope 512)
    :param k_rope: (num_tokens, dim_rope 64)
    :return: (num_tokens, 512 + 16 + 128) buffer typed float8_e4m3fn with
             byte layout [fp8 nope | 4x fp32 scales | bf16 rope].
    """
    assert k_nope.dtype == torch.bfloat16
    assert k_rope.dtype == torch.bfloat16
    assert k_nope.dtype == k_rope.dtype
    num_tokens, dim_nope = k_nope.shape
    rope_tokens, dim_rope = k_rope.shape
    assert num_tokens == rope_tokens
    assert dim_nope == 512
    assert dim_rope == 64
    num_tiles = dim_nope // group_size
    # The kernel addresses rows via stride(0) only, so rows must be dense.
    assert k_nope.stride(1) == 1
    assert k_rope.stride(1) == 1
    out = torch.empty(
        (num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope),
        dtype=torch.float8_e4m3fn,
        device=k_nope.device,
    )
    out_nope_q = out[..., :dim_nope]
    out_nope_s = out[..., dim_nope : dim_nope + num_tiles * 4].view(torch.float32)
    out_rope = out[..., dim_nope + num_tiles * 4 :].view(torch.bfloat16)
    blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
    assert blocks_per_token == 5
    assert dim_nope % group_size == 0
    num_nope_blocks = dim_nope // group_size
    launch_grid = (num_tokens, blocks_per_token)
    _quantize_k_cache_fast_kernel[launch_grid](
        out_nope_q,
        out_nope_s,
        out_rope,
        k_nope,
        k_rope,
        out_nope_q.stride(0),
        out_nope_s.stride(0),
        out_rope.stride(0),
        k_nope.stride(0),
        k_rope.stride(0),
        NUM_NOPE_BLOCKS=num_nope_blocks,
        GROUP_SIZE=group_size,
        DIM_NOPE=dim_nope,
        DIM_ROPE=dim_rope,
        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
    )
    return out
@triton.jit
def _quantize_k_cache_fast_kernel(
    output_nope_q_ptr,
    output_nope_s_ptr,
    output_rope_ptr,
    k_nope_ptr,
    k_rope_ptr,
    output_nope_q_stride_0: int,
    output_nope_s_stride_0: int,
    output_rope_stride_0: int,
    k_nope_stride_0: int,
    k_rope_stride_0: int,
    NUM_NOPE_BLOCKS: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    DIM_NOPE: tl.constexpr,
    DIM_ROPE: tl.constexpr,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    # Grid: (num_tokens, NUM_NOPE_BLOCKS + rope blocks). Programs with
    # raw_block_id < NUM_NOPE_BLOCKS quantize one GROUP_SIZE slice of k_nope
    # to fp8 with a per-group fp32 scale; the remaining program(s) copy the
    # rope slice through unmodified.
    token_id = tl.program_id(0)
    raw_block_id = tl.program_id(1)
    if raw_block_id < NUM_NOPE_BLOCKS:
        # a. quant nope
        effective_block_id = raw_block_id
        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_NOPE
        ptr = k_nope_ptr + token_id * k_nope_stride_0 + offs
        y = tl.load(ptr, mask=mask, other=0.0).to(tl.float32)
        # the ref impl do not have a `tl.maximum(... eps)`, so we remove it here
        # NOTE(review): an all-zero group yields y_s == 0 and y_s_inv == inf,
        # matching the reference implementation — confirm inputs can't be all zero.
        y_s = tl.max(tl.abs(y)) / FP8_MAX
        y_s_inv = 1.0 / y_s
        y_q = tl.clamp(y * y_s_inv, FP8_MIN, FP8_MAX).to(
            output_nope_q_ptr.dtype.element_ty
        )
        dst_q_ptr = output_nope_q_ptr + token_id * output_nope_q_stride_0 + offs
        dst_s_ptr = (
            output_nope_s_ptr + token_id * output_nope_s_stride_0 + effective_block_id
        )
        tl.store(dst_q_ptr, y_q, mask=mask)
        tl.store(dst_s_ptr, y_s)
    else:
        # b. copy rope
        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_ROPE
        src_ptr = k_rope_ptr + token_id * k_rope_stride_0 + offs
        dst_ptr = output_rope_ptr + token_id * output_rope_stride_0 + offs
        data = tl.load(src_ptr, mask=mask)
        tl.store(dst_ptr, data, mask=mask)
if __name__ == "__main__":
    # Self-test: quantize with both implementations, dequantize both results,
    # and check the round-trips agree within fp8 tolerance.
    cases = [
        (1, 1),
        (10, 64),
    ]
    for num_blocks, block_size in cases:
        dim_nope_and_rope = 512 + 64
        input_k_cache = torch.randn(
            (num_blocks, block_size, 1, dim_nope_and_rope),
            dtype=torch.bfloat16,
            device="cuda",
        )
        ref_quant = _quantize_k_cache_slow(input_k_cache)
        actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
        import dequant_k_cache

        ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
        ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
        actual_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(
            actual_quant
        )
        print(f"{ref_ref_dequant=}")
        print(f"{actual_actual_dequant=}")
        print(f"{actual_actual_dequant - ref_ref_dequant=}")
        print(f"{torch.mean(ref_ref_dequant - actual_actual_dequant)=}")
        # NOTE(review): tolerances are loose (0.2) — fp8 round-trip error;
        # confirm this bound is tight enough for the intended use.
        torch.testing.assert_close(
            ref_ref_dequant, ref_actual_dequant, atol=0.2, rtol=0.2
        )
        torch.testing.assert_close(
            ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
        )
        print("Passed")

View File

@@ -0,0 +1,813 @@
from typing import Optional, Tuple
# import tilelang
# import tilelang.language as T
import torch
# tilelang.set_log_level("WARNING")
# pass_configs = {
# tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
# tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
# tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
# }
# Dtype name strings handed to the tilelang kernels (currently disabled below).
BF16 = "bfloat16"
FP8 = "float8_e4m3"
FP32 = "float32"
'''
def fast_log2_ceil(x):
bits_x = T.reinterpret("uint32", x)
exp_x = (bits_x >> 23) & 0xFF
man_bits = bits_x & ((1 << 23) - 1)
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
def fast_pow2(x):
bits_x = (x + 127) << 23
return T.reinterpret("float32", bits_x)
def fast_round_scale(amax, fp8_max_inv):
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
):
M = T.symbolic("M")
fp8_min = -448.0
fp8_max = 448.0
fp8_max_inv = 1 / fp8_max
num_stages = 0 if round_scale else 2
blk_m = 32
group_size = 128
@T.prim_func
def act_quant_kernel_(
X: T.Tensor[(M, N), in_dtype],
Y: T.Tensor[(M, N), out_dtype],
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
pid_m,
pid_n,
):
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
amax_local = T.alloc_fragment((blk_m,), scale_dtype)
s_local = T.alloc_fragment((blk_m,), scale_dtype)
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
for _ in T.Pipelined(1, num_stages=num_stages):
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
T.copy(x_shared, x_local)
T.reduce_absmax(x_local, amax_local, dim=1)
for i in T.Parallel(blk_m):
amax_local[i] = T.max(amax_local[i], 1e-4)
if round_scale:
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
else:
s_local[i] = amax_local[i] * fp8_max_inv
for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
)
for i in T.Parallel(blk_m):
S[pid_m * blk_m + i, pid_n] = s_local[i]
T.copy(y_local, y_shared)
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
return act_quant_kernel_
def act_quant(
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), "Input tensor must be contiguous"
assert (
x.size(-1) % block_size == 0
), f"Last dimension size must be divisible by block_size (block_size={block_size})"
N = x.size(-1)
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
return y, s
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(h: int, d: int):
b = T.symbolic("b")
m = T.symbolic("m")
n = T.symbolic("n")
blk_n1 = 512
blk_n2 = 128
@T.prim_func
def fp8_index_kernel_(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None:
with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
q_smem = T.alloc_shared((h, d), FP8)
T.copy(q[i_b, i_m, 0, 0], q_smem)
q_s_frag = T.alloc_fragment(h, FP32)
T.copy(q_s[i_b, i_m, 0], q_s_frag)
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
k_smem = T.alloc_shared((blk_n2, d), FP8)
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
k_s_frag = T.alloc_fragment(blk_n2, FP32)
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
logits = T.alloc_fragment((blk_n2, h), FP32)
T.gemm(
k_smem,
q_smem,
logits,
transpose_A=False,
transpose_B=True,
clear_accum=True,
)
for i_h, i3_n in T.Parallel(h, blk_n2):
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
logits_sum = T.alloc_fragment(blk_n2, FP32)
T.reduce_sum(logits, logits_sum, dim=1)
for i3_n in T.Parallel(blk_n2):
logits_sum[i3_n] *= k_s_frag[i3_n]
T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
return fp8_index_kernel_
def fp8_index(
q: torch.Tensor,
q_s: torch.Tensor,
k: torch.Tensor,
k_s: torch.Tensor,
) -> torch.Tensor:
"""
Perform index score using FP8 precision.
Args:
q (torch.Tensor): The Q tensor, must be contiguous.
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
k (torch.Tensor): The K tensor, must be contiguous.
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
fp8 q @ fp8 k -> fp32 logits
relu(fp32 logits) * q_s (weights) -> fp32 logits
fp32 logits -> fp32 logits_sum
fp32 logits_sum * k_s (e8m0) -> fp32 index_score
"""
return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
def sparse_attention_fwd_kernel_v1(
    num_heads,
    dim,
    tail_dim,
    topk,
    *,
    kv_group=1,
    sm_scale=None,
    is_causal=True,
    block_I=64,
    num_stages=2,
    threads=256,
):
    """Build a tilelang kernel for sparse (top-k) attention forward, v1.

    One thread block handles one (batch, query token, kv_group) triple and
    iterates over the top-k selected KV rows in tiles of ``block_I``, running
    a flash-attention style online-softmax accumulation in fp32.

    Args:
        num_heads: total number of query heads.
        dim: output ("value") head dim; must be a power of two.
        tail_dim: extra Q/K-only tail dim (not part of the output); power of two.
        topk: number of selected KV rows per query; must be a multiple of block_I.
        kv_group: number of KV groups; padding of the head dim assumes 1.
        sm_scale: softmax scale; defaults to 1/sqrt(dim + tail_dim).
        is_causal: must be True (non-causal unsupported).
        block_I / num_stages / threads: tiling and pipelining knobs.
    """
    assert dim == tilelang.math.next_power_of_2(
        dim
    ), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim
    ), f"haven't check padding correctness yet, dim={tail_dim}"
    assert is_causal == True, "non-casual is not supported"
    assert (
        topk % block_I == 0
    ), "otherwise will load some index=0 thus causing wrong kv to be loaded"
    # Fold log2(e) into the scale so the softmax can use exp2 instead of exp.
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504  # log2(e)
    else:
        sm_scale = sm_scale * 1.44269504  # log2(e)
    batch = T.symbolic("batch")
    seq_len = T.symbolic("seq_len")
    seq_len_kv = T.symbolic("seq_len_kv")
    head_kv = num_heads // kv_group
    q_shape = [batch, seq_len, num_heads, dim + tail_dim]
    kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
    o_shape = [batch, seq_len, num_heads, dim]
    indices_shape = [batch, seq_len, kv_group, topk]
    indices_dtype = "int32"
    dtype = "bfloat16"
    accum_dtype = "float"
    H = head_kv
    # Pad the head count to a power of two (min 16) so GEMM tiles are well formed.
    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        assert kv_group == 1
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    D = dim
    D_tail = tail_dim
    # With more than 64 heads, replicate the grid over 64-head chunks so each
    # block works on at most 64 rows.
    if head_kv > 64:
        assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = head_kv // 64
    else:
        REPLICATE_H = 1
    H_per_block = padded_H if REPLICATE_H == 1 else 64

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Output: T.Tensor(o_shape, dtype),  # type: ignore
    ):
        with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
            bx,
            by,
            bz,
        ):
            Q_shared = T.alloc_shared([H_per_block, D], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            KV_shared = T.alloc_shared([BI, D], dtype)
            K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
            O_shared = T.alloc_shared([H_per_block, D], dtype)
            mask = T.alloc_fragment([BI], "bool")
            acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
            S_shared = T.alloc_shared([H_per_block, BI], dtype)
            sumexp = T.alloc_fragment([H_per_block], accum_dtype)
            sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
            alpha = T.alloc_fragment([H_per_block], accum_dtype)
            m_i = T.alloc_fragment([H_per_block], accum_dtype)
            m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
            T.fill(acc_o, 0)
            T.fill(sumexp, 0)
            T.fill(m_i, -(2**30))  # avoid -inf - inf to cause nan
            b_i, g_i = by, bz
            s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
            q_i = s_i
            max_kv_i = q_i
            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block
            T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)
            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
            for i_i in T.Pipelined(NI, num_stages=num_stages):
                # Negative indices mark padding; they are masked to -inf below.
                for bi_i in T.Parallel(BI):
                    mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] >= 0
                # Gather the selected KV rows (value part, then RoPE tail).
                for bi_i, d_i in T.Parallel(BI, D):
                    KV_shared[bi_i, d_i] = KV[
                        b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i
                    ]
                for bi_i, d_i in T.Parallel(BI, D_tail):
                    K_tail_shared[bi_i, d_i] = KV[
                        b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i
                    ]
                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.if_then_else(
                        mask[bi_i], 0, -T.infinity(acc_s.dtype)
                    )
                # acc_s += Q @ K^T over both the main dim and the tail dim.
                T.gemm(
                    Q_shared,
                    KV_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol,
                )
                T.gemm(
                    Q_tail_shared,
                    K_tail_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol,
                )
                # Online-softmax update: new running max, rescale factor alpha,
                # exponentiate the tile, and fold into the running sums.
                T.copy(m_i, m_i_prev)
                T.reduce_max(acc_s, m_i, dim=1, clear=False)
                for h_i in T.Parallel(H_per_block):
                    alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.exp2(
                        acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
                    )
                # Per-tile sum of exponentials (not accumulated by reduce_sum
                # itself; accumulation happens in the loop below).
                T.reduce_sum(acc_s, sumexp_i, dim=1)
                for h_i in T.Parallel(H_per_block):
                    sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
                for h_i, d_i in T.Parallel(H_per_block, D):
                    acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
                # acc_o += P @ V (P staged through shared memory as bf16).
                T.copy(acc_s, S_shared)
                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
            # Rescale
            for h_i, d_i in T.Parallel(H_per_block, D):
                acc_o[h_i, d_i] /= sumexp[h_i]
            # Log-sum-exp (base 2) per head; computed but not written out here.
            for h_i in T.Parallel(H_per_block):
                sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
            # NOTE(review): O_shared is written but Output is copied straight
            # from acc_o, so O_shared appears unused -- confirm intent.
            T.copy(acc_o, O_shared)
            T.copy(acc_o, Output[b_i, s_i, H0:H1, :])

    return main
@tilelang.jit(
    out_idx=[-1],
    compile_flags=[
        "-O3",
        "-Wno-deprecated-declarations",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--ptxas-options=-v,--register-usage-level=10",
        "-DNDEBUG",
    ],
)  # type: ignore
def sparse_attention_fwd_kernel_v2(
    num_heads: int,
    dim: int,
    tail_dim: int,
    topk: int,
    *,
    kv_group: int = 1,
    sm_scale: Optional[float] = None,
    block_I: int = 64,
):
    """Build the warp-specialized tilelang sparse-attention forward, v2.

    Hand-scheduled variant of v1: 384 threads are split into three groups of
    128. Group 0 computes Q@K^T plus the online softmax and the LEFT half of
    the output dim; group 1 computes the RIGHT half of P@V using the softmax
    scale (alpha) and probabilities (S) published by group 0; group 2 is a
    producer that gathers the top-k KV rows into double-buffered shared-memory
    tiles with cp.async. All cross-group hand-off goes through named barriers.

    Args:
        num_heads: query heads handled per (batch, token) block (up to 64 per
            grid replica).
        dim: output head dim, split in halves of ``dim // 2``; power of two.
        tail_dim: extra Q/K-only tail dim; power of two.
        topk: selected KV rows per query; multiple of ``2 * block_I``.
        kv_group: number of KV groups; head padding assumes 1.
        sm_scale: softmax scale; defaults to 1/sqrt(dim + tail_dim).
        block_I: KV tile height (rows per buffer).
    """
    assert dim == tilelang.math.next_power_of_2(
        dim
    ), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim
    ), f"haven't check padding correctness yet, dim={tail_dim}"
    assert (
        topk % block_I == 0
    ), "otherwise will load some index=0 thus causing wrong kv to be loaded"
    # Fold log2(e) into the scale so the softmax can use exp2 instead of exp.
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504  # log2(e)
    else:
        sm_scale = sm_scale * 1.44269504  # log2(e)
    # Three warp groups of 128 threads: two consumers + one producer.
    threads = 384
    batch = T.symbolic("batch")
    qo_len = T.symbolic("seq_len")
    num_pages = T.symbolic("num_pages")
    q_shape = [batch, qo_len, num_heads, dim + tail_dim]
    kv_shape = [batch, num_pages, kv_group, dim + tail_dim]
    o_shape = [batch, qo_len, num_heads, dim]
    indices_shape = [batch, qo_len, kv_group, topk]
    indices_dtype = "int32"
    dtype = "bfloat16"
    accum_dtype = "float"
    H = num_heads
    padded_H = max(tilelang.math.next_power_of_2(num_heads), 16)
    if padded_H != H:
        assert kv_group == 1
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    # The main loop is unrolled 2x over the two KV buffers.
    assert NI % 2 == 0, "NI should be a multiple of 2"
    D = dim
    D_tail = tail_dim
    if num_heads > 64:
        assert num_heads % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = num_heads // 64
    else:
        REPLICATE_H = 1
    H_per_block = padded_H if REPLICATE_H == 1 else 64

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Output: T.Tensor(o_shape, dtype),  # type: ignore
    ):
        """
        Q: [b, qo_len, H, D + D_tail] (bfloat16)
        KV: [b, num_pages, kv_group, D + D_tail] (bfloat16)
        Indices: [b, qo_len, kv_group, topk] (int32)
        """
        with T.Kernel(qo_len * REPLICATE_H, batch, 1, threads=threads) as (bx, by, bz):  # type: ignore
            Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)
            Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            # Double-buffered KV tiles, each split into left/right halves of D.
            KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype)
            K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype)
            K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype)
            # Output reuses Q's shared memory (Q is dead by the final copy).
            O_shared_l = Q_shared_l
            O_shared_r = Q_shared_r
            is_kv_valid_0 = T.alloc_shared([BI], "bool", scope="shared")
            is_kv_valid_1 = T.alloc_shared([BI], "bool", scope="shared")
            acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
            acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
            S_shared = T.alloc_shared([H_per_block, BI], dtype)
            sumexp = T.alloc_fragment([H_per_block], accum_dtype)
            sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype)
            sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
            alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared")
            alpha_local = T.alloc_fragment([H_per_block], accum_dtype)
            m_i = T.alloc_fragment([H_per_block], accum_dtype)
            m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
            indices_local = T.alloc_local([1], indices_dtype)
            indices_tmp = T.alloc_local([1], indices_dtype)
            # Barriers (arrive_count = number of participating threads):
            # bar_q: Q tiles staged; bar_k_*_ready/free: producer<->consumer
            # hand-off per KV buffer; bar_sScale_and_sS_*: group0 -> group1
            # hand-off of alpha_shared + S_shared; bar_{0,1,2}_128: intra-group
            # syncs; bar_final: group0's sumexp published for group1's rescale.
            bar_q = T.alloc_barrier(arrive_count=384)
            bar_k_0_ready = T.alloc_barrier(arrive_count=128)
            bar_k_1_ready = T.alloc_barrier(arrive_count=128)
            bar_k_0_free = T.alloc_barrier(arrive_count=256)
            bar_k_1_free = T.alloc_barrier(arrive_count=256)
            bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)
            bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)
            bar_0_128 = T.alloc_barrier(arrive_count=128)
            bar_1_128 = T.alloc_barrier(arrive_count=128)
            bar_2_128 = T.alloc_barrier(arrive_count=128)
            bar_final = T.alloc_barrier(arrive_count=128)
            b_i, g_i = by, bz
            s_i = bx if REPLICATE_H == 1 else bx // REPLICATE_H
            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block
            tx = T.get_thread_binding()
            T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l)
            T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r)
            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
            T.barrier_arrive(bar_q)
            if tx < 128:
                # --- Warpgroup 0 (consumer): Q@K^T, online softmax, left half
                # of the P@V accumulation.
                T.set_max_nreg(240, 1)
                T.fill(sumexp, 0)
                T.fill(m_i, -(2**30))  # avoid -inf - inf to cause nan
                T.fill(acc_o_l, 0)
                T.barrier_wait(bar_q, 0)
                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    # with sync_at(bar_0_128, 0):
                    T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
                    T.barrier_arrive(bar_0_128)
                    T.barrier_wait(bar_0_128, 0)
                    # Padding rows (invalid index) start at -inf so they drop
                    # out of the softmax.
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.if_then_else(
                            is_kv_valid_0[bi_i], 0, -T.infinity(acc_s.dtype)
                        )
                    T.gemm(
                        Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_tail_shared,
                        K_tail_shared_0,
                        acc_s,
                        transpose_B=True,
                        wg_wait=-1,
                    )
                    T.wait_wgmma(0)
                    # Wait until group 1 has drained the previous S/alpha tile
                    # (skipped on the very first iteration).
                    if i_i != 0:
                        T.barrier_arrive(bar_sScale_and_sS_free)
                        T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)
                    # Online-softmax update for this tile.
                    T.copy(m_i, m_i_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)
                    for h_i in T.Parallel(H_per_block):
                        alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.exp2(
                            acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
                        )
                    # Per-tile sum of exponentials; folded into `sumexp` below.
                    T.reduce_sum(acc_s, sumexp_i, dim=1)
                    for h_i in T.Parallel(H_per_block):
                        sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_l[h_i, d_i] *= alpha_local[h_i]
                    # Publish alpha and P for warpgroup 1.
                    T.copy(alpha_local, alpha_shared)
                    T.copy(acc_s, S_shared)
                    T.gemm(S_shared, KV_shared_0_l, acc_o_l)
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_arrive(bar_k_0_free[0])
                    # Buffer 1
                    T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
                    T.barrier_arrive(bar_0_128)
                    T.barrier_wait(bar_0_128, 1)
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.if_then_else(
                            is_kv_valid_1[bi_i], 0, -T.infinity(acc_s.dtype)
                        )
                    T.gemm(
                        Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_tail_shared,
                        K_tail_shared_1,
                        acc_s,
                        transpose_B=True,
                        wg_wait=-1,
                    )
                    T.wait_wgmma(0)
                    T.barrier_arrive(bar_sScale_and_sS_free)
                    T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)
                    T.copy(m_i, m_i_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)
                    for h_i in T.Parallel(H_per_block):
                        alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.exp2(
                            acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
                        )
                    T.reduce_sum(acc_s, sumexp_i, dim=1)
                    for h_i in T.Parallel(H_per_block):
                        sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_l[h_i, d_i] *= alpha_local[h_i]
                    T.copy(alpha_local, alpha_shared)
                    T.copy(acc_s, S_shared)
                    T.gemm(S_shared, KV_shared_1_l, acc_o_l)
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_arrive(bar_k_1_free[0])
                # Rescale
                # Publish the final softmax denominator for warpgroup 1, then
                # normalize the left half and write it out.
                for h_i in T.Parallel(H_per_block):
                    sum_exp_shared[h_i] = sumexp[h_i]
                T.barrier_arrive(bar_final)
                for h_i, d_i in T.Parallel(H_per_block, D // 2):
                    acc_o_l[h_i, d_i] /= sumexp[h_i]
                # Log-sum-exp (base 2); computed but not written out here.
                for h_i in T.Parallel(H_per_block):
                    sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
                T.copy(acc_o_l, O_shared_l)
                T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2])
            elif tx >= 128 and tx < 256:
                # --- Warpgroup 1 (consumer): right half of P@V, driven by the
                # alpha_shared / S_shared tiles published by warpgroup 0.
                # T.set_max_nreg(168, 1)
                T.fill(acc_o_r, 0)
                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))
                    T.barrier_arrive(bar_1_128)
                    T.barrier_wait(bar_1_128, 0)
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_r[h_i, d_i] *= alpha_shared[h_i]
                    T.gemm(S_shared, KV_shared_0_r, acc_o_r)
                    T.barrier_arrive(bar_k_0_free[0])
                    T.barrier_arrive(bar_sScale_and_sS_free)
                    # Buffer 1
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))
                    T.barrier_arrive(bar_1_128)
                    T.barrier_wait(bar_1_128, 1)
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_r[h_i, d_i] *= alpha_shared[h_i]
                    T.gemm(S_shared, KV_shared_1_r, acc_o_r)
                    T.barrier_arrive(bar_k_1_free[0])
                    # The last S tile is never reused, so skip the final release.
                    if i_i != T.ceildiv(NI, 2) - 1:
                        T.barrier_arrive(bar_sScale_and_sS_free)
                # Rescale
                T.barrier_wait(bar_final, 0)
                for h_i, d_i in T.Parallel(H_per_block, D // 2):
                    acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
                T.copy(acc_o_r, O_shared_r)
                T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D])
            elif tx >= 256:
                # --- Warpgroup 2 (producer): gathers the KV rows selected by
                # Indices into the double-buffered smem tiles via cp.async.
                T.set_max_nreg(80, 0)
                indices_local[0] = 0
                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
                    T.barrier_arrive(bar_2_128)
                    T.barrier_wait(bar_2_128, 0)
                    # 128 producer threads = 16 rows x 8 lanes per round;
                    # 4 rounds cover the BI=64 rows of a tile.
                    for r in T.serial(4):
                        indices_tmp[0] = Indices[
                            b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8
                        ]
                        is_kv_valid_0[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
                        # Invalid rows reuse the last valid index so the load
                        # stays in bounds; consumers mask those rows to -inf.
                        if is_kv_valid_0[r * 16 + (tx - 256) // 8]:
                            indices_local[0] = indices_tmp[0]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for u in T.serial(4):
                                for v in T.vectorized(8):
                                    KV_shared_0_l[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                                    KV_shared_0_r[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for v in T.vectorized(8):
                                K_tail_shared_0[
                                    r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
                                ] = KV[
                                    b_i,
                                    indices_local[0],
                                    g_i,
                                    D + (tx - 256) % 8 * 8 + v,
                                ]
                    T.cp_async_barrier_noinc(bar_k_0_ready[0])
                    # Buffer 1
                    T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
                    T.barrier_arrive(bar_2_128)
                    T.barrier_wait(bar_2_128, 1)
                    for r in T.serial(4):
                        indices_tmp[0] = Indices[
                            b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8
                        ]
                        is_kv_valid_1[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
                        if is_kv_valid_1[r * 16 + (tx - 256) // 8]:
                            indices_local[0] = indices_tmp[0]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for u in T.serial(4):
                                for v in T.vectorized(8):
                                    KV_shared_1_l[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                                    KV_shared_1_r[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for v in T.vectorized(8):
                                K_tail_shared_1[
                                    r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
                                ] = KV[
                                    b_i,
                                    indices_local[0],
                                    g_i,
                                    D + (tx - 256) % 8 * 8 + v,
                                ]
                    T.cp_async_barrier_noinc(bar_k_1_ready[0])

    return main
def tilelang_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> torch.Tensor:
    """Run the tilelang sparse-attention forward on unbatched (3-D) inputs.

    Assumes q is (seq, heads, d_v + rope_tail) and kv/indices are likewise
    batchless; a batch dim of 1 is added before calling the kernel -- confirm
    against callers.
    """
    assert q.dim() == 3 and kv.dim() == 3 and indices.dim() == 3
    topk = indices.shape[-1]
    assert topk == 2048
    head_cnt = q.shape[1]
    qk_dim = q.shape[2]
    rope_tail = qk_dim - d_v
    # NOTE(dark): v2 offers better performance than v1
    kernel = sparse_attention_fwd_kernel_v2(
        head_cnt, d_v, rope_tail, topk, sm_scale=sm_scale
    )
    batched = (q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0))
    return kernel(*batched)  # type: ignore
'''
def act_quant(
x: torch.Tensor,
block_size: int = 128,
scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
PyTorch fallback for act_quant
Block-wise FP8 E4M3 quantization
"""
if not x.is_contiguous():
x = x.contiguous()
N = x.size(-1)
assert N % block_size == 0, f"Last dim {N} must be divisible by block_size={block_size}"
# Reshape to blocks
x_2d = x.view(-1, N)
x_blocks = x_2d.view(-1, block_size)
# Compute absmax per block
amax = x_blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-4)
# FP8 E4M3 max value is ~448
fp8_max = 448.0
scale = amax / fp8_max
if scale_fmt is not None:
# Simulate rounded scale (power-of-2 rounding)
scale = torch.round(scale * 256) / 256
# Quantize and clamp
y_blocks = torch.clamp(torch.round(x_blocks / scale), -fp8_max, fp8_max)
# Convert to FP8
q = y_blocks.view_as(x_2d).to(torch.float8_e4m3fn)
# Reshape scale
s = scale.view(x_2d.size(0), N // block_size).to(torch.float32)
s = s.view(*x.shape[:-1], N // block_size)
return q.view_as(x), s

View File

@@ -0,0 +1,65 @@
import torch
from sglang.srt.utils import align
# NOTE(dark): flashmla P requires `params.topk % (2*B_TOPK) == 0`,
# where `B_TOPK=64`. So we align to 128 by default.
_TOPK_ALIGNMENT = 128


# TODO(dark): maybe this torch_op can support torch.compile
def _fast_topk_torch(
    input: torch.Tensor, seq_lens: torch.Tensor, topk: int, alignment: int
) -> torch.Tensor:
    """Top-k selection fallback built on torch.topk.

    Returns an int tensor of selected positions per row, with -1 marking
    invalid (past-end-of-sequence) entries. The result width is ``topk``, or
    ``align(max_seq_len, alignment)`` when that is <= ``topk`` (short-circuit
    path). NOTE: mutates the caller's ``input`` in place (masked_fill_) to
    avoid an extra copy.
    """
    # Fallback to torch.topk
    bs, max_seq_len = input.shape
    assert len(seq_lens) == bs
    # set those out-of-bound input to -inf
    padded_max_seq_len = align(max_seq_len, alignment)
    positions = torch.arange(
        padded_max_seq_len, device=input.device, dtype=seq_lens.dtype
    )
    positions = positions.unsqueeze(0).expand(bs, -1)
    # mask[i, j] is True when position j is past row i's sequence length.
    mask = positions >= seq_lens.unsqueeze(1)
    # NOTE(dark): just return all valid indices as an optimization
    if padded_max_seq_len <= topk:
        return positions.masked_fill(mask, -1)
    assert topk % alignment == 0
    # in-place operation: mask invalid inputs to -inf
    input = input.masked_fill_(mask[:, :max_seq_len], float("-inf"))
    result = input.topk(topk, dim=-1, sorted=True)
    # Scores are sorted descending, so any rank >= seq_len can only hold a
    # -inf (invalid) entry; overwrite those ranks with -1.
    return result.indices.masked_fill_(mask[:, :topk], -1)
def fast_topk_impl(
    input: torch.Tensor,
    seq_lens: torch.Tensor,
    topk: int,
    alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
    """Dispatch top-k selection to the torch fallback implementation."""
    selected = _fast_topk_torch(input, seq_lens, topk, alignment)
    return selected
def fast_topk_transform_fused_cuda(
    input: torch.Tensor,
    seq_lens: torch.Tensor,
    topk: int,
    dst_page_table: torch.Tensor,
    src_page_table: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
    """Fused CUDA path: top-k select and page-table transform in one kernel.

    Instead of returning raw top-k indices, writes the gathered page ids from
    ``src_page_table`` into ``dst_page_table`` (see the caller in the NSA
    indexer metadata, which returns ``dst_page_table``).
    """
    from sglang.srt.layers.attention.nsa.cuda import fast_topk_transform

    # The fused kernel is only exercised with topk == 2048; `alignment` is
    # validated here but not forwarded (the kernel presumably handles
    # alignment internally -- confirm against the CUDA implementation).
    assert topk == 2048 and topk % alignment == 0
    return fast_topk_transform(
        score=input,
        lengths=seq_lens,
        dst_page_table=dst_page_table,
        src_page_table=src_page_table,
        cu_seqlens=cu_seqlens_q,
    )

View File

@@ -0,0 +1,144 @@
from typing import List, Optional
import torch
import triton
import triton.language as tl
def transform_index_page_table_prefill(**kwargs):
    """Prefill entry point; currently dispatches to the reference implementation."""
    return transform_index_page_table_prefill_ref(**kwargs)
def transform_index_page_table_decode(**kwargs):
    """Decode entry point; currently dispatches to the reference implementation."""
    return transform_index_page_table_decode_ref(**kwargs)
@triton.jit
def transform_index_page_table_decode_kernel(
    page_table_ptr: torch.Tensor,
    topk_indices_ptr: torch.Tensor,
    result_ptr: torch.Tensor,
    page_size: tl.constexpr,
    max_seqlen_k: tl.constexpr,
):
    # One program per query row: result[i, j] = page_table[i, topk_indices[i, j]]
    # for valid (>= 0) indices, -1 for padding entries.
    TOPK: tl.constexpr = 2048  # fixed top-k; callers assert shape[1] == 2048
    req_id = tl.program_id(0)
    # Advance the row pointers for this request.
    page_table_ptr = page_table_ptr + req_id * max_seqlen_k
    topk_indices_ptr = topk_indices_ptr + req_id * TOPK
    result_ptr = result_ptr + req_id * TOPK
    offset = tl.arange(0, TOPK)  # topk should be 2048
    loaded_topk_indices = tl.load(topk_indices_ptr + offset)
    mask = loaded_topk_indices >= 0
    loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask)
    # Two complementary masked stores cover every lane: gathered page ids for
    # valid indices, -1 fill for the rest.
    tl.store(result_ptr + offset, loaded_kv_indices, mask=mask)
    tl.store(result_ptr + offset, -1, mask=~mask)
def transform_index_page_table_decode_fast(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
) -> torch.Tensor:
    """Triton-backed page-table transform for sparse top-k attention.

    Args:
        page_table: [qo_len, max_seqlen_k] original page table.
        topk_indices: [qo_len, topk] top-k indices per query position
            (topk must be 2048; negative entries mark padding).
        result: optional preallocated int32 output of the same shape.

    Returns:
        [qo_len, topk] transformed page table; padding entries become -1.
    """
    assert page_size == 1
    assert page_table.shape[0] == topk_indices.shape[0]
    assert topk_indices.shape[1] == 2048
    num_rows = topk_indices.shape[0]
    row_width = page_table.shape[1]
    if result is None:
        result = torch.empty_like(topk_indices, dtype=torch.int32)
    # One triton program per query row.
    transform_index_page_table_decode_kernel[(num_rows,)](
        page_table,
        topk_indices,
        result,
        page_size,
        max_seqlen_k=row_width,
    )
    return result
def transform_index_page_table_prefill_fast(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
) -> torch.Tensor:
    """Prefill variant: run the decode transform per request's extended rows."""
    # TODO(baizhou): can be implemented with another triton kernel
    assert page_size == 1
    result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert len(extend_lens_cpu) == page_table.shape[0]
    row_start = 0
    for req_idx, extend_len in enumerate(extend_lens_cpu):
        row_end = row_start + extend_len
        # Each request shares one page-table row across its extended tokens.
        transform_index_page_table_decode_fast(
            page_table[req_idx].unsqueeze(0).expand(extend_len, -1),
            topk_indices[row_start:row_end],
            result=result[row_start:row_end],
        )
        row_start = row_end
    assert row_start == topk_indices.shape[0]
    return result
def transform_index_page_table_decode_ref(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
) -> torch.Tensor:
    """Reference page-table transform: per-row gather with -1 for padding.

    Each row of ``topk_indices`` indexes the matching row of ``page_table``;
    negative indices mark padding and come out as -1.
    """
    assert page_size == 1
    assert page_table.shape[0] == topk_indices.shape[0]
    if result is None:
        result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert result.shape == topk_indices.shape
    # Clamp negatives to 0 so the gather stays in bounds; fix them up after.
    safe_index = topk_indices.clamp(min=0).long()
    torch.gather(
        page_table,
        dim=1,
        index=safe_index,
        out=result,
    )
    invalid = topk_indices < 0
    result.masked_fill_(invalid, -1)
    return result
def transform_index_page_table_prefill_ref(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
) -> torch.Tensor:
    """Reference prefill path: apply the decode transform per request slice."""
    assert page_size == 1
    result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert len(extend_lens_cpu) == page_table.shape[0]
    row_start = 0
    for req_idx, extend_len in enumerate(extend_lens_cpu):
        row_end = row_start + extend_len
        # Each request shares one page-table row across its extended tokens.
        transform_index_page_table_decode_ref(
            page_table[req_idx].unsqueeze(0).expand(extend_len, -1),
            topk_indices[row_start:row_end],
            result=result[row_start:row_end],
        )
        row_start = row_end
    assert row_start == topk_indices.shape[0]
    return result
if __name__ == "__main__":
    # Smoke test: the triton fast path must match the gather-based reference.
    bs, topk, max_seqlen = 10, 2048, 3000
    # NOTE: int32 is required here. The reference path gathers into an int32
    # `result` via torch.gather(..., out=...), and gather requires the out
    # tensor's dtype to match the input's; real page tables are int32 too.
    page_table = torch.randint(
        0, 100, (bs, max_seqlen), device="cuda", dtype=torch.int32
    )
    topk_indices = torch.full((bs, topk), -1, device="cuda")
    # First 1600 slots valid, the rest padding (-1).
    topk_indices[:, :1600] = torch.arange(1600).unsqueeze(0).repeat(bs, 1)
    ref_result = transform_index_page_table_decode_ref(page_table, topk_indices)
    result = transform_index_page_table_decode_fast(page_table, topk_indices)
    assert torch.all(result == ref_result)
    print("Passed")

View File

@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
class DummyModel(nn.Module):
    """Toy module comparing two equivalent head-gate computations."""

    def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
        super().__init__()
        self.weights_proj = nn.Linear(d_in, 1024)
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale

    def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Baseline: apply each scalar factor to the projection one at a time."""
        projected = self.weights_proj(x)
        projected = projected * self.n_heads**-0.5
        per_batch_scale = q_scale.unsqueeze(1)  # (B, 1, 1)
        return projected.unsqueeze(-1) * per_batch_scale * self.softmax_scale

    def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Optimized: fold all scalar factors into one constant before scaling."""
        projected = self.weights_proj(x)
        per_batch_scale = q_scale.unsqueeze(1)  # (B, 1, 1)
        combined = self.n_heads**-0.5 * per_batch_scale * self.softmax_scale
        return projected.unsqueeze(-1) * combined  # (B, 1024, 1)
def main():
    """Benchmark both head-gate implementations and verify they agree."""
    torch.manual_seed(0)
    model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
    x = torch.randn(128, 2048)  # batch=128, d_in=2048
    q_scale = torch.randn(128, 1)

    import time

    t0 = time.time()
    for _ in range(1000):
        out_orig = model._get_logits_head_gate_orig(x, q_scale)
    print("Original version time:", time.time() - t0)

    t0 = time.time()
    for _ in range(1000):
        out_opt = model._get_logits_head_gate_opt(x, q_scale)
    print("Optimized version time:", time.time() - t0)

    print("Difference:", (out_orig - out_opt).abs().max().item())
    assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"
# Script entry point: benchmark both implementations and verify they match.
if __name__ == "__main__":
    main()


# Sample output from one run (CPU, 1000 iterations each):
"""
Original version time: 0.49235057830810547
Optimized version time: 0.4087331295013428
Difference: 1.4901161193847656e-08
"""

View File

@@ -0,0 +1,32 @@
# temp NSA debugging environ
# Boolean feature flags for NSA debugging/tuning, read once at import time.
# Each is controlled by the named environment variable; the second argument
# is the default value.
from sglang.srt.utils import get_bool_env_var

NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 = get_bool_env_var(
    "SGLANG_NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8", "true"
)
NSA_KV_CACHE_STORE_FP8 = get_bool_env_var("SGLANG_NSA_KV_CACHE_STORE_FP8", "false")
NSA_QUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_QUANT_K_CACHE_FAST", "false")
NSA_DEQUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_DEQUANT_K_CACHE_FAST", "false")
def _print_bool_env_vars():
    """Print every NSA_* boolean flag on one line (startup debugging aid)."""
    parts = [
        f"{name}={value} "
        for name, value in globals().items()
        if name.startswith("NSA_") and isinstance(value, bool)
    ]
    print("".join(parts), flush=True)


_print_bool_env_vars()
# Storing the KV cache in FP8 is only valid when decode also computes in FP8.
if not NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8:
    assert not NSA_KV_CACHE_STORE_FP8
def compute_nsa_seqlens(original_seq_lens, nsa_index_topk: int):
    """Cap each sequence length at the indexer top-k (NSA attends to at most
    ``nsa_index_topk`` tokens per query)."""
    return original_seq_lens.clamp_max(nsa_index_topk)

View File

@@ -0,0 +1,869 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Dict,
List,
Literal,
Optional,
Tuple,
TypeAlias,
Union,
)
import torch
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.layers.attention.nsa.topk import (
fast_topk_impl,
fast_topk_transform_fused_cuda,
)
from sglang.srt.layers.attention.nsa.transform_index import (
transform_index_page_table_decode,
transform_index_page_table_prefill,
)
from sglang.srt.layers.attention.nsa.utils import (
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
NSA_FUSE_TOPK,
NSA_KV_CACHE_STORE_FP8,
compute_nsa_seqlens,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.two_batch_overlap import global_server_args_dict
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
@dataclass(frozen=True)
class NSAFlashMLAMetadata:
    """Metadata only needed by FlashMLA"""

    flashmla_metadata: torch.Tensor
    num_splits: torch.Tensor

    def slice(self, sli):
        """Slice only num_splits; the scheduler metadata tensor is shared."""
        sliced_splits = self.num_splits[sli]
        return NSAFlashMLAMetadata(self.flashmla_metadata, sliced_splits)

    def copy_(self, other: "NSAFlashMLAMetadata"):
        """In-place copy of both tensors' contents from ``other``."""
        self.flashmla_metadata.copy_(other.flashmla_metadata)
        self.num_splits.copy_(other.num_splits)
@dataclass(frozen=True)
class NSAMetadata:
    """Per-forward-batch metadata precomputed for the NSA attention backend."""

    page_size: int
    # Sequence lengths for the forward batch
    cache_seqlens_int32: torch.Tensor
    # Maximum sequence length for query
    max_seq_len_q: int
    # Maximum sequence length for key
    max_seq_len_k: int
    # Cumulative sequence lengths for query
    cu_seqlens_q: torch.Tensor
    # Cumulative sequence lengths for key
    cu_seqlens_k: torch.Tensor
    # Page table, the index of KV Cache Tables/Blocks
    # this table is always with page_size = 1
    page_table_1: torch.Tensor
    # NOTE(dark): This property will be used in:
    # 1. dense decode/prefill, we use paged flash attention, need real_page_table
    # 2. sparse decode/prefill, indexer need real_page_table to compute the score
    real_page_table: torch.Tensor

    # NSA metadata (nsa prefill are expanded)
    nsa_cache_seqlens_int32: torch.Tensor  # this seqlens is clipped to `topk`
    nsa_cu_seqlens_q: torch.Tensor  # must be arange(0, len(nsa_cu_seqlens_k))
    nsa_cu_seqlens_k: torch.Tensor  # cumsum of `nsa_cache_seqlens_int32`
    nsa_extend_seq_lens_list: List[int]
    nsa_seqlens_expanded: torch.Tensor  # expanded, unclipped `seqlens`
    nsa_max_seqlen_q: Literal[1] = 1  # always 1 for decode, variable for extend
    flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
@dataclass(frozen=True)
class NSAIndexerMetadata(BaseIndexerMetadata):
    """Indexer-facing view over a batch's :class:`NSAMetadata`."""

    attn_metadata: NSAMetadata

    def get_seqlens_int32(self) -> torch.Tensor:
        """Per-request KV sequence lengths (int32)."""
        return self.attn_metadata.cache_seqlens_int32

    def get_page_table_64(self) -> torch.Tensor:
        """Page table using the runner's real page size."""
        return self.attn_metadata.real_page_table

    def get_seqlens_expanded(self) -> torch.Tensor:
        """Per-query-token (expanded), unclipped sequence lengths."""
        return self.attn_metadata.nsa_seqlens_expanded

    def topk_transform(
        self,
        logits: torch.Tensor,
        topk: int,
    ) -> torch.Tensor:
        seqlens = self.get_seqlens_expanded()
        if not NSA_FUSE_TOPK:
            return fast_topk_impl(logits, seqlens, topk)
        # NOTE(dark): if fused, we return a transformed page table directly
        meta = self.attn_metadata
        dst_page_table = logits.new_empty((logits.shape[0], topk), dtype=torch.int32)
        fast_topk_transform_fused_cuda(
            input=logits,
            seq_lens=seqlens,
            topk=topk,
            dst_page_table=dst_page_table,
            src_page_table=meta.page_table_1,
            cu_seqlens_q=meta.cu_seqlens_q,
        )
        return dst_page_table
def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
    """Prefix sums with a leading zero: [a, b, c] -> [0, a, a+b, a+b+c]."""
    assert seqlens.dtype == torch.int32 and seqlens.is_cuda
    running = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    return torch.nn.functional.pad(running, (1, 0))
# Implementation selectors for the sparse-attention prefill/decode paths.
_NSA_IMPL_T: TypeAlias = Literal[
    "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
]

# Assigned in NativeSparseAttnBackend.__init__ from the server args.
NSA_PREFILL_IMPL: _NSA_IMPL_T
NSA_DECODE_IMPL: _NSA_IMPL_T
class NativeSparseAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
    """Bind NSA-specific configuration and buffers from the model runner."""
    super().__init__()
    # Populated by init_forward_metadata() before each forward pass.
    self.forward_metadata: NSAMetadata
    self.device = model_runner.device
    assert isinstance(model_runner.page_size, int)
    self.real_page_size = model_runner.page_size
    # NOTE(review): presumably 0 lets the attention kernel choose its split
    # count while 1 forces a single split for determinism -- confirm.
    self.num_splits = (
        1 if model_runner.server_args.enable_deterministic_inference else 0
    )
    self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
    assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
    self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
    self.max_context_len = model_runner.model_config.context_len
    # Query heads per tensor-parallel rank.
    self.num_q_heads = (
        model_runner.model_config.num_attention_heads // get_attention_tp_size()
    )
    self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
    assert model_runner.req_to_token_pool is not None
    self.req_to_token = model_runner.req_to_token_pool.req_to_token
    # Publish the module-level implementation selectors used by the
    # forward paths of this module.
    global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
    NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
    NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
    # Cached int32 arange; grown on demand by get_device_int32_arange().
    self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
def get_device_int32_arange(self, l: int) -> torch.Tensor:
    """Return ``arange(l)`` (int32, on device) backed by a cached buffer.

    The buffer grows to the next power of two when ``l`` exceeds it, so
    repeated growth is amortized.
    """
    if l > len(self._arange_buf):
        grown_size = 1 << (l - 1).bit_length()
        self._arange_buf = torch.arange(
            grown_size, device=self.device, dtype=torch.int32
        )
    return self._arange_buf[:l]
def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:
    """Convert a page_size=1 (token-level) table to the real-page-size table.

    Keeps every ``page_size``-th token slot and divides by ``page_size`` to
    turn token indices into page indices. No-op when the real page size is 1.
    """
    page_size = self.real_page_size
    if page_size == 1:
        return page_table
    max_seqlen_k = page_table.shape[1]
    page_start_cols = torch.arange(
        0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
    )
    return page_table[:, page_start_cols] // page_size
def init_forward_metadata(self, forward_batch: ForwardBatch):
    """Init the metadata for a forward pass.

    Builds dense (full KV) and sparse (top-k clamped) sequence-length
    bookkeeping plus the page-size-1 token table, and stores the result in
    ``self.forward_metadata`` for the subsequent attention calls.
    """
    batch_size = forward_batch.batch_size
    device = forward_batch.seq_lens.device
    assert (
        forward_batch.spec_info is None
    ), "Spec decoding is not supported for NSA backend now"

    # Dense KV lengths and their cumulative form.
    cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
    cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
    assert forward_batch.seq_lens_cpu is not None
    max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
    # Page-size-1 table: one KV slot per token, truncated to the batch max.
    page_table = forward_batch.req_to_token_pool.req_to_token[
        forward_batch.req_pool_indices, :max_seqlen_k
    ]

    if forward_batch.forward_mode.is_decode_or_idle():
        # Decode: exactly one query token per request, so cu_seqlens_q is an
        # arange and each token sees its full (dense) KV length.
        extend_seq_lens_cpu = [1] * batch_size
        max_seqlen_q = 1
        cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
        seqlens_expanded = cache_seqlens_int32
    elif forward_batch.forward_mode.is_extend():
        assert (
            forward_batch.extend_seq_lens_cpu is not None
            and forward_batch.extend_seq_lens is not None
            and forward_batch.extend_prefix_lens_cpu is not None
        ), "All of them must not be None"
        extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
        assert forward_batch.extend_seq_lens is not None
        if any(forward_batch.extend_prefix_lens_cpu):
            # Some request has a cached prefix: q lengths differ from k lengths.
            max_seqlen_q = max(extend_seq_lens_cpu)
            cu_seqlens_q = compute_cu_seqlens(
                forward_batch.extend_seq_lens.to(torch.int32)
            )
        else:
            # No prefix anywhere: q and k lengths coincide.
            max_seqlen_q = max_seqlen_k
            cu_seqlens_q = cu_seqlens_k
        # Per query token, the causal number of visible KV entries:
        # kv_len - qo_len + 1 .. kv_len for each request, concatenated.
        seqlens_expanded = torch.cat(
            [
                torch.arange(
                    kv_len - qo_len + 1,
                    kv_len + 1,
                    dtype=torch.int32,
                    device=device,
                )
                for qo_len, kv_len in zip(
                    forward_batch.extend_seq_lens_cpu,
                    forward_batch.seq_lens_cpu.tolist(),
                    strict=True,
                )
            ]
        )
    else:
        assert False, f"Unsupported {forward_batch.forward_mode = }"

    # 1D, expanded seqlens (1D means cheap to compute, so always compute it)
    nsa_cache_seqlens_int32 = compute_nsa_seqlens(
        original_seq_lens=seqlens_expanded,
        nsa_index_topk=self.nsa_index_topk,
    )
    nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
    # After top-k selection every query token is handled independently, so the
    # sparse cu_seqlens_q is an arange of the same length as nsa_cu_seqlens_k.
    nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))

    metadata = NSAMetadata(
        page_size=self.real_page_size,
        cache_seqlens_int32=cache_seqlens_int32,
        max_seq_len_q=max_seqlen_q,
        max_seq_len_k=max_seqlen_k,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        page_table_1=page_table,
        flashmla_metadata=(
            self._compute_flashmla_metadata(
                cache_seqlens=nsa_cache_seqlens_int32,
                seq_len_q=1,  # TODO handle MTP which is not 1
            )
            if NSA_DECODE_IMPL == "flashmla_decode"
            else None
        ),
        nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
        nsa_cu_seqlens_q=nsa_cu_seqlens_q,
        nsa_cu_seqlens_k=nsa_cu_seqlens_k,
        nsa_seqlens_expanded=seqlens_expanded,
        nsa_extend_seq_lens_list=extend_seq_lens_cpu,
        real_page_table=self._transform_table_1_to_real(page_table),
    )
    self.forward_metadata = metadata
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
    """Initialize CUDA graph state for the attention backend.

    Args:
        max_bs (int): Maximum batch size to support in CUDA graphs
        max_num_tokens (int): Maximum token count (unused by this backend).

    This creates fixed-size tensors that will be reused during CUDA graph replay
    to avoid memory allocations.
    """
    self.decode_cuda_graph_metadata: Dict = {
        "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
        # Decode has one query per request, so this stays a constant arange.
        "cu_seqlens_q": torch.arange(
            0, max_bs + 1, dtype=torch.int32, device=self.device
        ),
        "cu_seqlens_k": torch.zeros(
            max_bs + 1, dtype=torch.int32, device=self.device
        ),
        # fake page_table for sparse_prefill
        "page_table": torch.zeros(
            max_bs,
            self.max_context_len,
            dtype=torch.int32,
            device=self.device,
        ),
        # Preallocated FlashMLA scheduler metadata (sized with dummy seqlens
        # of 1); updated in place at capture/replay time.
        "flashmla_metadata": (
            self._compute_flashmla_metadata(
                cache_seqlens=torch.ones(
                    max_bs, dtype=torch.int32, device=self.device
                ),
                seq_len_q=1,  # TODO handle MTP which is not 1
            )
            if NSA_DECODE_IMPL == "flashmla_decode"
            else None
        ),
    }
def init_forward_metadata_capture_cuda_graph(
    self,
    bs: int,
    num_tokens: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
    """Initialize forward metadata for capturing CUDA graph.

    Builds an NSAMetadata that aliases the fixed buffers allocated in
    init_cuda_graph_state, caches it under ``bs`` for replay, and installs it
    as the current forward metadata.  Decode/idle mode only.
    """
    assert forward_mode.is_decode_or_idle(), "Only support decode for now"
    assert (
        spec_info is None
    ), "Speculative decoding is not supported for NSA backend now"

    # Normal Decode
    # Get sequence information
    cache_seqlens_int32 = seq_lens.to(torch.int32)
    cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
    # Use max context length for seq_len_k
    page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
    max_seq_len_k = page_table_1.shape[1]
    # Precompute page table
    # Precompute cumulative sequence lengths
    # NOTE(dark): this is always arange, since we are decoding
    cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
    nsa_cache_seqlens_int32 = compute_nsa_seqlens(
        cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
    )
    nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
    nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
    real_page_table = self._transform_table_1_to_real(page_table_1)
    if NSA_DECODE_IMPL == "flashmla_decode":
        # Slice a view of the preallocated metadata and refresh it in place so
        # graph replay can keep using the same storage.
        flashmla_metadata = self.decode_cuda_graph_metadata[
            "flashmla_metadata"
        ].slice(slice(0, bs + 1))
        flashmla_metadata.copy_(
            self._compute_flashmla_metadata(
                cache_seqlens=nsa_cache_seqlens_int32,
                seq_len_q=1,  # TODO handle MTP which is not 1
            )
        )
    else:
        flashmla_metadata = None
    metadata = NSAMetadata(
        page_size=self.real_page_size,
        cache_seqlens_int32=cache_seqlens_int32,
        max_seq_len_q=1,
        max_seq_len_k=max_seq_len_k,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        page_table_1=page_table_1,
        flashmla_metadata=flashmla_metadata,
        nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
        nsa_cu_seqlens_q=nsa_cu_seqlens_q,
        nsa_cu_seqlens_k=nsa_cu_seqlens_k,
        nsa_seqlens_expanded=cache_seqlens_int32,
        real_page_table=real_page_table,
        nsa_extend_seq_lens_list=[1] * bs,
    )
    # Cache per-batch-size metadata so replay can update it in place.
    self.decode_cuda_graph_metadata[bs] = metadata
    self.forward_metadata = metadata
def init_forward_metadata_replay_cuda_graph(
    self,
    bs: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    seq_lens_cpu: Optional[torch.Tensor],
    out_cache_loc: Optional[torch.Tensor] = None,
):
    """Initialize forward metadata for replaying CUDA graph.

    Refreshes, in place, the buffers of the NSAMetadata captured for batch
    size ``bs`` (sequence lengths, cumulative lengths, page tables, FlashMLA
    scheduler metadata).  Decode/idle mode only.
    """
    assert seq_lens_cpu is not None
    assert forward_mode.is_decode_or_idle(), "Only support decode for now"
    assert (
        spec_info is None
    ), "Speculative decoding is not supported for NSA backend now"
    # Inputs may be padded past the effective batch size; slice to bs.
    seq_lens = seq_lens[:bs]
    seq_lens_cpu = seq_lens_cpu[:bs]
    req_pool_indices = req_pool_indices[:bs]

    # Normal Decode
    metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
    max_len = int(seq_lens_cpu.max().item())

    cache_seqlens = seq_lens.to(torch.int32)
    # Copy into the captured buffers (the replayed graph reads these tensors).
    metadata.cache_seqlens_int32.copy_(cache_seqlens)
    metadata.cu_seqlens_k[1:].copy_(
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
    )
    page_indices = self.req_to_token[req_pool_indices, :max_len]
    metadata.page_table_1[:, :max_len].copy_(page_indices)
    assert (
        metadata.nsa_cache_seqlens_int32 is not None
        and metadata.nsa_cu_seqlens_k is not None
        and self.nsa_index_topk is not None
    )
    # Sparse lengths are the dense lengths clamped by the index top-k.
    nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
    metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
    metadata.nsa_cu_seqlens_k[1:].copy_(
        torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
    )
    # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
    assert self.real_page_size == metadata.page_size
    if self.real_page_size > 1:
        real_table = self._transform_table_1_to_real(page_indices)
        new_len = real_table.shape[1]
        metadata.real_page_table[:, :new_len].copy_(real_table)
    else:
        # With page size 1 the "real" table aliases the token table.
        assert metadata.real_page_table is metadata.page_table_1
    if NSA_DECODE_IMPL == "flashmla_decode":
        metadata.flashmla_metadata.copy_(
            self._compute_flashmla_metadata(
                cache_seqlens=nsa_cache_seqlens,
                seq_len_q=1,  # TODO handle MTP which is not 1
            )
        )
    self.forward_metadata = metadata
def forward_extend(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
    save_kv_cache=True,
    # For multi-head latent attention
    q_rope: Optional[torch.Tensor] = None,
    k_rope: Optional[torch.Tensor] = None,
    topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run NSA (absorbed MLA) attention for the extend/prefill phase.

    Optionally writes the new MLA KV entries into the pool, then dispatches
    to the kernel selected by the module-level NSA_PREFILL_IMPL using the
    indexer's ``topk_indices`` as sparse KV selection.
    """
    assert (
        not forward_batch.forward_mode.is_target_verify()
        and not forward_batch.forward_mode.is_draft_extend()
    ), "NSA backend doesn't support speculative decoding"
    if k is not None:
        assert v is not None
        if save_kv_cache:
            cache_loc = (
                forward_batch.out_cache_loc
                if not layer.is_cross_attention
                else forward_batch.encoder_out_cache_loc
            )
            forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
                layer,
                cache_loc,
                k,
                k_rope,
            )

    metadata = self.forward_metadata
    causal = not layer.is_cross_attention
    assert causal, "NSA is causal only"

    # For fa3 interface version compatibility, we put new fields into conditional keyword args
    kwargs = {}  # NOTE(review): currently unused — appears to be a leftover.

    # Do absorbed multi-latent attention
    assert q_rope is not None
    kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
    # when store in fp8 and compute in fp8, no need to convert dtype
    if not (NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and NSA_KV_CACHE_STORE_FP8):
        kv_cache = kv_cache.to(q.dtype)
    # NOTE(review): q_rope was asserted non-None above, so the else branch is
    # unreachable here; kept for symmetry with forward_decode.
    if q_rope is not None:
        q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
        q_rope = q_rope.view(
            -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
        )
    else:
        q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
        q_nope = q_all[:, :, : layer.v_head_dim]
        q_rope = q_all[:, :, layer.v_head_dim :]
    # NOTE(dark): here, we use page size = 1
    if NSA_FUSE_TOPK:
        # The indexer already emits KV-slot indices usable as a page-size-1 table.
        page_table_1 = topk_indices
    else:
        assert metadata.nsa_extend_seq_lens_list is not None
        page_table_1 = transform_index_page_table_prefill(
            page_table=metadata.page_table_1,
            topk_indices=topk_indices,
            extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
            page_size=1,
        )
    # Original tilelang dispatch, kept for reference:
    # if NSA_PREFILL_IMPL == "tilelang":
    #     from sglang.srt.layers.attention.nsa.tilelang_kernel import (
    #         tilelang_sparse_fwd,
    #     )
    #     if q_rope is not None:
    #         q_all = torch.cat([q_nope, q_rope], dim=-1)
    #     return self._forward_tilelang(
    #         q_all=q_all,
    #         kv_cache=kv_cache,
    #         page_table_1=page_table_1,
    #         sm_scale=layer.scaling,
    #         v_head_dim=layer.v_head_dim,
    #     )
    # elif NSA_PREFILL_IMPL == "flashmla_prefill":
    # Skip tilelang dependencies: "tilelang" is routed to the flashmla_prefill
    # path below.
    if NSA_PREFILL_IMPL == "tilelang" or NSA_PREFILL_IMPL == "flashmla_prefill":
        if q_rope is not None:
            q_all = torch.cat([q_nope, q_rope], dim=-1)
        return self._forward_flashmla_prefill(
            q_all=q_all,
            kv_cache=kv_cache,
            page_table_1=page_table_1,
            sm_scale=layer.scaling,
            v_head_dim=layer.v_head_dim,
        )
    elif NSA_PREFILL_IMPL == "flashmla_decode":
        if q_rope is not None:
            q_all = torch.cat([q_nope, q_rope], dim=-1)
        return self._forward_flashmla_decode(
            q_all=q_all,
            kv_cache=kv_cache,
            sm_scale=layer.scaling,
            v_head_dim=layer.v_head_dim,
            # TODO optimize args
            layer=layer,
            forward_batch=forward_batch,
            metadata=metadata,
            topk_indices=topk_indices,
            block_table=metadata.real_page_table,
        )
    elif NSA_PREFILL_IMPL == "fa3":
        return self._forward_fa3(
            q_rope=q_rope,
            kv_cache=kv_cache,
            v_head_dim=layer.v_head_dim,
            q_nope=q_nope,
            page_table=page_table_1,
            cache_seqlens=metadata.nsa_cache_seqlens_int32,
            cu_seqlens_q=metadata.nsa_cu_seqlens_q,
            cu_seqlens_k=metadata.nsa_cu_seqlens_k,
            max_seqlen_q=metadata.nsa_max_seqlen_q,
            sm_scale=layer.scaling,
            logit_cap=layer.logit_cap,
            page_size=1,
        )
    else:
        raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }")
def forward_decode(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
    save_kv_cache=True,
    # For multi-head latent attention
    q_rope: Optional[torch.Tensor] = None,
    k_rope: Optional[torch.Tensor] = None,
    topk_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run NSA (absorbed MLA) attention for the decode phase.

    Mirrors forward_extend but dispatches on NSA_DECODE_IMPL and uses the
    decode-side top-k index transform.
    """
    if k is not None:
        assert v is not None
        if save_kv_cache:
            cache_loc = (
                forward_batch.out_cache_loc
                if not layer.is_cross_attention
                else forward_batch.encoder_out_cache_loc
            )
            forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
                layer,
                cache_loc,
                k,
                k_rope,
            )
    metadata = self.forward_metadata
    causal = not layer.is_cross_attention
    assert causal, "NSA is causal only"

    # Do absorbed multi-latent attention
    kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
    if q_rope is not None:
        # q arrives pre-split: q holds the "nope" part, q_rope the rope part.
        q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
        q_rope = q_rope.view(
            -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
        )
    else:
        # q holds the full head dim; split into nope/rope views.
        q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
        q_nope = q_all[:, :, : layer.v_head_dim]
        q_rope = q_all[:, :, layer.v_head_dim :]
    if NSA_FUSE_TOPK:
        # The indexer already emits KV-slot indices usable as a page-size-1 table.
        page_table_1 = topk_indices
    else:
        page_table_1 = transform_index_page_table_decode(
            page_table=metadata.page_table_1,
            topk_indices=topk_indices,
            page_size=1,
        )
    if NSA_DECODE_IMPL == "flashmla_prefill":
        if q_rope is not None:
            q_all = torch.cat([q_nope, q_rope], dim=-1)
        return self._forward_flashmla_prefill(
            q_all=q_all,
            kv_cache=kv_cache,
            page_table_1=page_table_1,
            sm_scale=layer.scaling,
            v_head_dim=layer.v_head_dim,
        )
    elif NSA_DECODE_IMPL == "flashmla_decode":
        if q_rope is not None:
            q_all = torch.cat([q_nope, q_rope], dim=-1)
        return self._forward_flashmla_decode(
            q_all=q_all,
            kv_cache=kv_cache,
            sm_scale=layer.scaling,
            v_head_dim=layer.v_head_dim,
            # TODO optimize args
            layer=layer,
            forward_batch=forward_batch,
            metadata=metadata,
            topk_indices=topk_indices,
            block_table=metadata.real_page_table,
        )
    elif NSA_DECODE_IMPL == "tilelang":
        if q_rope is not None:
            q_all = torch.cat([q_nope, q_rope], dim=-1)
        return self._forward_tilelang(
            q_all=q_all,
            kv_cache=kv_cache,
            page_table_1=page_table_1,
            sm_scale=layer.scaling,
            v_head_dim=layer.v_head_dim,
        )
    elif NSA_DECODE_IMPL == "fa3":
        return self._forward_fa3(
            q_rope=q_rope,
            kv_cache=kv_cache,
            v_head_dim=layer.v_head_dim,
            q_nope=q_nope,
            page_table=page_table_1,
            cache_seqlens=metadata.nsa_cache_seqlens_int32,
            cu_seqlens_q=metadata.nsa_cu_seqlens_q,
            cu_seqlens_k=metadata.nsa_cu_seqlens_k,
            max_seqlen_q=metadata.nsa_max_seqlen_q,
            sm_scale=layer.scaling,
            logit_cap=layer.logit_cap,
            page_size=1,
        )
    else:
        assert False, f"Unsupported {NSA_DECODE_IMPL = }"
def _forward_fa3(
    self,
    q_rope: torch.Tensor,
    kv_cache: torch.Tensor,
    v_head_dim: int,
    q_nope: torch.Tensor,
    page_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    sm_scale: float,
    logit_cap: float,
    page_size: int,
) -> torch.Tensor:
    """Absorbed-MLA attention via FlashAttention-3's flash_attn_with_kvcache.

    The cache's last dim is split as [compressed_kv (v_head_dim) | k_rope];
    q_rope attends against the rope part while q_nope is passed as ``qv``.
    """
    # Split the cache's last dim into rope and compressed-KV components.
    k_rope_cache = kv_cache[:, :, v_head_dim:]
    c_kv_cache = kv_cache[:, :, :v_head_dim]
    qk_rope_dim = k_rope_cache.shape[-1]
    # Reshape to (num_pages, page_size, 1 kv head, dim) as the kernel expects.
    k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
    c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
    o = flash_attn_with_kvcache(
        q=q_rope,
        k_cache=k_rope_cache,
        v_cache=c_kv_cache,
        qv=q_nope,
        page_table=page_table,
        cache_seqlens=cache_seqlens,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k_new=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        softmax_scale=sm_scale,
        causal=True,
        softcap=logit_cap,
        return_softmax_lse=False,
        # 1 forces a deterministic single-split schedule (see __init__).
        num_splits=self.num_splits,
    )
    return o  # type: ignore
def _forward_flashmla_prefill(
    self,
    q_all: torch.Tensor,
    kv_cache: torch.Tensor,
    v_head_dim: int,
    page_table_1: torch.Tensor,
    sm_scale: float,
) -> torch.Tensor:
    """Sparse MLA forward via the native sparse kernel (page-size-1 indices).

    NOTE(review): originally dispatched to flash_mla.flash_mla_sparse_fwd;
    this build routes to a native implementation instead — presumed
    equivalent, to be confirmed.
    """
    # from flash_mla import flash_mla_sparse_fwd
    from sglang.srt.layers.attention.native_mla import native_mla_sparse_fwd

    # The kernel returns a tuple; only the attention output is used here.
    _, _, o = native_mla_sparse_fwd(
        q=q_all,
        kv=kv_cache,
        indices=page_table_1.unsqueeze(1),
        sm_scale=sm_scale,
        d_v=v_head_dim,
    )
    return o
def _forward_flashmla_decode(
    self,
    q_all: torch.Tensor,
    kv_cache: torch.Tensor,
    v_head_dim: int,
    sm_scale: float,
    layer,
    forward_batch: ForwardBatch,
    metadata: NSAMetadata,
    topk_indices,
    block_table,
) -> torch.Tensor:
    """Sparse decode attention via the native FlashMLA-style decode kernel.

    NOTE(review): originally dispatched to flash_mla.flash_mla_with_kvcache;
    this build routes to a native implementation instead — presumed
    equivalent, to be confirmed.
    """
    # from flash_mla import flash_mla_with_kvcache
    from sglang.srt.layers.attention.native_mla import native_mla_with_kvcache

    cache_seqlens = metadata.nsa_cache_seqlens_int32
    # TODO the 2nd dim is seq_len_q, need to be >1 when MTP
    q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
    kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)
    assert self.real_page_size == 64, "only page size 64 is supported"
    if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not NSA_KV_CACHE_STORE_FP8:
        # inefficiently quantize the whole cache
        kv_cache = quantize_k_cache(kv_cache)
    o, _ = native_mla_with_kvcache(
        q=q_all,
        k_cache=kv_cache,
        cache_seqlens=cache_seqlens,
        head_dim_v=v_head_dim,
        tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
        num_splits=metadata.flashmla_metadata.num_splits,
        softmax_scale=sm_scale,
        # TODO improve
        indices=_compute_indices_in_kvcache(
            block_table=block_table,
            topk_indices=topk_indices.to(torch.int32),
            page_size=self.real_page_size,
            nsa_index_topk=self.nsa_index_topk,
        ),
        # doc says it is not used, but if pass in None then error
        block_table=torch.empty(
            (q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
        ),
        is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
    )
    # TODO shape correct?
    return o
def _forward_tilelang(
    self,
    q_all: torch.Tensor,
    kv_cache: torch.Tensor,
    v_head_dim: int,
    page_table_1: torch.Tensor,
    sm_scale: float,
) -> torch.Tensor:
    """Run the TileLang sparse-MLA kernel over page-size-1 top-k indices."""
    from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd

    # Kernel expects indices with an explicit seq_len_q dimension.
    sparse_indices = page_table_1.unsqueeze(1)
    return tilelang_sparse_fwd(
        q=q_all,
        kv=kv_cache,
        indices=sparse_indices,
        sm_scale=sm_scale,
        d_v=v_head_dim,
    )
def get_cuda_graph_seq_len_fill_value(self):
    """Fill value used for padded sequence lengths in CUDA graph buffers."""
    fill_value = 1
    return fill_value
def get_indexer_metadata(
    self, layer_id: int, forward_batch: ForwardBatch
) -> NSAIndexerMetadata:
    """Expose the current attention metadata to the NSA indexer."""
    attn_metadata = self.forward_metadata
    return NSAIndexerMetadata(attn_metadata=attn_metadata)
def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
    """Build FlashMLA tile-scheduler metadata for the given (sparse) KV lengths.

    Returns an NSAFlashMLAMetadata bundling the scheduler metadata tensor and
    the per-request split counts produced by flash_mla.get_mla_metadata.
    """
    from flash_mla import get_mla_metadata

    flashmla_metadata, num_splits = get_mla_metadata(
        cache_seqlens=cache_seqlens,
        # TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`
        # but the name looks like need seq_len_q?
        num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,
        num_heads_k=1,
        num_heads_q=self.num_q_heads,
        is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
        topk=self.nsa_index_topk,
    )
    return NSAFlashMLAMetadata(
        flashmla_metadata=flashmla_metadata,
        num_splits=num_splits,
    )
# TODO speedup
def _compute_indices_in_kvcache(block_table, topk_indices, page_size, nsa_index_topk):
    """Translate per-request top-k token indices into physical KV-cache slots.

    Args:
        block_table: (batch, num_pages) logical-page -> physical-page mapping.
        topk_indices: (batch, k) logical token indices from the indexer;
            -1 marks an invalid/padded entry.
        page_size: tokens per KV-cache page.
        nsa_index_topk: full top-k width; output is right-padded with -1 to it.

    Returns:
        Tensor of shape (batch, 1, nsa_index_topk) with physical slot indices,
        where invalid entries are -1 (required by the decode kernel).
    """
    invalid = topk_indices == -1
    # Clamp invalid entries to 0 so the gather below stays in bounds.
    safe = topk_indices.masked_fill(invalid, 0)
    rows = torch.arange(block_table.size(0), device=safe.device).unsqueeze(1)
    page_ids = block_table[rows, safe // page_size]
    slots = page_ids * page_size + safe % page_size
    assert slots.shape == topk_indices.shape
    # the kernel requires invalid entry to be -1
    slots[invalid] = -1
    # return: (batch_size, seqlen_q_ori, topk) — insert the seq_len_q dim and
    # pad the top-k axis out to the full width.
    slots = slots[:, None, :]
    slots = torch.nn.functional.pad(
        slots,
        (0, nsa_index_topk - slots.shape[-1]),
        "constant",
        -1,
    )
    assert slots.shape[-1] == nsa_index_topk
    return slots

View File

@@ -127,8 +127,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
"disable_chunked_prefix_cache"
]
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
def _calc_padded_blocks(self, max_seq_len: int) -> int:
"""
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
@@ -219,7 +217,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
"""Initialize metadata for CUDA graph capture."""
# Delegate to parent for non-decode modes.
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
if not forward_mode.is_decode_or_idle():
return super().init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
@@ -230,9 +228,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
spec_info,
)
if forward_mode.is_target_verify():
seq_lens = seq_lens + self.num_draft_tokens
# Custom fast-path for decode/idle.
# Capture with full width so future longer sequences are safe during replay
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
@@ -275,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
):
"""Replay CUDA graph with new inputs."""
# Delegate to parent for non-decode modes.
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
if not forward_mode.is_decode_or_idle():
return super().init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
@@ -287,10 +282,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
seq_lens_cpu,
)
if forward_mode.is_target_verify():
seq_lens = seq_lens + self.num_draft_tokens
del seq_lens_sum # not handle "num_draft_tokens" but we do not need it
metadata = self.decode_cuda_graph_metadata[bs]
# Update block indices for new sequences.
@@ -341,10 +332,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
cum_seq_lens_q,
seq_lens,
)
elif (
forward_batch.forward_mode.is_decode_or_idle()
or forward_batch.forward_mode.is_target_verify()
):
elif forward_batch.forward_mode.is_decode_or_idle():
bs = forward_batch.batch_size
# Get maximum sequence length.
@@ -353,19 +341,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
else:
max_seq = forward_batch.seq_lens.max().item()
seq_lens = forward_batch.seq_lens
if forward_batch.forward_mode.is_target_verify():
max_seq = max_seq + self.num_draft_tokens
seq_lens = seq_lens + self.num_draft_tokens
max_seqlen_pad = self._calc_padded_blocks(max_seq)
block_kv_indices = self._create_block_kv_indices(
bs,
max_seqlen_pad,
forward_batch.req_pool_indices,
seq_lens,
seq_lens.device,
forward_batch.seq_lens,
forward_batch.seq_lens.device,
)
max_seq_len_val = int(max_seq)
@@ -505,7 +487,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
q_rope_reshaped = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
else:
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
else:
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
@@ -568,134 +553,84 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
save_kv_cache: bool = True,
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if forward_batch.forward_mode.is_draft_extend():
):
if (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend()
):
return super().forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
)
# chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
if forward_batch.attn_attend_prefix_cache is None:
return super().forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
)
# Save KV cache if requested
if save_kv_cache:
assert (
k is not None and k_rope is not None
), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer, forward_batch.out_cache_loc, k, k_rope
)
if q_rope is not None:
q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
q = _concat_mla_absorb_q_general(q, q_rope)
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
if k_rope is not None:
k = torch.cat([k, k_rope], dim=-1)
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
if forward_batch.forward_mode.is_target_verify():
metadata = (
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
or self.forward_decode_metadata
)
# Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
bs = forward_batch.batch_size
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
q_scale = 1.0
k_scale = (
layer.k_scale_float
if getattr(layer, "k_scale_float", None) is not None
else 1.0
)
bmm1_scale = q_scale * k_scale * layer.scaling
seq_lens = (
forward_batch.seq_lens.to(torch.int32)
+ forward_batch.spec_info.draft_token_num
)
max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
# TODO may use `mla_rope_quantize_fp8` fusion
q = q.to(self.data_type)
assert kv_cache.dtype == self.data_type
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_cache,
workspace_buffer=self.workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=metadata.block_kv_indices,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
bmm1_scale=bmm1_scale,
)
# Reshape output directly without slicing
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return output
if forward_batch.attn_attend_prefix_cache:
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert q_rope is None
assert k_rope is None
chunk_idx = forward_batch.prefix_chunk_idx
output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
return flashinfer.prefill.trtllm_ragged_attention_deepseek(
if not forward_batch.attn_attend_prefix_cache:
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
query=q,
key=k,
value=v,
workspace_buffer=self.workspace_buffer,
seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
seq_lens=self.forward_prefill_metadata.seq_lens,
max_q_len=self.forward_prefill_metadata.max_seq_len,
max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
max_kv_len=self.forward_prefill_metadata.max_seq_len,
bmm1_scale=layer.scaling,
bmm2_scale=1.0,
o_sf_scale=-1.0,
o_sf_scale=1.0,
batch_size=forward_batch.batch_size,
window_left=-1,
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
enable_pdl=False,
is_causal=False,
return_lse=True,
out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
is_causal=True,
return_lse=forward_batch.mha_return_lse,
)
else:
if not (
forward_batch.attn_attend_prefix_cache is not None
and forward_batch.mha_return_lse
):
output = super().forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
)
else:
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert q_rope is None
assert k_rope is None
chunk_idx = forward_batch.prefix_chunk_idx
return flashinfer.prefill.trtllm_ragged_attention_deepseek(
query=q,
key=k,
value=v,
workspace_buffer=self.workspace_buffer,
seq_lens=self.forward_prefill_metadata.seq_lens,
max_q_len=self.forward_prefill_metadata.max_seq_len,
max_kv_len=self.forward_prefill_metadata.max_seq_len,
bmm1_scale=layer.scaling,
bmm2_scale=1.0,
o_sf_scale=1.0,
batch_size=forward_batch.batch_size,
window_left=-1,
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
enable_pdl=False,
is_causal=True,
return_lse=forward_batch.mha_return_lse,
)
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
query=q,
key=k,
value=v,
workspace_buffer=self.workspace_buffer,
seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
max_q_len=self.forward_prefill_metadata.max_seq_len,
max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
bmm1_scale=layer.scaling,
bmm2_scale=1.0,
o_sf_scale=-1.0,
batch_size=forward_batch.batch_size,
window_left=-1,
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
enable_pdl=False,
is_causal=False,
return_lse=True,
out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
)
return output
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
@@ -713,10 +648,3 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
kv_indptr_buf=self.kv_indptr[i],
q_indptr_decode_buf=self.q_indptr_decode,
)
def _concat_mla_absorb_q_general(q_nope, q_rope):
if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
return concat_mla_absorb_q(q_nope, q_rope)
else:
return torch.cat([q_nope, q_rope], dim=-1)

View File

@@ -16,19 +16,14 @@ from sglang.srt.utils import (
get_device_capability,
is_blackwell,
is_cuda,
is_npu,
print_info_once,
)
_is_cuda = is_cuda()
_is_npu = is_npu()
if _is_cuda:
from sgl_kernel.flash_attn import flash_attn_varlen_func
if _is_npu:
import torch_npu
from sglang.srt.distributed import (
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
@@ -336,63 +331,10 @@ class VisionFlash3Attention(nn.Module):
return output
class VisionAscendAttention(nn.Module):
def __init__(
self,
**kwargs,
):
if not _is_npu:
raise Exception("VisionAscendAttention is only available for ascend npu")
super().__init__()
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
bsz: int,
seq_len: int,
**kwargs,
) -> torch.Tensor:
r"""
Args:
cu_seqlens: [b]
Returns:
[b * s, h, head_size]
"""
if cu_seqlens is None:
cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
if seq_lens.is_npu:
# cu_seqlens must be on cpu because of operator restriction
seq_lens = seq_lens.to("cpu")
_, num_heads, head_size = q.shape
num_kv_heads = k.shape[1]
output = torch.empty_like(q)
# operator requires pta version >= 2.5.1
torch_npu._npu_flash_attention_unpad(
query=q,
key=k,
value=v,
seq_len=seq_lens.to(torch.int32),
scale_value=head_size**-0.5,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
out=output,
)
return output
QKV_BACKEND_IMPL = {
"triton_attn": VisionTritonAttention,
"sdpa": VisionSdpaAttention,
"fa3": VisionFlash3Attention,
"ascend_attn": VisionAscendAttention,
}

View File

@@ -50,7 +50,6 @@ from sglang.srt.utils import (
is_hip,
is_sm90_supported,
is_sm100_supported,
prepare_weight_cache,
)
_is_flashinfer_available = is_flashinfer_available()
@@ -276,11 +275,7 @@ class LayerCommunicator:
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
cache=None,
):
if cache is not None:
self._context.cache = cache
return self._communicate_with_all_reduce_and_layer_norm_fn(
hidden_states=hidden_states,
residual=residual,
@@ -354,7 +349,6 @@ class CommunicateContext:
attn_tp_size: int
attn_dp_size: int
tp_size: int
cache = None
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -539,8 +533,6 @@ class CommunicateWithAllReduceAndLayerNormFn:
)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
if context.cache is not None:
_ = prepare_weight_cache(hidden_states, context.cache)
hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual

View File

@@ -187,9 +187,7 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
assert len(x.shape) == 2
assert (
x.shape == residual.shape and x.dtype == residual.dtype
), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
assert x.shape == residual.shape and x.dtype == residual.dtype
output, mid = torch.empty_like(x), torch.empty_like(x)
bs, hidden_dim = x.shape
if autotune:

View File

@@ -127,69 +127,34 @@ class RMSNorm(CustomOp):
return output, residual_out
return rms_norm(x, self.weight.data, self.variance_epsilon)
# def forward_hip(
# self,
# x: torch.Tensor,
# residual: Optional[torch.Tensor] = None,
# ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# if not x.is_contiguous():
# # NOTE: Remove this if aiter kernel supports discontinuous input
# x = x.contiguous()
# if residual is not None:
# if _vllm_version < Version("0.9"):
# fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
# return x, residual
# else:
# residual_out = torch.empty_like(x)
# output = torch.empty_like(x)
# fused_add_rms_norm(
# output,
# x,
# residual_out,
# residual,
# self.weight.data,
# self.variance_epsilon,
# )
# return output, residual_out
# out = torch.empty_like(x)
# rms_norm(out, x, self.weight.data, self.variance_epsilon)
# return out
def forward_hip(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
# NOTE: Remove this if aiter kernel supports discontinuous input
x = x.contiguous()
if residual is not None:
try:
output = torch.empty_like(x)
residual_out = torch.empty_like(x)
fused_add_rms_norm(
output,
x,
residual_out,
residual,
self.weight.data,
self.variance_epsilon,
)
return output, residual_out
except TypeError:
fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
#if _vllm_version < Version("0.9"):
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
# else:
# residual_out = torch.empty_like(x)
# output = torch.empty_like(x)
# fused_add_rms_norm(
# output,
# x,
# residual_out,
# residual,
# self.weight.data,
# self.variance_epsilon,
# )
# return output, residual_out
out = torch.empty_like(x)
rms_norm(out, x, self.weight.data, self.variance_epsilon)
return out
def forward_native(
self,
x: torch.Tensor,

View File

@@ -31,7 +31,6 @@ from sglang.srt.layers.parameter import (
_ColumnvLLMParameter,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.utils import pad_or_narrow_weight
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
if TYPE_CHECKING:
@@ -626,16 +625,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
end_idx = start_idx + shard_size
if end_idx > loaded_weight.shape[output_dim]:
loaded_weight = pad_or_narrow_weight(
loaded_weight, output_dim, start_idx, shard_size
)
else:
loaded_weight = loaded_weight.narrow(
output_dim, start_idx, shard_size
)
loaded_weight = loaded_weight.narrow(
output_dim, start_idx, shard_size
)
# Special case for AQLM codebooks.
elif is_metadata:
@@ -1310,16 +1302,7 @@ class RowParallelLinear(LinearBase):
shard_size,
)
else:
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
end_idx = start_idx + shard_size
if end_idx > loaded_weight.shape[input_dim]:
loaded_weight = pad_or_narrow_weight(
loaded_weight, input_dim, start_idx, shard_size
)
else:
loaded_weight = loaded_weight.narrow(
input_dim, start_idx, shard_size
)
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).

View File

@@ -220,7 +220,6 @@ class LogitsProcessor(nn.Module):
self.config = config
self.logit_scale = logit_scale
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = (
@@ -462,11 +461,7 @@ class LogitsProcessor(nn.Module):
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
if hasattr(lm_head, "weight"):
if self.use_fp32_lm_head:
logits = torch.matmul(
hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
)
elif use_intel_amx_backend(lm_head):
if use_intel_amx_backend(lm_head):
logits = torch.ops.sgl_kernel.weight_packed_linear(
hidden_states.to(lm_head.weight.dtype),
lm_head.weight,
@@ -480,15 +475,7 @@ class LogitsProcessor(nn.Module):
else:
# GGUF models
# TODO: use weight_packed_linear for GGUF models
if self.use_fp32_lm_head:
with torch.cuda.amp.autocast(enabled=False):
logits = lm_head.quant_method.apply(
lm_head, hidden_states.to(torch.float32), embedding_bias
)
else:
logits = lm_head.quant_method.apply(
lm_head, hidden_states, embedding_bias
)
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
if self.logit_scale is not None:
logits.mul_(self.logit_scale)

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional, Union
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch
import triton
import triton.language as tl
@@ -125,6 +124,7 @@ class EPMoE(FusedMoE):
)
self.intermediate_size = intermediate_size
if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
@@ -135,23 +135,11 @@ class EPMoE(FusedMoE):
self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
self.use_w4a8_marlin = False
elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.use_fp8_w8a8 = False
self.activation_scheme = None
self.use_w4a8_marlin = True
else:
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
self.use_w4a8_marlin = False
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
@@ -398,11 +386,11 @@ class DeepEPMoE(EPMoE):
return_recv_hook=True,
)
# if self.deepep_mode.enable_low_latency() and not _is_npu:
# # NPU supports low_latency deepep without deepgemm
# assert (
# deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
# ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
if self.deepep_mode.enable_low_latency() and not _is_npu:
# NPU supports low_latency deepep without deepgemm
assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
if _use_aiter:
# expert_mask is of size (self.num_local_experts + 1),
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
@@ -416,23 +404,23 @@ class DeepEPMoE(EPMoE):
)
# the last one is invalid rank_id
self.expert_mask[:-1] = 1
# elif not _is_npu:
# self.w13_weight_fp8 = (
# self.w13_weight,
# (
# self.w13_weight_scale_inv
# if self.use_block_quant
# else self.w13_weight_scale
# ),
# )
# self.w2_weight_fp8 = (
# self.w2_weight,
# (
# self.w2_weight_scale_inv
# if self.use_block_quant
# else self.w2_weight_scale
# ),
# )
elif not _is_npu:
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
),
)
def forward(
self,
@@ -478,15 +466,8 @@ class DeepEPMoE(EPMoE):
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
return self.forward_npu(dispatch_output)
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
#assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm_contiguous(dispatch_output)
elif self.use_w4a8_marlin:
return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
else:
raise ValueError(
f"Dispatch output is not supported"
)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if get_moe_runner_backend().is_flashinfer_cutedsl():
return self.forward_flashinfer_cutedsl(dispatch_output)
@@ -545,34 +526,6 @@ class DeepEPMoE(EPMoE):
expert_mask=self.expert_mask,
)
def forward_deepgemm_w4a8_marlin_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
dispatch_output
)
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
# if num_recv_tokens_per_expert is None:
return hidden_states_int8.bfloat16()
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output,
# topk_weights=topk_weights,
# topk_ids=topk_idx,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales if self.use_int8_dispatch else None
# # routed_scaling_factor=self.routed_scaling_factor,
# )
# return expert_output
def forward_deepgemm_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
@@ -836,45 +789,69 @@ class DeepEPMoE(EPMoE):
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
else:
# dynamic quant
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
hidden_states.device
)
if self.w13_weight.dtype != torch.int8:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight.permute(0, 2, 1)],
# per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight.permute(0, 2, 1)],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
else:
if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
scale=[self.w13_weight_scale.to(output_dtype)],
per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states
)
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states
@@ -883,47 +860,72 @@ class DeepEPMoE(EPMoE):
assert isinstance(dispatch_output, DeepEPLLOutput)
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
if isinstance(hidden_states, tuple):
per_token_scale = hidden_states[1]
hidden_states = hidden_states[0]
group_list = group_list.to(torch.int64)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32,
)[0]
if self.w13_weight.dtype != torch.int8:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight.permute(0, 2, 1)],
# per_token_scale=[per_token_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
hidden_states = torch_npu.npu_swiglu(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight.permute(0, 2, 1)],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
else:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w13_weight],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32,
)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=per_token_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=self.w13_weight_scale.to(torch.float32),
activation_scale=per_token_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[self.w2_weight],
scale=[self.w2_weight_scale.to(output_dtype)],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=output_dtype,
)[0]
return hidden_states

View File

@@ -1,146 +0,0 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -1,146 +0,0 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -51,14 +51,10 @@ def get_moe_configs(
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
# so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
config_dir = os.environ.get(
"SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
)
triton_version = triton.__version__
version_dir = f"triton_{triton_version.replace('.', '_')}"
config_file_path = os.path.join(
config_dir,
os.path.dirname(os.path.realpath(__file__)),
"configs",
version_dir,
json_file_name,
@@ -79,7 +75,7 @@ def get_moe_configs(
if try_triton_version == triton_version:
continue
try_config_file_path = os.path.join(
config_dir,
os.path.dirname(os.path.realpath(__file__)),
"configs",
f"triton_{try_triton_version.replace('.', '_')}",
json_file_name,

View File

@@ -575,10 +575,7 @@ class FusedMoE(torch.nn.Module):
)
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
if (
should_use_flashinfer_trtllm_moe()
and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
):
if should_use_flashinfer_trtllm_moe():
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]

View File

@@ -431,32 +431,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
deepep_post_reorder_triton_kernel,
)
#if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
output = hidden_states
# else:
# if hidden_states.shape[0] > 0:
# num_tokens = self.src2dst.shape[0] // self.router_topk
# output = torch.empty(
# (num_tokens, hidden_states.shape[1]),
# device=hidden_states.device,
# dtype=hidden_states.dtype,
# )
# deepep_post_reorder_triton_kernel[(num_tokens,)](
# hidden_states,
# output,
# self.src2dst,
# topk_idx,
# topk_weights,
# self.router_topk,
# hidden_states.shape[1],
# BLOCK_SIZE=512,
# )
# else:
# output = torch.zeros(
# (0, hidden_states.shape[1]),
# device=hidden_states.device,
# dtype=hidden_states.dtype,
# )
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
output = hidden_states
else:
if hidden_states.shape[0] > 0:
num_tokens = self.src2dst.shape[0] // self.router_topk
output = torch.empty(
(num_tokens, hidden_states.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
deepep_post_reorder_triton_kernel[(num_tokens,)](
hidden_states,
output,
self.src2dst,
topk_idx,
topk_weights,
self.router_topk,
hidden_states.shape[1],
BLOCK_SIZE=512,
)
else:
output = torch.zeros(
(0, hidden_states.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
previous_event = Buffer.capture() if self.async_finish else None
return output, previous_event

View File

@@ -7,7 +7,6 @@ from typing import Callable, Optional, Union
import torch
from torch.nn import Parameter
from sglang.srt.layers.utils import pad_or_narrow_weight
from sglang.srt.utils import is_cpu
__all__ = [
@@ -157,17 +156,9 @@ class _ColumnvLLMParameter(BasevLLMParameter):
)
else:
if not use_presharded_weights:
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
start_idx = tp_rank * shard_size
end_idx = start_idx + shard_size
if end_idx > loaded_weight.shape[self.output_dim]:
loaded_weight = pad_or_narrow_weight(
loaded_weight, self.output_dim, start_idx, shard_size
)
else:
loaded_weight = loaded_weight.narrow(
self.output_dim, start_idx, shard_size
)
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@@ -267,17 +258,9 @@ class RowvLLMParameter(BasevLLMParameter):
return
else:
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
start_idx = tp_rank * shard_size
end_idx = start_idx + shard_size
if end_idx > loaded_weight.shape[self.input_dim]:
loaded_weight = pad_or_narrow_weight(
loaded_weight, self.input_dim, start_idx, shard_size
)
else:
loaded_weight = loaded_weight.narrow(
self.input_dim, start_idx, shard_size
)
loaded_weight = loaded_weight.narrow(
self.input_dim, tp_rank * shard_size, shard_size
)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)

View File

@@ -61,7 +61,6 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
_is_mxfp_supported = mxfp_supported()
@@ -87,7 +86,6 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"w4afp8": W4AFp8Config,
"petit_nvfp4": PetitNvFp4Config,
"fbgemm_fp8": FBGEMMFp8Config,
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
}

View File

@@ -30,7 +30,6 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8,
)
from sglang.srt.layers.quantization.compressed_tensors.utils import (

View File

@@ -2,12 +2,10 @@
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsW8A8Fp8",
"CompressedTensorsW8A16Fp8",
"CompressedTensorsW8A8Int8",
]

View File

@@ -1,173 +0,0 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter
from sglang.srt.layers.parameter import (
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import int8_scaled_mm
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def __init__(
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
return 89
def process_weights_after_loading(self, layer) -> None:
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per channel
if self.strategy == QuantizationStrategy.TENSOR:
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
weight_scale = layer.weight_scale.data
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else:
raise ValueError(f"Unknown quantization strategy {self.strategy}")
# INPUT SCALE
if self.is_static_input_scheme and hasattr(layer, "input_scale"):
if self.input_symmetric:
layer.input_scale = Parameter(
layer.input_scale.max(), requires_grad=False
)
else:
input_scale = layer.input_scale
input_zero_point = layer.input_zero_point
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
# AZP loaded as int8 but used as int32
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
layer.input_scale = Parameter(scale, requires_grad=False)
layer.input_zero_point = Parameter(azp, requires_grad=False)
else:
layer.input_scale = None
layer.input_zero_point = None
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
if not self.input_symmetric:
weight = layer.weight
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if self.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = layer.input_zero_point * azp_adj
layer.azp_adj = Parameter(azp_adj, requires_grad=False)
else:
layer.azp_adj = None
def create_weights(
self,
layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition, input_size_per_partition, dtype=torch.int8
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
)
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = PerTensorScaleParameter(
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
)
layer.register_parameter("input_zero_point", input_zero_point)
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
# TODO: add cutlass_scaled_mm_azp support
x_q, x_scale = per_token_quant_int8(x)
return int8_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
)

View File

@@ -1,5 +1,7 @@
import logging
import torch
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
logger = logging.getLogger(__name__)
@@ -13,6 +15,7 @@ def _compute_enable_deep_gemm():
try:
import deep_gemm
except ImportError:
logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
return False
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")

View File

@@ -843,18 +843,10 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
topk_weights = topk_weights.to(
torch.float32
) # aiter's moe_sorting requires topk_weights to be FP32
if hasattr(torch, "float4_e2m1fn_x2"):
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
else:
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
output = fused_moe(
x,
w13_weight,
w2_weight,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,

View File

@@ -183,17 +183,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
moe_runner_config = self.moe_runner_config
topk_weights, topk_ids, _ = topk_output
if hasattr(torch, "float4_e2m1fn_x2"):
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
else:
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
output = fused_moe(
x,
w13_weight,
w2_weight,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,

View File

@@ -1,415 +0,0 @@
from typing import Any, Callable, Dict, List, Optional
import torch
from sglang.srt.layers.linear import set_weight_attrs
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import LinearMethodBase, QuantizationConfig, QuantizeMethodBase, FusedMoEMethodBase
from sglang.srt.layers.parameter import (
ChannelQuantScaleParameter,
_ColumnvLLMParameter,
RowvLLMParameter,
)
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8,
per_token_quant_int8)
from sglang.srt import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
import os
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for linear layer weights. Uses both column and
    row parallelism.

    Combines the column-parallel and row-parallel loading behavior of its
    two bases; no additional state or methods are added here.
    """
    pass
# Process-wide cache of tuned Triton GEMM configs (loaded from JSON files),
# shared by all linear layers; keyed by "{m}_{n}_{k}" problem shapes.
W8A8_TRITONJSON=W8a8GetCacheJSON()
def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Reference (unfused) scaled matmul: (scale_a * scale_b.T) * (a @ b) + bias.

    Computes in float32 for accuracy, then casts to `out_dtype`.
    """
    product = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
    combined_scale = scale_a * scale_b.T
    result = (combined_scale * product).to(out_dtype)
    if bias is not None:
        result = result + bias
    return result.to(out_dtype)
class SlimQuantW4A8Int8Config(QuantizationConfig):
    """Config class for W4A8 Int8 quantization.

    - Weight: static, per-channel, symmetric (int4 values packed two-per-int8)
    - Activation: dynamic, per-token, symmetric (int8)

    Fix: the docstring previously said "W8A8", contradicting the class name
    and the int4 weight packing used by the associated methods.
    """

    def __init__(self):
        pass

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum GPU compute capability: 7.5.
        return 75

    @classmethod
    def get_name(cls) -> str:
        # Fix: this is a classmethod, so the first parameter is the class;
        # it was previously misnamed `self`.
        return "slimquant_w4a8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config":
        return cls()

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        """Return the per-layer quantize method, or None for unsupported layers."""
        # Imported lazily to avoid a circular import with the MoE layer module.
        from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
        if isinstance(layer, LinearBase):
            return SlimQuantW4A8Int8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return SlimQuantW4A8Int8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
    """Linear method for slimquant W4A8: per-channel int8 weight scales and
    dynamic per-token int8 activation quantization.

    The GEMM backend is chosen once via the W8A8_SUPPORT_METHODS env var:
    1 = Triton (with a JSON cache of tuned configs), 2 = cutlass,
    anything else = rocBLAS.
    """
    def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
        self.quantization_config = quantization_config
        # Singleton cache of tuned Triton configs loaded from JSON files.
        self.tritonsingleton= W8a8GetCacheJSON()
        # Backend selector, read once at construction time.
        self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare the loaded [N, K] weight for the chosen backend and warm
        the Triton tuning cache for this shape (strategy 1 only)."""
        n=layer.weight.shape[0]
        k=layer.weight.shape[1]
        if self.w8a8_strategy==1:
            # NOTE(review): {n, k} is a *set*, so shapes (n, k) and (k, n)
            # dedup to the same entry, and n == k collapses to one element —
            # confirm this is intentional.
            if {n,k} not in self.tritonsingleton.weight_shapes:
                self.tritonsingleton.weight_shapes.append({n,k})
                json_file=self.tritonsingleton.get_w8a8json_name(n,k)
                configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
                if configs_dict:
                    self.tritonsingleton.triton_json_dict.update(configs_dict)
                    # Pre-launch each cached (m, n, k) config so the Triton
                    # kernels are JIT-compiled before serving traffic.
                    for key, value in configs_dict.items():
                        m=int(key.split('_')[0])
                        ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
        else:
            # Non-Triton backends expect a transposed, contiguous layout.
            weight_data=layer.weight.data
            _weight=weight_data.T.contiguous().reshape(n,-1)
            layer.weight.data=_weight
        # Store the weight transposed ([K, N]) so apply() computes x_q @ weight.
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the int8 weight [sum(out_partitions), K] and per-channel
        float32 weight_scale [sum(out_partitions), 1]."""
        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)
        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        input_quant_args: Optional[list[torch.Tensor]] = None,
        silu_quant_args: Optional[list[torch.Tensor]] = None
    ):
        """Dynamically quantize `x` per-token to int8 and run the scaled GEMM
        on the backend selected by self.w8a8_strategy."""
        # if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
        #     assert len(input_quant_args) == 2
        #     x_q, x_scale = input_quant_args
        # elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
        #     x_q, x_scale = silu_quant_args
        # else:
        x_q, x_scale = per_token_quant_int8(x)
        if self.w8a8_strategy==1:
            m=x_q.shape[0]
            k=x_q.shape[1]
            n=layer.weight.shape[1]
            if len(W8A8_TRITONJSON.triton_json_dict)==0:
                best_config=None
            elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
                # Bucket m to the granularity the tuning JSON was generated
                # at, so the lookup key matches a cached config.
                if m<=16:
                    m_=m
                elif m<=64:
                    m_= (m + 3) & -4  # round up to the nearest multiple of 4
                elif m<=160:
                    m_=(m + 7) & -8  # round up to the nearest multiple of 8
                elif m<200: #256
                    m_=160
                elif m<480: #512
                    m_=256
                elif m<960: #1024
                    m_=512
                elif m<2048:
                    m_=1024
                elif m<4096:
                    m_=2048
                elif m<6000:
                    m_=4096
                else:
                    m_=8192
                best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
            else:
                best_config=None
            #if best_config==None:
            #    print("m:{},n:{},k:{}".format(m,n,k))
            #    print("config not found!")
            return ops.triton_scaled_mm(x_q,
                                        layer.weight,
                                        scale_a=x_scale,
                                        scale_b=layer.weight_scale,
                                        out_dtype=x.dtype,
                                        bias=bias,best_config=best_config)
        elif self.w8a8_strategy==2:
            return ops.cutlass_scaled_mm(x_q,
                                         layer.weight,
                                         scale_a=x_scale,
                                         scale_b=layer.weight_scale,
                                         out_dtype=x.dtype,
                                         bias=bias)
        else:
            return ops.rocblas_scaled_mm(x_q,
                                         layer.weight,
                                         scale_a=x_scale,
                                         scale_b=layer.weight_scale,
                                         out_dtype=x.dtype,
                                         bias=bias)
class SlimQuantW4A8Int8MoEMethod:
    """MoE method for W4A8INT8.
    Supports loading INT8 checkpoints with static weight scale and
    dynamic/static activation scale.
    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.
    Args:
        quant_config: The quantization config.
    """
    def __new__(cls, *args, **kwargs):
        # FusedMoEMethodBase can only be imported lazily here (circular
        # import), so on first construction the class is dynamically rebuilt
        # with FusedMoEMethodBase as its base and that rebuilt class is
        # instantiated instead of `cls`.
        from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)
    def __init__(self, quant_config):
        self.quant_config = quant_config
        # Singleton JSON cache of tuned Triton MoE kernel configs.
        self.tritonsingleton= W8a8GetCacheJSON()
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate packed expert weights and per-channel fp32 scales.

        The last weight dim is halved (hidden_size//2, intermediate_size//2)
        because two int4 values are packed per int8 byte (matches
        use_int4_w4a8=True in apply()).
        """
        from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
        tp_size = get_tensor_model_parallel_world_size()
        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
        # Per-output-channel weight scales (CHANNEL quantization scheme).
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        # Activations are quantized dynamically, so no static input scales.
        w13_input_scale = None
        layer.register_parameter("w13_input_scale", w13_input_scale)
        w2_input_scale = None
        layer.register_parameter("w2_input_scale", w2_input_scale)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Warm the Triton MoE tuning cache for this expert shape once, then
        freeze weights/scales as plain Parameters."""
        E=layer.w13_weight.shape[0]
        N1=layer.w13_weight.shape[1]
        N2=layer.w2_weight.shape[1]
        K=N1//2
        if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
            self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
            TOPK= self.tritonsingleton.topk
            json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,use_int4_w4a8=True)
            configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
            #warmup
            if configs_dict:
                self.tritonsingleton.triton_moejson_dict.update(configs_dict)
        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
        layer.w13_weight_scale = Parameter(
            layer.w13_weight_scale.data, requires_grad=False
        )
        layer.w2_weight_scale = Parameter(
            layer.w2_weight_scale.data, requires_grad=False
        )
    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Triton backend runner; config is kept for apply-time parameters.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        use_fused_gate: Optional[bool] = False,
        **_
    ) -> torch.Tensor:
        """Select experts from router_logits, then run the fused int4-w4a8
        Triton expert kernel in place on x.

        Raises:
            NotImplementedError: if enable_eplb is requested.
        """
        from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
        # Expert selection
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            routed_scaling_factor=routed_scaling_factor,
            use_fused_gate=use_fused_gate
        )
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_int4_w4a8=True,
            per_channel_quant=True,
            activation=activation,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            w1_scale=(layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=use_nn_moe,
        )

View File

@@ -1,318 +0,0 @@
from typing import Any, Callable, Dict, List, Optional
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
import torch
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import set_weight_attrs
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.quantization.w4a8_utils import w4a8_weight_repack_impl
from sglang.srt.layers.quantization.base_config import (FusedMoEMethodBase, QuantizeMethodBase)
from sglang.srt.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
# lmslim provides the fused w4a8 Marlin MoE kernel. The import is optional so
# that deployments without quantized MoE models can still load this module.
try:
    from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
except Exception:
    print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
class MarlinMoeWorkspace:
    """
    Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
    global_reduce_buffer will take 1.5MB * cus (about 120MB for BW200) memory in each device
    """
    # One shared instance per device key.
    _instances = {}
    def __new__(cls, device):
        # NOTE(review): creation is not guarded by a lock — presumably all
        # callers construct from a single thread; confirm.
        if device not in cls._instances:
            instance = super().__new__(cls)
            instance._initialized = False
            cls._instances[device] = instance
        return cls._instances[device]
    def __init__(self, device):
        # __init__ runs on every construction; only populate buffers once.
        if self._initialized:
            return
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        # Small scratch workspace for the Marlin kernel.
        self.workspace = torch.zeros(
            500, dtype=torch.int, device=device, requires_grad=False
        )
        # Per-SM reduction buffer: sms * 6 * 128 * 512 int32 elements.
        self.global_reduce_buffer = torch.zeros(
            sms * 6 * 128 * 512, dtype=torch.int, device=device, requires_grad=False
        )
        self._initialized = True
    def get_buffers(self):
        """Return (workspace, global_reduce_buffer) for this device."""
        return self.workspace, self.global_reduce_buffer
def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Unfused reference for a scaled matmul: (scale_a * scale_b.T) * (a @ b) + bias."""
    acc = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
    per_elem_scale = scale_a * scale_b.T
    out = (per_elem_scale * acc).to(out_dtype)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)
class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
    """Config class for W4A8 Int8 Quantization (Marlin MoE kernels).

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self):
        pass

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum GPU compute capability: 7.5.
        return 75

    @classmethod
    def get_name(cls) -> str:
        # Fix: this is a classmethod, so the first parameter is the class;
        # it was previously misnamed `self`.
        return "slimquant_w4a8_marlin"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8MarlinConfig":
        return cls()

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[str]:
        """Route checkpoints quantized as `slimquant_w4a8` to this Marlin
        implementation when the user explicitly requests it."""
        if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \
                and user_quant == "slimquant_w4a8_marlin":
            return cls.get_name()
        return None

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        """Return the per-layer quantize method, or None for unsupported layers."""
        # Imported lazily to avoid a circular import with the MoE layer module.
        from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
        if isinstance(layer, LinearBase):
            return SlimQuantW4A8Int8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return SlimQuantW4A8Int8MarlinMoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
class SlimQuantW4A8Int8MarlinMoEMethod:
    """MoE method for W4A8INT8 Marlin.
    Supports loading INT8 checkpoints with static weight scale and
    dynamic/static activation scale.
    Args:
        quant_config: The quantization config.
    """
    def __new__(cls, *args, **kwargs):
        # FusedMoEMethodBase can only be imported lazily (circular import),
        # so on first construction the class is dynamically rebuilt with
        # FusedMoEMethodBase as its base and that rebuilt class is
        # instantiated instead of `cls`.
        from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)
    def __init__(self, quant_config):
        self.quant_config = quant_config
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate packed expert weights and per-channel fp32 scales.

        The last weight dim is halved (hidden_size//2, intermediate_size//2)
        because two int4 values are packed per int8 byte (matches
        use_int4_w4a8=True in apply()).
        """
        from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
        tp_size = get_tensor_model_parallel_world_size()
        intermediate_size = intermediate_size_per_partition
        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
        # Per-output-channel weight scales (CHANNEL quantization scheme).
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        # Activations are quantized dynamically, so no static input scales.
        w13_input_scale = None
        layer.register_parameter("w13_input_scale", w13_input_scale)
        w2_input_scale = None
        layer.register_parameter("w2_input_scale", w2_input_scale)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Freeze the scales and repack both expert weight tensors into the
        Marlin tile layout (one-time, load-time cost)."""
        layer.w13_weight_scale = Parameter(
            layer.w13_weight_scale.data, requires_grad=False
        )
        layer.w2_weight_scale = Parameter(
            layer.w2_weight_scale.data, requires_grad=False
        )
        layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
        layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Triton backend runner; config is kept for apply-time parameters.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Run the fused w4a8 Marlin expert kernel on the dispatched tokens.

        Takes hidden states and the top-k routing result from
        `dispatch_output` and returns a StandardCombineInput wrapping the
        expert outputs.
        """
        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
        topk_weights, topk_ids, _ = topk_output
        x, topk_weights = apply_topk_weights_cpu(
            self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
        )
        # Per-device singleton scratch buffers for the Marlin kernel.
        workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
        # NOTE(review): activation/global_num_experts are read from
        # layer.moe_runner_config while apply_router_weight_on_input comes
        # from self.moe_runner_config — confirm these are the same object.
        output = fused_experts_impl_w4a8_marlin(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            workspace=workspace,
            global_reduce_buffer=global_reduce_buffer,
            inplace=True,
            use_int4_w4a8=True,
            per_channel_quant=True,
            activation=layer.moe_runner_config.activation,
            expert_map=layer.expert_map_gpu,
            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
            global_num_experts=layer.moe_runner_config.num_experts,
            w1_scale=(layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=False,
        )
        return StandardCombineInput(hidden_states=output)
# def _apply(
# self,
# layer: torch.nn.Module,
# x: torch.Tensor,
# router_logits: torch.Tensor,
# top_k: int,
# #renormalize: bool,
# #use_grouped_topk: bool = False,
# topk_group: Optional[int] = None,
# num_expert_group: Optional[int] = None,
# global_num_experts: int = -1,
# expert_map: Optional[torch.Tensor] = None,
# custom_routing_function: Optional[Callable] = None,
# scoring_func: str = "softmax",
# e_score_correction_bias: Optional[torch.Tensor] = None,
# apply_router_weight_on_input: bool = False,
# activation: str = "silu",
# enable_eplb: bool = False,
# use_nn_moe: Optional[bool] = False,
# routed_scaling_factor: Optional[float] = None,
# use_fused_gate: Optional[bool] = False,
# **_
# ) -> torch.Tensor:
# from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
# from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
# if enable_eplb:
# raise NotImplementedError(
# "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# # Expert selection
# topk_weights, topk_ids = FusedMoE.select_experts(
# hidden_states=x,
# router_logits=router_logits,
# #use_grouped_topk=use_grouped_topk,
# top_k=top_k,
# #renormalize=renormalize,
# topk_group=topk_group,
# num_expert_group=num_expert_group,
# custom_routing_function=custom_routing_function,
# scoring_func=scoring_func,
# e_score_correction_bias=e_score_correction_bias,
# routed_scaling_factor=routed_scaling_factor,
# use_fused_gate=use_fused_gate
# )
# workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
# return fused_experts_impl_w4a8_marlin(
# x,
# layer.w13_weight,
# layer.w2_weight,
# topk_weights=topk_weights,
# topk_ids=topk_ids,
# workspace=workspace,
# global_reduce_buffer=global_reduce_buffer,
# inplace=True,
# use_int4_w4a8=True,
# per_channel_quant=True,
# activation=activation,
# expert_map=expert_map,
# apply_router_weight_on_input=apply_router_weight_on_input,
# global_num_experts=global_num_experts,
# w1_scale=(layer.w13_weight_scale),
# w2_scale=(layer.w2_weight_scale),
# a1_scale=layer.w13_input_scale,
# a2_scale=layer.w2_input_scale,
# use_nn_moe=use_nn_moe,
# )

View File

@@ -1,92 +0,0 @@
import torch
import numpy as np
# Optional fast repack kernel from lightop.
try:
    from lightop import awq_marlin_repack_w4a8
    # NOTE(review): use_lightop is set to False even when the import
    # succeeds, so the lightop fast path in w4a8_weight_repack_impl is never
    # taken — confirm whether this is a deliberate disable or should be True.
    use_lightop = False
except Exception:
    use_lightop = False
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
    """Unpack an [N, K//2] torch.int8 tensor into an [N, K] torch.int32 tensor.

    Each int8 byte holds two packed unsigned 4-bit values. The high nibble is
    written to even output columns and the low nibble to odd output columns;
    each nibble lands in the low 4 bits of an int32 (no sign extension).

    Fix: the original implementation named the `& 0x0F` result `high4` and
    the `>> 4` result `low4` — the names were swapped. Only the names are
    corrected here; the output layout is unchanged.

    Args:
        tensor_int8: input of shape [N, K//2], dtype torch.int8.

    Returns:
        Tensor of shape [N, K], dtype torch.int32, on the same device.

    Raises:
        ValueError: if the input dtype is not torch.int8.
    """
    if tensor_int8.dtype != torch.int8:
        raise ValueError("Input tensor must be of type torch.int8")
    N, K_half = tensor_int8.shape
    # View as uint8 so the shift on the high nibble is logical, not
    # arithmetic (int8 >> would sign-extend negative bytes).
    tensor_uint8 = tensor_int8.to(torch.uint8)
    low_nibble = tensor_uint8 & 0x0F
    high_nibble = (tensor_uint8 >> 4) & 0x0F
    unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
    unpacked[:, 0::2] = high_nibble.to(torch.int32)
    unpacked[:, 1::2] = low_nibble.to(torch.int32)
    return unpacked
def get_weight_perms(interleave: bool=True):
    """Build the Marlin weight permutation for one 2048-element tile.

    Returns a torch LongTensor of length 2048 that is a permutation of
    range(2048). When `interleave` is True, consecutive groups of 8 entries
    are additionally reordered with the lane pattern [4, 0, 5, 1, 6, 2, 7, 3].
    """
    indices = []
    for block in range(64):
        base_col = (block % 16) * 4
        base_row = (block // 16) * 8
        for col_off in range(4):
            col = base_col + col_off
            for row_off in range(8):
                indices.append((base_row + row_off) * 64 + col)
    perm_arr = np.array(indices)
    if interleave:
        lane_order = np.array([4, 0, 5, 1, 6, 2, 7, 3])
        perm_arr = perm_arr.reshape((-1, 8))[:, lane_order].ravel()
    return torch.from_numpy(perm_arr)
def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
    """Tile, permute, and bit-pack a quantized weight matrix for Marlin.

    `q_w` is a [size_k, size_n] tensor of 4-bit values. The matrix is split
    into k_tile x n_tile blocks, permuted within chunks of
    weight_perm.numel(), and then `pack_factor` values are packed into each
    output int32 (4 bits per value).
    """
    size_k, size_n = q_w.shape
    # Rearrange into (k-tiles, n-tiles) blocks, then flatten each k-tile row.
    tiled = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
    tiled = tiled.permute((0, 2, 1, 3))
    tiled = tiled.reshape((size_k // k_tile, size_n * k_tile))
    # Apply the Marlin permutation within chunks of weight_perm.numel().
    tiled = tiled.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(tiled.shape)
    device = tiled.device
    values = tiled.contiguous().to(torch.int32)
    M, N = values.shape
    assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
    packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=device)
    for slot in range(pack_factor):
        packed += values[:, slot::pack_factor] << (4 * slot)
    return packed
def w4a8_2_marlin_weight(w4a8_w):
    """Convert one packed w4a8 weight matrix [N, K//2] (int8) to Marlin layout.

    Unpacks the nibbles to int32, transposes to [K, N] (marlin_weights
    expects size_k first), then tiles/permutes/packs with the default
    32x64 tile and 8-values-per-int32 packing.
    """
    full_w4a8_w = unpack_int8_to_int4(w4a8_w)
    full_w4a8_w = full_w4a8_w.T
    weight_perm = get_weight_perms()
    marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
    return marlin_q_w
def w4a8_weight_repack_impl(input):
    """Repack a batch of packed int4 expert weights [E, N, K//2] into Marlin format.

    Uses the lightop kernel when `use_lightop` is enabled; otherwise falls
    back to the pure-PyTorch per-expert conversion (slower, load-time only).
    """
    if use_lightop:
        size_batch = input.shape[0]
        size_n = input.shape[1]
        size_k = input.shape[2] * 2  # two int4 values per stored int8
        # Kernel output layout: K//32 tile rows of N*4 packed int32s.
        output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
        awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
    else:
        # Fallback: convert each expert independently and stack the results.
        w_marlin_list = []
        for e in range(input.shape[0]):
            w_marlin_in = w4a8_2_marlin_weight(input[e])
            w_marlin_list.append(w_marlin_in)
        output = torch.stack(w_marlin_list, dim=0)
    return output

View File

@@ -19,6 +19,10 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import is_npu, set_weight_attrs
_is_npu = is_npu()
if not _is_npu:
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE

View File

@@ -393,23 +393,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
x.dtype,
True, # is_vnni
)
x_q, x_scale = per_token_quant_int8(x)
x_q_2d = x_q.view(-1, x_q.shape[-1])
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
output = int8_scaled_mm(
x_q_2d,
layer.weight,
x_scale_2d,
layer.weight_scale,
out_dtype=x.dtype,
bias=bias,
return int8_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
)
return output.view(output_shape)
class W8A8Int8MoEMethod(FusedMoEMethodBase):
"""MoE method for INT8.
@@ -648,7 +638,6 @@ class NPU_W8A8LinearMethodImpl:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8LinearMethodMTImpl:
@@ -841,7 +830,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):

View File

@@ -12,7 +12,6 @@ from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
get_compiler_backend,
is_cpu,
is_cuda,
is_hip,
@@ -27,19 +26,13 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
else:
FusedSetKVBufferArg = None
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
if _use_aiter:
from aiter.rotary_embedding import get_rope as aiter_get_rope
if is_npu():
import torch_npu
NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
@@ -149,13 +142,8 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
assert (
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for native implementation"
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
@@ -184,17 +172,12 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-npu implementation of forward()."""
assert (
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for npu implementation"
import os
if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
return self.forward_native(
positions, query, key, offsets, fused_set_kv_buffer_arg
)
return self.forward_native(positions, query, key, offsets)
else:
rotary_mode = "half"
if self.is_neox_style:
@@ -219,12 +202,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert (
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for cpu implementation"
positions = torch.add(positions, offsets) if offsets is not None else positions
if _is_cpu_amx_available:
return torch.ops.sgl_kernel.rotary_embedding_cpu(
@@ -236,9 +214,7 @@ class RotaryEmbedding(CustomOp):
self.is_neox_style,
)
else:
return self.forward_native(
positions, query, key, offsets, fused_set_kv_buffer_arg
)
return self.forward_native(positions, query, key, offsets)
def forward_cuda(
self,
@@ -246,7 +222,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
fused_set_kv_buffer_arg=None, # Optional[FusedSetKVBufferArg]
) -> Tuple[torch.Tensor, torch.Tensor]:
if _is_cuda and (self.head_size in [64, 128, 256, 512]):
apply_rope_with_cos_sin_cache_inplace(
@@ -789,7 +765,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
cos_for_key = cos[:, 0, ...]
sin_for_key = sin[:, 0, ...]
key_rot = key_rot * cos_for_key + rotate_fn(key_rot) * sin_for_key
#key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
@@ -1059,7 +1038,7 @@ class MRotaryEmbedding(RotaryEmbedding):
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
)
@torch.compile(dynamic=True, backend=get_compiler_backend())
@torch.compile(dynamic=True)
def forward(
self,
positions: torch.Tensor,
@@ -1207,7 +1186,7 @@ class MRotaryEmbedding(RotaryEmbedding):
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
elif model_type == "qwen2_vl":
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
@@ -1918,30 +1897,17 @@ def apply_rotary_pos_emb_npu(
sin: torch.Tensor,
unsqueeze_dim=1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Ascend implementation equivalent to apply_rotary_pos_emb_native.
Args:
q: [num_tokens, num_heads, head_size]
k: [num_tokens, num_kv_heads, head_size]
cos: [num_tokens, head_size]
sin: [num_tokens, head_size]
"""
if (
cos.dim() != 2
or q.dim() != 3
or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
):
# Note: num_heads and head_size of q must be less than 1000 and 896, respectively
if q.shape[1] != 128:
return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
q = q.unsqueeze(0)
k = k.unsqueeze(0)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
q_embed = q_embed.squeeze(0)
k_embed = k_embed.squeeze(0)
cos = cos.unsqueeze(unsqueeze_dim)
cos = torch.transpose(cos, 1, 2)
sin = sin.unsqueeze(unsqueeze_dim)
sin = torch.transpose(sin, 1, 2)
q = torch.transpose(q, 1, 2)
k = torch.transpose(k, 1, 2)
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
q_embed = torch.transpose(q_embed, 1, 2)
k_embed = torch.transpose(k_embed, 1, 2)
return q_embed, k_embed

View File

@@ -15,29 +15,6 @@ def get_layer_id(weight_name):
return None
def pad_or_narrow_weight(
loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
) -> torch.Tensor:
# Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
if valid_size > 0:
loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
pad_shape = list(loaded_weight.shape)
pad_shape[input_dim] = shard_size - valid_size
pad = torch.zeros(
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
)
return torch.cat([loaded_slice, pad], dim=input_dim)
# All padding
pad_shape = list(loaded_weight.shape)
pad_shape[input_dim] = shard_size
return torch.zeros(
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
)
class PPMissingLayer(torch.nn.Identity):
# Adapted from
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1

View File

@@ -5,7 +5,7 @@ import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import cached_triton_kernel
from sglang.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))

View File

@@ -3,7 +3,7 @@ import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import cached_triton_kernel
from sglang.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))

View File

@@ -275,17 +275,43 @@ class HiCacheController:
and self.storage_config.tp_rank != 0
)
# Use storage backend factory for dynamic backend creation
from sglang.srt.mem_cache.storage import StorageBackendFactory
if storage_backend == "file":
from sglang.srt.mem_cache.hicache_storage import HiCacheFile
try:
self.storage_backend = StorageBackendFactory.create_backend(
storage_backend, self.storage_config, self.mem_pool_host
self.storage_backend = HiCacheFile(self.storage_config)
elif storage_backend == "nixl":
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
self.storage_backend = HiCacheNixl()
elif storage_backend == "mooncake":
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
MooncakeStore,
)
except ValueError as e:
raise ValueError(f"Failed to create storage backend: {e}") from e
self.storage_backend.register_mem_pool_host(self.mem_pool_host)
self.storage_backend = MooncakeStore(self.storage_config)
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
assert self.mem_pool_host.layout == "page_first"
elif storage_backend == "hf3fs":
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
HiCacheHF3FS,
)
if self.mem_pool_host.layout == "page_first":
bytes_per_page = (
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
)
elif self.mem_pool_host.layout == "layer_first":
bytes_per_page = (
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
)
dtype = mem_pool_host.dtype
self.storage_backend = HiCacheHF3FS.from_env_config(
bytes_per_page, dtype, self.storage_config
)
else:
raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}"
)
self.enable_storage = True
# todo: threshold policy for prefetching
@@ -309,10 +335,18 @@ class HiCacheController:
# Select the get and set functions
self.page_get_func = self._generic_page_get
self.page_set_func = self._generic_page_set
if self.storage_backend_type in ["hf3fs", "mooncake"]:
self.page_get_func = self._page_get_zero_copy
self.page_set_func = self._page_set_zero_copy
self.batch_exists_func = self.storage_backend.batch_exists
self.is_3fs_zerocopy = (
self.storage_backend_type == "hf3fs"
and self.mem_pool_host.layout == "page_first"
)
if self.storage_backend_type == "mooncake":
self.page_get_func = self._mooncake_page_get
self.page_set_func = self._mooncake_page_set
elif self.is_3fs_zerocopy:
self.page_get_func = self._3fs_zero_copy_page_get
self.page_set_func = self._3fs_zero_copy_page_set
self.batch_exists_func = self._3fs_zero_copy_batch_exists
self.device = self.mem_pool_device.device
self.layer_num = self.mem_pool_device.layer_num
@@ -436,6 +470,7 @@ class HiCacheController:
host_indices = self.mem_pool_host.alloc(len(device_indices))
if host_indices is None:
return None
self.mem_pool_host.protect_write(host_indices)
self.write_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority)
)
@@ -459,6 +494,7 @@ class HiCacheController:
self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend
)
self.mem_pool_host.complete_io(op.host_indices)
finish_event.record()
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
@@ -482,6 +518,7 @@ class HiCacheController:
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
if device_indices is None:
return None
self.mem_pool_host.protect_load(host_indices)
self.load_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority)
)
@@ -526,6 +563,7 @@ class HiCacheController:
self.io_backend,
)
producer_event.complete(i)
self.mem_pool_host.complete_io(op.host_indices)
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
# still alive when the load stream is executing.
@@ -543,16 +581,29 @@ class HiCacheController:
)
return producer_id
def evict_device(self, device_indices: torch.Tensor) -> int:
self.mem_pool_device_allocator.free(device_indices)
return len(device_indices)
def evict_device(
self, device_indices: torch.Tensor, host_indices: torch.Tensor
) -> int:
if self.mem_pool_host.is_synced(host_indices):
self.mem_pool_device_allocator.free(device_indices)
self.mem_pool_host.update_backup(host_indices)
return len(device_indices)
else:
raise ValueError(
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
)
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
if not backup_only:
raise ValueError("Other eviction policies are not supported yet.")
self.mem_pool_host.free(host_indices)
return len(host_indices)
if self.mem_pool_host.is_backup(host_indices):
self.mem_pool_host.free(host_indices)
return len(host_indices)
else:
raise ValueError(
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
)
def prefetch(
self,
@@ -579,19 +630,42 @@ class HiCacheController:
for chunk in chunks:
self.host_mem_release_queue.put(chunk)
def _page_get_zero_copy(self, operation, hash_values, host_indices):
results = self.storage_backend.batch_get_v1(hash_values, host_indices)
inc = 0
for i in range(len(hash_values)):
if not results[i]:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
)
break
inc += self.page_size
operation.increment(inc)
def _3fs_zero_copy_batch_exists(self, batch_hashes):
_batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
return hit_page_num
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
hash_values, host_indices
)
page_data = self.storage_backend.batch_get(hashes, dsts)
if page_data:
inc = self.page_size * len(hashes) // factor
operation.increment(inc)
else:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
)
def _mooncake_page_get(self, operation, hash_values, host_indices):
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values,
host_indices,
self.storage_config.tp_rank,
)
get_result = self.storage_backend.batch_get(
key_strs,
target_locations=buffer_ptrs,
target_sizes=buffer_sizes,
)
if get_result != len(hash_values):
logger.warning(
f"Prefetch operation {operation.request_id} failed or partially failed."
)
if get_result != 0:
operation.increment(get_result * self.page_size)
# todo: deprecate
def _generic_page_get(self, operation, hash_values, host_indices):
dummy_page_dst = [
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
@@ -681,7 +755,7 @@ class HiCacheController:
batch_tokens[i : i + self.page_size], last_hash
)
batch_hashes.append(last_hash)
hit_page_num = self.storage_backend.batch_exists(batch_hashes)
hit_page_num = self.batch_exists_func(batch_hashes)
hash_value.extend(batch_hashes[:hit_page_num])
storage_query_count += hit_page_num * self.page_size
if hit_page_num < len(batch_hashes):
@@ -750,16 +824,34 @@ class HiCacheController:
self.backup_queue.put(operation)
return operation.id
# todo: deprecate
# non-zero copy
def _generic_page_set(self, hash_values, host_indices) -> bool:
data = [
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
for i in range(len(hash_values))
]
return self.storage_backend.batch_set(hash_values, data)
def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
# zero copy
def _mooncake_page_set(self, hash_values, host_indices) -> bool:
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values,
host_indices,
self.storage_config.tp_rank,
)
success = self.storage_backend.batch_set(
key_strs,
target_locations=buffer_ptrs,
target_sizes=buffer_sizes,
)
return success
# zero copy
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
hash_values, host_indices
)
return self.storage_backend.batch_set(hashes, dsts)
# Backup batch by batch
def _page_backup(self, operation):

View File

@@ -35,7 +35,6 @@ else:
Image = Any
# Parameters for a session
@dataclass
class SessionParams:
id: Optional[str] = None
@@ -133,23 +132,18 @@ class GenerateReqInput:
# Conversation id used for tracking requests
conversation_id: Optional[str] = None
# Label for the request
label: Optional[str] = None
# Priority for the request
priority: Optional[int] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[Union[List[str], str]] = None
# Whether to disallow logging for this request (e.g. due to ZDR)
no_logs: bool = False
# For custom metric labels
custom_labels: Optional[Dict[str, str]] = None
# (Deprecated, please use custom_labels) Label for the request
label: Optional[str] = None
# (Internal) Whether to return bytes for image generation
# Image gen grpc migration
return_bytes: bool = False
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
def contains_mm_input(self) -> bool:
return (
has_valid_data(self.image_data)
@@ -548,11 +542,8 @@ class GenerateReqInput:
self.data_parallel_rank if self.data_parallel_rank is not None else None
),
conversation_id=self.conversation_id,
priority=self.priority,
extra_key=self.extra_key,
no_logs=self.no_logs,
custom_labels=self.custom_labels,
label=self.label,
priority=self.priority,
return_bytes=self.return_bytes,
)
@@ -609,23 +600,18 @@ class TokenizedGenerateReqInput:
# For dp balance
dp_balance_id: int = -1
# Label for the request
label: Optional[str] = None
# Priority for the request
priority: Optional[int] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[str] = None
# Whether to disallow logging for this request (e.g. due to ZDR)
no_logs: bool = False
# Image gen grpc migration
return_bytes: bool = False
# tracing context
trace_context: Optional[Dict] = None
# (Deprecated, please use custom_labels) Label for the request
label: Optional[str] = None
# (Internal) Whether to return bytes for image generation
return_bytes: bool = False
@dataclass
class BatchTokenizedGenerateReqInput:

View File

@@ -507,7 +507,6 @@ def embed_mm_inputs(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
use_deepstack: bool = False,
) -> Optional[torch.Tensor]:
"""
Embed multimodal inputs and integrate them with text token embeddings.
@@ -523,7 +522,7 @@ def embed_mm_inputs(
Returns:
Combined embedding tensor with multimodal content integrated
"""
other_info = {}
if mm_inputs_list is None:
return None
@@ -533,7 +532,7 @@ def embed_mm_inputs(
for mm_inputs in mm_inputs_list:
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
embeddings, masks, deepstack_embeddings = [], [], []
embeddings, masks = [], []
# 2. Get multimodal embedding separately
# Try get mm embedding if any
for modality in Modality.all():
@@ -579,12 +578,6 @@ def embed_mm_inputs(
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
if use_deepstack and embedding is not None:
embedding, deepstack_embedding = (
multimodal_model.separate_deepstack_embeds(embedding)
)
deepstack_embeddings += [deepstack_embedding]
embeddings += [embedding]
masks += [mask]
@@ -598,37 +591,13 @@ def embed_mm_inputs(
inputs_embeds = input_embedding(input_ids)
# 4. scatter embeddings into input embedding
# deepstack embedding
if use_deepstack:
num_deepstack_embeddings = (
len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
)
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
inputs_embeds.shape[-1] * num_deepstack_embeddings,
)
input_deepstack_embeds = torch.zeros(
deepstack_embedding_shape,
device=inputs_embeds.device,
dtype=inputs_embeds.dtype,
)
other_info["input_deepstack_embeds"] = input_deepstack_embeds
for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
for embedding, mask in zip(embeddings, masks):
if embedding is None or mask is None:
continue
# in-place update
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
if use_deepstack:
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
inputs_embeds.device, inputs_embeds.dtype
)
return inputs_embeds, other_info
return inputs_embeds
def general_mm_embed_routine(
@@ -640,7 +609,6 @@ def general_mm_embed_routine(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
use_deepstack: bool = False,
**kwargs,
) -> torch.Tensor:
"""
@@ -652,7 +620,6 @@ def general_mm_embed_routine(
language_model: Base language model to use
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
placeholder_tokens: Token IDs for multimodal placeholders
use_deepstack: Whether to use deepstack embeddings
**kwargs: Additional arguments passed to language model
Returns:
@@ -678,20 +645,16 @@ def general_mm_embed_routine(
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
if forward_batch.mm_inputs[i] is not None
]
inputs_embeds, other_info = embed_mm_inputs(
inputs_embeds = embed_mm_inputs(
mm_inputs_list=mm_inputs_list,
extend_prefix_lens=extend_prefix_lens,
extend_seq_lens=extend_seq_lens,
input_ids=input_ids,
multimodal_model=multimodal_model,
input_embedding=embed_tokens,
multimodal_model=multimodal_model,
data_embedding_func_mapping=data_embedding_funcs,
placeholder_tokens=placeholder_tokens,
use_deepstack=use_deepstack,
)
# add for qwen3_vl deepstack
if use_deepstack:
kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
# just being defensive here
forward_batch.mm_inputs = None

View File

@@ -12,7 +12,8 @@ logger = logging.getLogger(__name__)
PROCESSOR_MAPPING = {}
def import_processors(package_name: str):
def import_processors():
package_name = "sglang.srt.multimodal.processors"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:

View File

@@ -1,53 +0,0 @@
import torch
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.utils import get_compiler_backend
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)
class FutureMap:
def __init__(
self,
max_running_requests: int,
device: torch.device,
):
self.future_ct = 0
# A factor of 3 is used to avoid collision in the circular buffer.
self.future_limit = max_running_requests * 3
# A factor of 5 is used to ensure the buffer is large enough.
self.future_buffer_len = max_running_requests * 5
self.device = device
self.token_ids_buf = torch.empty(
(self.future_buffer_len,), dtype=torch.int64, device=self.device
)
def update_ct(self, bs: int) -> int:
"""Update the circular buffer pointer and return the current pointer."""
cur_future_ct = self.future_ct
self.future_ct = (cur_future_ct + bs) % self.future_limit
return cur_future_ct
def resolve_future(self, model_worker_batch: ModelWorkerBatch):
input_ids = model_worker_batch.input_ids
_resolve_future_token_ids(input_ids, self.token_ids_buf)
def update_next_future(self, future_ct: int, bs: int):
return torch.arange(
-(future_ct + 1),
-(future_ct + 1 + bs),
-1,
dtype=torch.int64,
device=self.device,
)
def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids

View File

@@ -67,14 +67,14 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -90,7 +90,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
"disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache",
"enable_dp_lm_head",
"enable_fp32_lm_head",
"flashinfer_mxfp4_moe_precision",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
@@ -113,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_custom_logit_processor",
"disaggregation_mode",
"enable_deterministic_inference",
"nsa_prefill",
"nsa_decode",
]
# Put some global args for easy access
@@ -492,7 +493,7 @@ class Req:
self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states
# extra key for classifying the request (e.g. cache_salt)
# extra key for classifying the request (e.g. lora_id, cache_salt)
if lora_id is not None:
extra_key = (
extra_key or ""
@@ -608,8 +609,6 @@ class Req:
) = None
self.hidden_states: List[List[float]] = []
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
self.output_topk_p = None
self.output_topk_index = None
# Embedding (return values)
self.embedding = None
@@ -954,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
None
)
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
] = None
# Whether to return hidden states
return_hidden_states: bool = False
@@ -1609,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if (
self.spec_algorithm.is_eagle()
or self.spec_algorithm.is_standalone()
or self.spec_algorithm.is_ngram()
or self.spec_algorithm.is_lookahead()
):
# if spec decoding is used, the decode batch is prepared inside
# `forward_batch_speculative_generation` after running draft models.
@@ -1736,14 +1735,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
if self.spec_info:
if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
has_been_filtered = False
else:
has_been_filtered = True
self.spec_info.filter_batch(
new_indices=keep_indices_device,
has_been_filtered=has_been_filtered,
)
self.spec_info.filter_batch(keep_indices_device)
def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1992,9 +1984,9 @@ class ModelWorkerBatch:
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
None
)
spec_info: Optional[
Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = -1

View File

@@ -318,6 +318,7 @@ class PrefillAdder:
new_token_ratio: float,
rem_input_tokens: int,
rem_chunk_tokens: Optional[int],
max_prefill_bs: Optional[int],
mixed_with_decode_tokens: int = 0,
priority_scheduling_preemption_threshold: int = 0,
):
@@ -358,6 +359,10 @@ class PrefillAdder:
priority_scheduling_preemption_threshold
)
self.max_prefill_bs = (
max_prefill_bs if max_prefill_bs is not None else 2147483647
)
def _get_running_request_total_token_offset(self, req: Req) -> int:
return (
min(
@@ -549,6 +554,9 @@ class PrefillAdder:
def add_one_req(
self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
):
if len(self.can_run_list) >= self.max_prefill_bs:
return AddReqResult.OTHER
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
return self.add_one_req_ignore_eos(req, has_chunked_req)

Some files were not shown because too many files have changed in this diff Show More