Compare commits
3 Commits
v0.5.3_dev
...
0.5.3rc0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7993ed8ddd | ||
|
|
443a1b4ab3 | ||
|
|
852a49c5cc |
@@ -57,7 +57,7 @@ dependencies = [
|
||||
"uvicorn",
|
||||
"uvloop",
|
||||
"xgrammar==0.1.24",
|
||||
"sgl-kernel==0.3.13",
|
||||
"sgl-kernel==0.3.11",
|
||||
"torch==2.8.0",
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
@@ -67,7 +67,7 @@ dependencies = [
|
||||
"tiktoken",
|
||||
"anthropic>=0.20.0",
|
||||
"torch_memory_saver==0.0.8",
|
||||
"nvidia-cutlass-dsl==4.2.1",
|
||||
"nvidia-cutlass-dsl==4.2.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -103,8 +103,8 @@ dev = ["sglang[test]", "sglang[decord]"]
|
||||
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
|
||||
"srt/layers/quantization/configs/*.json",
|
||||
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
|
||||
"srt/speculative/cpp_ngram/*.cpp",
|
||||
"srt/speculative/cpp_ngram/*.h",
|
||||
"srt/speculative/cpp_lookahead/*.cpp",
|
||||
"srt/speculative/cpp_lookahead/*.h",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
|
||||
@@ -65,30 +65,29 @@ tracing = [
|
||||
|
||||
srt = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.3.13",
|
||||
"sgl-kernel==0.3.11",
|
||||
"torch==2.8.0",
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
"cuda-python",
|
||||
"flashinfer_python==0.4.0rc1",
|
||||
"flashinfer_python==0.3.1",
|
||||
]
|
||||
|
||||
blackwell = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.3.13",
|
||||
"sgl-kernel==0.3.11",
|
||||
"torch==2.8.0",
|
||||
"torchaudio==2.8.0",
|
||||
"torchvision",
|
||||
"cuda-python",
|
||||
"flashinfer_python==0.4.0rc1",
|
||||
"nvidia-cutlass-dsl==4.2.1",
|
||||
"flashinfer_python==0.3.1",
|
||||
"nvidia-cutlass-dsl==4.2.0",
|
||||
]
|
||||
|
||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
|
||||
srt_hip = [
|
||||
"sglang[runtime_common]",
|
||||
"torch",
|
||||
"petit_kernel==0.0.2",
|
||||
"wave-lang==3.7.0",
|
||||
]
|
||||
|
||||
@@ -443,9 +443,11 @@ def latency_test_run_once(
|
||||
|
||||
if profile:
|
||||
profiler.stop()
|
||||
trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
|
||||
_save_profile_trace_results(profiler, trace_filename)
|
||||
rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")
|
||||
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
|
||||
_save_profile_trace_results(profiler, profile_filename)
|
||||
rank_print(
|
||||
f"torch profiler chrome trace for prefill saved to {profile_filename}"
|
||||
)
|
||||
|
||||
# Decode
|
||||
decode_latencies = []
|
||||
@@ -477,10 +479,10 @@ def latency_test_run_once(
|
||||
|
||||
if profile and i == output_len / 2:
|
||||
profiler.stop()
|
||||
trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
|
||||
_save_profile_trace_results(profiler, trace_filename)
|
||||
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
|
||||
_save_profile_trace_results(profiler, profile_filename)
|
||||
rank_print(
|
||||
f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
|
||||
f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
|
||||
)
|
||||
|
||||
# Record decode timing from 2nd output
|
||||
|
||||
@@ -9,7 +9,6 @@ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --
|
||||
|
||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
|
||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
|
||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -20,17 +19,12 @@ import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sglang.bench_serving import (
|
||||
get_tokenizer,
|
||||
sample_mmmu_requests,
|
||||
sample_random_requests,
|
||||
)
|
||||
from sglang.bench_serving import get_tokenizer, sample_random_requests
|
||||
from sglang.profiler import run_profile
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -38,108 +32,6 @@ from sglang.srt.utils import is_blackwell, kill_process_tree
|
||||
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
||||
|
||||
|
||||
class ProfileLinks(BaseModel):
|
||||
"""Pydantic model for profile trace links."""
|
||||
|
||||
extend: Optional[str] = None
|
||||
decode: Optional[str] = None
|
||||
|
||||
|
||||
class BenchmarkResult(BaseModel):
|
||||
"""Pydantic model for benchmark results table data, for a single isl and osl"""
|
||||
|
||||
model_path: str
|
||||
run_name: str
|
||||
batch_size: int
|
||||
input_len: int
|
||||
output_len: int
|
||||
latency: float
|
||||
ttft: float
|
||||
input_throughput: float
|
||||
output_throughput: float
|
||||
overall_throughput: float
|
||||
last_gen_throughput: float
|
||||
acc_length: Optional[float] = None
|
||||
profile_links: Optional[ProfileLinks] = None
|
||||
|
||||
@staticmethod
|
||||
def help_str() -> str:
|
||||
return f"""
|
||||
Note: To view the traces through perfetto-ui, please:
|
||||
1. open with Google Chrome
|
||||
2. allow popup
|
||||
"""
|
||||
|
||||
def to_markdown_row(
|
||||
self, trace_dir, base_url: str = "", relay_base: str = ""
|
||||
) -> str:
|
||||
"""Convert this benchmark result to a markdown table row."""
|
||||
# Calculate costs (assuming H100 pricing for now)
|
||||
hourly_cost_per_gpu = 2 # $2/hour for one H100
|
||||
hourly_cost = hourly_cost_per_gpu * 1 # Assuming tp_size = 1 for simplicity
|
||||
input_util = 0.7
|
||||
accept_length = (
|
||||
round(self.acc_length, 2) if self.acc_length is not None else "n/a"
|
||||
)
|
||||
itl = 1 / (self.output_throughput / self.batch_size) * 1000
|
||||
input_cost = 1e6 / (self.input_throughput * input_util) / 3600 * hourly_cost
|
||||
output_cost = 1e6 / self.output_throughput / 3600 * hourly_cost
|
||||
|
||||
def get_perfetto_relay_link_from_trace_file(trace_file: str):
|
||||
import os
|
||||
from urllib.parse import quote
|
||||
|
||||
rel_path = os.path.relpath(trace_file, trace_dir)
|
||||
raw_file_link = f"{base_url}/{rel_path}"
|
||||
relay_link = (
|
||||
f"{relay_base}?src={quote(raw_file_link, safe='')}"
|
||||
if relay_base and quote
|
||||
else raw_file_link
|
||||
)
|
||||
return relay_link
|
||||
|
||||
# Handle profile links
|
||||
profile_link = "NA | NA"
|
||||
if self.profile_links:
|
||||
if self.profile_links.extend or self.profile_links.decode:
|
||||
# Create a combined link or use the first available one
|
||||
trace_files = [self.profile_links.extend, self.profile_links.decode]
|
||||
trace_files_relay_links = [
|
||||
f"[trace]({get_perfetto_relay_link_from_trace_file(trace_file)})"
|
||||
for trace_file in trace_files
|
||||
]
|
||||
|
||||
profile_link = " | ".join(trace_files_relay_links)
|
||||
|
||||
# Build the row
|
||||
return f"| {self.batch_size} | {self.input_len} | {self.latency:.2f} | {self.input_throughput:.2f} | {self.output_throughput:.2f} | {accept_length} | {itl:.2f} | {input_cost:.2f} | {output_cost:.2f} | {profile_link} |\n"
|
||||
|
||||
@classmethod
|
||||
def generate_markdown_report(
|
||||
cls, trace_dir, results: List["BenchmarkResult"]
|
||||
) -> str:
|
||||
"""Generate a markdown report from a list of BenchmarkResult object from a single run."""
|
||||
import os
|
||||
|
||||
summary = f"### {results[0].model_path}\n"
|
||||
|
||||
# summary += (
|
||||
# f"Input lens: {result.input_len}. Output lens: {result.output_len}.\n"
|
||||
# )
|
||||
summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | profile (extend) | profile (decode)|\n"
|
||||
summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ | --------------- | -------------- |\n"
|
||||
|
||||
# all results should share the same isl & osl
|
||||
for result in results:
|
||||
base_url = os.getenv("TRACE_BASE_URL", "").rstrip("/")
|
||||
relay_base = os.getenv("PERFETTO_RELAY_URL", "").rstrip("/")
|
||||
relay_base = "https://docs.sglang.ai/ci-data/pages/perfetto_relay.html"
|
||||
# base_url = "https://github.com/sgl-project/ci-data/traces"
|
||||
summary += result.to_markdown_row(trace_dir, base_url, relay_base)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BenchArgs:
|
||||
run_name: str = "default"
|
||||
@@ -158,12 +50,8 @@ class BenchArgs:
|
||||
profile: bool = False
|
||||
profile_steps: int = 3
|
||||
profile_by_stage: bool = False
|
||||
profile_filename_prefix: str = None
|
||||
append_to_github_summary: bool = True
|
||||
dataset_path: str = ""
|
||||
parallel_batch: bool = False
|
||||
dataset_name: str = "random"
|
||||
output_path: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
@@ -179,13 +67,6 @@ class BenchArgs:
|
||||
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
||||
)
|
||||
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default=BenchArgs.dataset_name,
|
||||
choices=["mmmu", "random"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument("--return-logprob", action="store_true")
|
||||
parser.add_argument(
|
||||
"--client-stream-interval",
|
||||
@@ -215,36 +96,14 @@ class BenchArgs:
|
||||
help="Path to the dataset.",
|
||||
)
|
||||
parser.add_argument("--parallel-batch", action="store_true")
|
||||
parser.add_argument(
|
||||
"--profile-filename-prefix",
|
||||
type=str,
|
||||
default=BenchArgs.profile_filename_prefix,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-append-to-github-summary",
|
||||
action="store_false",
|
||||
dest="append_to_github_summary",
|
||||
help="Disable appending the output of this run to github ci summary",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-path",
|
||||
type=str,
|
||||
default=BenchArgs.output_path,
|
||||
help="Path to save benchmark results as JSON format. If not specified, results will only be saved to result-filename.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
# use the default value's type to cast the args into correct types.
|
||||
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
|
||||
kwargs = {}
|
||||
for attr, attr_type in attrs:
|
||||
val = getattr(args, attr)
|
||||
if attr_type is type(None):
|
||||
kwargs[attr] = val
|
||||
else:
|
||||
kwargs[attr] = attr_type(val)
|
||||
return cls(**kwargs)
|
||||
return cls(
|
||||
**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
|
||||
)
|
||||
|
||||
|
||||
def launch_server_internal(server_args):
|
||||
@@ -289,35 +148,23 @@ def run_one_case(
|
||||
run_name: str,
|
||||
result_filename: str,
|
||||
tokenizer,
|
||||
dataset_name="",
|
||||
profile: bool = False,
|
||||
profile_steps: int = 3,
|
||||
profile_by_stage: bool = False,
|
||||
profile_filename_prefix: str = None,
|
||||
dataset_path: str = "",
|
||||
parallel_batch: bool = False,
|
||||
):
|
||||
requests.post(url + "/flush_cache")
|
||||
# TODO: reuse bench_serving.get_dataset ?
|
||||
if dataset_name == "mmmu":
|
||||
input_requests = sample_mmmu_requests(
|
||||
num_requests=batch_size,
|
||||
tokenizer=tokenizer,
|
||||
fixed_output_len=output_len,
|
||||
apply_chat_template=True,
|
||||
random_sample=False,
|
||||
)
|
||||
elif dataset_name == "random":
|
||||
input_requests = sample_random_requests(
|
||||
input_len=input_len,
|
||||
output_len=output_len,
|
||||
num_prompts=batch_size,
|
||||
range_ratio=1.0,
|
||||
tokenizer=tokenizer,
|
||||
dataset_path=dataset_path,
|
||||
random_sample=True,
|
||||
return_text=False,
|
||||
)
|
||||
input_requests = sample_random_requests(
|
||||
input_len=input_len,
|
||||
output_len=output_len,
|
||||
num_prompts=batch_size,
|
||||
range_ratio=1.0,
|
||||
tokenizer=tokenizer,
|
||||
dataset_path=dataset_path,
|
||||
random_sample=True,
|
||||
return_text=False,
|
||||
)
|
||||
|
||||
use_structured_outputs = False
|
||||
if use_structured_outputs:
|
||||
@@ -334,48 +181,26 @@ def run_one_case(
|
||||
|
||||
profile_link = None
|
||||
if profile:
|
||||
output_dir, profile_name = None, None
|
||||
if profile_filename_prefix:
|
||||
output_dir = os.path.dirname(profile_filename_prefix)
|
||||
profile_name = os.path.basename(profile_filename_prefix)
|
||||
profile_link: str = run_profile(
|
||||
url,
|
||||
profile_steps,
|
||||
["CPU", "GPU"],
|
||||
output_dir,
|
||||
profile_name,
|
||||
profile_by_stage,
|
||||
url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
|
||||
)
|
||||
|
||||
tic = time.perf_counter()
|
||||
|
||||
payload = {
|
||||
"sampling_params": {
|
||||
"temperature": temperature,
|
||||
"max_new_tokens": output_len,
|
||||
"ignore_eos": True,
|
||||
"json_schema": json_schema,
|
||||
"stream_interval": stream_interval,
|
||||
},
|
||||
"return_logprob": return_logprob,
|
||||
"stream": True,
|
||||
**({"parallel_batch": parallel_batch} if parallel_batch else {}),
|
||||
}
|
||||
if dataset_name == "mmmu":
|
||||
# vlm
|
||||
input_ids = []
|
||||
for input_req in input_requests:
|
||||
input_ids += [tokenizer.encode(input_req.prompt)]
|
||||
payload["image_data"] = [req.image_data for req in input_requests]
|
||||
|
||||
else:
|
||||
input_ids = [req.prompt for req in input_requests]
|
||||
|
||||
payload["input_ids"] = input_ids
|
||||
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json=payload,
|
||||
json={
|
||||
"input_ids": [req.prompt for req in input_requests],
|
||||
"sampling_params": {
|
||||
"temperature": temperature,
|
||||
"max_new_tokens": output_len,
|
||||
"ignore_eos": True,
|
||||
"json_schema": json_schema,
|
||||
"stream_interval": stream_interval,
|
||||
},
|
||||
"return_logprob": return_logprob,
|
||||
"stream": True,
|
||||
**({"parallel_batch": parallel_batch} if parallel_batch else {}),
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
@@ -439,100 +264,10 @@ def run_one_case(
|
||||
overall_throughput,
|
||||
last_gen_throughput,
|
||||
acc_length,
|
||||
profile_link,
|
||||
profile_link if profile else None,
|
||||
)
|
||||
|
||||
|
||||
def save_results_as_json(result: List[Tuple], bench_args: BenchArgs, model: str):
|
||||
"""Save benchmark results as JSON using Pydantic models."""
|
||||
json_results = []
|
||||
|
||||
# Generate all parameter combinations to match with results
|
||||
param_combinations = list(
|
||||
itertools.product(
|
||||
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
||||
)
|
||||
)
|
||||
|
||||
for i, (
|
||||
batch_size,
|
||||
latency,
|
||||
ttft,
|
||||
input_throughput,
|
||||
output_throughput,
|
||||
overall_throughput,
|
||||
last_gen_throughput,
|
||||
acc_length,
|
||||
profile_link,
|
||||
) in enumerate(result):
|
||||
# Get the corresponding parameters for this result
|
||||
bs, input_len, output_len = param_combinations[i]
|
||||
|
||||
# Parse profile links if available
|
||||
profile_links = None
|
||||
if profile_link:
|
||||
profile_links = parse_profile_links(
|
||||
profile_link, batch_size, input_len, output_len
|
||||
)
|
||||
|
||||
benchmark_result = BenchmarkResult(
|
||||
model_path=model,
|
||||
run_name=bench_args.run_name,
|
||||
batch_size=batch_size,
|
||||
input_len=input_len,
|
||||
output_len=output_len,
|
||||
latency=latency,
|
||||
ttft=ttft,
|
||||
input_throughput=input_throughput,
|
||||
output_throughput=output_throughput,
|
||||
overall_throughput=overall_throughput,
|
||||
last_gen_throughput=last_gen_throughput,
|
||||
acc_length=acc_length,
|
||||
profile_links=profile_links,
|
||||
)
|
||||
json_results.append(benchmark_result.model_dump())
|
||||
|
||||
# Save to JSON file
|
||||
with open(bench_args.output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(json_results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"Results saved as JSON to {bench_args.output_path}")
|
||||
|
||||
|
||||
def parse_profile_links(
|
||||
profile_dir: str, batch_size: int, input_len: int, output_len: int
|
||||
) -> Optional[ProfileLinks]:
|
||||
"""Parse profile directory to extract extend and decode trace file links."""
|
||||
if not profile_dir or not os.path.exists(profile_dir):
|
||||
return None
|
||||
|
||||
extend_link = None
|
||||
decode_link = None
|
||||
|
||||
# Look for extend/prefill trace files
|
||||
for file in os.listdir(profile_dir):
|
||||
if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
|
||||
if "extend" in file.lower() or "prefill" in file.lower():
|
||||
extend_link = os.path.join(profile_dir, file)
|
||||
elif "decode" in file.lower():
|
||||
decode_link = os.path.join(profile_dir, file)
|
||||
|
||||
# If no specific extend/decode files found, try to find files with batch/input/output info
|
||||
if not extend_link or not decode_link:
|
||||
for file in os.listdir(profile_dir):
|
||||
if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
|
||||
if f"_batch{batch_size}_input{input_len}_output{output_len}_" in file:
|
||||
if "prefill" in file.lower() or "extend" in file.lower():
|
||||
extend_link = os.path.join(profile_dir, file)
|
||||
elif "decode" in file.lower():
|
||||
decode_link = os.path.join(profile_dir, file)
|
||||
|
||||
if extend_link or decode_link:
|
||||
return ProfileLinks(extend=extend_link, decode=decode_link)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_report_summary(
|
||||
result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
|
||||
):
|
||||
@@ -623,7 +358,6 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
return_logprob=bench_args.return_logprob,
|
||||
stream_interval=bench_args.client_stream_interval,
|
||||
input_len_step_percentage=bench_args.input_len_step_percentage,
|
||||
dataset_name=bench_args.dataset_name,
|
||||
run_name="",
|
||||
result_filename="",
|
||||
tokenizer=tokenizer,
|
||||
@@ -650,12 +384,10 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
stream_interval=bench_args.client_stream_interval,
|
||||
input_len_step_percentage=bench_args.input_len_step_percentage,
|
||||
run_name=bench_args.run_name,
|
||||
dataset_name=bench_args.dataset_name,
|
||||
result_filename=bench_args.result_filename,
|
||||
tokenizer=tokenizer,
|
||||
dataset_path=bench_args.dataset_path,
|
||||
parallel_batch=bench_args.parallel_batch,
|
||||
profile_filename_prefix=bench_args.profile_filename_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -678,13 +410,11 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
run_name=bench_args.run_name,
|
||||
result_filename=bench_args.result_filename,
|
||||
tokenizer=tokenizer,
|
||||
dataset_name=bench_args.dataset_name,
|
||||
profile=bench_args.profile,
|
||||
profile_steps=bench_args.profile_steps,
|
||||
profile_by_stage=bench_args.profile_by_stage,
|
||||
dataset_path=bench_args.dataset_path,
|
||||
parallel_batch=bench_args.parallel_batch,
|
||||
profile_filename_prefix=bench_args.profile_filename_prefix,
|
||||
)[-1],
|
||||
)
|
||||
)
|
||||
@@ -697,16 +427,13 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
|
||||
print(f"\nResults are saved to {bench_args.result_filename}")
|
||||
|
||||
# Save results as JSON if output_path is specified
|
||||
if bench_args.output_path:
|
||||
save_results_as_json(result, bench_args, model=server_args.model_path)
|
||||
|
||||
if not bench_args.show_report:
|
||||
return
|
||||
|
||||
summary = get_report_summary(result, server_args, bench_args)
|
||||
print(summary)
|
||||
|
||||
if is_in_ci() and bench_args.append_to_github_summary:
|
||||
if is_in_ci():
|
||||
write_github_step_summary(summary)
|
||||
|
||||
|
||||
|
||||
@@ -208,10 +208,6 @@ async def async_request_openai_completions(
|
||||
"ignore_eos": not args.disable_ignore_eos,
|
||||
**request_func_input.extra_request_body,
|
||||
}
|
||||
|
||||
if request_func_input.image_data:
|
||||
payload.update({"image_data": request_func_input.image_data})
|
||||
|
||||
headers = get_auth_headers()
|
||||
|
||||
output = RequestFuncOutput.init_new(request_func_input)
|
||||
@@ -1763,9 +1759,7 @@ async def benchmark(
|
||||
pbar.close()
|
||||
|
||||
if "sglang" in backend:
|
||||
server_info = requests.get(
|
||||
base_url + "/get_server_info", headers=get_auth_headers()
|
||||
)
|
||||
server_info = requests.get(base_url + "/get_server_info")
|
||||
if server_info.status_code == 200:
|
||||
server_info_json = server_info.json()
|
||||
if "decode" in server_info_json:
|
||||
|
||||
@@ -124,8 +124,6 @@ class Envs:
|
||||
SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
|
||||
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
|
||||
SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
|
||||
SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
|
||||
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
|
||||
|
||||
# Model Parallel
|
||||
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
|
||||
@@ -37,8 +37,8 @@ class GlobalConfig:
|
||||
)
|
||||
# Runtime constants: others
|
||||
self.retract_decode_steps = 20
|
||||
self.flashinfer_workspace_size = int(
|
||||
os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
|
||||
self.flashinfer_workspace_size = os.environ.get(
|
||||
"FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
|
||||
)
|
||||
|
||||
# Output tokenization configs
|
||||
|
||||
@@ -7,23 +7,9 @@ from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.server_args import prepare_server_args
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
|
||||
MOVE_ENVS_WARN = """
|
||||
########################################################################
|
||||
# For contributors and developers: #
|
||||
# Please move environment variable definitions to sglang.srt.environ #
|
||||
# using the following pattern: #
|
||||
# SGLANG_XXX = EnvBool(False) #
|
||||
# #
|
||||
########################################################################
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
server_args = prepare_server_args(sys.argv[1:])
|
||||
|
||||
from sglang.srt.server_args import print_deprecated_warning
|
||||
|
||||
print_deprecated_warning(MOVE_ENVS_WARN)
|
||||
|
||||
try:
|
||||
launch_server(server_args)
|
||||
finally:
|
||||
|
||||
@@ -5,15 +5,6 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
|
||||
try:
|
||||
from lmslim import quant_ops
|
||||
from lmslim import quant_tools
|
||||
except Exception:
|
||||
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
|
||||
try:
|
||||
import lightop
|
||||
except Exception:
|
||||
print("INFO: Please install lightop if you want to infer awq of marlin.\n")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
use_vllm_custom_allreduce = get_bool_env_var(
|
||||
@@ -184,25 +175,3 @@ def mscclpp_allreduce(
|
||||
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
|
||||
) -> None:
|
||||
return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
|
||||
|
||||
def triton_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
best_config:Optional[list] = None) -> torch.Tensor:
|
||||
|
||||
return quant_ops.triton_scaled_mm(a, b,scale_a,scale_b,out_dtype,bias,best_config)
|
||||
|
||||
def triton_int8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
per_token_act_quant: bool,
|
||||
per_out_channel_weight_quant: bool,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.float16,
|
||||
device: str = "cuda:0",
|
||||
best_config:Optional[list] = None,
|
||||
repeat:Optional[int] = 2):
|
||||
return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config,repeat)
|
||||
@@ -24,8 +24,6 @@ class LoadFormat(str, enum.Enum):
|
||||
JAX = "jax"
|
||||
REMOTE = "remote"
|
||||
REMOTE_INSTANCE = "remote_instance"
|
||||
RDMA = "rdma"
|
||||
LOCAL_CACHED = "local_cached"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -49,7 +47,6 @@ class LoadConfig:
|
||||
checkpoints.
|
||||
decryption_key_file: If set, decrypts the output files with a password read
|
||||
from this file (after PBKDF2).
|
||||
decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
|
||||
"""
|
||||
|
||||
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
|
||||
@@ -57,11 +54,6 @@ class LoadConfig:
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None
|
||||
decryption_key_file: Optional[str] = None
|
||||
decrypt_max_concurrency: int = -1
|
||||
tp_rank: Optional[int] = None
|
||||
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
|
||||
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
|
||||
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||
|
||||
@@ -31,7 +31,7 @@ from sglang.srt.hf_transformers_utils import (
|
||||
)
|
||||
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip, retry
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||
from sglang.utils import is_in_ci
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -48,6 +48,30 @@ class ModelImpl(str, Enum):
|
||||
TRANSFORMERS = "transformers"
|
||||
|
||||
|
||||
def is_deepseek_nsa(config: PretrainedConfig) -> bool:
|
||||
return (
|
||||
config.architectures is not None
|
||||
and config.architectures[0]
|
||||
in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
|
||||
and getattr(config, "index_topk", None) is not None
|
||||
)
|
||||
|
||||
|
||||
def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
|
||||
assert is_deepseek_nsa(config)
|
||||
return config.index_head_dim
|
||||
|
||||
|
||||
def get_nsa_index_topk(config: PretrainedConfig) -> int:
|
||||
assert is_deepseek_nsa(config)
|
||||
return config.index_topk
|
||||
|
||||
|
||||
def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
|
||||
assert is_deepseek_nsa(config)
|
||||
return config.index_n_heads
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -64,20 +88,35 @@ class ModelConfig:
|
||||
is_draft_model: bool = False,
|
||||
hybrid_kvcache_ratio: Optional[float] = None,
|
||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||
tp_rank: Optional[int] = None,
|
||||
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
|
||||
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
|
||||
remote_instance_weight_loader_send_weights_group_ports: Optional[
|
||||
List[int]
|
||||
] = None,
|
||||
) -> None:
|
||||
# Parse args
|
||||
self.model_path = model_path
|
||||
self.revision = revision
|
||||
self.quantization = quantization
|
||||
self.is_draft_model = is_draft_model
|
||||
self.model_impl = model_impl
|
||||
self.tp_rank = tp_rank
|
||||
self.remote_instance_weight_loader_seed_instance_ip = (
|
||||
remote_instance_weight_loader_seed_instance_ip
|
||||
)
|
||||
self.remote_instance_weight_loader_seed_instance_service_port = (
|
||||
remote_instance_weight_loader_seed_instance_service_port
|
||||
)
|
||||
self.remote_instance_weight_loader_send_weights_group_ports = (
|
||||
remote_instance_weight_loader_send_weights_group_ports
|
||||
)
|
||||
|
||||
# Get hf config
|
||||
self._maybe_pull_model_tokenizer_from_remote()
|
||||
self.maybe_pull_model_tokenizer_from_remote()
|
||||
self.model_override_args = json.loads(model_override_args)
|
||||
kwargs = {}
|
||||
if override_config_file and override_config_file.strip():
|
||||
kwargs["_configuration_file"] = override_config_file.strip()
|
||||
|
||||
self.hf_config = get_config(
|
||||
self.model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@@ -85,7 +124,7 @@ class ModelConfig:
|
||||
model_override_args=self.model_override_args,
|
||||
**kwargs,
|
||||
)
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
|
||||
self.hf_generation_config = get_generation_config(
|
||||
self.model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@@ -93,25 +132,7 @@ class ModelConfig:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Set enable_multimodal
|
||||
if enable_multimodal is None:
|
||||
mm_disabled_models = [
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Llama4ForConditionalGeneration",
|
||||
"Step3VLForConditionalGeneration",
|
||||
]
|
||||
if self.hf_config.architectures[0] in mm_disabled_models:
|
||||
enable_multimodal = False
|
||||
logger.info(
|
||||
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
|
||||
)
|
||||
else:
|
||||
enable_multimodal = True
|
||||
|
||||
# Config draft model
|
||||
self._config_draft_model()
|
||||
|
||||
# Check model type
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.attention_chunk_size = getattr(
|
||||
self.hf_text_config, "attention_chunk_size", None
|
||||
)
|
||||
@@ -127,70 +148,20 @@ class ModelConfig:
|
||||
self.hf_config.architectures, self.hf_text_config.num_hidden_layers
|
||||
)
|
||||
)
|
||||
self.is_generation = is_generation_model(
|
||||
self.hf_config.architectures, is_embedding
|
||||
)
|
||||
self.is_multimodal = enable_multimodal and is_multimodal_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_image_gen = enable_multimodal and is_image_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_audio_model = enable_multimodal and is_audio_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_multimodal_chunked_prefill_supported = (
|
||||
enable_multimodal
|
||||
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
|
||||
)
|
||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
|
||||
# Derive context length and model shapes
|
||||
self._derive_context_length(context_length)
|
||||
self._derive_model_shapes()
|
||||
|
||||
# Verify quantization
|
||||
self._verify_quantization()
|
||||
|
||||
# Verify dual-chunk attention config
|
||||
self._verify_dual_chunk_attention_config()
|
||||
|
||||
# Cache attributes
|
||||
self.hf_eos_token_id = self._get_hf_eos_token_id()
|
||||
|
||||
# multimodal
|
||||
self.image_token_id = getattr(
|
||||
self.hf_config, "image_token_id", None
|
||||
) or getattr(self.hf_config, "image_token_index", None)
|
||||
|
||||
@staticmethod
|
||||
def from_server_args(
|
||||
server_args: ServerArgs,
|
||||
model_path: str = None,
|
||||
model_revision: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
return ModelConfig(
|
||||
model_path=model_path or server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=model_revision or server_args.revision,
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
enable_multimodal=server_args.enable_multimodal,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||
model_impl=server_args.model_impl,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _config_draft_model(self):
|
||||
is_draft_model = self.is_draft_model
|
||||
if enable_multimodal is None:
|
||||
mm_disabled_models = [
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Llama4ForConditionalGeneration",
|
||||
"Step3VLForConditionalGeneration",
|
||||
]
|
||||
if self.hf_config.architectures[0] in mm_disabled_models:
|
||||
enable_multimodal = False
|
||||
logger.info(
|
||||
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
|
||||
)
|
||||
else:
|
||||
enable_multimodal = True
|
||||
|
||||
if (
|
||||
is_draft_model
|
||||
@@ -225,10 +196,31 @@ class ModelConfig:
|
||||
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
|
||||
self.hf_config.num_nextn_predict_layers = 1
|
||||
|
||||
def _derive_context_length(self, context_length: int):
|
||||
is_draft_model = self.is_draft_model
|
||||
derived_context_len = get_context_length(self.hf_text_config)
|
||||
# Check model type
|
||||
self.is_generation = is_generation_model(
|
||||
self.hf_config.architectures, is_embedding
|
||||
)
|
||||
self.is_multimodal = enable_multimodal and is_multimodal_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_image_gen = enable_multimodal and is_image_gen_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_audio_model = enable_multimodal and is_audio_model(
|
||||
self.hf_config.architectures
|
||||
)
|
||||
self.is_multimodal_chunked_prefill_supported = (
|
||||
enable_multimodal
|
||||
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
|
||||
)
|
||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
|
||||
# Derive context length
|
||||
derived_context_len = get_context_length(self.hf_text_config)
|
||||
if context_length is not None:
|
||||
if context_length > derived_context_len:
|
||||
reason = "Target model's" if is_draft_model else "User-specified"
|
||||
@@ -242,11 +234,6 @@ class ModelConfig:
|
||||
):
|
||||
logger.warning(msg)
|
||||
self.context_len = context_length
|
||||
if is_draft_model:
|
||||
self.hf_text_config.max_position_embeddings = context_length
|
||||
logger.warning(
|
||||
f"Overriding the draft model's max_position_embeddings to {context_length}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
|
||||
@@ -256,10 +243,6 @@ class ModelConfig:
|
||||
else:
|
||||
self.context_len = derived_context_len
|
||||
|
||||
# Transfer context_len to HuggingFace config so models can access it
|
||||
self.hf_config.context_len = self.context_len
|
||||
|
||||
def _derive_model_shapes(self):
|
||||
# Unify the config keys for hf_text_config
|
||||
self.head_dim = getattr(
|
||||
self.hf_text_config,
|
||||
@@ -270,6 +253,7 @@ class ModelConfig:
|
||||
# FIXME: temporary special judge for MLA architecture
|
||||
if (
|
||||
"DeepseekV2ForCausalLM" in self.hf_config.architectures
|
||||
or "DeepseekV32ForCausalLM" in self.hf_config.architectures
|
||||
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
|
||||
or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
|
||||
or "LongcatFlashForCausalLM" in self.hf_config.architectures
|
||||
@@ -282,6 +266,11 @@ class ModelConfig:
|
||||
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
|
||||
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
||||
self.v_head_dim = self.hf_config.v_head_dim
|
||||
self.index_head_dim = (
|
||||
get_nsa_index_head_dim(self.hf_config)
|
||||
if is_deepseek_nsa(self.hf_config)
|
||||
else None
|
||||
)
|
||||
|
||||
# Handle rope scaling with yarn
|
||||
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
|
||||
@@ -354,6 +343,45 @@ class ModelConfig:
|
||||
)
|
||||
self.vocab_size = self.hf_text_config.vocab_size
|
||||
|
||||
# Verify quantization
|
||||
self._verify_quantization()
|
||||
|
||||
# Verify dual-chunk attention config
|
||||
self._verify_dual_chunk_attention_config()
|
||||
|
||||
# Cache attributes
|
||||
self.hf_eos_token_id = self.get_hf_eos_token_id()
|
||||
|
||||
# multimodal
|
||||
self.image_token_id = getattr(
|
||||
self.hf_config, "image_token_id", None
|
||||
) or getattr(self.hf_config, "image_token_index", None)
|
||||
|
||||
@staticmethod
|
||||
def from_server_args(
|
||||
server_args: ServerArgs,
|
||||
model_path: str = None,
|
||||
model_revision: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
return ModelConfig(
|
||||
model_path=model_path or server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=model_revision or server_args.revision,
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
enable_multimodal=server_args.enable_multimodal,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||
model_impl=server_args.model_impl,
|
||||
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
|
||||
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
|
||||
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_total_num_attention_heads(self) -> int:
|
||||
return self.num_attention_heads
|
||||
|
||||
@@ -454,31 +482,13 @@ class ModelConfig:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
hf_api = HfApi()
|
||||
|
||||
def check_hf_quant_config():
|
||||
return hf_api.file_exists(
|
||||
self.model_path, "hf_quant_config.json"
|
||||
)
|
||||
|
||||
# Retry HF API call up to 3 times
|
||||
file_exists = retry(
|
||||
check_hf_quant_config,
|
||||
max_retry=2,
|
||||
initial_delay=1.0,
|
||||
max_delay=5.0,
|
||||
)
|
||||
|
||||
if file_exists:
|
||||
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
|
||||
quant_cfg = modelopt_quant_config
|
||||
|
||||
except huggingface_hub.errors.OfflineModeIsEnabled:
|
||||
logger.warning(
|
||||
"Offline mode is enabled, skipping hf_quant_config.json check"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to check hf_quant_config.json: {self.model_path} {e}"
|
||||
)
|
||||
pass
|
||||
|
||||
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
|
||||
quant_config_file = os.path.join(
|
||||
@@ -508,7 +518,6 @@ class ModelConfig:
|
||||
"petit_nvfp4",
|
||||
"quark",
|
||||
"mxfp4",
|
||||
"slimquant_w4a8_marlin",
|
||||
]
|
||||
optimized_quantization_methods = [
|
||||
"fp8",
|
||||
@@ -527,7 +536,6 @@ class ModelConfig:
|
||||
"qoq",
|
||||
"w4afp8",
|
||||
"petit_nvfp4",
|
||||
"slimquant_w4a8_marlin",
|
||||
]
|
||||
compatible_quantization_methods = {
|
||||
"modelopt_fp4": ["modelopt"],
|
||||
@@ -608,7 +616,7 @@ class ModelConfig:
|
||||
"sparse_attention_enabled"
|
||||
] = True
|
||||
|
||||
def _get_hf_eos_token_id(self) -> Optional[Set[int]]:
|
||||
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
|
||||
eos_ids = getattr(self.hf_config, "eos_token_id", None)
|
||||
if eos_ids is not None:
|
||||
# it can be either int or list of int
|
||||
@@ -628,7 +636,7 @@ class ModelConfig:
|
||||
eos_ids = eos_ids | generation_eos_ids
|
||||
return eos_ids
|
||||
|
||||
def _maybe_pull_model_tokenizer_from_remote(self) -> None:
|
||||
def maybe_pull_model_tokenizer_from_remote(self) -> None:
|
||||
"""
|
||||
Pull the model config files to a temporary
|
||||
directory in case of remote.
|
||||
@@ -771,8 +779,6 @@ multimodal_model_archs = [
|
||||
"Qwen2AudioForConditionalGeneration",
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"Qwen3VLForConditionalGeneration",
|
||||
"Qwen3VLMoeForConditionalGeneration",
|
||||
"KimiVLForConditionalGeneration",
|
||||
"InternVLChatModel",
|
||||
"InternS1ForConditionalGeneration",
|
||||
|
||||
@@ -1,586 +0,0 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class Qwen3VLVisionConfig(PretrainedConfig):
|
||||
model_type = "qwen3_vl"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=27,
|
||||
hidden_size=1152,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
intermediate_size=4304,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=16,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
out_hidden_size=3584,
|
||||
num_position_embeddings=2304,
|
||||
deepstack_visual_indexes=[8, 16, 24],
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.depth = depth
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.num_position_embeddings = num_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.deepstack_visual_indexes = deepstack_visual_indexes
|
||||
|
||||
|
||||
class Qwen3VLTextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a
|
||||
Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 151936):
|
||||
Vocabulary size of the Qwen3VL model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen3VLModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 22016):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 32):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
||||
head_dim (`int`, *optional*, defaults to 128):
|
||||
The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 128000):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 5000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig
|
||||
|
||||
>>> # Initializing a Qwen3VL style configuration
|
||||
>>> configuration = Qwen3VLTextConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen3-VL-7B style configuration
|
||||
>>> model = Qwen3VLTextModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen3_vl_text"
|
||||
base_config_key = "text_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=151936,
|
||||
hidden_size=4096,
|
||||
intermediate_size=22016,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=32,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=128000,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=5000000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
|
||||
class Qwen3VLConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a
|
||||
Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
|
||||
The config object or dictionary of the vision backbone.
|
||||
image_token_id (`int`, *optional*, defaults to 151655):
|
||||
The image token index to encode the image prompt.
|
||||
video_token_id (`int`, *optional*, defaults to 151656):
|
||||
The video token index to encode the image prompt.
|
||||
vision_start_token_id (`int`, *optional*, defaults to 151652):
|
||||
The start token index to encode the image prompt.
|
||||
vision_end_token_id (`int`, *optional*, defaults to 151653):
|
||||
The end token index to encode the image prompt.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie the word embeddings.
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig
|
||||
|
||||
>>> # Initializing a Qwen3-VL style configuration
|
||||
>>> configuration = Qwen3VLConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen3-VL-4B style configuration
|
||||
>>> model = Qwen3VLForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen3_vl"
|
||||
sub_configs = {
|
||||
"vision_config": Qwen3VLVisionConfig,
|
||||
"text_config": Qwen3VLTextConfig,
|
||||
}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_config=None,
|
||||
vision_config=None,
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
vision_start_token_id=151652,
|
||||
vision_end_token_id=151653,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = self.sub_configs["vision_config"](**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = self.sub_configs["vision_config"]()
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
self.text_config = self.sub_configs["text_config"](**text_config)
|
||||
elif text_config is None:
|
||||
self.text_config = self.sub_configs["text_config"]()
|
||||
|
||||
self.image_token_id = image_token_id
|
||||
self.video_token_id = video_token_id
|
||||
self.vision_start_token_id = vision_start_token_id
|
||||
self.vision_end_token_id = vision_end_token_id
|
||||
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
|
||||
|
||||
|
||||
class Qwen3VLMoeTextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a
|
||||
Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 151936):
|
||||
Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen2MoeModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 2048):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 5632):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 16):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 128000):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 5000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
decoder_sparse_step (`int`, *optional*, defaults to 1):
|
||||
The frequency of the MoE layer.
|
||||
moe_intermediate_size (`int`, *optional*, defaults to 1408):
|
||||
Intermediate size of the routed expert.
|
||||
num_experts_per_tok (`int`, *optional*, defaults to 4):
|
||||
Number of selected experts.
|
||||
num_experts (`int`, *optional*, defaults to 60):
|
||||
Number of routed experts.
|
||||
norm_topk_prob (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the topk probabilities.
|
||||
mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
|
||||
Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock
|
||||
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
|
||||
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
head_dim (`int`, *optional*):
|
||||
The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
|
||||
|
||||
>>> # Initializing a Qwen3VLMoe style configuration
|
||||
>>> configuration = Qwen3VLMoeConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
|
||||
>>> model = Qwen3VLMoeForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen3_vl_moe_text"
|
||||
base_config_key = "text_config"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `Qwen3VLMoe`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=151936,
|
||||
hidden_size=2048,
|
||||
intermediate_size=5632,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=128000,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=5000000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
decoder_sparse_step=1,
|
||||
moe_intermediate_size=1408,
|
||||
num_experts_per_tok=4,
|
||||
num_experts=60,
|
||||
norm_topk_prob=True,
|
||||
mlp_only_layers=None,
|
||||
rope_scaling=None,
|
||||
head_dim=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
|
||||
rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
|
||||
|
||||
# MoE arguments
|
||||
self.decoder_sparse_step = decoder_sparse_step
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.num_experts = num_experts
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
|
||||
class Qwen3VLMoeVisionConfig(PretrainedConfig):
|
||||
model_type = "qwen3_vl_moe"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=27,
|
||||
hidden_size=1152,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
intermediate_size=4304,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=16,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=2,
|
||||
out_hidden_size=3584,
|
||||
num_position_embeddings=2304,
|
||||
deepstack_visual_indexes=[8, 16, 24],
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.depth = depth
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.num_position_embeddings = num_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.deepstack_visual_indexes = deepstack_visual_indexes
|
||||
|
||||
|
||||
class Qwen3VLMoeConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a
|
||||
Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of
|
||||
Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):
|
||||
The config object or dictionary of the vision backbone.
|
||||
image_token_id (`int`, *optional*, defaults to 151655):
|
||||
The image token index to encode the image prompt.
|
||||
video_token_id (`int`, *optional*, defaults to 151656):
|
||||
The video token index to encode the image prompt.
|
||||
vision_start_token_id (`int`, *optional*, defaults to 151652):
|
||||
The start token index to encode the image prompt.
|
||||
vision_end_token_id (`int`, *optional*, defaults to 151653):
|
||||
The end token index to encode the image prompt.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie the word embeddings.
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
|
||||
|
||||
>>> # Initializing a Qwen3-VL-MOE style configuration
|
||||
>>> configuration = Qwen3VLMoeConfig()
|
||||
|
||||
>>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
|
||||
>>> model = Qwen3VLMoeForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "qwen3_vl_moe"
|
||||
sub_configs = {
|
||||
"vision_config": Qwen3VLMoeVisionConfig,
|
||||
"text_config": Qwen3VLMoeTextConfig,
|
||||
}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_config=None,
|
||||
vision_config=None,
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
vision_start_token_id=151652,
|
||||
vision_end_token_id=151653,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = self.sub_configs["vision_config"](**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = self.sub_configs["vision_config"]()
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
self.text_config = self.sub_configs["text_config"](**text_config)
|
||||
elif text_config is None:
|
||||
self.text_config = self.sub_configs["text_config"]()
|
||||
|
||||
self.image_token_id = image_token_id
|
||||
self.video_token_id = video_token_id
|
||||
self.vision_start_token_id = vision_start_token_id
|
||||
self.vision_end_token_id = vision_end_token_id
|
||||
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Qwen3VLMoeConfig",
|
||||
"Qwen3VLMoeVisionConfig",
|
||||
"Qwen3VLConfig",
|
||||
"Qwen3VLVisionConfig",
|
||||
]
|
||||
@@ -2,9 +2,19 @@ import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
|
||||
try:
|
||||
from mf_adapter import TransferEngine
|
||||
|
||||
import_error = None
|
||||
except ImportError as e:
|
||||
import_error = e
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine):
|
||||
def __init__(
|
||||
self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
|
||||
):
|
||||
try:
|
||||
from mf_adapter import TransferEngine
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
if import_error is not None:
|
||||
logger.warning(
|
||||
"Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
|
||||
) from e
|
||||
)
|
||||
raise import_error
|
||||
|
||||
self.engine = TransferEngine()
|
||||
self.hostname = hostname
|
||||
@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine):
|
||||
self.initialize()
|
||||
|
||||
def initialize(self) -> None:
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
)
|
||||
|
||||
transfer_protocol = self._get_transfer_protocol()
|
||||
if transfer_protocol is None or transfer_protocol == "sdma":
|
||||
trans_op_type = TransferEngine.TransDataOpType.SDMA
|
||||
else:
|
||||
trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
|
||||
"""with device RDMA for PD transfer"""
|
||||
tmp_tensor = torch.zeros(1, device="npu")
|
||||
output_tensor_list = [
|
||||
torch.empty_like(tmp_tensor)
|
||||
for _ in range(get_tensor_model_parallel_world_size())
|
||||
]
|
||||
# Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization.
|
||||
torch.distributed.all_gather(
|
||||
output_tensor_list, tmp_tensor, group=get_tp_group().device_group
|
||||
)
|
||||
"""Initialize the ascend transfer instance."""
|
||||
ret_value = self.engine.initialize(
|
||||
self.store_url,
|
||||
self.session_id,
|
||||
self.role,
|
||||
self.npu_id,
|
||||
self.store_url, self.session_id, self.role, self.npu_id, trans_op_type
|
||||
)
|
||||
if ret_value != 0:
|
||||
logger.error("Ascend Transfer Engine initialization failed.")
|
||||
@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine):
|
||||
ret_value = -1
|
||||
if ret_value != 0:
|
||||
logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
|
||||
|
||||
@staticmethod
|
||||
def _get_transfer_protocol():
|
||||
protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
|
||||
allowed_protocols = {"device_rdma", "sdma"}
|
||||
if protocol and protocol.lower() in allowed_protocols:
|
||||
return protocol.lower()
|
||||
else:
|
||||
logger.warning(
|
||||
"Invalid or no transfer protocol specified, using default protocol."
|
||||
)
|
||||
return None
|
||||
@@ -95,6 +95,14 @@ class CommonKVManager(BaseKVManager):
|
||||
def _bind_server_socket(self):
|
||||
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
||||
|
||||
@cache
|
||||
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
||||
socket = zmq.Context().socket(zmq.PUSH)
|
||||
if is_ipv6:
|
||||
socket.setsockopt(zmq.IPV6, 1)
|
||||
socket.connect(endpoint)
|
||||
return socket
|
||||
|
||||
def _register_to_bootstrap(self):
|
||||
"""Register KVSender to bootstrap server via HTTP POST."""
|
||||
if self.dist_init_addr:
|
||||
@@ -148,33 +156,6 @@ class CommonKVManager(BaseKVManager):
|
||||
socket.connect(endpoint)
|
||||
return socket
|
||||
|
||||
def get_mha_kv_ptrs_with_pp(
|
||||
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
|
||||
) -> Tuple[List[int], List[int], List[int], List[int], int]:
|
||||
# pp is not supported on the decode side yet
|
||||
start_layer = self.kv_args.prefill_start_layer
|
||||
num_kv_layers = len(src_kv_ptrs) // 2
|
||||
end_layer = start_layer + num_kv_layers
|
||||
dst_num_total_layers = len(dst_kv_ptrs) // 2
|
||||
src_k_ptrs = src_kv_ptrs[:num_kv_layers]
|
||||
src_v_ptrs = src_kv_ptrs[num_kv_layers:]
|
||||
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||
dst_v_ptrs = dst_kv_ptrs[
|
||||
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
||||
]
|
||||
layers_current_pp_stage = len(src_k_ptrs)
|
||||
return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
|
||||
|
||||
def get_mla_kv_ptrs_with_pp(
|
||||
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
|
||||
) -> Tuple[List[int], List[int], int]:
|
||||
# pp is not supported on the decode side yet
|
||||
start_layer = self.kv_args.prefill_start_layer
|
||||
end_layer = start_layer + len(src_kv_ptrs)
|
||||
sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||
layers_current_pp_stage = len(src_kv_ptrs)
|
||||
return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
|
||||
|
||||
|
||||
class CommonKVSender(BaseKVSender):
|
||||
|
||||
|
||||
@@ -609,21 +609,15 @@ class DecodeTransferQueue:
|
||||
idx = decode_req.metadata_buffer_index
|
||||
(
|
||||
output_id,
|
||||
cached_tokens,
|
||||
output_token_logprobs_val,
|
||||
output_token_logprobs_idx,
|
||||
output_top_logprobs_val,
|
||||
output_top_logprobs_idx,
|
||||
output_topk_p,
|
||||
output_topk_index,
|
||||
output_hidden_states,
|
||||
) = self.metadata_buffers.get_buf(idx)
|
||||
|
||||
decode_req.req.output_ids.append(output_id[0].item())
|
||||
decode_req.req.cached_tokens = cached_tokens[0].item()
|
||||
if not self.spec_algorithm.is_none():
|
||||
decode_req.req.output_topk_p = output_topk_p
|
||||
decode_req.req.output_topk_index = output_topk_index
|
||||
decode_req.req.hidden_states_tensor = output_hidden_states
|
||||
if decode_req.req.return_logprob:
|
||||
decode_req.req.output_token_logprobs_val.append(
|
||||
@@ -713,15 +707,12 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
elif prepare_mlp_sync_flag:
|
||||
batch, _ = self._prepare_idle_batch_and_run(None)
|
||||
|
||||
queue_size = (
|
||||
if batch is None and (
|
||||
len(self.waiting_queue)
|
||||
+ len(self.disagg_decode_transfer_queue.queue)
|
||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||
)
|
||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||
queue_size += len(self.decode_offload_manager.ongoing_offload)
|
||||
|
||||
if batch is None and queue_size == 0:
|
||||
== 0
|
||||
):
|
||||
self.self_check_during_idle()
|
||||
|
||||
self.last_batch = batch
|
||||
@@ -790,15 +781,12 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
|
||||
queue_size = (
|
||||
if batch is None and (
|
||||
len(self.waiting_queue)
|
||||
+ len(self.disagg_decode_transfer_queue.queue)
|
||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||
)
|
||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||
queue_size += len(self.decode_offload_manager.ongoing_offload)
|
||||
|
||||
if batch is None and queue_size == 0:
|
||||
== 0
|
||||
):
|
||||
self.self_check_during_idle()
|
||||
|
||||
self.last_batch = batch
|
||||
@@ -917,6 +905,3 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.disagg_decode_transfer_queue.pop_transferred()
|
||||
) # the requests which kv has arrived
|
||||
self.waiting_queue.extend(alloc_reqs)
|
||||
|
||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||
self.decode_offload_manager.check_offload_progress()
|
||||
|
||||
@@ -1,185 +0,0 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool_host import (
|
||||
MHATokenToKVPoolHost,
|
||||
MLATokenToKVPoolHost,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DecodeKVCacheOffloadManager:
|
||||
"""Manage decode-side KV cache offloading lifecycle and operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
tp_group: torch.distributed.ProcessGroup,
|
||||
tree_cache: BasePrefixCache,
|
||||
server_args: ServerArgs,
|
||||
) -> None:
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.page_size = server_args.page_size
|
||||
self.server_args = server_args
|
||||
self.request_counter = 0
|
||||
self.tree_cache = tree_cache
|
||||
kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
|
||||
if isinstance(kv_cache, MHATokenToKVPool):
|
||||
self.decode_host_mem_pool = MHATokenToKVPoolHost(
|
||||
kv_cache,
|
||||
server_args.hicache_ratio,
|
||||
server_args.hicache_size,
|
||||
self.page_size,
|
||||
server_args.hicache_mem_layout,
|
||||
)
|
||||
elif isinstance(kv_cache, MLATokenToKVPool):
|
||||
self.decode_host_mem_pool = MLATokenToKVPoolHost(
|
||||
kv_cache,
|
||||
server_args.hicache_ratio,
|
||||
server_args.hicache_size,
|
||||
self.page_size,
|
||||
server_args.hicache_mem_layout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unsupported KV cache type for decode offload")
|
||||
|
||||
self.tp_group = tp_group
|
||||
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
||||
self.cache_controller = HiCacheController(
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
mem_pool_host=self.decode_host_mem_pool,
|
||||
page_size=self.page_size,
|
||||
tp_group=tp_group,
|
||||
io_backend=server_args.hicache_io_backend,
|
||||
load_cache_event=threading.Event(),
|
||||
storage_backend=server_args.hicache_storage_backend,
|
||||
model_name=server_args.served_model_name,
|
||||
storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
|
||||
)
|
||||
|
||||
self.ongoing_offload = {}
|
||||
self.ongoing_backup = {}
|
||||
logger.info("Enable offload kv cache for decode side")
|
||||
|
||||
def offload_kv_cache(self, req) -> bool:
|
||||
"""Offload a finished request's KV cache to storage."""
|
||||
|
||||
if self.cache_controller is None or self.decode_host_mem_pool is None:
|
||||
return False
|
||||
|
||||
if req.req_pool_idx == -1:
|
||||
return False
|
||||
|
||||
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
|
||||
if token_indices.dim() == 0 or token_indices.numel() == 0:
|
||||
logger.debug(
|
||||
f"Request {req.rid} has invalid token_indices: {token_indices}"
|
||||
)
|
||||
return False
|
||||
|
||||
tokens = req.origin_input_ids + req.output_ids
|
||||
aligned_len = (len(tokens) // self.page_size) * self.page_size
|
||||
if aligned_len == 0:
|
||||
return False
|
||||
|
||||
token_indices = token_indices[:aligned_len]
|
||||
tokens = tokens[:aligned_len]
|
||||
|
||||
# Asynchronously offload KV cache from device to host by cache controller
|
||||
self.request_counter += 1
|
||||
ack_id = self.request_counter
|
||||
host_indices = self.cache_controller.write(
|
||||
device_indices=token_indices.long(),
|
||||
node_id=ack_id,
|
||||
)
|
||||
if host_indices is None:
|
||||
logger.error(f"Not enough host memory for request {req.rid}")
|
||||
return False
|
||||
|
||||
self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
|
||||
return True
|
||||
|
||||
def check_offload_progress(self):
|
||||
"""Check the progress of offload from device to host and backup from host to storage."""
|
||||
cc = self.cache_controller
|
||||
|
||||
qsizes = torch.tensor(
|
||||
[
|
||||
len(cc.ack_write_queue),
|
||||
cc.ack_backup_queue.qsize(),
|
||||
],
|
||||
dtype=torch.int,
|
||||
)
|
||||
if self.tp_world_size > 1:
|
||||
torch.distributed.all_reduce(
|
||||
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
|
||||
)
|
||||
|
||||
n_write, n_backup = map(int, qsizes.tolist())
|
||||
self._check_offload_progress(n_write)
|
||||
self._check_backup_progress(n_backup)
|
||||
|
||||
def _check_offload_progress(self, finish_count):
|
||||
"""Check the progress of offload from device to host."""
|
||||
while finish_count > 0:
|
||||
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
||||
finish_event.synchronize()
|
||||
for ack_id in ack_list:
|
||||
req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)
|
||||
|
||||
# Release device
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
|
||||
# Trigger async backup from host to storage by cache controller
|
||||
self._trigger_backup(req.rid, host_indices, tokens, start_time)
|
||||
finish_count -= 1
|
||||
|
||||
def _check_backup_progress(self, finish_count):
|
||||
"""Check the progress of backup from host to storage."""
|
||||
for _ in range(finish_count):
|
||||
storage_operation = self.cache_controller.ack_backup_queue.get()
|
||||
ack_id = storage_operation.id
|
||||
req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)
|
||||
|
||||
# Release host memory
|
||||
self.decode_host_mem_pool.free(host_indices)
|
||||
|
||||
logger.debug(
|
||||
f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
|
||||
)
|
||||
|
||||
def _trigger_backup(self, req_id, host_indices, tokens, start_time):
|
||||
"""Trigger async backup from host to storage by cache controller."""
|
||||
|
||||
# Generate page hashes and write to storage
|
||||
page_hashes = self._compute_prefix_hash(tokens)
|
||||
ack_id = self.cache_controller.write_storage(
|
||||
host_indices,
|
||||
tokens,
|
||||
hash_value=page_hashes,
|
||||
)
|
||||
self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)
|
||||
|
||||
def _compute_prefix_hash(self, tokens):
|
||||
last_hash = ""
|
||||
page_hashes = []
|
||||
for offset in range(0, len(tokens), self.page_size):
|
||||
page_tokens = tokens[offset : offset + self.page_size]
|
||||
last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
|
||||
page_hashes.append(last_hash)
|
||||
return page_hashes
|
||||
@@ -125,33 +125,25 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
||||
req.grammar.finished = req.finished()
|
||||
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
||||
|
||||
# Simulate the eagle run.
|
||||
if self.spec_algorithm.is_eagle():
|
||||
# Simulate the eagle run. We add mock data to hidden states for the
|
||||
# ease of implementation now meaning the first token will have acc rate
|
||||
# of 0.
|
||||
if not self.spec_algorithm.is_none():
|
||||
|
||||
b = len(self.reqs)
|
||||
topk = server_args.speculative_eagle_topk
|
||||
topk_p = torch.stack(
|
||||
[
|
||||
torch.as_tensor(
|
||||
req.output_topk_p[:topk],
|
||||
device=self.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
for req in self.reqs
|
||||
],
|
||||
dim=0,
|
||||
topk_p = torch.arange(
|
||||
b * server_args.speculative_eagle_topk,
|
||||
0,
|
||||
-1,
|
||||
device=self.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
topk_index = torch.stack(
|
||||
[
|
||||
torch.as_tensor(
|
||||
req.output_topk_index[:topk],
|
||||
device=self.device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
for req in self.reqs
|
||||
],
|
||||
dim=0,
|
||||
topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
|
||||
topk_p /= b * server_args.speculative_eagle_topk
|
||||
topk_index = torch.arange(
|
||||
b * server_args.speculative_eagle_topk, device=self.device
|
||||
)
|
||||
topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
|
||||
|
||||
hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
|
||||
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
|
||||
|
||||
@@ -264,10 +264,12 @@ class MooncakeKVManager(CommonKVManager):
|
||||
layers_params = None
|
||||
|
||||
# pp is not supported on the decode side yet
|
||||
start_layer = self.kv_args.prefill_start_layer
|
||||
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
|
||||
if self.is_mla_backend:
|
||||
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
||||
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||
)
|
||||
src_kv_ptrs = self.kv_args.kv_data_ptrs
|
||||
layers_per_pp_stage = len(src_kv_ptrs)
|
||||
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||
layers_params = [
|
||||
(
|
||||
@@ -275,12 +277,18 @@ class MooncakeKVManager(CommonKVManager):
|
||||
dst_kv_ptrs[layer_id],
|
||||
kv_item_len,
|
||||
)
|
||||
for layer_id in range(layers_current_pp_stage)
|
||||
for layer_id in range(layers_per_pp_stage)
|
||||
]
|
||||
else:
|
||||
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
||||
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||
)
|
||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
||||
dst_num_total_layers = num_kv_layers * self.pp_size
|
||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
||||
layers_per_pp_stage = len(src_k_ptrs)
|
||||
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||
dst_v_ptrs = dst_kv_ptrs[
|
||||
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
||||
]
|
||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||
layers_params = [
|
||||
(
|
||||
@@ -288,14 +296,14 @@ class MooncakeKVManager(CommonKVManager):
|
||||
dst_k_ptrs[layer_id],
|
||||
kv_item_len,
|
||||
)
|
||||
for layer_id in range(layers_current_pp_stage)
|
||||
for layer_id in range(layers_per_pp_stage)
|
||||
] + [
|
||||
(
|
||||
src_v_ptrs[layer_id],
|
||||
dst_v_ptrs[layer_id],
|
||||
kv_item_len,
|
||||
)
|
||||
for layer_id in range(layers_current_pp_stage)
|
||||
for layer_id in range(layers_per_pp_stage)
|
||||
]
|
||||
assert layers_params is not None
|
||||
|
||||
@@ -393,9 +401,18 @@ class MooncakeKVManager(CommonKVManager):
|
||||
num_heads_to_send = dst_heads_per_rank
|
||||
dst_head_start_offset = 0
|
||||
|
||||
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
||||
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||
)
|
||||
# pp is not supported on the decode side yet
|
||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
||||
dst_num_total_layers = num_kv_layers * self.pp_size
|
||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
||||
layers_per_pp_stage = len(src_k_ptrs)
|
||||
start_layer = self.pp_rank * layers_per_pp_stage
|
||||
end_layer = start_layer + layers_per_pp_stage
|
||||
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||
dst_v_ptrs = dst_kv_ptrs[
|
||||
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
||||
]
|
||||
|
||||
# Calculate precise byte offset and length for the sub-slice within the token
|
||||
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
||||
@@ -421,7 +438,7 @@ class MooncakeKVManager(CommonKVManager):
|
||||
dst_head_slice_offset,
|
||||
heads_bytes_per_token_to_send,
|
||||
)
|
||||
for layer_id in range(layers_current_pp_stage)
|
||||
for layer_id in range(layers_per_pp_stage)
|
||||
] + [
|
||||
(
|
||||
src_v_ptrs[layer_id],
|
||||
@@ -432,7 +449,7 @@ class MooncakeKVManager(CommonKVManager):
|
||||
dst_head_slice_offset,
|
||||
heads_bytes_per_token_to_send,
|
||||
)
|
||||
for layer_id in range(layers_current_pp_stage)
|
||||
for layer_id in range(layers_per_pp_stage)
|
||||
]
|
||||
|
||||
def process_layer_tp_aware(layer_params):
|
||||
|
||||
@@ -421,8 +421,6 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
last_hidden_index = (
|
||||
hidden_state_offset + extend_input_len_per_req[i] - 1
|
||||
)
|
||||
req.output_topk_p = batch.spec_info.topk_p[i]
|
||||
req.output_topk_index = batch.spec_info.topk_index[i]
|
||||
if self.spec_algorithm.is_eagle3():
|
||||
req.hidden_states_tensor = (
|
||||
batch.spec_info.hidden_states[i].cpu().clone()
|
||||
|
||||
@@ -85,7 +85,7 @@ class MetadataBuffers:
|
||||
self,
|
||||
size: int,
|
||||
hidden_size: int,
|
||||
hidden_states_dtype: torch.dtype,
|
||||
dtype: torch.dtype,
|
||||
max_top_logprobs_num: int = 128,
|
||||
custom_mem_pool: torch.cuda.MemPool = None,
|
||||
):
|
||||
@@ -107,9 +107,7 @@ class MetadataBuffers:
|
||||
# We transfer the metadata of first output token to decode
|
||||
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
|
||||
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
|
||||
self.cached_tokens = torch.zeros(
|
||||
(size, 16), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
self.output_token_logprobs_val = torch.zeros(
|
||||
(size, 16), dtype=torch.float32, device=device
|
||||
)
|
||||
@@ -122,49 +120,33 @@ class MetadataBuffers:
|
||||
self.output_top_logprobs_idx = torch.zeros(
|
||||
(size, max_top_logprobs_num), dtype=torch.int32, device=device
|
||||
)
|
||||
# For PD + spec decode
|
||||
self.output_topk_p = torch.zeros(
|
||||
(size, 16), dtype=torch.float32, device=device
|
||||
)
|
||||
self.output_topk_index = torch.zeros(
|
||||
(size, 16), dtype=torch.int64, device=device
|
||||
)
|
||||
self.output_hidden_states = torch.zeros(
|
||||
(size, hidden_size), dtype=hidden_states_dtype, device=device
|
||||
(size, hidden_size), dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def get_buf_infos(self):
|
||||
ptrs = [
|
||||
self.output_ids.data_ptr(),
|
||||
self.cached_tokens.data_ptr(),
|
||||
self.output_token_logprobs_val.data_ptr(),
|
||||
self.output_token_logprobs_idx.data_ptr(),
|
||||
self.output_top_logprobs_val.data_ptr(),
|
||||
self.output_top_logprobs_idx.data_ptr(),
|
||||
self.output_topk_p.data_ptr(),
|
||||
self.output_topk_index.data_ptr(),
|
||||
self.output_hidden_states.data_ptr(),
|
||||
]
|
||||
data_lens = [
|
||||
self.output_ids.nbytes,
|
||||
self.cached_tokens.nbytes,
|
||||
self.output_token_logprobs_val.nbytes,
|
||||
self.output_token_logprobs_idx.nbytes,
|
||||
self.output_top_logprobs_val.nbytes,
|
||||
self.output_top_logprobs_idx.nbytes,
|
||||
self.output_topk_p.nbytes,
|
||||
self.output_topk_index.nbytes,
|
||||
self.output_hidden_states.nbytes,
|
||||
]
|
||||
item_lens = [
|
||||
self.output_ids[0].nbytes,
|
||||
self.cached_tokens[0].nbytes,
|
||||
self.output_token_logprobs_val[0].nbytes,
|
||||
self.output_token_logprobs_idx[0].nbytes,
|
||||
self.output_top_logprobs_val[0].nbytes,
|
||||
self.output_top_logprobs_idx[0].nbytes,
|
||||
self.output_topk_p[0].nbytes,
|
||||
self.output_topk_index[0].nbytes,
|
||||
self.output_hidden_states[0].nbytes,
|
||||
]
|
||||
return ptrs, data_lens, item_lens
|
||||
@@ -172,20 +154,16 @@ class MetadataBuffers:
|
||||
def get_buf(self, idx: int):
|
||||
return (
|
||||
self.output_ids[idx],
|
||||
self.cached_tokens[idx],
|
||||
self.output_token_logprobs_val[idx],
|
||||
self.output_token_logprobs_idx[idx],
|
||||
self.output_top_logprobs_val[idx],
|
||||
self.output_top_logprobs_idx[idx],
|
||||
self.output_topk_p[idx],
|
||||
self.output_topk_index[idx],
|
||||
self.output_hidden_states[idx],
|
||||
)
|
||||
|
||||
def set_buf(self, req: Req):
|
||||
|
||||
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
|
||||
self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
|
||||
if req.return_logprob:
|
||||
if req.output_token_logprobs_val: # not none or empty list
|
||||
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
|
||||
@@ -208,17 +186,8 @@ class MetadataBuffers:
|
||||
] = torch.tensor(
|
||||
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
# For PD + spec decode
|
||||
# for PD + spec decode
|
||||
if req.hidden_states_tensor is not None:
|
||||
# speculative_eagle_topk should not be greater than 16 currently
|
||||
topk = req.output_topk_p.size(0)
|
||||
|
||||
self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
|
||||
req.output_topk_p
|
||||
)
|
||||
self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
|
||||
req.output_topk_index
|
||||
)
|
||||
self.output_hidden_states[req.metadata_buffer_index].copy_(
|
||||
req.hidden_states_tensor
|
||||
)
|
||||
|
||||
@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
|
||||
assert_pkg_version(
|
||||
"sgl-kernel",
|
||||
"0.3.12",
|
||||
"0.3.11",
|
||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
@@ -12,8 +11,7 @@ import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import grpc
|
||||
import zmq
|
||||
@@ -81,10 +79,11 @@ class GrpcReqState:
|
||||
last_completion_tokens: int = 1
|
||||
|
||||
# Streaming state
|
||||
last_output_offset: int = 0
|
||||
stream_finished: bool = False
|
||||
input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming
|
||||
|
||||
# Token accumulation (for non-streaming)
|
||||
# Output accumulation
|
||||
text: str = ""
|
||||
output_ids: List[int] = dataclasses.field(default_factory=list)
|
||||
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
|
||||
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
|
||||
@@ -140,6 +139,8 @@ class GrpcRequestManager:
|
||||
self.is_pause_cond = asyncio.Condition()
|
||||
|
||||
# Metrics
|
||||
self.request_counter = 0
|
||||
self.request_counter_lock = asyncio.Lock()
|
||||
self.last_receive_tstamp = time.time()
|
||||
|
||||
# Crash dump for debugging
|
||||
@@ -157,133 +158,22 @@ class GrpcRequestManager:
|
||||
obj: TokenizedGenerateReqInput,
|
||||
request_id: Optional[str] = None,
|
||||
grpc_context: Optional[grpc.aio.ServicerContext] = None,
|
||||
) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
|
||||
) -> asyncio.Queue:
|
||||
"""
|
||||
Submit a generation request to the scheduler with n>1 parallel sampling support.
|
||||
|
||||
This method implements the same two-phase approach as tokenizer_manager.py:
|
||||
1. Phase 1: Send prefix caching request (max_new_tokens=0)
|
||||
2. Phase 2: Send n generation requests that reuse the cached prefix
|
||||
|
||||
Yields individual responses for streaming, or aggregated responses for non-streaming.
|
||||
Submit a generation request to the scheduler.
|
||||
Returns a queue for streaming outputs.
|
||||
"""
|
||||
n = getattr(obj.sampling_params, "n", 1)
|
||||
|
||||
if n <= 1:
|
||||
async for response in self._handle_single_request(
|
||||
obj, request_id, grpc_context
|
||||
):
|
||||
yield response
|
||||
return
|
||||
|
||||
# N>1 handling - two-phase approach
|
||||
logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
|
||||
|
||||
# Generate base request ID if not provided
|
||||
if request_id is None:
|
||||
base_request_id = f"grpc-{uuid.uuid4().hex}"
|
||||
else:
|
||||
base_request_id = request_id
|
||||
|
||||
# Phase 1: Cache the common prefix
|
||||
logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
|
||||
prefix_obj = copy.copy(obj)
|
||||
prefix_obj.sampling_params = copy.copy(obj.sampling_params)
|
||||
prefix_obj.sampling_params.max_new_tokens = 0 # Prefill-only
|
||||
prefix_obj.sampling_params.n = 1 # Don't replicate prefix request
|
||||
|
||||
# Send prefix caching request and consume response
|
||||
async for _ in self._handle_single_request(
|
||||
prefix_obj, f"{base_request_id}-prefix", grpc_context
|
||||
):
|
||||
# Consume prefix response (usually just one chunk with finish_reason)
|
||||
pass
|
||||
|
||||
logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
|
||||
|
||||
# Phase 2: Generate n parallel requests
|
||||
logger.debug(f"Phase 2: Generating {n} parallel requests")
|
||||
generators = []
|
||||
request_ids = []
|
||||
|
||||
for i in range(n):
|
||||
# Create individual generation request
|
||||
gen_obj = copy.copy(obj)
|
||||
gen_obj.sampling_params = copy.copy(obj.sampling_params)
|
||||
gen_obj.sampling_params.n = 1 # Each request generates 1 response
|
||||
|
||||
gen_request_id = f"{base_request_id}-{i}"
|
||||
request_ids.append(gen_request_id)
|
||||
|
||||
# Start generation request
|
||||
generators.append(
|
||||
self._handle_single_request(gen_obj, gen_request_id, grpc_context)
|
||||
)
|
||||
|
||||
# Handle response aggregation
|
||||
is_stream = getattr(obj, "stream", False)
|
||||
|
||||
if not is_stream:
|
||||
# Non-streaming: collect all responses and return as batch
|
||||
logger.debug(f"Non-streaming mode: collecting {n} responses")
|
||||
responses = []
|
||||
for generator in generators:
|
||||
async for response in generator:
|
||||
responses.append(response)
|
||||
yield responses # Return all responses as a batch
|
||||
else:
|
||||
# Streaming mode: multiplex responses with index for ordering
|
||||
logger.debug(f"Streaming mode: multiplexing {n} streams")
|
||||
rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
|
||||
|
||||
# Create async tasks for all generators
|
||||
task_map = {}
|
||||
for generator in generators:
|
||||
task = asyncio.create_task(generator.__anext__())
|
||||
task_map[task] = generator
|
||||
|
||||
# Process responses as they arrive
|
||||
while task_map:
|
||||
done, _ = await asyncio.wait(
|
||||
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
for task in done:
|
||||
generator = task_map.pop(task)
|
||||
try:
|
||||
response = await task
|
||||
|
||||
# Add index for client-side ordering
|
||||
if isinstance(response, dict) and "meta_info" in response:
|
||||
response_rid = response["meta_info"].get("id", "")
|
||||
if response_rid in rid_to_index:
|
||||
response["index"] = rid_to_index[response_rid]
|
||||
|
||||
yield response
|
||||
|
||||
# Create next task for this generator
|
||||
next_task = asyncio.create_task(generator.__anext__())
|
||||
task_map[next_task] = generator
|
||||
|
||||
except StopAsyncIteration:
|
||||
# This generator is finished
|
||||
pass
|
||||
|
||||
async def _handle_single_request(
|
||||
self,
|
||||
obj: TokenizedGenerateReqInput,
|
||||
request_id: Optional[str] = None,
|
||||
grpc_context: Optional[grpc.aio.ServicerContext] = None,
|
||||
):
|
||||
"""Handle a single request - core implementation without n>1 logic."""
|
||||
# Generate request ID if not provided
|
||||
if request_id is None:
|
||||
request_id = f"grpc-{uuid.uuid4().hex}"
|
||||
async with self.request_counter_lock:
|
||||
request_id = f"grpc-{self.request_counter}"
|
||||
self.request_counter += 1
|
||||
|
||||
obj.rid = request_id
|
||||
|
||||
# Create and register request state
|
||||
# TODO: support log_request
|
||||
|
||||
# Create request state
|
||||
state = GrpcReqState(
|
||||
request_id=request_id,
|
||||
grpc_context=grpc_context,
|
||||
@@ -299,51 +189,19 @@ class GrpcRequestManager:
|
||||
state.session_id = obj.session_params.session_id
|
||||
state.is_session_request = True
|
||||
|
||||
# Register state
|
||||
self.rid_to_state[request_id] = state
|
||||
self.record_request_for_crash_dump(obj)
|
||||
|
||||
# Send to scheduler via ZMQ
|
||||
try:
|
||||
# Send to scheduler - let exceptions bubble up to grpc_server.py
|
||||
await self._send_to_scheduler(obj)
|
||||
|
||||
is_stream = getattr(obj, "stream", False)
|
||||
|
||||
while True:
|
||||
# Client cancelled - notify scheduler and exit
|
||||
if grpc_context and grpc_context.cancelled():
|
||||
await self.abort_request(request_id)
|
||||
return
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
|
||||
|
||||
if is_stream:
|
||||
yield response
|
||||
|
||||
# Non-streaming: yield final response with accumulated tokens from state
|
||||
if isinstance(response, dict) and response.get("finished", False):
|
||||
if not is_stream:
|
||||
final_response = response.copy()
|
||||
final_response["token_ids"] = state.output_ids
|
||||
yield final_response
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Timeout waiting for response - abort and cleanup
|
||||
logger.warning(
|
||||
f"Timeout waiting for response for request {request_id}"
|
||||
)
|
||||
await self.abort_request(request_id)
|
||||
return
|
||||
|
||||
finally:
|
||||
# Always clean up request state when exiting
|
||||
self._cleanup_request_state(request_id)
|
||||
|
||||
def _cleanup_request_state(self, request_id: str):
|
||||
"""Clean up local request state (does not notify scheduler)."""
|
||||
if request_id in self.rid_to_state:
|
||||
except Exception as e:
|
||||
# Clean up on failure
|
||||
del self.rid_to_state[request_id]
|
||||
raise RuntimeError(f"Failed to send request to scheduler: {e}")
|
||||
|
||||
return state.out_queue
|
||||
|
||||
async def embedding_request(
|
||||
self,
|
||||
@@ -356,7 +214,9 @@ class GrpcRequestManager:
|
||||
"""
|
||||
# Generate request ID if not provided
|
||||
if request_id is None:
|
||||
request_id = f"grpc-embed-{uuid.uuid4().hex}"
|
||||
async with self.request_counter_lock:
|
||||
request_id = f"grpc-embed-{self.request_counter}"
|
||||
self.request_counter += 1
|
||||
|
||||
obj.rid = request_id
|
||||
|
||||
@@ -495,6 +355,7 @@ class GrpcRequestManager:
|
||||
# Extract output for this request
|
||||
output_data = {
|
||||
"request_id": rid,
|
||||
"text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
|
||||
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
|
||||
"finished": batch_out.finished_reasons[i] is not None,
|
||||
"meta_info": {
|
||||
@@ -506,9 +367,6 @@ class GrpcRequestManager:
|
||||
if batch_out.completion_tokens
|
||||
else 0
|
||||
),
|
||||
"cached_tokens": (
|
||||
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
|
||||
),
|
||||
"finish_reason": (
|
||||
str(batch_out.finished_reasons[i])
|
||||
if batch_out.finished_reasons[i]
|
||||
@@ -517,110 +375,29 @@ class GrpcRequestManager:
|
||||
},
|
||||
}
|
||||
|
||||
# Accumulate input logprobs (only once, usually in first chunk)
|
||||
if batch_out.input_token_logprobs_val and i < len(
|
||||
batch_out.input_token_logprobs_val
|
||||
):
|
||||
if not state.input_token_logprobs_val:
|
||||
state.input_token_logprobs_val.extend(
|
||||
batch_out.input_token_logprobs_val[i]
|
||||
)
|
||||
if batch_out.input_token_logprobs_idx and i < len(
|
||||
batch_out.input_token_logprobs_idx
|
||||
):
|
||||
state.input_token_logprobs_idx.extend(
|
||||
batch_out.input_token_logprobs_idx[i]
|
||||
)
|
||||
if batch_out.input_top_logprobs_val and i < len(
|
||||
batch_out.input_top_logprobs_val
|
||||
):
|
||||
state.input_top_logprobs_val.extend(
|
||||
batch_out.input_top_logprobs_val[i]
|
||||
)
|
||||
if batch_out.input_top_logprobs_idx and i < len(
|
||||
batch_out.input_top_logprobs_idx
|
||||
):
|
||||
state.input_top_logprobs_idx.extend(
|
||||
batch_out.input_top_logprobs_idx[i]
|
||||
)
|
||||
|
||||
# Send input logprobs based on mode
|
||||
if state.input_token_logprobs_val:
|
||||
if state.obj.stream and not state.input_logprobs_sent:
|
||||
# Streaming: send input logprobs once in first chunk that has them
|
||||
output_data["input_logprobs"] = {
|
||||
"token_logprobs_val": state.input_token_logprobs_val,
|
||||
"token_logprobs_idx": state.input_token_logprobs_idx,
|
||||
"top_logprobs_val": state.input_top_logprobs_val,
|
||||
"top_logprobs_idx": state.input_top_logprobs_idx,
|
||||
}
|
||||
state.input_logprobs_sent = True
|
||||
elif not state.obj.stream and output_data["finished"]:
|
||||
# Non-streaming: send input logprobs in final chunk
|
||||
output_data["input_logprobs"] = {
|
||||
"token_logprobs_val": state.input_token_logprobs_val,
|
||||
"token_logprobs_idx": state.input_token_logprobs_idx,
|
||||
"top_logprobs_val": state.input_top_logprobs_val,
|
||||
"top_logprobs_idx": state.input_top_logprobs_idx,
|
||||
}
|
||||
|
||||
# Add output logprobs if available (RAW - no detokenization!)
|
||||
# Add logprobs if available
|
||||
if batch_out.output_token_logprobs_val and i < len(
|
||||
batch_out.output_token_logprobs_val
|
||||
):
|
||||
# Accumulate in state first
|
||||
state.output_token_logprobs_val.extend(
|
||||
batch_out.output_token_logprobs_val[i]
|
||||
)
|
||||
if batch_out.output_token_logprobs_idx and i < len(
|
||||
batch_out.output_token_logprobs_idx
|
||||
):
|
||||
state.output_token_logprobs_idx.extend(
|
||||
batch_out.output_token_logprobs_idx[i]
|
||||
)
|
||||
if batch_out.output_top_logprobs_val and i < len(
|
||||
batch_out.output_top_logprobs_val
|
||||
):
|
||||
state.output_top_logprobs_val.extend(
|
||||
output_data["logprobs"] = {
|
||||
"tokens": batch_out.output_token_logprobs_val[i],
|
||||
"top_logprobs": (
|
||||
batch_out.output_top_logprobs_val[i]
|
||||
)
|
||||
if batch_out.output_top_logprobs_idx and i < len(
|
||||
batch_out.output_top_logprobs_idx
|
||||
):
|
||||
state.output_top_logprobs_idx.extend(
|
||||
batch_out.output_top_logprobs_idx[i]
|
||||
)
|
||||
if batch_out.output_top_logprobs_val
|
||||
and i < len(batch_out.output_top_logprobs_val)
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
if state.obj.stream:
|
||||
# For streaming: send incremental logprobs (only new tokens in this chunk)
|
||||
# NOTE: this is different than TokenizerManager, which always accumulates
|
||||
def get_part(attr_name):
|
||||
source_list = getattr(batch_out, attr_name, None)
|
||||
return (
|
||||
source_list[i]
|
||||
if source_list and i < len(source_list)
|
||||
else []
|
||||
)
|
||||
# Update state
|
||||
if output_data["text"]:
|
||||
state.text += output_data["text"][state.last_output_offset :]
|
||||
state.last_output_offset = len(output_data["text"])
|
||||
|
||||
output_data["output_logprobs"] = {
|
||||
"token_logprobs_val": batch_out.output_token_logprobs_val[i],
|
||||
"token_logprobs_idx": get_part("output_token_logprobs_idx"),
|
||||
"top_logprobs_val": get_part("output_top_logprobs_val"),
|
||||
"top_logprobs_idx": get_part("output_top_logprobs_idx"),
|
||||
}
|
||||
elif output_data["finished"]:
|
||||
# Non-streaming: send cumulative output logprobs in final chunk
|
||||
output_data["output_logprobs"] = {
|
||||
"token_logprobs_val": state.output_token_logprobs_val,
|
||||
"token_logprobs_idx": state.output_token_logprobs_idx,
|
||||
"top_logprobs_val": state.output_top_logprobs_val,
|
||||
"top_logprobs_idx": state.output_top_logprobs_idx,
|
||||
}
|
||||
|
||||
# Update state for accumulation
|
||||
if output_data["token_ids"]:
|
||||
state.output_ids.extend(output_data["token_ids"])
|
||||
|
||||
# Send to output queue
|
||||
await state.out_queue.put(output_data)
|
||||
|
||||
# Handle completion
|
||||
|
||||
@@ -181,34 +181,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
# Convert gRPC request to internal format
|
||||
tokenized_req = self._convert_generate_request(request)
|
||||
|
||||
# Submit to request manager (automatically handles n>1)
|
||||
response_generator = self.request_manager.generate_request(
|
||||
# Submit to request manager
|
||||
output_queue = await self.request_manager.generate_request(
|
||||
obj=tokenized_req,
|
||||
request_id=request.request_id,
|
||||
grpc_context=context,
|
||||
)
|
||||
|
||||
async for output in response_generator:
|
||||
# Handle batch responses (for n>1 non-streaming)
|
||||
if isinstance(output, list):
|
||||
for batch_output in output:
|
||||
if "error" in batch_output:
|
||||
yield sglang_scheduler_pb2.GenerateResponse(
|
||||
request_id=request.request_id,
|
||||
error=sglang_scheduler_pb2.GenerateError(
|
||||
message=batch_output["error"],
|
||||
http_status_code=(
|
||||
"500" if "abort" not in batch_output else "499"
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
# All non-error batch outputs are final responses
|
||||
yield self._create_completion_response(
|
||||
request.request_id, batch_output
|
||||
)
|
||||
else:
|
||||
# Handle single response (for streaming or n=1 non-streaming)
|
||||
# Stream outputs
|
||||
while True:
|
||||
try:
|
||||
# Get output with timeout
|
||||
output = await asyncio.wait_for(output_queue.get(), timeout=4)
|
||||
|
||||
# Check for errors
|
||||
if "error" in output:
|
||||
yield sglang_scheduler_pb2.GenerateResponse(
|
||||
request_id=request.request_id,
|
||||
@@ -219,13 +205,27 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
),
|
||||
),
|
||||
)
|
||||
elif output.get("finished", False):
|
||||
break
|
||||
|
||||
# Check if finished
|
||||
if output.get("finished", False):
|
||||
# Send completion
|
||||
yield self._create_completion_response(
|
||||
request.request_id, output
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Send chunk
|
||||
yield self._create_chunk_response(request.request_id, output)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Check if context is still active
|
||||
if context.cancelled():
|
||||
# Abort the request
|
||||
await self.request_manager.abort_request(request.request_id)
|
||||
break
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
|
||||
yield sglang_scheduler_pb2.GenerateResponse(
|
||||
@@ -266,6 +266,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
prompt_tokens=result.get("prompt_tokens", 0),
|
||||
cached_tokens=0,
|
||||
embedding_dim=len(result["embedding"]),
|
||||
generation_time=time.time() - self.start_time,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -321,14 +322,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
logger.info(f"Sending health check request to request manager...")
|
||||
|
||||
# Submit and wait for response
|
||||
output_generator = self.request_manager.generate_request(
|
||||
output_queue = await self.request_manager.generate_request(
|
||||
health_request, request_id=rid
|
||||
)
|
||||
|
||||
try:
|
||||
# Get first response with timeout
|
||||
# Wait for response with configurable timeout
|
||||
response = await asyncio.wait_for(
|
||||
output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
|
||||
output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
|
||||
)
|
||||
|
||||
# Clean up
|
||||
@@ -403,8 +404,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
return_logprob=grpc_req.return_logprob,
|
||||
logprob_start_len=grpc_req.logprob_start_len or -1,
|
||||
top_logprobs_num=grpc_req.top_logprobs_num or 0,
|
||||
stream=grpc_req.stream or False,
|
||||
lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
|
||||
stream=True, # Always stream for gRPC
|
||||
lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
|
||||
token_ids_logprob=(
|
||||
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
|
||||
),
|
||||
@@ -437,7 +438,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
regex = None
|
||||
json_schema = None
|
||||
ebnf_grammar = None
|
||||
structural_tag = None
|
||||
|
||||
if grpc_params.HasField("regex"):
|
||||
regex = grpc_params.regex
|
||||
@@ -445,8 +445,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
json_schema = grpc_params.json_schema
|
||||
elif grpc_params.HasField("ebnf_grammar"):
|
||||
ebnf_grammar = grpc_params.ebnf_grammar
|
||||
elif grpc_params.HasField("structural_tag"):
|
||||
structural_tag = grpc_params.structural_tag
|
||||
|
||||
return SGLSamplingParams(
|
||||
temperature=grpc_params.temperature or 1.0,
|
||||
@@ -458,74 +456,33 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
repetition_penalty=grpc_params.repetition_penalty or 1.0,
|
||||
max_new_tokens=grpc_params.max_new_tokens or 128,
|
||||
min_new_tokens=grpc_params.min_new_tokens or 0,
|
||||
stop=list(grpc_params.stop) if grpc_params.stop else [],
|
||||
stop=list(grpc_params.stop) if grpc_params.stop else None,
|
||||
stop_token_ids=(
|
||||
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
|
||||
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
|
||||
),
|
||||
skip_special_tokens=grpc_params.skip_special_tokens,
|
||||
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
|
||||
regex=regex,
|
||||
json_schema=json_schema,
|
||||
ebnf=ebnf_grammar,
|
||||
structural_tag=structural_tag,
|
||||
n=grpc_params.n or 1,
|
||||
ignore_eos=grpc_params.ignore_eos,
|
||||
)
|
||||
|
||||
def _convert_logprobs_to_proto(
|
||||
self, logprobs_data: Dict
|
||||
) -> Optional[sglang_scheduler_pb2.LogProbs]:
|
||||
"""Convert logprobs dict to proto LogProbs format (transport RAW data only)."""
|
||||
if not logprobs_data:
|
||||
return None
|
||||
|
||||
token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
|
||||
token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
|
||||
top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
|
||||
top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
|
||||
|
||||
# Build TopLogProbs entries
|
||||
top_logprobs_proto = []
|
||||
if top_logprobs_val and top_logprobs_idx:
|
||||
for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
|
||||
top_logprobs_proto.append(
|
||||
sglang_scheduler_pb2.TopLogProbs(
|
||||
values=val_list,
|
||||
token_ids=idx_list,
|
||||
)
|
||||
)
|
||||
|
||||
return sglang_scheduler_pb2.LogProbs(
|
||||
token_logprobs=token_logprobs_val,
|
||||
token_ids=token_logprobs_idx,
|
||||
top_logprobs=top_logprobs_proto,
|
||||
)
|
||||
|
||||
def _create_chunk_response(
|
||||
self, request_id: str, output: Dict
|
||||
) -> sglang_scheduler_pb2.GenerateResponse:
|
||||
"""Create a streaming chunk response."""
|
||||
meta_info = output.get("meta_info", {})
|
||||
|
||||
# Convert output logprobs if present
|
||||
output_logprobs_proto = self._convert_logprobs_to_proto(
|
||||
output.get("output_logprobs")
|
||||
)
|
||||
|
||||
# Convert input logprobs if present (only in first chunk)
|
||||
input_logprobs_proto = self._convert_logprobs_to_proto(
|
||||
output.get("input_logprobs")
|
||||
)
|
||||
|
||||
return sglang_scheduler_pb2.GenerateResponse(
|
||||
request_id=request_id,
|
||||
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
|
||||
token_ids=output.get("token_ids", []),
|
||||
prompt_tokens=meta_info.get("prompt_tokens", 0),
|
||||
completion_tokens=meta_info.get("completion_tokens", 0),
|
||||
cached_tokens=meta_info.get("cached_tokens", 0),
|
||||
output_logprobs=output_logprobs_proto,
|
||||
input_logprobs=input_logprobs_proto,
|
||||
token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
|
||||
text=output.get("text", ""),
|
||||
prompt_tokens=0,
|
||||
completion_tokens=len(output.get("token_ids", [])),
|
||||
cached_tokens=0,
|
||||
generation_time=time.time() - self.start_time,
|
||||
queue_time=0.0,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -534,56 +491,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
) -> sglang_scheduler_pb2.GenerateResponse:
|
||||
"""Create a completion response."""
|
||||
|
||||
# Extract meta info and finish reason details
|
||||
# Determine finish reason
|
||||
finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
|
||||
meta_info = output.get("meta_info", {})
|
||||
finish_reason_data = meta_info.get("finish_reason")
|
||||
|
||||
# Determine finish reason, default is stop
|
||||
finish_reason = "stop"
|
||||
if finish_reason_data:
|
||||
if isinstance(finish_reason_data, dict):
|
||||
finish_reason_type = finish_reason_data.get("type")
|
||||
else:
|
||||
# Handle legacy string format
|
||||
finish_reason_type = finish_reason_data
|
||||
|
||||
if finish_reason_type == "length":
|
||||
finish_reason = "length"
|
||||
elif finish_reason_type == "abort":
|
||||
finish_reason = "abort"
|
||||
|
||||
# Extract matched_stop information
|
||||
matched_stop_kwargs = {}
|
||||
if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
|
||||
matched = finish_reason_data["matched"]
|
||||
if isinstance(matched, int):
|
||||
matched_stop_kwargs["matched_token_id"] = matched
|
||||
elif isinstance(matched, str):
|
||||
matched_stop_kwargs["matched_stop_str"] = matched
|
||||
|
||||
# Convert output logprobs if present
|
||||
output_logprobs_proto = self._convert_logprobs_to_proto(
|
||||
output.get("output_logprobs")
|
||||
)
|
||||
|
||||
# Convert input logprobs if present
|
||||
input_logprobs_proto = self._convert_logprobs_to_proto(
|
||||
output.get("input_logprobs")
|
||||
)
|
||||
if meta_info.get("finish_reason") == "length":
|
||||
finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
|
||||
elif meta_info.get("finish_reason") == "eos_token":
|
||||
finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
|
||||
|
||||
return sglang_scheduler_pb2.GenerateResponse(
|
||||
request_id=request_id,
|
||||
complete=sglang_scheduler_pb2.GenerateComplete(
|
||||
output_ids=output.get("token_ids", []),
|
||||
output_text=output.get("text", ""),
|
||||
finish_reason=finish_reason,
|
||||
prompt_tokens=meta_info.get("prompt_tokens", 0),
|
||||
completion_tokens=meta_info.get(
|
||||
"completion_tokens", len(output.get("token_ids", []))
|
||||
),
|
||||
cached_tokens=meta_info.get("cached_tokens", 0),
|
||||
output_logprobs=output_logprobs_proto,
|
||||
input_logprobs=input_logprobs_proto,
|
||||
**matched_stop_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
|
||||
from typing import Any, Dict, List, Optional, TypeAlias, Union
|
||||
|
||||
from openai.types.responses import (
|
||||
ResponseFunctionToolCall,
|
||||
@@ -228,15 +228,11 @@ class CompletionRequest(BaseModel):
|
||||
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[Union[List[str], str]] = None
|
||||
# Cache salt for request caching
|
||||
cache_salt: Optional[Union[List[str], str]] = None
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
# For custom metric labels
|
||||
custom_labels: Optional[Dict[str, str]] = None
|
||||
# For customer metric labels
|
||||
customer_labels: Optional[Dict[str, str]] = None
|
||||
|
||||
@field_validator("max_tokens")
|
||||
@classmethod
|
||||
@@ -343,7 +339,7 @@ class FunctionResponse(BaseModel):
|
||||
"""Function response."""
|
||||
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str | Dict[str, Any]] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
@@ -392,7 +388,7 @@ class Function(BaseModel):
|
||||
"""Function descriptions."""
|
||||
|
||||
description: Optional[str] = Field(default=None, examples=[None])
|
||||
name: str
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
strict: bool = False
|
||||
|
||||
@@ -549,10 +545,6 @@ class ChatCompletionRequest(BaseModel):
|
||||
|
||||
# For request id
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[Union[List[str], str]] = None
|
||||
# Cache salt for request caching
|
||||
cache_salt: Optional[Union[List[str], str]] = None
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
@@ -786,13 +778,6 @@ class ResponsesRequest(BaseModel):
|
||||
description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
|
||||
)
|
||||
priority: int = Field(default=0, description="Request priority")
|
||||
extra_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Extra key for classifying the request (e.g. cache_salt)",
|
||||
)
|
||||
cache_salt: Optional[str] = Field(
|
||||
default=None, description="Cache salt for request caching"
|
||||
)
|
||||
|
||||
# SGLang-specific sampling parameters
|
||||
frequency_penalty: float = 0.0
|
||||
@@ -943,16 +928,6 @@ class MessageProcessingResult:
|
||||
tool_call_constraint: Optional[Any] = None
|
||||
|
||||
|
||||
class ToolCallProcessingResult(NamedTuple):
|
||||
"""Result of processing tool calls in a response."""
|
||||
|
||||
tool_calls: Optional[
|
||||
List[Any]
|
||||
] # List of ToolCall objects or None if parsing failed
|
||||
remaining_text: str # Text remaining after parsing tool calls
|
||||
finish_reason: Dict[str, Any] # Updated finish reason dictionary
|
||||
|
||||
|
||||
class ResponseReasoningTextContent(BaseModel):
|
||||
text: str
|
||||
type: Literal["reasoning_text"] = "reasoning_text"
|
||||
|
||||
@@ -27,10 +27,10 @@ class OpenAIServingBase(ABC):
|
||||
self.tokenizer_manager = tokenizer_manager
|
||||
self.allowed_custom_labels = (
|
||||
set(
|
||||
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
|
||||
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
|
||||
)
|
||||
if isinstance(self.tokenizer_manager.server_args, ServerArgs)
|
||||
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
|
||||
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -62,12 +62,6 @@ class OpenAIServingBase(ABC):
|
||||
return self.create_error_response(
|
||||
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
|
||||
)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(
|
||||
message=str(e),
|
||||
err_type="BadRequest",
|
||||
status_code=400,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in request: {e}")
|
||||
return self.create_error_response(
|
||||
@@ -92,19 +86,6 @@ class OpenAIServingBase(ABC):
|
||||
|
||||
return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
|
||||
|
||||
def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
|
||||
"""Compute the final extra_key by concatenating cache_salt and extra_key if both are provided."""
|
||||
parts = []
|
||||
for key in ["cache_salt", "extra_key"]:
|
||||
value = getattr(request, key, None)
|
||||
if value:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(
|
||||
f"Value of {key} must be a string, but got {type(value).__name__}"
|
||||
)
|
||||
parts.append(value)
|
||||
return "".join(parts) if parts else None
|
||||
|
||||
@abstractmethod
|
||||
def _convert_to_internal_request(
|
||||
self,
|
||||
@@ -184,14 +165,14 @@ class OpenAIServingBase(ABC):
|
||||
)
|
||||
return json.dumps({"error": error.model_dump()})
|
||||
|
||||
def extract_custom_labels(self, raw_request):
|
||||
def extract_customer_labels(self, raw_request):
|
||||
if (
|
||||
not self.allowed_custom_labels
|
||||
or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
|
||||
):
|
||||
return None
|
||||
|
||||
custom_labels = None
|
||||
customer_labels = None
|
||||
header = (
|
||||
self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
|
||||
)
|
||||
@@ -206,9 +187,9 @@ class OpenAIServingBase(ABC):
|
||||
raw_labels = None
|
||||
|
||||
if isinstance(raw_labels, dict):
|
||||
custom_labels = {
|
||||
customer_labels = {
|
||||
label: value
|
||||
for label, value in raw_labels.items()
|
||||
if label in self.allowed_custom_labels
|
||||
}
|
||||
return custom_labels
|
||||
return customer_labels
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Uni
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||
from jsonschema import Draft202012Validator, SchemaError
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
@@ -26,8 +25,6 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
LogProbs,
|
||||
MessageProcessingResult,
|
||||
ToolCall,
|
||||
ToolCallProcessingResult,
|
||||
ToolChoice,
|
||||
TopLogprob,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||
@@ -36,10 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
|
||||
process_hidden_states_from_ret,
|
||||
to_openai_style_logprobs,
|
||||
)
|
||||
from sglang.srt.function_call.core_types import ToolCallItem
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.function_call.json_array_parser import JsonArrayParser
|
||||
from sglang.srt.function_call.utils import get_json_schema_constraint
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.parser.conversation import generate_chat_conv
|
||||
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
|
||||
@@ -64,7 +58,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
super().__init__(tokenizer_manager)
|
||||
self.template_manager = template_manager
|
||||
self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
|
||||
self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
|
||||
|
||||
def _request_id_prefix(self) -> str:
|
||||
return "chatcmpl-"
|
||||
@@ -81,23 +74,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
):
|
||||
return "Tools cannot be empty if tool choice is set to required."
|
||||
|
||||
if request.tool_choice is not None and not isinstance(request.tool_choice, str):
|
||||
if not request.tools:
|
||||
return "Tools cannot be empty if tool choice is set to a specific tool."
|
||||
tool_name = request.tool_choice.function.name
|
||||
tool_exists = any(tool.function.name == tool_name for tool in request.tools)
|
||||
if not tool_exists:
|
||||
return f"Tool '{tool_name}' not found in tools list."
|
||||
|
||||
# Validate tool definitions
|
||||
for i, tool in enumerate(request.tools or []):
|
||||
if tool.function.parameters is None:
|
||||
continue
|
||||
try:
|
||||
Draft202012Validator.check_schema(tool.function.parameters)
|
||||
except SchemaError as e:
|
||||
return f"Tool {i} function has invalid 'parameters' schema: {str(e)}"
|
||||
|
||||
max_output_tokens = request.max_completion_tokens or request.max_tokens
|
||||
server_context_length = self.tokenizer_manager.server_args.context_length
|
||||
if (
|
||||
@@ -152,8 +128,8 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
|
||||
|
||||
# Extract custom labels from raw request headers
|
||||
custom_labels = self.extract_custom_labels(raw_request)
|
||||
# Extract customer labels from raw request headers
|
||||
customer_labels = self.extract_customer_labels(raw_request)
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
@@ -173,9 +149,8 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
rid=request.rid,
|
||||
extra_key=self._compute_extra_key(request),
|
||||
priority=request.priority,
|
||||
custom_labels=custom_labels,
|
||||
customer_labels=customer_labels,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
@@ -213,14 +188,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
tool_call_constraint = parser.get_structure_constraint(
|
||||
request.tool_choice
|
||||
)
|
||||
# Handle JSON schema constraint directly for required or named tool choice
|
||||
if request.tool_choice == "required" or isinstance(
|
||||
request.tool_choice, ToolChoice
|
||||
):
|
||||
json_schema = get_json_schema_constraint(
|
||||
request.tools, request.tool_choice
|
||||
)
|
||||
tool_call_constraint = ("json_schema", json_schema)
|
||||
|
||||
# Use chat template
|
||||
if self.template_manager.chat_template_name is None:
|
||||
@@ -468,10 +435,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||
constraint_value.model_dump(by_alias=True)
|
||||
)
|
||||
elif constraint_type == "json_schema":
|
||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||
constraint_value
|
||||
)
|
||||
else:
|
||||
sampling_params[constraint_type] = constraint_value
|
||||
return sampling_params
|
||||
@@ -564,7 +527,10 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
stream_buffers[index] = stream_buffer + delta
|
||||
|
||||
# Handle reasoning content
|
||||
if self.reasoning_parser and request.separate_reasoning:
|
||||
if (
|
||||
self.tokenizer_manager.server_args.reasoning_parser
|
||||
and request.separate_reasoning
|
||||
):
|
||||
reasoning_text, delta = self._process_reasoning_stream(
|
||||
index, delta, reasoning_parser_dict, content, request
|
||||
)
|
||||
@@ -754,7 +720,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
|
||||
# Handle reasoning content
|
||||
reasoning_text = None
|
||||
reasoning_parser = self.reasoning_parser
|
||||
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
|
||||
if reasoning_parser and request.separate_reasoning:
|
||||
is_force_reasoning = (
|
||||
self.template_manager.force_reasoning
|
||||
@@ -782,13 +748,8 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
and request.tools
|
||||
and self.tool_call_parser
|
||||
):
|
||||
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
||||
tool_calls, text, finish_reason = self._process_tool_calls(
|
||||
text,
|
||||
request.tools,
|
||||
finish_reason,
|
||||
request.tool_choice,
|
||||
history_tool_calls_cnt,
|
||||
text, request.tools, finish_reason
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
@@ -878,76 +839,13 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
|
||||
return ChoiceLogprobs(content=token_logprobs)
|
||||
|
||||
def _process_tool_call_id(
|
||||
self,
|
||||
call_item: ToolCallItem,
|
||||
history_tool_calls_cnt: int,
|
||||
) -> str:
|
||||
"""Process for generating a new and unique `tool_call_id`"""
|
||||
if self.tool_call_parser != "kimi_k2":
|
||||
# A simple uuid is sufficient for all models except for Kimi-K2.
|
||||
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
return tool_call_id
|
||||
else:
|
||||
# Align with Kimi-K2 format: functions.{name}:{index}
|
||||
# Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
|
||||
# Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
|
||||
tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
|
||||
logger.debug(
|
||||
f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
|
||||
)
|
||||
return tool_call_id
|
||||
|
||||
def _process_tool_calls(
|
||||
self,
|
||||
text: str,
|
||||
tools: List[Any],
|
||||
finish_reason: Dict[str, Any],
|
||||
tool_choice: Optional[Union[str, ToolChoice]] = None,
|
||||
history_tool_calls_cnt: int = 0,
|
||||
) -> ToolCallProcessingResult:
|
||||
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
|
||||
"""Process tool calls in the response"""
|
||||
|
||||
# Handle required or named tool choice
|
||||
if tool_choice == "required" or (
|
||||
isinstance(tool_choice, ToolChoice) and tool_choice.type == "function"
|
||||
):
|
||||
# Set finish reason to tool_calls since we're processing tool calls
|
||||
if finish_reason["type"] == "stop":
|
||||
finish_reason["type"] = "tool_calls"
|
||||
finish_reason["matched"] = None
|
||||
try:
|
||||
# For required tool choice, we expect a JSON array of tool calls
|
||||
tool_call_data = json.loads(text)
|
||||
tool_calls = []
|
||||
for i, tool in enumerate(tool_call_data):
|
||||
# Create a ToolCallItem from the JSON data
|
||||
call_info = ToolCallItem(
|
||||
tool_index=i, # Use the loop index as tool_index
|
||||
name=tool["name"],
|
||||
parameters=json.dumps(tool["parameters"], ensure_ascii=False),
|
||||
)
|
||||
tool_id = self._process_tool_call_id(
|
||||
call_info, history_tool_calls_cnt
|
||||
)
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=tool_id,
|
||||
index=i,
|
||||
function=FunctionResponse(
|
||||
name=tool["name"],
|
||||
arguments=json.dumps(
|
||||
tool["parameters"], ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
return ToolCallProcessingResult(tool_calls, "", finish_reason)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Tool call parsing error: {e}")
|
||||
return ToolCallProcessingResult(None, text, finish_reason)
|
||||
|
||||
# Use parser since output is not constrained by JSON schema
|
||||
parser = FunctionCallParser(tools, self.tool_call_parser)
|
||||
if parser.has_tool_call(text):
|
||||
if finish_reason["type"] == "stop":
|
||||
@@ -957,9 +855,15 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
text, call_info_list = parser.parse_non_stream(text)
|
||||
tool_calls = []
|
||||
for call_info in call_info_list:
|
||||
tool_id = self._process_tool_call_id(
|
||||
call_info, history_tool_calls_cnt
|
||||
)
|
||||
# For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
|
||||
if (
|
||||
self.tool_call_parser == "kimi_k2"
|
||||
and call_info.name is not None
|
||||
):
|
||||
tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
|
||||
else:
|
||||
tool_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=tool_id,
|
||||
@@ -969,13 +873,13 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
),
|
||||
)
|
||||
)
|
||||
return ToolCallProcessingResult(tool_calls, text, finish_reason)
|
||||
return tool_calls, text, finish_reason
|
||||
except Exception as e:
|
||||
logger.error(f"Tool call parsing error: {e}")
|
||||
# Return error but don't fail the whole request
|
||||
return ToolCallProcessingResult(None, text, finish_reason)
|
||||
return None, text, finish_reason
|
||||
|
||||
return ToolCallProcessingResult(None, text, finish_reason)
|
||||
return None, text, finish_reason
|
||||
|
||||
def _process_streaming_logprobs(
|
||||
self, content: Dict[str, Any], n_prev_token: int
|
||||
@@ -1008,33 +912,13 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
or self._get_enable_thinking_from_request(request)
|
||||
)
|
||||
reasoning_parser_dict[index] = ReasoningParser(
|
||||
self.reasoning_parser,
|
||||
self.tokenizer_manager.server_args.reasoning_parser,
|
||||
request.stream_reasoning,
|
||||
is_force_reasoning,
|
||||
)
|
||||
reasoning_parser = reasoning_parser_dict[index]
|
||||
return reasoning_parser.parse_stream_chunk(delta)
|
||||
|
||||
def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
|
||||
"""Counts the number of tool calls in the request's message history.
|
||||
|
||||
NOTE: This method is only useful for models that include self-increasing
|
||||
history tool call idx in tool calls id, such as kimi-k2
|
||||
|
||||
Args:
|
||||
request: The chat completion request object.
|
||||
|
||||
Returns:
|
||||
The total number of tool calls in the history, or 0 if not applicable.
|
||||
"""
|
||||
messages = getattr(request, "messages", [])
|
||||
idx = 0
|
||||
for msg in messages:
|
||||
if msg.role == "assistant":
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
|
||||
return idx
|
||||
|
||||
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
|
||||
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
|
||||
|
||||
@@ -1048,11 +932,11 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
"""
|
||||
if hasattr(request, "chat_template_kwargs") and request.chat_template_kwargs:
|
||||
# For Qwen3 models, `enable_thinking` is supported.
|
||||
if self.reasoning_parser in ["qwen3", "glm45"]:
|
||||
return request.chat_template_kwargs.get("enable_thinking", False)
|
||||
if request.chat_template_kwargs.get("enable_thinking") is not None:
|
||||
return request.chat_template_kwargs.get("enable_thinking")
|
||||
# For DeepSeek-V3.1 models, `thinking` is supported.
|
||||
elif self.reasoning_parser in ["deepseek-v3"]:
|
||||
return request.chat_template_kwargs.get("thinking", False)
|
||||
elif request.chat_template_kwargs.get("thinking") is not None:
|
||||
return request.chat_template_kwargs.get("thinking")
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
@@ -1068,25 +952,13 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
):
|
||||
"""Process tool calls in streaming response"""
|
||||
if index not in parser_dict:
|
||||
# Use JSON detector directly for required or named tool choice
|
||||
if request.tool_choice == "required" or isinstance(
|
||||
request.tool_choice, ToolChoice
|
||||
):
|
||||
parser_dict[index] = JsonArrayParser()
|
||||
else:
|
||||
parser_dict[index] = FunctionCallParser(
|
||||
tools=request.tools,
|
||||
tool_call_parser=self.tool_call_parser,
|
||||
)
|
||||
|
||||
parser_dict[index] = FunctionCallParser(
|
||||
tools=request.tools,
|
||||
tool_call_parser=self.tool_call_parser,
|
||||
)
|
||||
parser = parser_dict[index]
|
||||
|
||||
# Handle both FunctionCallParser and JsonArrayParser
|
||||
if isinstance(parser, JsonArrayParser):
|
||||
result = parser.parse_streaming_increment(delta, request.tools)
|
||||
normal_text, calls = result.normal_text, result.calls
|
||||
else:
|
||||
normal_text, calls = parser.parse_stream_chunk(delta)
|
||||
normal_text, calls = parser.parse_stream_chunk(delta)
|
||||
|
||||
# Yield normal text
|
||||
if normal_text:
|
||||
@@ -1104,7 +976,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Yield tool calls
|
||||
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
||||
for call_item in calls:
|
||||
# Mark that this choice has tool calls
|
||||
has_tool_calls[index] = True
|
||||
@@ -1112,9 +983,11 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
# Tool call ID should be generated only once per tool call
|
||||
if call_item.name:
|
||||
# First chunk: include ID and function name
|
||||
tool_call_id = self._process_tool_call_id(
|
||||
call_item, history_tool_calls_cnt
|
||||
)
|
||||
if self.tool_call_parser == "kimi_k2":
|
||||
# Align with Kimi-K2 format: functions.{name}:{index}
|
||||
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
|
||||
else:
|
||||
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||
function_name = call_item.name
|
||||
else:
|
||||
# Subsequent chunks: null ID and name for argument deltas
|
||||
@@ -1145,7 +1018,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
|
||||
def _check_for_unstreamed_tool_args(
|
||||
self,
|
||||
parser: Union[FunctionCallParser, JsonArrayParser],
|
||||
parser: FunctionCallParser,
|
||||
content: Dict[str, Any],
|
||||
request: ChatCompletionRequest,
|
||||
index: int,
|
||||
@@ -1155,31 +1028,30 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
when generation finishes. This ensures tool calls are properly completed
|
||||
even if the model generates the final arguments in the last chunk.
|
||||
"""
|
||||
# Get the detector - either from FunctionCallParser or directly if json detector
|
||||
detector = parser.detector if hasattr(parser, "detector") else parser
|
||||
|
||||
# Only check if we have tool calls and the detector has tracked data
|
||||
# Only check if we have tool calls and the parser has tracked data
|
||||
if (
|
||||
not hasattr(detector, "prev_tool_call_arr")
|
||||
or not detector.prev_tool_call_arr
|
||||
not hasattr(parser.detector, "prev_tool_call_arr")
|
||||
or not parser.detector.prev_tool_call_arr
|
||||
):
|
||||
return None
|
||||
|
||||
if (
|
||||
not hasattr(detector, "streamed_args_for_tool")
|
||||
or not detector.streamed_args_for_tool
|
||||
not hasattr(parser.detector, "streamed_args_for_tool")
|
||||
or not parser.detector.streamed_args_for_tool
|
||||
):
|
||||
return None
|
||||
|
||||
# Get the last tool call that was being processed
|
||||
tool_index = len(detector.prev_tool_call_arr) - 1
|
||||
if tool_index < 0 or tool_index >= len(detector.streamed_args_for_tool):
|
||||
tool_index = len(parser.detector.prev_tool_call_arr) - 1
|
||||
if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
|
||||
return None
|
||||
|
||||
# Get expected vs actual arguments
|
||||
expected_args = detector.prev_tool_call_arr[tool_index].get("arguments", {})
|
||||
expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
|
||||
"arguments", {}
|
||||
)
|
||||
expected_call = json.dumps(expected_args, ensure_ascii=False)
|
||||
actual_call = detector.streamed_args_for_tool[tool_index]
|
||||
actual_call = parser.detector.streamed_args_for_tool[tool_index]
|
||||
|
||||
# Check if there are remaining arguments to send
|
||||
remaining_call = (
|
||||
|
||||
@@ -90,8 +90,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": prompt}
|
||||
|
||||
# Extract custom labels from raw request headers
|
||||
custom_labels = self.extract_custom_labels(raw_request)
|
||||
# Extract customer labels from raw request headers
|
||||
customer_labels = self.extract_customer_labels(raw_request)
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
@@ -107,9 +107,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
return_hidden_states=request.return_hidden_states,
|
||||
rid=request.rid,
|
||||
extra_key=self._compute_extra_key(request),
|
||||
priority=request.priority,
|
||||
custom_labels=custom_labels,
|
||||
customer_labels=customer_labels,
|
||||
)
|
||||
|
||||
return adapted_request, request
|
||||
|
||||
@@ -245,7 +245,6 @@ class OpenAIServingResponses(OpenAIServingChat):
|
||||
sampling_params=sampling_params,
|
||||
stream=request.stream,
|
||||
rid=request.request_id,
|
||||
extra_key=self._compute_extra_key(request),
|
||||
background=request.background,
|
||||
)
|
||||
|
||||
@@ -1251,7 +1250,6 @@ class OpenAIServingResponses(OpenAIServingChat):
|
||||
sampling_params=sampling_params,
|
||||
stream=adapted_request.stream,
|
||||
rid=request_id,
|
||||
extra_key=adapted_request.extra_key,
|
||||
return_logprob=adapted_request.return_logprob,
|
||||
logprob_start_len=adapted_request.logprob_start_len,
|
||||
top_logprobs_num=adapted_request.top_logprobs_num,
|
||||
|
||||
@@ -231,7 +231,6 @@ class ExpertLocationMetadata:
|
||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
||||
logical_to_rank_dispatch_physical_map=(
|
||||
compute_logical_to_rank_dispatch_physical_map(
|
||||
server_args=server_args,
|
||||
logical_to_all_physical_map=logical_to_all_physical_map,
|
||||
num_gpus=ep_size,
|
||||
num_physical_experts=num_physical_experts,
|
||||
@@ -341,7 +340,6 @@ def _pad_nested_array(arr, pad_value):
|
||||
|
||||
# TODO optimize performance (rewrite and/or run in separate process with overlap)
|
||||
def compute_logical_to_rank_dispatch_physical_map(
|
||||
server_args: ServerArgs,
|
||||
logical_to_all_physical_map: torch.Tensor,
|
||||
num_gpus: int,
|
||||
num_physical_experts: int,
|
||||
@@ -350,9 +348,7 @@ def compute_logical_to_rank_dispatch_physical_map(
|
||||
):
|
||||
r = random.Random(seed)
|
||||
|
||||
num_local_gpu_physical_experts = num_physical_experts // num_gpus
|
||||
num_gpus_per_node = server_args.ep_size // server_args.nnodes
|
||||
num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
|
||||
num_local_physical_experts = num_physical_experts // num_gpus
|
||||
num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
|
||||
dtype = logical_to_all_physical_map.dtype
|
||||
|
||||
@@ -376,28 +372,13 @@ def compute_logical_to_rank_dispatch_physical_map(
|
||||
physical_expert_id
|
||||
for physical_expert_id in candidate_physical_expert_ids
|
||||
if _compute_gpu_id_of_physical_expert(
|
||||
physical_expert_id, num_local_gpu_physical_experts
|
||||
physical_expert_id, num_local_physical_experts
|
||||
)
|
||||
== gpu_id
|
||||
]
|
||||
if len(same_gpu_physical_expert_ids) > 0:
|
||||
# 1. Prefer same-GPU experts
|
||||
output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
|
||||
else:
|
||||
# 2. Otherwise, prefer same-node experts
|
||||
node_id = gpu_id // num_gpus_per_node
|
||||
same_node_physical_expert_ids = [
|
||||
physical_expert_id
|
||||
for physical_expert_id in candidate_physical_expert_ids
|
||||
if _compute_node_id_of_physical_expert(
|
||||
physical_expert_id, num_local_node_physical_experts
|
||||
)
|
||||
== node_id
|
||||
]
|
||||
if len(same_node_physical_expert_ids) > 0:
|
||||
output_partial[gpu_id] = same_node_physical_expert_ids[0]
|
||||
|
||||
# 3. Fill remaining slots with fair random choices
|
||||
num_remain = torch.sum(output_partial == -1).item()
|
||||
output_partial[output_partial == -1] = torch.tensor(
|
||||
_fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
|
||||
@@ -423,15 +404,9 @@ def _logical_to_all_physical_raw(
|
||||
|
||||
|
||||
def _compute_gpu_id_of_physical_expert(
|
||||
physical_expert_id: int, num_local_gpu_physical_experts: int
|
||||
physical_expert_id: int, num_local_physical_experts: int
|
||||
) -> int:
|
||||
return physical_expert_id // num_local_gpu_physical_experts
|
||||
|
||||
|
||||
def _compute_node_id_of_physical_expert(
|
||||
physical_expert_id: int, num_local_host_physical_experts: int
|
||||
) -> int:
|
||||
return physical_expert_id // num_local_host_physical_experts
|
||||
return physical_expert_id // num_local_physical_experts
|
||||
|
||||
|
||||
def _fair_choices(arr: List, k: int, r: random.Random) -> List:
|
||||
|
||||
@@ -20,7 +20,6 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
||||
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
|
||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||
from sglang.srt.function_call.step3_detector import Step3Detector
|
||||
from sglang.srt.function_call.utils import get_json_schema_constraint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -179,8 +178,8 @@ class FunctionCallParser:
|
||||
strict_tag = self.get_structure_tag()
|
||||
return ("structural_tag", strict_tag)
|
||||
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
|
||||
json_schema = get_json_schema_constraint(self.tools, tool_choice)
|
||||
return ("json_schema", json_schema)
|
||||
ebnf = self.get_ebnf(tool_choice)
|
||||
return ("ebnf", ebnf) if ebnf is not None else None
|
||||
|
||||
def get_ebnf(
|
||||
self, tool_choice: Union[ToolChoice, Literal["required"]]
|
||||
|
||||
@@ -39,7 +39,7 @@ def parse_arguments(json_value):
|
||||
|
||||
class Glm4MoeDetector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for GLM-4.5 and GLM-4.6 models.
|
||||
Detector for GLM-4.5 models.
|
||||
Assumes function call format:
|
||||
<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
|
||||
"""
|
||||
@@ -53,7 +53,7 @@ class Glm4MoeDetector(BaseFormatDetector):
|
||||
self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
|
||||
"""Check if the text contains a glm-4.5 format tool call."""
|
||||
return self.bot_token in text
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
@@ -102,7 +102,7 @@ class Glm4MoeDetector(BaseFormatDetector):
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format.
|
||||
Streaming incremental parsing tool calls for GLM-4.5 format.
|
||||
"""
|
||||
self._buffer += new_text
|
||||
current_text = self._buffer
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
||||
from sglang.srt.function_call.core_types import StreamingParseResult
|
||||
|
||||
|
||||
class JsonArrayParser(BaseFormatDetector):
|
||||
"""
|
||||
Parser for JSON array tool calls when JSON schema constraints are active.
|
||||
|
||||
This parser is used when tool_choice="required" or a specific tool is named,
|
||||
bypassing model-specific parsers in favor of direct JSON array parsing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Configure for JSON array parsing
|
||||
self.bot_token = "["
|
||||
self.eot_token = "]"
|
||||
self.tool_call_separator = ","
|
||||
|
||||
def has_tool_call(self, text: str) -> bool:
|
||||
"""
|
||||
Check if the given text contains a JSON tool call (array or single object).
|
||||
"""
|
||||
return "[" in text or "{" in text
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||
"""
|
||||
Parse JSON tool calls using the base class implementation.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Detect and parse not supported for JSON schema constraints."
|
||||
)
|
||||
|
||||
def build_ebnf(self, tools: List[Tool]) -> str:
|
||||
"""
|
||||
Build an EBNF grammar for constrained generation.
|
||||
This is not used for JSON schema constraints as they are handled
|
||||
by the constraint backends directly.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"EBNF generation is not supported for JSON schema constraints."
|
||||
)
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Tool]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing with tool validation.
|
||||
"""
|
||||
return super().parse_streaming_increment(new_text, tools)
|
||||
|
||||
def structure_info(self) -> callable:
|
||||
"""
|
||||
Return a function that creates StructureInfo for constrained generation.
|
||||
This is not used for JSON schema constraints as they are handled
|
||||
by the constraint backends directly.
|
||||
"""
|
||||
raise NotImplementedError("structure_info not used for JSON schema constraints")
|
||||
@@ -1,13 +1,10 @@
|
||||
import json
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from json.decoder import WHITESPACE
|
||||
from typing import Any, List, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Tuple
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import Tool, ToolChoice
|
||||
|
||||
|
||||
def _find_common_prefix(s1: str, s2: str) -> str:
|
||||
prefix = ""
|
||||
@@ -40,12 +37,10 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
||||
"""
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except (JSONDecodeError, IndexError) as e:
|
||||
msg = getattr(e, "msg", str(e))
|
||||
if "Extra data" in msg or "pop from empty list" in msg:
|
||||
start = WHITESPACE.match(input_str, 0).end()
|
||||
obj, end = JSONDecoder().raw_decode(input_str, start)
|
||||
return obj, end
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
raise
|
||||
|
||||
|
||||
@@ -55,89 +50,3 @@ def _is_complete_json(input_str: str) -> bool:
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
def _get_tool_schema_defs(tools: List[Tool]) -> dict:
|
||||
"""
|
||||
Get consolidated $defs from all tools, validating for conflicts.
|
||||
|
||||
Args:
|
||||
tools: List of tools to process
|
||||
|
||||
Returns:
|
||||
Dictionary of consolidated $defs from all tools
|
||||
|
||||
Raises:
|
||||
ValueError: If conflicting $defs are found
|
||||
"""
|
||||
all_defs = {}
|
||||
for tool in tools:
|
||||
if tool.function.parameters is None:
|
||||
continue
|
||||
defs = tool.function.parameters.get("$defs", {})
|
||||
for def_name, def_schema in defs.items():
|
||||
if def_name in all_defs and all_defs[def_name] != def_schema:
|
||||
raise ValueError(
|
||||
f"Tool definition '{def_name}' has "
|
||||
"multiple schemas, which is not "
|
||||
"supported."
|
||||
)
|
||||
else:
|
||||
all_defs[def_name] = def_schema
|
||||
return all_defs
|
||||
|
||||
|
||||
def _get_tool_schema(tool: Tool) -> dict:
|
||||
return {
|
||||
"properties": {
|
||||
"name": {"type": "string", "enum": [tool.function.name]},
|
||||
"parameters": (
|
||||
tool.function.parameters
|
||||
if tool.function.parameters
|
||||
else {"type": "object", "properties": {}}
|
||||
),
|
||||
},
|
||||
"required": ["name", "parameters"],
|
||||
}
|
||||
|
||||
|
||||
def get_json_schema_constraint(
|
||||
tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the JSON schema constraint for the specified tool choice.
|
||||
|
||||
Args:
|
||||
tool_choice: The tool choice specification
|
||||
|
||||
Returns:
|
||||
JSON schema dict, or None if no valid tools found
|
||||
"""
|
||||
|
||||
if isinstance(tool_choice, ToolChoice):
|
||||
# For specific function choice, return the user's parameters schema directly
|
||||
fn_name = tool_choice.function.name
|
||||
for tool in tools:
|
||||
if tool.function.name == fn_name:
|
||||
return {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"maxItems": 1,
|
||||
"items": _get_tool_schema(tool),
|
||||
}
|
||||
return None
|
||||
elif tool_choice == "required":
|
||||
json_schema = {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"anyOf": [_get_tool_schema(tool) for tool in tools],
|
||||
},
|
||||
}
|
||||
json_schema_defs = _get_tool_schema_defs(tools)
|
||||
if json_schema_defs:
|
||||
json_schema["$defs"] = json_schema_defs
|
||||
return json_schema
|
||||
|
||||
return None
|
||||
|
||||
@@ -36,9 +36,9 @@ message SamplingParams {
|
||||
float presence_penalty = 6;
|
||||
float repetition_penalty = 7;
|
||||
|
||||
optional int32 max_new_tokens = 8;
|
||||
int32 max_new_tokens = 8;
|
||||
repeated string stop = 9;
|
||||
repeated uint32 stop_token_ids = 10;
|
||||
repeated int32 stop_token_ids = 10;
|
||||
bool skip_special_tokens = 11;
|
||||
bool spaces_between_special_tokens = 12;
|
||||
|
||||
@@ -47,24 +47,24 @@ message SamplingParams {
|
||||
string regex = 13;
|
||||
string json_schema = 14;
|
||||
string ebnf_grammar = 15;
|
||||
string structural_tag = 16;
|
||||
}
|
||||
|
||||
// LoRA adapter
|
||||
string lora_path = 17;
|
||||
string lora_path = 16;
|
||||
|
||||
// Speculative decoding
|
||||
int32 n = 18; // Number of samples
|
||||
int32 n = 17; // Number of samples
|
||||
|
||||
// Token healing
|
||||
bool token_healing = 19;
|
||||
bool token_healing = 18;
|
||||
|
||||
// Additional parameters
|
||||
int32 min_new_tokens = 20;
|
||||
bool ignore_eos = 21;
|
||||
bool no_stop_trim = 22;
|
||||
int32 stream_interval = 23;
|
||||
map<string, float> logit_bias = 24;
|
||||
int32 min_new_tokens = 19;
|
||||
bool ignore_eos = 20;
|
||||
bool no_stop_trim = 21;
|
||||
int32 stream_interval = 22;
|
||||
map<string, float> logit_bias = 23;
|
||||
string structural_tag = 24;
|
||||
|
||||
// Custom parameters for extensibility
|
||||
google.protobuf.Struct custom_params = 25;
|
||||
@@ -98,7 +98,7 @@ message GenerateRequest {
|
||||
bool return_logprob = 5;
|
||||
int32 logprob_start_len = 6;
|
||||
int32 top_logprobs_num = 7;
|
||||
repeated uint32 token_ids_logprob = 8;
|
||||
repeated int32 token_ids_logprob = 8;
|
||||
bool return_hidden_states = 9;
|
||||
|
||||
// For disaggregated serving
|
||||
@@ -122,14 +122,11 @@ message GenerateRequest {
|
||||
|
||||
// For load balancing
|
||||
int32 dp_balance_id = 17;
|
||||
|
||||
// Whether client wants streaming response
|
||||
bool stream = 18;
|
||||
}
|
||||
|
||||
message TokenizedInput {
|
||||
string original_text = 1; // For reference
|
||||
repeated uint32 input_ids = 2;
|
||||
repeated int32 input_ids = 2;
|
||||
}
|
||||
|
||||
message MultimodalInputs {
|
||||
@@ -166,50 +163,51 @@ message GenerateResponse {
|
||||
}
|
||||
|
||||
message GenerateStreamChunk {
|
||||
// Generated tokens (incremental chunk)
|
||||
repeated uint32 token_ids = 1;
|
||||
// Generated token
|
||||
int32 token_id = 1;
|
||||
string text = 2;
|
||||
|
||||
// Cumulative counts
|
||||
int32 prompt_tokens = 2;
|
||||
int32 completion_tokens = 3;
|
||||
int32 cached_tokens = 4;
|
||||
|
||||
// Output logprobs (if requested) - incremental for streaming
|
||||
LogProbs output_logprobs = 5;
|
||||
|
||||
// Hidden states (if requested)
|
||||
repeated float hidden_states = 6;
|
||||
|
||||
// Input logprobs (if requested) - only in first chunk
|
||||
LogProbs input_logprobs = 7;
|
||||
}
|
||||
|
||||
message GenerateComplete {
|
||||
// Final output
|
||||
repeated uint32 output_ids = 1;
|
||||
|
||||
// Finish reason as OpenAI-compatible string ("stop", "length", "abort")
|
||||
string finish_reason = 2;
|
||||
|
||||
// Token usage counts
|
||||
int32 prompt_tokens = 3;
|
||||
int32 completion_tokens = 4;
|
||||
int32 cached_tokens = 5;
|
||||
|
||||
// Output logprobs if requested (cumulative)
|
||||
LogProbs output_logprobs = 6;
|
||||
// Logprobs (if requested)
|
||||
LogProbs logprobs = 6;
|
||||
|
||||
// Hidden states (if requested)
|
||||
repeated float hidden_states = 7;
|
||||
|
||||
// Metadata
|
||||
float generation_time = 8; // Time to generate this token
|
||||
int32 queue_time = 9; // Time spent in queue
|
||||
}
|
||||
|
||||
message GenerateComplete {
|
||||
// Final output
|
||||
repeated int32 output_ids = 1;
|
||||
string output_text = 2;
|
||||
|
||||
// Finish reason
|
||||
enum FinishReason {
|
||||
// The model generated a stop sequence.
|
||||
STOP = 0;
|
||||
// The model reached the maximum generation length.
|
||||
LENGTH = 1;
|
||||
// The model generated an end-of-sequence (EOS) token.
|
||||
EOS_TOKEN = 2;
|
||||
// The model generated a user-provided stop string.
|
||||
STOP_STR = 3;
|
||||
// The request was aborted by the user or system.
|
||||
ABORT = 4;
|
||||
}
|
||||
FinishReason finish_reason = 3;
|
||||
|
||||
// All logprobs if requested
|
||||
repeated LogProbs all_logprobs = 11;
|
||||
|
||||
// All hidden states if requested
|
||||
repeated HiddenStates all_hidden_states = 7;
|
||||
|
||||
// Matched stop information (for stop sequences)
|
||||
oneof matched_stop {
|
||||
uint32 matched_token_id = 8;
|
||||
string matched_stop_str = 9;
|
||||
}
|
||||
|
||||
// Input logprobs if requested (for prompt tokens)
|
||||
LogProbs input_logprobs = 10;
|
||||
repeated HiddenStates all_hidden_states = 12;
|
||||
}
|
||||
|
||||
message GenerateError {
|
||||
@@ -224,11 +222,15 @@ message LogProbs {
|
||||
|
||||
// Top logprobs at each position
|
||||
repeated TopLogProbs top_logprobs = 3;
|
||||
|
||||
// Decoded text for tokens
|
||||
repeated string token_texts = 4;
|
||||
}
|
||||
|
||||
message TopLogProbs {
|
||||
repeated float values = 1;
|
||||
repeated int32 token_ids = 2;
|
||||
repeated string token_texts = 3;
|
||||
}
|
||||
|
||||
message HiddenStates {
|
||||
@@ -283,9 +285,10 @@ message EmbedComplete {
|
||||
|
||||
// Additional metadata
|
||||
int32 embedding_dim = 4;
|
||||
float generation_time = 5;
|
||||
|
||||
// For batch embeddings
|
||||
repeated Embedding batch_embeddings = 5;
|
||||
repeated Embedding batch_embeddings = 6;
|
||||
}
|
||||
|
||||
message Embedding {
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -3,6 +3,7 @@ import datetime
|
||||
from google.protobuf import timestamp_pb2 as _timestamp_pb2
|
||||
from google.protobuf import struct_pb2 as _struct_pb2
|
||||
from google.protobuf.internal import containers as _containers
|
||||
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
|
||||
@@ -11,7 +12,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
|
||||
DESCRIPTOR: _descriptor.FileDescriptor
|
||||
|
||||
class SamplingParams(_message.Message):
|
||||
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
|
||||
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
|
||||
class LogitBiasEntry(_message.Message):
|
||||
__slots__ = ("key", "value")
|
||||
KEY_FIELD_NUMBER: _ClassVar[int]
|
||||
@@ -34,7 +35,6 @@ class SamplingParams(_message.Message):
|
||||
REGEX_FIELD_NUMBER: _ClassVar[int]
|
||||
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
|
||||
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
|
||||
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
|
||||
LORA_PATH_FIELD_NUMBER: _ClassVar[int]
|
||||
N_FIELD_NUMBER: _ClassVar[int]
|
||||
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
|
||||
@@ -43,6 +43,7 @@ class SamplingParams(_message.Message):
|
||||
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
|
||||
STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
|
||||
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
|
||||
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
|
||||
CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
|
||||
temperature: float
|
||||
top_p: float
|
||||
@@ -59,7 +60,6 @@ class SamplingParams(_message.Message):
|
||||
regex: str
|
||||
json_schema: str
|
||||
ebnf_grammar: str
|
||||
structural_tag: str
|
||||
lora_path: str
|
||||
n: int
|
||||
token_healing: bool
|
||||
@@ -68,8 +68,9 @@ class SamplingParams(_message.Message):
|
||||
no_stop_trim: bool
|
||||
stream_interval: int
|
||||
logit_bias: _containers.ScalarMap[str, float]
|
||||
structural_tag: str
|
||||
custom_params: _struct_pb2.Struct
|
||||
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
|
||||
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
|
||||
|
||||
class DisaggregatedParams(_message.Message):
|
||||
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
|
||||
@@ -82,7 +83,7 @@ class DisaggregatedParams(_message.Message):
|
||||
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
|
||||
|
||||
class GenerateRequest(_message.Message):
|
||||
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
|
||||
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
|
||||
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
|
||||
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
|
||||
@@ -100,7 +101,6 @@ class GenerateRequest(_message.Message):
|
||||
LORA_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
|
||||
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
STREAM_FIELD_NUMBER: _ClassVar[int]
|
||||
request_id: str
|
||||
tokenized: TokenizedInput
|
||||
mm_inputs: MultimodalInputs
|
||||
@@ -118,8 +118,7 @@ class GenerateRequest(_message.Message):
|
||||
lora_id: str
|
||||
data_parallel_rank: int
|
||||
dp_balance_id: int
|
||||
stream: bool
|
||||
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
|
||||
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
|
||||
|
||||
class TokenizedInput(_message.Message):
|
||||
__slots__ = ("original_text", "input_ids")
|
||||
@@ -162,46 +161,52 @@ class GenerateResponse(_message.Message):
|
||||
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
|
||||
|
||||
class GenerateStreamChunk(_message.Message):
|
||||
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
|
||||
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||
__slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
|
||||
TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
TEXT_FIELD_NUMBER: _ClassVar[int]
|
||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
||||
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
|
||||
QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
|
||||
token_id: int
|
||||
text: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
cached_tokens: int
|
||||
output_logprobs: LogProbs
|
||||
logprobs: LogProbs
|
||||
hidden_states: _containers.RepeatedScalarFieldContainer[float]
|
||||
input_logprobs: LogProbs
|
||||
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
|
||||
generation_time: float
|
||||
queue_time: int
|
||||
def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
|
||||
|
||||
class GenerateComplete(_message.Message):
|
||||
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
|
||||
__slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
|
||||
class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
||||
__slots__ = ()
|
||||
STOP: _ClassVar[GenerateComplete.FinishReason]
|
||||
LENGTH: _ClassVar[GenerateComplete.FinishReason]
|
||||
EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
|
||||
STOP_STR: _ClassVar[GenerateComplete.FinishReason]
|
||||
ABORT: _ClassVar[GenerateComplete.FinishReason]
|
||||
STOP: GenerateComplete.FinishReason
|
||||
LENGTH: GenerateComplete.FinishReason
|
||||
EOS_TOKEN: GenerateComplete.FinishReason
|
||||
STOP_STR: GenerateComplete.FinishReason
|
||||
ABORT: GenerateComplete.FinishReason
|
||||
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||
OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
|
||||
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
|
||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
||||
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
|
||||
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||
output_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||
finish_reason: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
cached_tokens: int
|
||||
output_logprobs: LogProbs
|
||||
output_text: str
|
||||
finish_reason: GenerateComplete.FinishReason
|
||||
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
|
||||
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
|
||||
matched_token_id: int
|
||||
matched_stop_str: str
|
||||
input_logprobs: LogProbs
|
||||
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
|
||||
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
|
||||
|
||||
class GenerateError(_message.Message):
|
||||
__slots__ = ("message", "http_status_code", "details")
|
||||
@@ -214,22 +219,26 @@ class GenerateError(_message.Message):
|
||||
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
|
||||
|
||||
class LogProbs(_message.Message):
|
||||
__slots__ = ("token_logprobs", "token_ids", "top_logprobs")
|
||||
__slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
|
||||
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
|
||||
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
|
||||
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
|
||||
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
|
||||
token_texts: _containers.RepeatedScalarFieldContainer[str]
|
||||
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
|
||||
|
||||
class TopLogProbs(_message.Message):
|
||||
__slots__ = ("values", "token_ids")
|
||||
__slots__ = ("values", "token_ids", "token_texts")
|
||||
VALUES_FIELD_NUMBER: _ClassVar[int]
|
||||
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
|
||||
values: _containers.RepeatedScalarFieldContainer[float]
|
||||
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...
|
||||
token_texts: _containers.RepeatedScalarFieldContainer[str]
|
||||
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
|
||||
|
||||
class HiddenStates(_message.Message):
|
||||
__slots__ = ("values", "layer", "position")
|
||||
@@ -274,18 +283,20 @@ class EmbedResponse(_message.Message):
|
||||
def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
|
||||
|
||||
class EmbedComplete(_message.Message):
|
||||
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings")
|
||||
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
|
||||
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
|
||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||
EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
|
||||
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
|
||||
BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
|
||||
embedding: _containers.RepeatedScalarFieldContainer[float]
|
||||
prompt_tokens: int
|
||||
cached_tokens: int
|
||||
embedding_dim: int
|
||||
generation_time: float
|
||||
batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
|
||||
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
|
||||
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
|
||||
|
||||
class Embedding(_message.Message):
|
||||
__slots__ = ("values", "index")
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
# This file is auto-generated. Do not edit manually.
|
||||
# Regenerate with: python compile_proto.py
|
||||
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
@@ -119,6 +119,37 @@ def get_hf_text_config(config: PretrainedConfig):
|
||||
return config
|
||||
|
||||
|
||||
def _load_deepseek_v32_model(
|
||||
model_path: str,
|
||||
trust_remote_code: bool = False,
|
||||
revision: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# first get the local path
|
||||
local_path = download_from_hf(model_path)
|
||||
# then load the config file in json
|
||||
config_file = os.path.join(local_path, "config.json")
|
||||
if not os.path.exists(config_file):
|
||||
raise RuntimeError(f"Can't find config file in {local_path}.")
|
||||
|
||||
with open(config_file, "r") as f:
|
||||
config_json = json.load(f)
|
||||
|
||||
config_json["architectures"] = ["DeepseekV3ForCausalLM"]
|
||||
config_json["model_type"] = "deepseek_v3"
|
||||
|
||||
tmp_path = os.path.join(local_path, "_tmp_config_folder")
|
||||
os.makedirs(tmp_path, exist_ok=True)
|
||||
|
||||
unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
|
||||
with open(unique_path, "w") as f:
|
||||
json.dump(config_json, f)
|
||||
|
||||
return AutoConfig.from_pretrained(
|
||||
unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||
)
|
||||
|
||||
|
||||
@lru_cache_frozenset(maxsize=32)
|
||||
def get_config(
|
||||
model: str,
|
||||
@@ -140,9 +171,17 @@ def get_config(
|
||||
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
|
||||
model = client.get_local_dir()
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||
)
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(
|
||||
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||
)
|
||||
except ValueError as e:
|
||||
if not "deepseek_v32" in str(e):
|
||||
raise e
|
||||
config = _load_deepseek_v32_model(
|
||||
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||
)
|
||||
|
||||
if (
|
||||
config.architectures is not None
|
||||
and config.architectures[0] == "Phi4MMForCausalLM"
|
||||
|
||||
@@ -619,11 +619,7 @@ class AiterAttnBackend(AttentionBackend):
|
||||
assert len(k.shape) == 3
|
||||
assert len(v.shape) == 3
|
||||
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
):
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
if kv_indices.shape[0] == 0:
|
||||
o = flash_attn_varlen_func(
|
||||
q,
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import custom_ops
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
@@ -36,6 +37,8 @@ class ForwardMetadata:
|
||||
seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||
seq_lens_cpu_list: Optional[List[int]] = None
|
||||
seq_lens_list_cumsum: Optional[List[int]] = None
|
||||
seq_lens: Optional[torch.Tensor] = None
|
||||
actual_seq_lengths_q: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class AscendAttnBackend(AttentionBackend):
|
||||
@@ -67,6 +70,9 @@ class AscendAttnBackend(AttentionBackend):
|
||||
if self.use_mla:
|
||||
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
|
||||
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
|
||||
self.q_head_dim = (
|
||||
self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
|
||||
)
|
||||
self.native_attn = TorchNativeAttnBackend(model_runner)
|
||||
self.graph_metadata = {}
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
@@ -102,10 +108,6 @@ class AscendAttnBackend(AttentionBackend):
|
||||
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
||||
|
||||
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
|
||||
if forward_batch.is_extend_in_batch:
|
||||
seq_lens_list_cumsum[-1] = (
|
||||
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
|
||||
) * tp_size
|
||||
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
|
||||
|
||||
self.graph_mode = False
|
||||
@@ -133,6 +135,10 @@ class AscendAttnBackend(AttentionBackend):
|
||||
|
||||
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
|
||||
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
|
||||
metadata.seq_lens = seq_lens
|
||||
metadata.actual_seq_lengths_q = torch.tensor(
|
||||
[1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
|
||||
)
|
||||
|
||||
self.graph_metadata[bs] = metadata
|
||||
self.forward_metadata = metadata
|
||||
@@ -161,6 +167,8 @@ class AscendAttnBackend(AttentionBackend):
|
||||
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
|
||||
metadata.block_tables[bs:, :].fill_(0)
|
||||
|
||||
metadata.seq_lens[:bs].copy_(seq_lens[:bs])
|
||||
|
||||
self.forward_metadata = metadata
|
||||
|
||||
self.graph_mode = True
|
||||
@@ -168,6 +176,64 @@ class AscendAttnBackend(AttentionBackend):
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 0
|
||||
|
||||
def forward_sparse(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache: bool = True,
|
||||
# For multi_head latent attention
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
topk_indices: torch.Tensor = None,
|
||||
):
|
||||
|
||||
is_prefill = forward_batch.forward_mode.is_extend()
|
||||
|
||||
if save_kv_cache:
|
||||
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
|
||||
k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, k_rope
|
||||
)
|
||||
q_nope, q_pe = q, q_rope
|
||||
k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||
block_table = self.forward_metadata.block_tables
|
||||
if is_prefill:
|
||||
actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||
else:
|
||||
if self.forward_metadata.actual_seq_lengths_q is None:
|
||||
actual_seq_qlen = (
|
||||
torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
|
||||
)
|
||||
else:
|
||||
actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
|
||||
if self.forward_metadata.seq_lens_cpu_int is None:
|
||||
actual_seq_lengths_kv = self.forward_metadata.seq_lens
|
||||
else:
|
||||
actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
|
||||
|
||||
attn_out = torch.ops.custom.npu_sparse_flash_attention(
|
||||
query=q_nope,
|
||||
key=k_nope,
|
||||
value=k_nope,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
sparse_indices=topk_indices,
|
||||
scale_value=layer.scaling,
|
||||
actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
|
||||
block_table=block_table,
|
||||
sparse_block_size=1,
|
||||
layout_query="TND",
|
||||
layout_kv="PA_BSND",
|
||||
sparse_mode=3,
|
||||
)
|
||||
|
||||
return attn_out
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
q,
|
||||
@@ -176,7 +242,23 @@ class AscendAttnBackend(AttentionBackend):
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache: bool = True,
|
||||
# For multi_head latent attention
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
topk_indices: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if topk_indices is not None:
|
||||
return self.forward_sparse(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
layer,
|
||||
forward_batch,
|
||||
save_kv_cache,
|
||||
q_rope,
|
||||
k_rope,
|
||||
topk_indices,
|
||||
)
|
||||
if not self.use_mla:
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
@@ -437,10 +519,23 @@ class AscendAttnBackend(AttentionBackend):
|
||||
# For multi-head latent attention
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
topk_indices: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if is_mla_preprocess_enabled():
|
||||
# MLAPO does saving kv_cache
|
||||
save_kv_cache = False
|
||||
if topk_indices is not None:
|
||||
return self.forward_sparse(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
layer,
|
||||
forward_batch,
|
||||
save_kv_cache,
|
||||
q_rope,
|
||||
k_rope,
|
||||
topk_indices,
|
||||
)
|
||||
|
||||
if self.graph_mode:
|
||||
return self.forward_decode_graph(
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ATTENTION_BACKENDS = {}
|
||||
|
||||
|
||||
@@ -66,6 +62,13 @@ def create_ascend_backend(runner):
|
||||
return AscendAttnBackend(runner)
|
||||
|
||||
|
||||
@register_attention_backend("nsa")
|
||||
def create_nsa_backend(runner):
|
||||
from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
|
||||
|
||||
return NativeSparseAttnBackend(runner)
|
||||
|
||||
|
||||
@register_attention_backend("triton")
|
||||
def create_triton_backend(runner):
|
||||
assert not runner.model_config.is_encoder_decoder, (
|
||||
@@ -162,37 +165,35 @@ def create_dual_chunk_flash_attn_backend(runner):
|
||||
return DualChunkFlashAttentionBackend(runner)
|
||||
|
||||
|
||||
def attn_backend_wrapper(runner, full_attn_backend):
|
||||
"""
|
||||
Wrapper for special models like hybrid GDN, so we don't
|
||||
need to change the code of the original attention backend.
|
||||
"""
|
||||
assert not (
|
||||
runner.is_hybrid_gdn and runner.use_mla_backend
|
||||
), "hybrid_gdn can only be used with non-MLA models."
|
||||
@register_attention_backend("hybrid_linear_attn")
|
||||
def create_hybrid_linear_attn_backend(runner):
|
||||
assert (
|
||||
runner.is_hybrid_gdn
|
||||
), "hybrid_linear_attn backend can only be used with hybrid GDN models."
|
||||
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
||||
HybridLinearAttnBackend,
|
||||
MambaAttnBackend,
|
||||
)
|
||||
from sglang.srt.utils import is_blackwell, is_npu
|
||||
|
||||
# wrap for hybrid GDN models
|
||||
if runner.is_hybrid_gdn:
|
||||
from sglang.srt.utils import is_blackwell, is_npu
|
||||
if is_npu():
|
||||
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
|
||||
|
||||
if is_blackwell():
|
||||
assert (
|
||||
runner.server_args.attention_backend == "triton"
|
||||
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
|
||||
if is_npu():
|
||||
assert (
|
||||
runner.server_args.attention_backend == "ascend"
|
||||
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
|
||||
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
|
||||
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
||||
HybridLinearAttnBackend,
|
||||
MambaAttnBackend,
|
||||
full_attn_backend = AscendAttnBackend(runner)
|
||||
elif is_blackwell():
|
||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||
|
||||
full_attn_backend = TritonAttnBackend(runner)
|
||||
else:
|
||||
from sglang.srt.layers.attention.flashattention_backend import (
|
||||
FlashAttentionBackend,
|
||||
)
|
||||
|
||||
linear_attn_backend = MambaAttnBackend(runner)
|
||||
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
|
||||
return HybridLinearAttnBackend(
|
||||
full_attn_backend, linear_attn_backend, full_attn_layers
|
||||
)
|
||||
full_attn_backend = FlashAttentionBackend(runner)
|
||||
|
||||
return full_attn_backend
|
||||
linear_attn_backend = MambaAttnBackend(runner)
|
||||
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
|
||||
|
||||
return HybridLinearAttnBackend(
|
||||
full_attn_backend, linear_attn_backend, full_attn_layers
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
|
||||
def support_triton(self):
|
||||
"""Check if the current backend supports triton."""
|
||||
return True
|
||||
|
||||
def get_indexer_metadata(
|
||||
self,
|
||||
layer_id: int,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> Optional[BaseIndexerMetadata]:
|
||||
"""Get the indexer metadata. None means don't support indexer."""
|
||||
return None
|
||||
|
||||
@@ -692,13 +692,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
k_descale, v_descale = None, None
|
||||
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
|
||||
# has corresponding quantization method so that layer.k_scale is not None,
|
||||
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
|
||||
# 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
|
||||
if (
|
||||
self.kv_cache_dtype_str != "auto"
|
||||
and layer.head_dim <= 256
|
||||
and self.fa_impl_ver != 4
|
||||
):
|
||||
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
|
||||
if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
|
||||
if layer.k_scale is not None:
|
||||
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
|
||||
k_descale = layer.k_scale.expand(descale_shape)
|
||||
|
||||
@@ -29,7 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
||||
from sglang.srt.utils import (
|
||||
get_int_env_var,
|
||||
is_flashinfer_available,
|
||||
@@ -344,7 +344,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
decode_wrappers = []
|
||||
@@ -451,7 +453,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
@@ -669,7 +673,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
disable_split_kv: Optional[bool] = None,
|
||||
):
|
||||
@@ -684,7 +690,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
disable_split_kv: Optional[bool] = None,
|
||||
):
|
||||
@@ -710,7 +718,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
disable_split_kv: Optional[bool] = None,
|
||||
):
|
||||
@@ -760,7 +770,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
disable_split_kv: Optional[bool] = None,
|
||||
):
|
||||
@@ -794,7 +806,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
paged_kernel_lens_sum: int,
|
||||
kv_indptr: torch.Tensor,
|
||||
kv_start_idx: torch.Tensor,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
fixed_split_size: Optional[int] = None,
|
||||
@@ -905,7 +919,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
@@ -921,7 +937,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
):
|
||||
if use_ragged:
|
||||
@@ -959,7 +977,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
@@ -1006,7 +1026,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
@@ -1049,7 +1071,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
kv_indptr: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
fixed_split_size: Optional[int] = None,
|
||||
):
|
||||
@@ -1078,7 +1102,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
custom_mask = None
|
||||
else:
|
||||
assert isinstance(
|
||||
spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
|
||||
spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
|
||||
)
|
||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||
spec_info.generate_attn_arg_prefill(
|
||||
|
||||
@@ -5,13 +5,20 @@ Support attention backend for FlashMLA.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
|
||||
|
||||
from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
|
||||
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
|
||||
from sglang.srt.layers.attention.nsa.utils import (
|
||||
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
|
||||
NSA_KV_CACHE_STORE_FP8,
|
||||
compute_nsa_seqlens,
|
||||
)
|
||||
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
@@ -74,10 +81,17 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
self.scaling = model_runner.model_config.scaling
|
||||
self.data_type = model_runner.kv_cache_dtype
|
||||
self.q_data_type = model_runner.dtype
|
||||
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
|
||||
self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
|
||||
|
||||
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||
|
||||
self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
|
||||
self.nsa_index_topk = (
|
||||
get_nsa_index_topk(model_runner.model_config.hf_config)
|
||||
if self.use_nsa
|
||||
else None
|
||||
)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
|
||||
bs = forward_batch.batch_size
|
||||
@@ -100,10 +114,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
self.req_to_token.stride(0),
|
||||
max_seqlen_pad,
|
||||
)
|
||||
mla_metadata, num_splits = get_mla_metadata(
|
||||
forward_batch.seq_lens.to(torch.int32),
|
||||
self.num_q_heads,
|
||||
1,
|
||||
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
|
||||
seq_len_q=1,
|
||||
num_heads_q=self.num_q_heads,
|
||||
num_heads_k=1,
|
||||
nsa_index_topk=self.nsa_index_topk,
|
||||
)
|
||||
self.forward_metadata = FlashMLADecodeMetadata(
|
||||
mla_metadata,
|
||||
@@ -130,10 +146,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
self.req_to_token.stride(0),
|
||||
max_seqlen_pad,
|
||||
)
|
||||
mla_metadata, num_splits = get_mla_metadata(
|
||||
seq_lens.to(torch.int32),
|
||||
self.num_draft_tokens * self.num_q_heads,
|
||||
1,
|
||||
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||
cache_seqlens=seq_lens.to(torch.int32),
|
||||
seq_len_q=self.num_draft_tokens,
|
||||
num_heads_q=self.num_q_heads,
|
||||
num_heads_k=1,
|
||||
nsa_index_topk=self.nsa_index_topk,
|
||||
)
|
||||
|
||||
# Use FlashMLADecodeMetadata which has the attributes forward_extend expects
|
||||
@@ -162,20 +180,28 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
cuda_graph_kv_indices = block_kv_indices
|
||||
|
||||
if self.num_draft_tokens:
|
||||
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
|
||||
torch.ones(
|
||||
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
|
||||
),
|
||||
self.num_draft_tokens * self.num_q_heads,
|
||||
1,
|
||||
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
|
||||
_get_mla_metadata_wrapped(
|
||||
cache_seqlens=torch.ones(
|
||||
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
|
||||
),
|
||||
seq_len_q=self.num_draft_tokens,
|
||||
num_heads_q=self.num_q_heads,
|
||||
num_heads_k=1,
|
||||
nsa_index_topk=self.nsa_index_topk,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
|
||||
torch.ones(
|
||||
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
|
||||
),
|
||||
self.num_q_heads,
|
||||
1,
|
||||
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
|
||||
_get_mla_metadata_wrapped(
|
||||
cache_seqlens=torch.ones(
|
||||
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
|
||||
),
|
||||
seq_len_q=1,
|
||||
num_heads_q=self.num_q_heads,
|
||||
num_heads_k=1,
|
||||
nsa_index_topk=self.nsa_index_topk,
|
||||
)
|
||||
)
|
||||
self.cuda_graph_kv_indices = cuda_graph_kv_indices
|
||||
|
||||
@@ -201,10 +227,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
self.req_to_token.stride(0),
|
||||
self.cuda_graph_kv_indices.stride(0),
|
||||
)
|
||||
mla_metadata, num_splits = get_mla_metadata(
|
||||
seq_lens.to(torch.int32),
|
||||
self.num_q_heads,
|
||||
1,
|
||||
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||
cache_seqlens=seq_lens.to(torch.int32),
|
||||
seq_len_q=1,
|
||||
num_heads_q=self.num_q_heads,
|
||||
num_heads_k=1,
|
||||
nsa_index_topk=self.nsa_index_topk,
|
||||
)
|
||||
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||
@@ -226,10 +254,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
self.req_to_token.stride(0),
|
||||
self.cuda_graph_kv_indices.stride(0),
|
||||
)
|
||||
mla_metadata, num_splits = get_mla_metadata(
|
||||
seq_lens.to(torch.int32),
|
||||
self.num_draft_tokens * self.num_q_heads,
|
||||
1,
|
||||
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||
cache_seqlens=seq_lens.to(torch.int32),
|
||||
seq_len_q=self.num_draft_tokens,
|
||||
num_heads_q=self.num_q_heads,
|
||||
num_heads_k=1,
|
||||
nsa_index_topk=self.nsa_index_topk,
|
||||
)
|
||||
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||
@@ -275,10 +305,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
self.req_to_token.stride(0),
|
||||
self.cuda_graph_kv_indices.stride(0),
|
||||
)
|
||||
mla_metadata, num_splits = get_mla_metadata(
|
||||
seq_lens.to(torch.int32),
|
||||
self.num_q_heads,
|
||||
1,
|
||||
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||
cache_seqlens=seq_lens.to(torch.int32),
|
||||
seq_len_q=1,
|
||||
num_heads_q=self.num_q_heads,
|
||||
num_heads_k=1,
|
||||
nsa_index_topk=self.nsa_index_topk,
|
||||
)
|
||||
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||
@@ -300,10 +332,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
self.req_to_token.stride(0),
|
||||
self.cuda_graph_kv_indices.stride(0),
|
||||
)
|
||||
mla_metadata, num_splits = get_mla_metadata(
|
||||
seq_lens.to(torch.int32),
|
||||
self.num_draft_tokens * self.num_q_heads,
|
||||
1,
|
||||
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||
cache_seqlens=seq_lens.to(torch.int32),
|
||||
seq_len_q=self.num_draft_tokens,
|
||||
num_heads_q=self.num_q_heads,
|
||||
num_heads_k=1,
|
||||
nsa_index_topk=self.nsa_index_topk,
|
||||
)
|
||||
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||
@@ -335,6 +369,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache: bool = True,
|
||||
topk_indices: Optional[torch.Tensor] = None,
|
||||
):
|
||||
cache_loc = forward_batch.out_cache_loc
|
||||
|
||||
@@ -349,13 +384,14 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
)
|
||||
bs = forward_batch.batch_size
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
k_cache = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
|
||||
|
||||
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
||||
if self.data_type == torch.float8_e4m3fn:
|
||||
if (not self.use_nsa) and self.data_type == torch.float8_e4m3fn:
|
||||
reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
|
||||
o, _ = flash_mla_with_kvcache(
|
||||
q=reshape_q_fp8,
|
||||
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
|
||||
k_cache=k_cache,
|
||||
block_table=self.forward_metadata.block_kv_indices[:bs],
|
||||
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
|
||||
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
|
||||
@@ -369,17 +405,49 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
else:
|
||||
block_table = self.forward_metadata.block_kv_indices[:bs]
|
||||
cache_seqlens = forward_batch.seq_lens.to(torch.int32)
|
||||
|
||||
extra_kwargs: Dict
|
||||
if self.use_nsa:
|
||||
assert topk_indices is not None
|
||||
extra_kwargs = dict(
|
||||
indices=_compute_indices_in_kvcache(
|
||||
block_table=block_table,
|
||||
topk_indices=topk_indices.to(torch.int32),
|
||||
page_size=self.page_size,
|
||||
),
|
||||
# doc says it is not used, but if pass in None then error
|
||||
block_table=block_table,
|
||||
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
|
||||
)
|
||||
cache_seqlens = compute_nsa_seqlens(
|
||||
cache_seqlens, nsa_index_topk=self.nsa_index_topk
|
||||
)
|
||||
else:
|
||||
extra_kwargs = dict(
|
||||
block_table=block_table,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if (
|
||||
self.use_nsa
|
||||
and NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8
|
||||
and not NSA_KV_CACHE_STORE_FP8
|
||||
):
|
||||
# inefficiently quantize the whole cache
|
||||
k_cache = quantize_k_cache(k_cache)
|
||||
|
||||
# todo: need check all causal True or False?
|
||||
o, _ = flash_mla_with_kvcache(
|
||||
q=reshape_q,
|
||||
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
|
||||
block_table=self.forward_metadata.block_kv_indices[:bs],
|
||||
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
|
||||
k_cache=k_cache,
|
||||
cache_seqlens=cache_seqlens,
|
||||
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
|
||||
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
|
||||
num_splits=self.forward_metadata.num_splits,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
@@ -539,3 +607,52 @@ class FlashMLAMultiStepDraftBackend:
|
||||
)
|
||||
|
||||
self.common_template(forward_batch, call_fn)
|
||||
|
||||
|
||||
def _get_mla_metadata_wrapped(
|
||||
*,
|
||||
cache_seqlens: torch.Tensor,
|
||||
seq_len_q: int,
|
||||
num_heads_q: int,
|
||||
num_heads_k: int,
|
||||
nsa_index_topk: Optional[int],
|
||||
):
|
||||
if nsa_index_topk is not None:
|
||||
assert nsa_index_topk is not None
|
||||
return get_mla_metadata(
|
||||
cache_seqlens=cache_seqlens,
|
||||
# TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`
|
||||
# but the name looks like need seq_len_q?
|
||||
num_q_tokens_per_head_k=seq_len_q * num_heads_q // num_heads_k,
|
||||
num_heads_k=num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
|
||||
topk=nsa_index_topk,
|
||||
)
|
||||
else:
|
||||
assert nsa_index_topk is None
|
||||
return get_mla_metadata(
|
||||
cache_seqlens=cache_seqlens,
|
||||
num_heads_per_head_k=seq_len_q * num_heads_q // num_heads_k,
|
||||
num_heads_k=num_heads_k,
|
||||
)
|
||||
|
||||
|
||||
# TODO speedup
|
||||
def _compute_indices_in_kvcache(block_table, topk_indices, page_size):
|
||||
topk_indices_safe = topk_indices.masked_fill(topk_indices == -1, 0)
|
||||
|
||||
idx0 = torch.arange(block_table.size(0), device=topk_indices_safe.device).unsqueeze(
|
||||
1
|
||||
)
|
||||
block_idx = block_table[idx0, topk_indices_safe // page_size]
|
||||
offset = topk_indices_safe % page_size
|
||||
indices_in_kvcache = block_idx * page_size + offset
|
||||
|
||||
# the kernel requires invalid entry to be -1
|
||||
assert indices_in_kvcache.shape == topk_indices.shape
|
||||
indices_in_kvcache[topk_indices == -1] = -1
|
||||
|
||||
# return: (batch_size, seqlen_q_ori, topk)
|
||||
indices_in_kvcache = indices_in_kvcache[:, None, :]
|
||||
return indices_in_kvcache
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional, Union
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
@@ -138,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
|
||||
return backend.forward_extend(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||
)
|
||||
|
||||
def get_indexer_metadata(
|
||||
self, layer_id: int, forward_batch: ForwardBatch
|
||||
) -> Optional[BaseIndexerMetadata]:
|
||||
backend = self._select_backend(forward_batch.forward_mode)
|
||||
return backend.get_indexer_metadata(layer_id, forward_batch)
|
||||
|
||||
121
python/sglang/srt/layers/attention/native_mla.py
Normal file
121
python/sglang/srt/layers/attention/native_mla.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import math
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
|
||||
def cdiv(x: int, y: int):
|
||||
return (x+y-1) // y
|
||||
|
||||
def native_mla_sparse_fwd(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
sm_scale: float,
|
||||
d_v: int = 512,) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
|
||||
s_q, _, d_qk = q.size()
|
||||
s_kv = kv.size(0)
|
||||
topk = indices.size(-1)
|
||||
|
||||
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
|
||||
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
|
||||
|
||||
indices = indices[:, 0, :] # [s_q, topk]
|
||||
invalid_indices_mask = (indices < 0) | (indices >= s_kv)
|
||||
qs = q.float() # [s_q, h_q, d_qk]
|
||||
kvs = kv[ :, 0, :].float() # [s_kv, d_qk]
|
||||
|
||||
kvs = torch.index_select(kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()).view(s_q, topk, d_qk) # [s_q, topk, d_qk]
|
||||
attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk]
|
||||
attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float('-inf'))
|
||||
attn_score *= sm_scale * math.log2(math.e)
|
||||
max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q]
|
||||
lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q]
|
||||
attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk]
|
||||
result = attn_score @ kvs[:, :, :d_v]
|
||||
return (max_logits, lse, result)
|
||||
|
||||
|
||||
|
||||
def native_mla_with_kvcache(
|
||||
q: torch.Tensor, # [batch_size, s_q, h_q, d]
|
||||
blocked_k: torch.Tensor, # [?, block_size, h_kv, d]
|
||||
block_table: torch.Tensor, # [batch_size, ?]
|
||||
cache_seqlens: torch.Tensor, # [batch_size]
|
||||
dv: int,
|
||||
is_causal: bool,
|
||||
indices: Optional[torch.Tensor] = None # [batch_size, s_q, topk]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
A reference implementation in PyTorch
|
||||
"""
|
||||
def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
|
||||
mask = torch.zeros(s_q, s_k, dtype=torch.bool)
|
||||
for i in range(s_q):
|
||||
cur_indices = indices[i]
|
||||
valid_indices = cur_indices[cur_indices != -1]
|
||||
mask[i, valid_indices] = True
|
||||
return mask
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
batch_idx: int,
|
||||
query: torch.Tensor, # [h_q, s_q, d]
|
||||
kv: torch.Tensor, # [h_kv, s_k, d]
|
||||
dv: int,
|
||||
is_causal,
|
||||
indices: Optional[torch.Tensor], # [s_q, topk]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
h_q = query.size(0)
|
||||
h_kv = kv.size(0)
|
||||
s_q = query.shape[-2]
|
||||
s_k = kv.shape[-2]
|
||||
query = query.float()
|
||||
kv = kv.float()
|
||||
if h_kv != 1:
|
||||
kv = kv.repeat_interleave(h_q // h_kv, dim=0)
|
||||
kv[kv != kv] = 0.0
|
||||
attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
|
||||
if (is_causal and query.size(1) > 1) or indices is not None:
|
||||
mask = torch.ones(s_q, s_k, dtype=torch.bool)
|
||||
if is_causal:
|
||||
assert indices is None
|
||||
mask = mask.tril(diagonal=s_k - s_q)
|
||||
if indices is not None:
|
||||
mask &= get_topk_attn_mask(s_q, s_k, indices)
|
||||
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float)
|
||||
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
|
||||
attn_weight += attn_bias.to(q.dtype)
|
||||
attn_weight /= math.sqrt(query.size(-1))
|
||||
lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q]
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
|
||||
output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv]
|
||||
# Correct for q tokens which has no attendable k
|
||||
lonely_q_mask = (lse == float("-inf"))
|
||||
output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
|
||||
lse[lonely_q_mask] = float("+inf")
|
||||
|
||||
return output, lse
|
||||
|
||||
b, s_q, h_q, d = q.size()
|
||||
block_size = blocked_k.size(1)
|
||||
h_kv = blocked_k.size(2)
|
||||
cache_seqlens_cpu = cache_seqlens.cpu()
|
||||
out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
|
||||
lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32)
|
||||
for i in range(b):
|
||||
cur_len = cache_seqlens_cpu[i].item()
|
||||
cur_num_blocks = cdiv(cur_len, block_size)
|
||||
cur_block_indices = block_table[i][0: cur_num_blocks]
|
||||
cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
|
||||
cur_out, cur_lse = scaled_dot_product_attention(
|
||||
i,
|
||||
q[i].transpose(0, 1),
|
||||
cur_kv.transpose(0, 1),
|
||||
dv,
|
||||
is_causal,
|
||||
indices[i] if indices is not None else None
|
||||
)
|
||||
out_ref[i] = cur_out.transpose(0, 1)
|
||||
lse_ref[i] = cur_lse
|
||||
out_ref = out_ref.to(torch.bfloat16)
|
||||
return out_ref, lse_ref
|
||||
@@ -76,12 +76,14 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
|
||||
self.rotary_emb = rotary_emb
|
||||
self.layer_id = layer_id
|
||||
self.has_preprocess_weights = False
|
||||
self.dtype = None
|
||||
|
||||
self.q_lora_rank = self.q_b_proj.input_size # 1536
|
||||
self.kv_lora_rank = self.kv_a_layernorm.hidden_size # 512
|
||||
self.num_local_heads = num_local_heads # tp
|
||||
self.qk_nope_head_dim = qk_nope_head_dim # 128
|
||||
self.qk_rope_head_dim = qk_rope_head_dim # 64
|
||||
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
|
||||
def preprocess_weights(self, hidden_states):
|
||||
self.dummy = torch.empty(
|
||||
@@ -236,7 +238,83 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
|
||||
slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
|
||||
return k_cache, v_cache, slot_mapping
|
||||
|
||||
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
|
||||
def forward_absorb_prepare_npu_rms_norm_cache(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch,
|
||||
zero_allocator,
|
||||
):
|
||||
bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape
|
||||
self.dtype = hidden_states.dtype
|
||||
self.cos, self.sin = self.get_sin_cos(positions)
|
||||
self.kvCache, self.kvCacheRope, self.slotmapping = (
|
||||
self.get_kv_cache_and_cache_idx(forward_batch)
|
||||
)
|
||||
|
||||
if not self.has_preprocess_weights:
|
||||
self.has_preprocess_weights = True
|
||||
|
||||
cos, sin = self.cos, self.sin
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]
|
||||
q_lowrank, latent_cache = fused_qkv_a_proj_out.split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
q = self.q_a_layernorm(q_lowrank)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(
|
||||
-1, self.num_local_heads, self.qk_head_dim
|
||||
)
|
||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
) # b*s,n,d
|
||||
|
||||
q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)
|
||||
q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)
|
||||
|
||||
q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
|
||||
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||
q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin) # (B,N,S,D)
|
||||
q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)
|
||||
|
||||
latent_cache = latent_cache.view(
|
||||
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
|
||||
) # (B*S,N,1,D)
|
||||
|
||||
cache_mode = "PA_BNSD"
|
||||
self.kvCache = self.kvCache.view(
|
||||
-1,
|
||||
forward_batch.attn_backend.page_size,
|
||||
1,
|
||||
forward_batch.attn_backend.kv_lora_rank,
|
||||
)
|
||||
self.kvCacheRope = self.kvCacheRope.view(
|
||||
-1,
|
||||
forward_batch.attn_backend.page_size,
|
||||
1,
|
||||
forward_batch.attn_backend.qk_rope_head_dim,
|
||||
)
|
||||
k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||
latent_cache,
|
||||
self.kv_a_layernorm.weight,
|
||||
cos,
|
||||
sin,
|
||||
self.slotmapping.to(torch.int64),
|
||||
self.kvCacheRope,
|
||||
self.kvCache,
|
||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||
cache_mode=cache_mode,
|
||||
)
|
||||
|
||||
return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)
|
||||
|
||||
def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):
|
||||
input_dtype = hidden_states.dtype
|
||||
if not self.has_preprocess_weights:
|
||||
self.preprocess_weights(hidden_states)
|
||||
@@ -298,3 +376,18 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
|
||||
zero_allocator,
|
||||
positions,
|
||||
)
|
||||
|
||||
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
|
||||
_is_w8a8 = (
|
||||
hasattr(self.qkv_a_proj.quant_method, "quantization_config")
|
||||
and self.qkv_a_proj.quant_method.quantization_config.get_name()
|
||||
== "w8a8_int8"
|
||||
)
|
||||
if _is_w8a8:
|
||||
return self.forward_mlapo(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
)
|
||||
else:
|
||||
return self.forward_absorb_prepare_npu_rms_norm_cache(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
)
|
||||
|
||||
3
python/sglang/srt/layers/attention/nsa/cuda/__init__.py
Normal file
3
python/sglang/srt/layers/attention/nsa/cuda/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .topk import fast_topk, fast_topk_transform
|
||||
|
||||
__all__ = ["fast_topk", "fast_topk_transform"]
|
||||
505
python/sglang/srt/layers/attention/nsa/cuda/csrc/topk.cu
Normal file
505
python/sglang/srt/layers/attention/nsa/cuda/csrc/topk.cu
Normal file
@@ -0,0 +1,505 @@
|
||||
#include <ATen/core/TensorBase.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <torch/python.h>
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr int TopK = 2048;
|
||||
constexpr int kThreadsPerBlock = 1024;
|
||||
constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t); // 128KB
|
||||
|
||||
struct FastTopKParams {
|
||||
const float *__restrict__ input; // [B, input_stride]
|
||||
int32_t *__restrict__ indices; // [B, TopK]
|
||||
int32_t *__restrict__ lengths; // [B]
|
||||
int64_t input_stride;
|
||||
bool use_tilelang;
|
||||
};
|
||||
|
||||
// Degenerate top-k for rows with length <= TopK: every position is selected,
// so emit the identity indices [0, length) and pad the tail with -1.
__device__ void naive_topk_cuda(const float *__restrict__ score,
                                int32_t *__restrict__ indice, int32_t length) {
  for (int pos = static_cast<int>(threadIdx.x); pos < TopK;
       pos += kThreadsPerBlock) {
    if (pos < length) {
      indice[pos] = pos;
    } else {
      indice[pos] = -1;
    }
  }
}
|
||||
|
||||
// Degenerate top-k + gather for rows with length <= TopK: copy the first
// `length` source page-table entries through and pad the tail with -1.
__device__ void
naive_topk_transform(const float *__restrict__ score, int32_t length,
                     int32_t *__restrict__ dst_page_table,
                     const int32_t *__restrict__ src_page_table) {
  for (auto pos = threadIdx.x; pos < TopK; pos += kThreadsPerBlock) {
    if (pos < static_cast<uint32_t>(length)) {
      dst_page_table[pos] = src_page_table[pos];
    } else {
      dst_page_table[pos] = -1;
    }
  }
}
|
||||
|
||||
// Map a float to an order-preserving 8-bit key: round to fp16, apply the
// classic sign-flip trick so larger floats give larger unsigned keys, and
// keep only the most significant byte for the coarse radix pass.
__device__ __forceinline__ uint8_t convert_to_uint8(float x) {
  const uint16_t raw = __half_as_ushort(__float2half_rn(x));
  uint16_t ordered;
  if (raw & 0x8000) {
    // negative value: invert all bits so more-negative sorts lower
    ordered = static_cast<uint16_t>(~raw & 0xFFFF);
  } else {
    // non-negative value: set the sign bit so it sorts above all negatives
    ordered = static_cast<uint16_t>(raw | 0x8000);
  }
  return static_cast<uint8_t>(ordered >> 8);
}
|
||||
|
||||
// Full-precision order-preserving key: reinterpret the fp32 bits and apply
// the sign-flip trick so unsigned comparison matches float ordering.
__device__ __forceinline__ uint32_t convert_to_uint32(float x) {
  const uint32_t raw = __float_as_uint(x);
  if (raw & 0x80000000u) {
    return ~raw & 0xFFFFFFFFu;  // negative: invert
  }
  return raw | 0x80000000u;  // non-negative: set sign bit
}
|
||||
|
||||
// One 8-bit radix pass of the block-wide descending top-k selection.
//
// `indexer(i)` maps a dense loop counter to a candidate element index and
// `loader(idx)` extracts that element's current 8-bit digit (larger digit ==
// larger value; see convert_to_uint8 / convert_to_uint32). The pass:
//   1. histograms all `length` candidates by digit,
//   2. finds the "threshold" digit where the descending cumulative count
//      first reaches `topk`,
//   3. emits every candidate strictly above the threshold into `index`
//      (output slots claimed via atomics on `s_counter`) and parks the ties
//      at the threshold digit in `s_remain_idx` for the next, finer pass.
// When `Is_Epilogue` is true (last pass), ties are emitted directly instead.
//
// Returns the number of top-k slots still to be filled from the parked ties.
// All `s_*` reference arguments must be shared-memory objects, and the whole
// thread block must call this together (it uses __syncthreads internally).
template <bool Is_Epilogue = false, typename Indexer, typename Loader,
          int LENGTH, int MAX_REMAIN>
__device__ __forceinline__ auto
radix_topk(Indexer indexer, Loader loader, uint32_t length, int topk,
           int *__restrict__ index, int &__restrict__ s_counter,
           int (&__restrict__ s_histogram)[LENGTH],
           int &__restrict__ s_remain_cnt,
           int (&__restrict__ s_remain_idx)[MAX_REMAIN]) -> int {
  constexpr auto RADIX = LENGTH - 1;
  static_assert(RADIX > 1 && (RADIX & (RADIX - 1)) == 0,
                "RADIX must be power of 2");
  static_assert(RADIX <= kThreadsPerBlock);
  __shared__ uint32_t s_threshold_bin_id;

  const auto tx = threadIdx.x;
  // zero the histogram: one bin per digit value plus one sentinel bin
  if (tx < RADIX + 1)
    s_histogram[tx] = 0;
  __syncthreads();

  /// NOTE: Use uint32_t as the index
  for (auto i = tx; i < length; i += kThreadsPerBlock) {
    const auto idx = indexer(i);
    const auto bin = loader(idx);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();

  // cumsum (descending): after this, s_histogram[i] == count of digits >= i
  if (tx == 0) {
    s_histogram[RADIX] = 0;
    s_remain_cnt = 0;
    for (int i = RADIX - 2; i >= 0; --i) {
      s_histogram[i] += s_histogram[i + 1];
    }
    // threshold bin: the digit at which the cumulative count crosses topk
    for (int i = 0; i < RADIX; i++) {
      if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
        s_threshold_bin_id = i;
        break;
      }
    }
  }
  __syncthreads();

  const auto threshold_bin = s_threshold_bin_id;
  // slots not already covered by digits strictly above the threshold
  const auto new_topk = topk - s_histogram[threshold_bin + 1];

  for (auto i = tx; i < length; i += kThreadsPerBlock) {
    const auto idx = indexer(i);
    const auto bin_id = static_cast<uint32_t>(loader(idx));
    if (bin_id > threshold_bin) {
      // definitely in the top-k: claim an output slot
      index[::atomicAdd(&s_counter, 1)] = idx;
    } else if (bin_id == threshold_bin && new_topk > 0) {
      if constexpr (Is_Epilogue) {
        // final pass: accept ties in arrival order
        // NOTE(review): may emit more than `new_topk` exact ties — presumably
        // acceptable for the consumers of `index`; confirm.
        index[::atomicAdd(&s_counter, 1)] = idx;
      } else {
        // park the ties for the next refinement pass; candidates beyond
        // MAX_REMAIN are dropped
        if (const auto cnt = ::atomicAdd(&s_remain_cnt, 1);
            C10_LIKELY(cnt < MAX_REMAIN)) {
          s_remain_idx[cnt] = idx;
        }
      }
    }
  }
  __syncthreads();

  return new_topk;
}
|
||||
|
||||
// Block-wide selection of the indices of the `topk` largest entries of
// input[0, length).
//
// Strategy: one coarse pass over the 8-bit fp16-derived key of every element,
// then up to four refinement passes over the surviving ties, using successive
// 8-bit digits of the full fp32 key. Candidate sets for consecutive passes
// ping-pong between s_input_idx[0] and s_input_idx[1] (dynamic shared
// memory). Selected indices land in `index` in no particular order.
__device__ void fast_topk_cuda(const float *__restrict__ input,
                               int *__restrict__ index, int length,
                               int topk = TopK) {
  constexpr auto RADIX = 256;
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));

  __shared__ int s_histogram[RADIX + 1];
  __shared__ int s_num_input[2];
  __shared__ int s_counter;

  // allocate for two rounds
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
  // every thread writes the same 0 before radix_topk's first barrier
  s_counter = 0;

  // collect candidates: coarse pass over the 8-bit fp16 key
  const auto indexer = [](int idx) { return idx; };
  const auto loader = [&input](int idx) {
    return convert_to_uint8(input[idx]);
  };
  int new_topk = radix_topk(indexer, loader, length, topk, index, s_counter,
                            s_histogram, s_num_input[0], s_input_idx[0]);
  if (new_topk <= 0)
    return;

  // round 0: bits 31..24 of the fp32 key
  const auto indexer_0 = [](int idx) { return s_input_idx[0][idx]; };
  const auto loader_0 = [&input](int idx) {
    return (convert_to_uint32(input[idx]) >> 24) & 0xFF;
  };
  new_topk = radix_topk(indexer_0, loader_0, s_num_input[0], new_topk, index,
                        s_counter, s_histogram, s_num_input[1], s_input_idx[1]);
  if (new_topk <= 0)
    return;

  // round 1: bits 23..16
  const auto indexer_1 = [](int idx) { return s_input_idx[1][idx]; };
  const auto loader_1 = [&input](int idx) {
    return (convert_to_uint32(input[idx]) >> 16) & 0xFF;
  };
  new_topk = radix_topk(indexer_1, loader_1, s_num_input[1], new_topk, index,
                        s_counter, s_histogram, s_num_input[0], s_input_idx[0]);
  if (new_topk <= 0)
    return;

  // round 2: bits 15..8 (candidates are back in s_input_idx[0])
  const auto loader_2 = [&input](int idx) {
    return (convert_to_uint32(input[idx]) >> 8) & 0xFF;
  };
  new_topk = radix_topk(indexer_0, loader_2, s_num_input[0], new_topk, index,
                        s_counter, s_histogram, s_num_input[1], s_input_idx[1]);
  if (new_topk <= 0)
    return;

  // round 3: bits 7..0
  const auto loader_3 = [&input](int idx) {
    return convert_to_uint32(input[idx]) & 0xFF;
  };
  // epilogue: remaining ties are emitted directly
  radix_topk<true>(indexer_1, loader_3, s_num_input[1], new_topk, index,
                   s_counter, s_histogram, s_num_input[0], s_input_idx[0]);
}
|
||||
|
||||
// Variant of fast_topk_cuda ported from a TileLang implementation: same
// coarse-then-refine radix selection, but with the per-pass logic written
// inline in one function instead of going through the radix_topk template.
__device__ void fast_topk_cuda_tl(const float *__restrict__ input,
                                  int *__restrict__ index, int length,
                                  int topk = TopK) {
  constexpr auto BLOCK_SIZE = 1024;
  constexpr auto RADIX = 256;
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));

  __shared__ int s_threshold_bin_id;
  __shared__ int s_histogram[RADIX + 1];
  __shared__ int s_num_input[2];
  __shared__ int s_counter;

  // allocate for two rounds (ping-pong candidate buffers)
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];

  int tx = threadIdx.x;

  // stage 1: 8bit coarse histogram over the fp16-derived key
  if (tx < RADIX + 1)
    s_histogram[tx] = 0;
  __syncthreads();

  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin = convert_to_uint8(input[idx]);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();

  // cumsum (descending): s_histogram[i] becomes the count of keys >= i
  if (tx == 0) {
    for (int i = RADIX - 2; i >= 0; --i) {
      s_histogram[i] += s_histogram[i + 1];
    }
    // threshold bin: the digit where the cumulative count crosses topk
    for (int i = 0; i < RADIX; i++) {
      if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
        s_threshold_bin_id = i;
        break;
      }
    }
    s_num_input[0] = 0;
    s_counter = 0;
  }
  __syncthreads();

  int threshold_bin = s_threshold_bin_id;
  // slots not already covered by digits strictly above the threshold
  int new_topk = topk - s_histogram[threshold_bin + 1];

  // collect candidates: sure winners go straight to `index`, ties are parked
  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin_id = static_cast<int>(convert_to_uint8(input[idx]));
    if (bin_id > threshold_bin) {
      int pos = ::atomicAdd(&s_counter, 1);
      index[pos] = idx;
    } else if (bin_id == threshold_bin && new_topk > 0) {
      int pos = ::atomicAdd(&s_num_input[0], 1);
      if (pos < SMEM_INPUT_SIZE) {
        [[likely]] s_input_idx[0][pos] = idx;
      }
    }
  }
  __syncthreads();

  // stage 2: refine with 8bit radix passes over successive fp32-key bytes
#pragma unroll 4
  for (int round = 0; round < 4; ++round) {
    if (new_topk <= 0)
      break;
    int r_idx = round % 2;  // which ping-pong buffer holds this round's input

    // reset histogram
    if (tx < RADIX + 1)
      s_histogram[tx] = 0;
    __syncthreads();

    int num_input = s_num_input[r_idx];
    for (int i = tx; i < num_input; i += BLOCK_SIZE) {
      int idx = s_input_idx[r_idx][i];
      uint32_t bin32 =
          (convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
      ::atomicAdd(&s_histogram[bin32], 1);
    }
    __syncthreads();

    // descending cumsum + threshold search for this digit
    if (tx == 0) {
      for (int i = RADIX - 2; i >= 0; --i)
        s_histogram[i] += s_histogram[i + 1];
      for (int i = 0; i < RADIX; i++) {
        if (s_histogram[i] >= new_topk && s_histogram[i + 1] < new_topk) {
          s_threshold_bin_id = i;
          break;
        }
      }
      s_num_input[r_idx ^ 1] = 0;  // reset the other buffer's counter
    }
    __syncthreads();

    new_topk -= s_histogram[s_threshold_bin_id + 1];
    int threshold_bin = s_threshold_bin_id;

    for (int i = tx; i < num_input; i += BLOCK_SIZE) {
      int idx = s_input_idx[r_idx][i];
      uint32_t bin32 =
          (convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
      if (bin32 > threshold_bin) {
        int pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      } else if (bin32 == threshold_bin && new_topk > 0) {
        if (round == 3) {
          // last byte: accept remaining ties directly
          int pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        } else {
          // park ties for the next pass in the other buffer
          int pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
          if (pos < SMEM_INPUT_SIZE)
            s_input_idx[r_idx ^ 1][pos] = idx;
        }
      }
    }
    __syncthreads();
  }
}
|
||||
|
||||
// One block per batch row: write the indices of this row's TopK highest
// scores into params.indices. Rows shorter than TopK degenerate to the
// identity selection (padded with -1).
__global__ void topk_kernel(const FastTopKParams params) {
  const auto row = blockIdx.x;
  const auto row_len = params.lengths[row];
  int32_t *row_indices = params.indices + row * TopK;
  const float *row_score = params.input + row * params.input_stride;
  if (row_len <= TopK) {
    naive_topk_cuda(row_score, row_indices, row_len);
  } else if (params.use_tilelang) {
    fast_topk_cuda_tl(row_score, row_indices, row_len);
  } else {
    fast_topk_cuda(row_score, row_indices, row_len);
  }
}
|
||||
|
||||
// Decode-path fused "top-k + gather": one block per batch row. Selects the
// row's TopK highest scores and writes the corresponding entries of that
// row's source page table into the destination page table (rows shorter
// than TopK copy their pages through and pad with -1).
__global__ void topk_kernel_transform_decode( // decode
    const FastTopKParams params, int32_t *__restrict__ dst_page_table,
    const int32_t *__restrict__ src_page_table, const int64_t src_stride) {
  const auto &[input, _, lengths, input_stride, use_tilelang] = params;
  const auto bid = blockIdx.x;
  const auto tid = threadIdx.x;
  const auto length = *(lengths + bid);
  const auto src_page_entry = src_page_table + bid * src_stride;
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;
  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    // stage the selected score indices in shared memory, then gather pages
    __shared__ int s_indices[TopK];
    if (use_tilelang) {
      fast_topk_cuda_tl(score, s_indices, length);
    } else {
      fast_topk_cuda(score, s_indices, length);
    }
    // copy src[s_indices] to dst, we manually unroll here
    // (each thread handles exactly two of the TopK output slots)
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}
|
||||
|
||||
// Prefill-path fused "top-k + gather". The grid has one block per expanded
// query token; `cu_seqlens` maps each block id back to its request so the
// correct row of the per-request source page table is gathered.
__global__ void topk_kernel_transform_prefill( // prefill
    const FastTopKParams params, int32_t *__restrict__ dst_page_table,
    const int32_t *__restrict__ src_page_table, const int64_t src_stride,
    const int32_t *__restrict__ cu_seqlens, const int64_t prefill_bs) {
  const auto &[input, _, lengths, input_stride, use_tilelang] = params;
  const auto bid = blockIdx.x;
  const auto tid = threadIdx.x;
  const auto length = *(lengths + bid);
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;

  /// NOTE: prefill bs is usually small, we can just use a simple loop here
  /// We ensure that last cu_seqlens is equal to number of blocks launched
  assert(gridDim.x == cu_seqlens[prefill_bs] &&
         "Invalid cu_seqlens in topk-transform-prefill");
  // thread 0 resolves which request this token block belongs to
  __shared__ const int32_t *s_src_page_entry;
  if (tid == 0) {
    for (int64_t offset = 0; offset < prefill_bs; ++offset) {
      if (bid < cu_seqlens[offset + 1]) {
        s_src_page_entry = src_page_table + offset * src_stride;
        break;
      }
    }
  }
  __syncthreads();
  const auto src_page_entry = s_src_page_entry;

  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    // stage the selected score indices in shared memory, then gather pages
    __shared__ int s_indices[TopK];
    if (use_tilelang) {
      fast_topk_cuda_tl(score, s_indices, length);
    } else {
      fast_topk_cuda(score, s_indices, length);
    }
    // copy src[s_indices] to dst, we manually unroll here
    // (each thread handles exactly two of the TopK output slots)
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}
|
||||
|
||||
// Validate the score/lengths (and optional indices) tensors and pack their
// raw pointers into the POD parameter struct shared by all top-k kernels.
auto get_params(at::Tensor score, at::Tensor lengths, bool use_tilelang,
                std::optional<at::Tensor> indices_opt = std::nullopt)
    -> FastTopKParams {
  const auto batch = score.size(0);
  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
  TORCH_CHECK(lengths.size(0) == batch);

  FastTopKParams params{};
  params.input = score.data_ptr<float>();
  params.indices = nullptr;  // filled in below when an output tensor is given
  params.lengths = lengths.data_ptr<int32_t>();
  params.input_stride = score.stride(0);
  params.use_tilelang = use_tilelang;

  if (indices_opt.has_value()) {
    auto &indices = *indices_opt;
    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
    TORCH_CHECK(indices.size(0) == batch);
    TORCH_CHECK(indices.size(1) == TopK);
    params.indices = indices.data_ptr<int32_t>();
  }
  return params;
}
|
||||
|
||||
// Raise kernel `f`'s dynamic shared-memory limit exactly once per process:
// the static initializer runs on the first call only, and the cached status
// is re-checked on every call so a failed setup keeps being reported.
template <auto *f, size_t max_dynamic_smem>
auto setup_kernel_smem_once() -> void {
  static const cudaError_t status = ::cudaFuncSetAttribute(
      f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
  TORCH_CHECK(status == cudaSuccess,
              "set_up_kernel_once failed:", ::cudaGetErrorString(status));
}
|
||||
|
||||
// Python-facing entry: per-row top-k over `score`, writing the selected
// int32 indices into `indices` on the current CUDA stream.
auto fast_topk_interface(at::Tensor score, at::Tensor indices,
                         at::Tensor lengths, bool use_tilelang) -> void {
  const auto kernel_params = get_params(score, lengths, use_tilelang, indices);
  const auto batch = score.size(0);
  setup_kernel_smem_once<topk_kernel, kSmem>();
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  topk_kernel<<<dim3{static_cast<uint32_t>(batch)}, dim3{kThreadsPerBlock},
                kSmem, stream>>>(kernel_params);
  const auto err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
              "topk kernel failed:", ::cudaGetErrorString(err));
}
|
||||
|
||||
// Python-facing entry for the fused "top-k + page-table gather".
// dst_page_table is (B, TopK) int32; src_page_table has one row per request
// (prefill_bs rows); cu_seqlens has prefill_bs + 1 entries. When every
// request maps to exactly one expanded row (prefill_bs == B) the cheaper
// decode kernel is dispatched, otherwise the prefill kernel resolves each
// row's request via cu_seqlens.
auto fast_topk_transform_interface(at::Tensor score, at::Tensor lengths,
                                   at::Tensor dst_page_table,
                                   at::Tensor src_page_table,
                                   at::Tensor cu_seqlens,
                                   bool use_tilelang) -> void {
  const auto params = get_params(score, lengths, use_tilelang);
  const auto B = score.size(0);
  TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
  TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
  TORCH_CHECK(cu_seqlens.dim() == 1 && cu_seqlens.is_contiguous());
  const auto prefill_bs = cu_seqlens.size(0) - 1;
  TORCH_CHECK(dst_page_table.size(0) == B);
  TORCH_CHECK(dst_page_table.size(1) == TopK);
  TORCH_CHECK(src_page_table.size(0) == prefill_bs);
  TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs

  // launch kernel
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  const auto src_stride = src_page_table.stride(0);

  // dispatch to decode or prefill
  const auto is_decode = (prefill_bs == B);
  if (is_decode) {
    setup_kernel_smem_once<topk_kernel_transform_decode, kSmem>();
    topk_kernel_transform_decode<<<grid, block, kSmem, stream>>>(
        params, dst_page_table.data_ptr<int32_t>(),
        src_page_table.data_ptr<int32_t>(), src_stride);
  } else {
    setup_kernel_smem_once<topk_kernel_transform_prefill, kSmem>();
    topk_kernel_transform_prefill<<<grid, block, kSmem, stream>>>(
        params, dst_page_table.data_ptr<int32_t>(),
        src_page_table.data_ptr<int32_t>(), src_stride,
        cu_seqlens.data_ptr<int32_t>(), prefill_bs);
  }

  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess,
              "topk kernel failed:", ::cudaGetErrorString(result));
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// JIT-compiled extension entry point (module name must match the `name`
// passed to torch.utils.cpp_extension.load on the Python side).
PYBIND11_MODULE(topk_kernel, m) {
  m.def("fast_topk", &fast_topk_interface);
  m.def("fast_topk_transform", &fast_topk_transform_interface);
}
|
||||
39
python/sglang/srt/layers/attention/nsa/cuda/topk.py
Normal file
39
python/sglang/srt/layers/attention/nsa/cuda/topk.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import load_kernel_module
|
||||
|
||||
|
||||
def _load_topk_module() -> Any:
    """JIT-compile (if needed) and return the CUDA top-k extension module."""
    return load_kernel_module("topk.cu", "topk_kernel")
|
||||
|
||||
|
||||
# TODO(dark): figure out why my cuda impl is a little slower....
# I believe it has something to do with unrolling loops (?)
|
||||
_USE_TL = True
|
||||
|
||||
|
||||
def fast_topk(
    score: torch.Tensor,
    indices: torch.Tensor,
    lengths: torch.Tensor,
) -> torch.Tensor:
    """Per-row top-k of `score`, writing the selected indices into `indices`."""
    module = _load_topk_module()
    return module.fast_topk(score, indices, lengths, _USE_TL)
|
||||
|
||||
|
||||
def fast_topk_transform(
    score: torch.Tensor,
    lengths: torch.Tensor,
    dst_page_table: torch.Tensor,
    src_page_table: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    """Fused per-row top-k + page-table gather (see csrc/topk.cu)."""
    module = _load_topk_module()
    return module.fast_topk_transform(
        score, lengths, dst_page_table, src_page_table, cu_seqlens, _USE_TL
    )
|
||||
44
python/sglang/srt/layers/attention/nsa/cuda/utils.py
Normal file
44
python/sglang/srt/layers/attention/nsa/cuda/utils.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Any, Iterable
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def _prepare_for_load() -> str:
|
||||
import os
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore", category=UserWarning, module="torch.utils.cpp_extension"
|
||||
)
|
||||
return os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
@lru_cache()
def load_kernel_module(
    path: str | Iterable[str],
    name: str,
    *,
    build: str = "build",
    cflags: Iterable[str] | None = None,
    cuda_flags: Iterable[str] | None = None,
    ldflags: Iterable[str] | None = None,
) -> Any:
    """JIT-compile and load a C++/CUDA extension from ``csrc/`` next to this file.

    Results are cached per argument set, so repeated calls return the same
    loaded module without re-invoking the build.
    NOTE(review): because of ``lru_cache``, callers must pass hashable
    arguments (a string or tuple for ``path``, not a list) — confirm callers.

    :param path: source filename(s) relative to the ``csrc/`` directory
    :param name: extension module name (must match the PYBIND11_MODULE name)
    :param build: subdirectory used as the cpp_extension build directory
    :param cflags: extra C++ flags; defaults to ["-O3", "-std=c++17"]
    :param cuda_flags: extra CUDA flags; defaults to ["-O3", "-std=c++17"]
    :param ldflags: extra linker flags; defaults to none
    """
    from torch.utils.cpp_extension import load

    if isinstance(path, str):
        path = (path,)

    abs_path = _prepare_for_load()
    build_dir = f"{abs_path}/{build}"
    os.makedirs(build_dir, exist_ok=True)
    return load(
        name=name,
        sources=[f"{abs_path}/csrc/{p}" for p in path],
        # `or` fallbacks: an explicitly passed empty list also means "defaults"
        extra_cflags=list(cflags or []) or ["-O3", "-std=c++17"],
        extra_cuda_cflags=list(cuda_flags or []) or ["-O3", "-std=c++17"],
        extra_ldflags=list(ldflags or []) or None,
        build_directory=build_dir,
    )
|
||||
163
python/sglang/srt/layers/attention/nsa/dequant_k_cache.py
Normal file
163
python/sglang/srt/layers/attention/nsa/dequant_k_cache.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.nsa.utils import NSA_DEQUANT_K_CACHE_FAST
|
||||
|
||||
|
||||
def dequantize_k_cache(quant_k_cache):
    """Dequantize a paged k-cache, dispatching on the NSA_DEQUANT_K_CACHE_FAST flag."""
    impl = (
        _dequantize_k_cache_fast_wrapped
        if NSA_DEQUANT_K_CACHE_FAST
        else _dequantize_k_cache_slow
    )
    return impl(quant_k_cache)
|
||||
|
||||
|
||||
def _dequantize_k_cache_slow(
|
||||
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
|
||||
dv: int = 512,
|
||||
tile_size: int = 128,
|
||||
d: int = 576,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
De-quantize the k-cache
|
||||
"""
|
||||
assert dv % tile_size == 0
|
||||
num_tiles = dv // tile_size
|
||||
num_blocks, block_size, h_k, _ = quant_k_cache.shape
|
||||
assert h_k == 1
|
||||
result = torch.empty(
|
||||
(num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
|
||||
)
|
||||
|
||||
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
|
||||
|
||||
input_nope = quant_k_cache[..., :dv]
|
||||
input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
|
||||
input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
|
||||
result[..., dv:] = input_rope
|
||||
|
||||
for tile_idx in range(0, num_tiles):
|
||||
cur_nope = input_nope[
|
||||
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
|
||||
].to(torch.float32)
|
||||
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
|
||||
result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
|
||||
cur_nope * cur_scales
|
||||
)
|
||||
|
||||
result = result.view(num_blocks, block_size, 1, d)
|
||||
return result
|
||||
|
||||
|
||||
def _dequantize_k_cache_fast_wrapped(
    quant_k_cache: torch.Tensor,
    dv: int = 512,
    tile_size: int = 128,
) -> torch.Tensor:
    """Adapter around the Triton fast path: flatten the 4-D cache layout to
    2-D for the kernel, then restore the (num_blocks, block_size, 1, d) shape."""
    # TODO the final API may be 2D instead of 4D, thus we convert them here
    num_blocks, block_size, _, dim_quant = quant_k_cache.shape
    assert dv == 512
    assert dim_quant == 656
    assert tile_size == 128

    flat = quant_k_cache.view(-1, dim_quant)
    dequantized = _dequantize_k_cache_fast(flat)
    return dequantized.view(num_blocks, block_size, 1, -1)
|
||||
|
||||
|
||||
def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
    """Triton-backed dequantization of a flat (num_tokens, 656) fp8 k-cache
    into a bf16 (num_tokens, 576) tensor (512 rescaled nope + 64 rope dims)."""
    num_tokens, dim_quant = quant_k_cache.shape
    assert quant_k_cache.dtype == torch.float8_e4m3fn

    dim_nope = 512
    dim_rope = 64
    num_tiles = dim_nope // group_size
    assert dim_quant == 656
    assert dim_nope % group_size == 0
    NUM_NOPE_BLOCKS = dim_nope // group_size

    # one program per (token, group): 4 nope groups + 1 rope group
    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
    assert num_blocks_per_token == 5

    output = torch.empty(
        (num_tokens, dim_nope + dim_rope),
        dtype=torch.bfloat16,
        device=quant_k_cache.device,
    )

    # split the packed byte layout: quantized nope | fp32 scales | bf16 rope
    nope_q = quant_k_cache[:, :dim_nope]
    nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(torch.float32)
    rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)

    grid = (num_tokens, num_blocks_per_token)
    _dequantize_k_cache_fast_kernel[grid](
        output,
        nope_q,
        nope_s,
        rope,
        output.stride(0),
        nope_q.stride(0),
        nope_s.stride(0),
        rope.stride(0),
        NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
        GROUP_SIZE=group_size,
        DIM_NOPE=dim_nope,
        DIM_ROPE=dim_rope,
    )
    return output
|
||||
|
||||
|
||||
@triton.jit
def _dequantize_k_cache_fast_kernel(
    output_ptr,  # bf16 out, (num_tokens, DIM_NOPE + DIM_ROPE)
    input_nope_q_ptr,  # quantized nope values, (num_tokens, DIM_NOPE)
    input_nope_s_ptr,  # fp32 per-group scales, (num_tokens, NUM_NOPE_BLOCKS)
    input_rope_ptr,  # bf16 rope tail, (num_tokens, DIM_ROPE)
    output_stride_0: int,
    input_nope_q_stride_0: int,
    input_nope_s_stride_0: int,
    input_rope_stride_0: int,
    NUM_NOPE_BLOCKS: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    DIM_NOPE: tl.constexpr,
    DIM_ROPE: tl.constexpr,
):
    """Dequantize one GROUP_SIZE-wide slice of one token.

    Grid: (num_tokens, NUM_NOPE_BLOCKS + rope groups). Programs with
    block id < NUM_NOPE_BLOCKS rescale a quantized nope group by its fp32
    scale; the remaining programs copy the already-bf16 rope tail through.
    """
    token_id = tl.program_id(0)
    raw_block_id = tl.program_id(1)

    if raw_block_id < NUM_NOPE_BLOCKS:
        # a. dequant nope
        effective_block_id = raw_block_id

        offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs_q < DIM_NOPE
        ptr_q = input_nope_q_ptr + token_id * input_nope_q_stride_0 + offs_q
        # one scalar scale per (token, group)
        ptr_s = input_nope_s_ptr + token_id * input_nope_s_stride_0 + effective_block_id

        y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
        y_s = tl.load(ptr_s)

        y = (y_q * y_s).to(output_ptr.dtype.element_ty)

        dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
        tl.store(dst_ptr, y, mask=mask)
    else:
        # b. copy rope (stored unquantized, placed after the nope section)
        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS

        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_ROPE

        src_ptr = input_rope_ptr + token_id * input_rope_stride_0 + offs
        dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs

        data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
        tl.store(dst_ptr, data, mask=mask)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise Exception("UT is in quant_k_cache.py")
|
||||
135
python/sglang/srt/layers/attention/nsa/fallback_fp8.py
Normal file
135
python/sglang/srt/layers/attention/nsa/fallback_fp8.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# fallback_fp8.py
|
||||
# PyTorch fallback implementation for DeepGEMM-like fp8 logits ops
|
||||
from sglang.srt.utils import ceil_div
|
||||
import torch
|
||||
|
||||
@torch.no_grad()
def fallback_fp8_mqa_logits(
    q: torch.Tensor,
    kv: torch.Tensor,
    weights: torch.Tensor,
    ks: torch.Tensor,
    ke: torch.Tensor,
    cost_only: bool = False,
) -> torch.Tensor:
    """Pure-PyTorch fallback for the DeepGEMM fp8 MQA logits kernel.

    No real fp8 is used; the math runs in fp32.

    Args:
        q: (M, H, D) queries.
        kv: (N, D) keys.
        weights: (M, H) per-head mixing weights.
        ks: (M,) inclusive start of each query's valid key range.
        ke: (M,) exclusive end of each query's valid key range.
        cost_only: if True, only return the total number of valid
            (query, key) pairs instead of computing logits.

    Returns:
        (M, N) fp32 logits with -inf outside each query's [ks, ke) range,
        or a scalar pair count when ``cost_only`` is set.
    """
    seq_len_kv = kv.shape[0]

    if cost_only:
        start = ks.clamp(min=0, max=seq_len_kv)
        end = ke.clamp(min=0, max=seq_len_kv)
        count_ones_per_row = (end - start).clamp(min=0)
        return count_ones_per_row.sum()

    q = q.float()
    k = kv.float()

    # Key j is visible to query i iff ks[i] <= j < ke[i].
    # BUGFIX: use the input tensors' device instead of hard-coding "cuda",
    # so this fallback also works on CPU and on non-default CUDA devices.
    positions = torch.arange(0, seq_len_kv, device=kv.device)
    mask_lo = positions[None, :] >= ks[:, None]
    mask_hi = positions[None, :] < ke[:, None]
    mask = mask_lo & mask_hi

    # relu'd per-head scores, mixed by the per-query head weights
    score = torch.einsum('mhd,nd->hmn', q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float('-inf'))

    return logits
|
||||
|
||||
# """
|
||||
# PyTorch fallback for fp8_mqa_logits.
|
||||
# No real fp8 used, just FP32.
|
||||
# Args:
|
||||
# q: (M, H, D) query
|
||||
# k: (N, D) key
|
||||
# weights: (M, H)
|
||||
# ks: (M,) int32
|
||||
# ke: (M,) int32
|
||||
# Returns:
|
||||
# logits: (M, N) with -inf outside of valid range
|
||||
# """
|
||||
# M, H, D = q.shape
|
||||
# N = k[0].shape[0]
|
||||
# logits = torch.full((M, N), float("-inf"), dtype=torch.float32, device=q.device)
|
||||
|
||||
# # for i in range(M):
|
||||
# # start = max(ks[i].item(), 0)
|
||||
# # end = min(ke[i].item(), N)
|
||||
# # if start >= end:
|
||||
# # continue
|
||||
# # qi = q[i] # (H, D)
|
||||
# # ki = k[start:end] # (L, D)
|
||||
# # sim = torch.matmul(qi, ki.T) # (H, L)
|
||||
# # weighted_sim = (sim.relu() * weights[i].unsqueeze(-1)).sum(dim=0) # (L,)
|
||||
# # logits[i, start:end] = weighted_sim
|
||||
# return logits
|
||||
|
||||
|
||||
@torch.no_grad()
def fallback_fp8_paged_mqa_logits(q: torch.Tensor,
                                  kv_cache: torch.Tensor,
                                  weights: torch.Tensor,
                                  context_lens: torch.Tensor,
                                  block_tables: torch.Tensor,
                                  max_model_len: int) -> torch.Tensor:
    """PyTorch fallback for fp8_paged_mqa_logits (no real fp8; fp32 math).

    Args:
        q: (batch_size, next_n, heads, dim) queries.
        kv_cache: (num_blocks, block_size, 1, dim) paged keys.
        weights: (batch_size * next_n, heads) per-head mixing weights.
        context_lens: (batch_size,) context length of each request.
        block_tables: per-request page indices; row i indexed up to
            ceil(context_lens[i] / block_size).
        max_model_len: column width of the returned logits matrix.

    Returns:
        (batch_size * next_n, max_model_len) fp32 logits, -inf outside each
        query's causal/context range.
    """
    batch_size, next_n, heads, dim = q.size()
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32)
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        # positions of this request's next_n query tokens (last tokens of context)
        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
        weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous()
        for block_rk in range(ceil_div(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device=q.device)
            # key visible iff inside the context and not after the query position
            mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None])
            s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), float('-inf'))
            # relu'd per-head scores mixed by the per-query head weights
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf'))
    return logits
|
||||
|
||||
|
||||
"""
|
||||
PyTorch fallback for fp8_paged_mqa_logits.
|
||||
No real fp8 used, just FP32.
|
||||
Args:
|
||||
q: (B, N, H, D)
|
||||
kv_cache: (num_blocks, block_size, 1, D)
|
||||
weights: (B * N, H)
|
||||
context_lens: (B,)
|
||||
block_tables: (B, max_blocks)
|
||||
max_model_len: int
|
||||
Returns:
|
||||
logits: (B * N, max_model_len)
|
||||
"""
|
||||
B, N, H, D = q.shape
|
||||
block_size = kv_cache.shape[1]
|
||||
logits = torch.full((B * N, max_model_len), float("-inf"), dtype=torch.float32, device=q.device)
|
||||
|
||||
for i in range(B):
|
||||
ctx_len = context_lens[i].item()
|
||||
q_offsets = torch.arange(ctx_len - N, ctx_len, device=q.device)
|
||||
weight_slice = weights[i * N:(i + 1) * N, :].transpose(0, 1).contiguous()
|
||||
|
||||
for br in range((ctx_len + block_size - 1) // block_size):
|
||||
blk_idx = block_tables[i, br].item()
|
||||
if blk_idx < 0:
|
||||
continue
|
||||
qx = q[i] # (N, H, D)
|
||||
kx = kv_cache[blk_idx] # (block_size, 1, D)
|
||||
kx = kx.squeeze(1) # (block_size, D)
|
||||
k_offsets = torch.arange(br * block_size, (br + 1) * block_size, device=q.device)
|
||||
|
||||
mask = (k_offsets[None, :] < ctx_len) & (k_offsets[None, :] <= q_offsets[:, None]) # (N, block_size)
|
||||
s = torch.where(mask[None, :, :],
|
||||
torch.einsum('nhd,ld->hnl', qx, kx),
|
||||
torch.full((H, N, block_size), float("-inf"), device=q.device))
|
||||
s = s.relu() * weight_slice[..., None]
|
||||
logits_slice = s.sum(dim=0) # (N, block_size)
|
||||
|
||||
mask_block = (k_offsets[None, :] <= q_offsets[:, None])
|
||||
logits[i * N:(i + 1) * N, br * block_size:(br + 1) * block_size] = \
|
||||
torch.where(mask_block, logits_slice, float("-inf"))
|
||||
|
||||
return logits
|
||||
|
||||
354
python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
Normal file
354
python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
Normal file
@@ -0,0 +1,354 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
|
||||
|
||||
"""
|
||||
k: data, 128 item per token, fp8
|
||||
s: scale, 1 item per token, fp32
|
||||
"""
|
||||
|
||||
|
||||
class GetK:
    """Gather per-token index-k data (fp8 stored as uint8 bytes) from the paged buffer."""

    @classmethod
    def execute(cls, *args, **kwargs):
        # The vectorized gather is the production path; `slow` is a reference impl.
        return cls.torch_fast(*args, **kwargs)

    @classmethod
    def slow(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        """Reference implementation: copy one page at a time."""
        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
        seq_len_ = num_pages * pool.page_size  # seq_len rounded up to whole pages
        index_k_fp8 = torch.empty(
            (seq_len_, pool.index_head_dim),
            dtype=torch.uint8,
            device=pool.device,
        )
        for i in range(num_pages):
            page_index = page_indices[i]
            # The k bytes occupy the first page_size * index_head_dim bytes of a page.
            index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
                page_index
            ][: pool.page_size * pool.index_head_dim].view(-1, pool.index_head_dim)

        return index_k_fp8[:seq_len]

    @classmethod
    def torch_fast(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        """
        Vectorized gather of the k bytes via flat advanced indexing.

        :param page_indices: (num_pages,), int32
        :return: (seq_len, index_head_dim), uint8
        """
        # can handle per 128B instead of per element

        # page_indices: (num_pages,), element := a page index
        buf_numel_per_page = buf.shape[1]

        num_k_bytes_per_page = pool.page_size * pool.index_head_dim
        num_k_bytes_per_token = pool.index_head_dim

        # buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8
        # flat_buf: (whatever,), uint8
        flat_buf = buf.flatten()

        # flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an index into
        # flat_buf that we want to access.
        # FIX: use buf.device instead of a hard-coded "cuda" so non-default devices work.
        flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange(
            num_k_bytes_per_page, dtype=torch.int32, device=buf.device
        )[None, :]
        flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token]

        out = flat_buf[flat_indices]
        # FIX: reshape by the pool's head dim instead of a hard-coded 128, keeping this
        # consistent with `slow` and with num_k_bytes_per_token above.
        return out.view(-1, pool.index_head_dim)
|
||||
|
||||
|
||||
class GetS:
    """Gather per-token scale bytes (one fp32 stored as 4 uint8 bytes) from the paged buffer."""

    @classmethod
    def execute(cls, *args, **kwargs):
        return cls.torch_fast(*args, **kwargs)

    @classmethod
    def slow(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        """Reference implementation: copy one page at a time."""
        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
        seq_len_ = num_pages * pool.page_size
        # exactly one quant block (one fp32 scale) per token is assumed here
        assert pool.index_head_dim // pool.quant_block_size == 1
        index_k_scale_fp8 = torch.empty(
            (seq_len_, 4),
            dtype=torch.uint8,
            device=pool.device,
        )
        for i in range(num_pages):
            page_index = page_indices[i]
            # scales live after the k bytes within each page
            index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
                page_index
            ][pool.page_size * pool.index_head_dim :].view(-1, 4)
        return index_k_scale_fp8[:seq_len]

    @classmethod
    def torch_fast(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        """
        Vectorized gather of the scale bytes via flat advanced indexing.

        :param page_indices: (num_pages,), int32
        :return: (seq_len, index_head_dim // quant_block_size * 4), uint8
        """
        buf_numel_per_page = buf.shape[1]

        num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim
        num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4
        s_offset_in_page = pool.page_size * pool.index_head_dim

        flat_buf = buf.flatten()
        # FIX: use buf.device instead of a hard-coded "cuda" so non-default devices work.
        flat_indices = (
            (page_indices * buf_numel_per_page)[:, None]
            + torch.arange(num_s_bytes_per_page, dtype=torch.int32, device=buf.device)[
                None, :
            ]
            + s_offset_in_page
        )
        flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token]

        out = flat_buf[flat_indices]
        # FIX: reshape by the per-token scale width instead of a hard-coded 4,
        # matching num_s_bytes_per_token above.
        return out.view(-1, num_s_bytes_per_token)
|
||||
|
||||
|
||||
class SetK:
    """Scatter per-token index-k bytes into the paged buffer."""

    @classmethod
    def execute(cls, *args, buf, **kwargs):
        return cls.torch_fast(*args, **kwargs, buf=buf)

    @classmethod
    def slow(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        """Reference implementation: write one token at a time."""
        for i in range(len(loc)):
            page_index = loc[i] // pool.page_size
            offset = loc[i] % pool.page_size
            buf[
                page_index,
                offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim,
            ] = index_k[i].view(torch.uint8)

    @classmethod
    def torch_fast(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        """
        Vectorized scatter of the k bytes via flat advanced indexing.

        :param loc: (num_tokens_to_write,), element := destination token index
        :param index_k: (num_tokens_to_write, index_head_dim), viewable as uint8
        """
        (num_tokens_to_write,) = loc.shape
        buf_numel_per_page = buf.shape[1]
        num_k_bytes_per_token = pool.index_head_dim

        # loc: (num_tokens_to_write,), int32, element := the token index to write to
        loc_page_index = loc // pool.page_size
        loc_token_offset_in_page = loc % pool.page_size

        # NOTE: buf must be contiguous for flatten() to alias the original storage.
        flat_buf = buf.flatten()
        # FIX: use buf.device instead of a hard-coded "cuda" so non-default devices work.
        flat_indices = (
            (loc_page_index * buf_numel_per_page)[:, None]
            + (loc_token_offset_in_page * num_k_bytes_per_token)[:, None]
            + torch.arange(num_k_bytes_per_token, dtype=torch.int32, device=buf.device)[
                None, :
            ]
        )
        num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token
        flat_indices = flat_indices.flatten()[:num_k_bytes_total]
        flat_buf[flat_indices] = index_k.view(torch.uint8).flatten()
|
||||
|
||||
|
||||
class SetS:
    """Scatter per-token scale bytes (fp32 as 4 uint8 bytes) into the paged buffer."""

    @classmethod
    def execute(cls, *args, buf, **kwargs):
        return cls.torch_fast(*args, **kwargs, buf=buf)

    @classmethod
    def slow(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k_scale: torch.Tensor,
    ):
        """Reference implementation: write one token's 4 scale bytes at a time."""
        for i in range(len(loc)):
            page_index = loc[i] // pool.page_size
            offset = loc[i] % pool.page_size
            # scales live after the k bytes within each page
            start = pool.page_size * pool.index_head_dim
            buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = (
                index_k_scale[i].view(torch.uint8)
            )

    @classmethod
    def torch_fast(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k_scale: torch.Tensor,
    ):
        """
        Vectorized scatter of the scale bytes via flat advanced indexing.

        :param loc: (num_tokens_to_write,), element := destination token index
        :param index_k_scale: (num_tokens_to_write, 1) fp32 (4 bytes/token), viewable as uint8
        """
        (num_tokens_to_write,) = loc.shape
        buf_numel_per_page = buf.shape[1]
        num_s_bytes_per_token = 4
        s_offset_in_page = pool.page_size * pool.index_head_dim

        # loc: (num_tokens_to_write,), int32, element := the token index to write to
        loc_page_index = loc // pool.page_size
        loc_token_offset_in_page = loc % pool.page_size

        flat_buf = buf.flatten()
        # FIX: use buf.device instead of a hard-coded "cuda" so non-default devices work.
        flat_indices = (
            (loc_page_index * buf_numel_per_page)[:, None]
            + s_offset_in_page
            + (loc_token_offset_in_page * num_s_bytes_per_token)[:, None]
            + torch.arange(num_s_bytes_per_token, dtype=torch.int32, device=buf.device)[
                None, :
            ]
        )
        number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token
        flat_indices = flat_indices.flatten()[:number_s_bytes_total]
        flat_buf[flat_indices] = index_k_scale.view(torch.uint8).flatten()
|
||||
|
||||
|
||||
class SetKAndS:
    # Fused write of k bytes and scale bytes via a single triton kernel launch.

    @classmethod
    def execute(cls, *args, buf, **kwargs):
        # NOTE(review): dead debug branch (`if 0`) comparing the vanilla and
        # triton paths, kept for manual verification. Beware that `.to_list()`
        # in the assert message is a typo for `.tolist()` and would raise if
        # this branch were re-enabled and the assert fired.
        if 0:
            # print("SetK, SetS comparison test")
            buf_cloned = buf.clone()
            cls.vanilla(*args, **kwargs, buf=buf)
            cls.triton(*args, **kwargs, buf=buf_cloned)

            def _clear_token_0(target):
                # Token 0's k bytes and scale bytes are zeroed in both buffers
                # before comparison (token 0 may hold scratch data).
                target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0

            _clear_token_0(buf)
            _clear_token_0(buf_cloned)

            assert torch.all(
                buf == buf_cloned
            ), f"{buf=} {buf_cloned=} {kwargs['loc'].to_list()=}"
            return

        cls.triton(*args, **kwargs, buf=buf)

    @classmethod
    def vanilla(cls, pool, buf, loc, index_k, index_k_scale):
        # Reference path: two separate scatter passes.
        SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k)
        SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale)

    @classmethod
    def triton(cls, pool, buf, loc, index_k, index_k_scale):
        # Fused path: one kernel writes both k and scale per token.
        _set_k_and_s_triton(
            buf=buf,
            loc=loc,
            index_k=index_k,
            index_k_scale=index_k_scale,
            page_size=pool.page_size,
        )
|
||||
|
||||
|
||||
def _set_k_and_s_triton(
    buf: torch.Tensor,
    loc: torch.Tensor,
    index_k: torch.Tensor,
    index_k_scale: torch.Tensor,
    page_size: int,
):
    """
    Launch the fused k+scale scatter kernel, one program per token.

    :param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8
    :param loc: (num_tokens_to_write,), int, element := the token index to write to
    :param index_k: (num_tokens_to_write, 128 elem), fp8
    :param index_k_scale: (num_tokens_to_write, 1 elem), fp32
    :return:
    """
    num_pages, buf_numel_per_page = buf.shape
    (num_tokens_to_write,) = loc.shape
    num_tokens_to_write_, index_head_dim = index_k.shape
    num_tokens_to_write__, scale_dim = index_k_scale.shape
    # The kernel hard-codes the 64-token / 128-dim / 4-byte-scale page layout.
    assert buf_numel_per_page == 64 * (128 + 4)
    assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
    assert index_head_dim == 128
    assert scale_dim == 1
    assert page_size == 64

    assert buf.dtype == torch.uint8
    assert loc.dtype == torch.int64, f"{loc.dtype=}"  # can be int32
    assert index_k.dtype == torch.float8_e4m3fn
    assert index_k_scale.dtype == torch.float32

    # Contiguity is required: the kernel computes raw flat offsets.
    assert buf.is_contiguous()
    assert loc.is_contiguous()
    assert index_k.is_contiguous()
    assert index_k_scale.is_contiguous()

    # Two typed aliases of the same storage: fp8 view for k, fp32 view for scales.
    buf_fp8 = buf.view(torch.float8_e4m3fn)
    buf_fp32 = buf.view(torch.float32)

    # Grid: one program per token to write.
    _set_k_and_s_triton_kernel[(num_tokens_to_write,)](
        buf_fp8,
        buf_fp32,
        loc,
        index_k,
        index_k_scale,
        index_k.stride(0),
        PAGE_SIZE=page_size,
        BUF_NUMEL_PER_PAGE=buf_numel_per_page,
        NUM_K_ELEMS_PER_TOKEN=index_head_dim,
        S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim,
    )
|
||||
|
||||
|
||||
@triton.jit
def _set_k_and_s_triton_kernel(
    buf_fp8_ptr,
    buf_fp32_ptr,
    loc_ptr,
    index_k_ptr,
    index_k_scale_ptr,
    index_k_ptr_stride_0,
    PAGE_SIZE: tl.constexpr,
    BUF_NUMEL_PER_PAGE: tl.constexpr,
    NUM_K_ELEMS_PER_TOKEN: tl.constexpr,
    S_OFFSET_NBYTES_IN_PAGE: tl.constexpr,
):
    # One program per token: copy that token's NUM_K_ELEMS_PER_TOKEN fp8
    # k-elements and its single fp32 scale into the paged buffer.
    token_id = tl.program_id(0)

    # Destination token index for this program.
    loc = tl.load(loc_ptr + token_id)

    # Source offsets for this token's k row.
    in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)

    # no need for `mask`, since we read 128B for k and 4B for scale, both pow of 2
    k = tl.load(index_k_ptr + in_k_offsets)
    k_scale = tl.load(index_k_scale_ptr + token_id)

    # Destination page and in-page slot.
    loc_page_index = loc // PAGE_SIZE
    loc_token_offset_in_page = loc % PAGE_SIZE

    # k destination: offsets in fp8 elements (1 byte each).
    out_k_offsets = (
        loc_page_index * BUF_NUMEL_PER_PAGE
        + loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN
        + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
    )

    # "//4" b/c it is fp32 instead of uint8
    out_s_offset = (
        loc_page_index * BUF_NUMEL_PER_PAGE // 4
        + S_OFFSET_NBYTES_IN_PAGE // 4
        + loc_token_offset_in_page
    )

    tl.store(buf_fp8_ptr + out_k_offsets, k)
    tl.store(buf_fp32_ptr + out_s_offset, k_scale)
|
||||
709
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Normal file
709
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Normal file
@@ -0,0 +1,709 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
from sglang.srt.layers.attention.nsa.fallback_fp8 import fallback_fp8_mqa_logits, fallback_fp8_paged_mqa_logits
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.debug_utils.dumper import dumper
|
||||
from sglang.srt.utils import add_prefix, is_npu
|
||||
|
||||
if not is_npu():
|
||||
from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
|
||||
#import deep_gemm
|
||||
|
||||
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_group
|
||||
from sglang.srt.layers.linear import ReplicatedLinear
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import add_prefix, align, is_cuda
|
||||
|
||||
# try:
|
||||
# import deep_gemm_v32
|
||||
# except ImportError as e:
|
||||
# print("Error when importing deep_gemm_v32, try deep_gemm")
|
||||
# try:
|
||||
# import deep_gemm as deep_gemm_v32
|
||||
# except ImportError as e:
|
||||
# print("Error when importing deep_gemm, skip")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
|
||||
|
||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0
|
||||
|
||||
|
||||
class BaseIndexerMetadata(ABC):
    """Interface the attention backend implements to feed the NSA indexer.

    Supplies sequence lengths and the page table, and owns the top-k
    selection step (which the backend may transform arbitrarily).
    """

    @abstractmethod
    def get_seqlens_int32(self) -> torch.Tensor:
        """
        Return: (batch_size,) int32 tensor
        """

    @abstractmethod
    def get_page_table_64(self) -> torch.Tensor:
        """
        Return: (batch_size, num_blocks) int32, page table.
        The page size of the table is 64.
        """

    @abstractmethod
    def get_seqlens_expanded(self) -> torch.Tensor:
        """
        Return: (sum_extend_seq_len,) int32 tensor
        """

    @abstractmethod
    def topk_transform(
        self,
        logits: torch.Tensor,
        topk: int,
    ) -> torch.Tensor:
        """
        Perform topk selection on the logits and possibly transform the result.

        NOTE that attention backend may override this function to do some
        transformation, which means the result of this topk_transform may not
        be the topk indices of the input logits.

        Return: Anything, since it will be passed to the attention backend
            for further processing on sparse attention computation.
            Don't assume it is the topk indices of the input logits.
        """
|
||||
|
||||
def hadamard_transform_pytorch(x: torch.Tensor, scale: float) -> torch.Tensor:
    """
    Native PyTorch Fast Hadamard Transform (recursive butterfly), mirroring
    the call signature of the custom CUDA kernel.

    Args:
        x: Input tensor of shape (*, N), where N is a power of 2.
        scale: Normalization factor applied once, to the final result.

    Returns:
        The Hadamard-transformed tensor.
    """
    n = x.shape[-1]
    # Recursion bottoms out at a single element.
    if n == 1:
        return x

    # Butterfly: transform each half, then form (sum, difference).
    lo = x[..., : n // 2]
    hi = x[..., n // 2 :]
    # Intermediate levels are unscaled; only the top-level call scales.
    lo_t = hadamard_transform_pytorch(lo, scale=1.0)
    hi_t = hadamard_transform_pytorch(hi, scale=1.0)

    combined = torch.cat((lo_t + hi_t, lo_t - hi_t), dim=-1)
    return combined * scale
|
||||
|
||||
|
||||
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Apply a 1/sqrt(N)-normalized Hadamard transform along the last dim."""
    assert x.dtype == torch.bfloat16

    hidden_size = x.size(-1)
    # Hadamard transform is only defined here for power-of-2 widths.
    is_pow2 = (hidden_size & (hidden_size - 1)) == 0
    assert is_pow2, "Hidden size must be a power of 2 for Hadamard transform."
    return hadamard_transform_pytorch(x, scale=hidden_size**-0.5)
|
||||
|
||||
|
||||
class V32LayerNorm(nn.Module):
    """Layer normalization computed in fp32 with learnable affine parameters,
    with the result cast back to the input dtype."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Affine parameters are kept in fp32 to match the fp32 compute path.
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        # Normalize in fp32 for numerical stability, then restore x's dtype.
        normed = F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps)
        return normed.type_as(x)
|
||||
|
||||
|
||||
class Indexer(CustomOp):
|
||||
    def __init__(
        self,
        hidden_size: int,
        index_n_heads: int,
        index_head_dim: int,
        rope_head_dim: int,
        index_topk: int,
        q_lora_rank: int,
        max_position_embeddings: int,
        rope_theta: float,
        layer_id: int,
        scale_fmt: Optional[str],
        block_size: int = 128,
        rope_scaling: Optional[Dict[str, Any]] = None,
        prefix: str = "",
        quant_config: Optional[QuantizationConfig] = None,
        alt_stream: Optional[torch.cuda.Stream] = None,
    ):
        """Build the NSA indexer: query/key projections, key layernorm,
        per-head gate projection, and rotary embedding.

        Args:
            scale_fmt: activation-quant scale format forwarded to act_quant.
            block_size: quantization block size for act_quant.
            alt_stream: optional secondary CUDA stream for dual-stream overlap.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = index_n_heads
        self.head_dim = index_head_dim
        self.rope_head_dim = rope_head_dim
        self.index_topk = index_topk
        self.q_lora_rank = q_lora_rank
        self.layer_id = layer_id
        self.alt_stream = alt_stream
        if not is_npu():
            # Half the SMs (rounded to a multiple of 8) are reserved for the
            # deep-gemm side when running dual-stream.
            self.sm_count = torch.cuda.get_device_properties(0).multi_processor_count
            self.half_device_sm_count = align(self.sm_count // 2, 8)

        # Query up-projection from the LoRA rank to all index heads.
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wq_b", prefix),
        )
        # Single shared key projection (one head worth of key per token).
        self.wk = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wk", prefix),
        )
        self.k_norm = V32LayerNorm(self.head_dim)
        # NOTE: weight_proj is not quantized
        self.weights_proj = ReplicatedLinear(
            self.hidden_size,
            self.n_heads,
            bias=False,
            prefix=add_prefix("weights_proj", prefix),
        )
        self.rotary_emb = get_rope_wrapper(
            rope_head_dim,
            rotary_dim=rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,  # type: ignore
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )
        self.block_size = block_size
        self.scale_fmt = scale_fmt
        self.softmax_scale = self.head_dim**-0.5
|
||||
|
||||
def _forward_fake(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
q_lora: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layer_id: int,
|
||||
):
|
||||
bs = x.shape[0]
|
||||
assert self.index_topk == 2048
|
||||
ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
|
||||
None, ...
|
||||
].repeat(bs, 1)
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
assert (
|
||||
forward_batch.extend_seq_lens_cpu is not None
|
||||
and forward_batch.seq_lens_cpu is not None
|
||||
)
|
||||
which = 0
|
||||
for i, (kv_len, qo_len) in enumerate(
|
||||
zip(
|
||||
forward_batch.seq_lens_cpu.tolist(),
|
||||
forward_batch.extend_seq_lens_cpu,
|
||||
strict=True,
|
||||
)
|
||||
):
|
||||
for j in range(kv_len - qo_len, kv_len):
|
||||
ans[which, j + 1 :] = -1
|
||||
which += 1
|
||||
assert which == ans.shape[0]
|
||||
else:
|
||||
assert forward_batch.seq_lens_cpu is not None
|
||||
for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
|
||||
ans[i, seq_len:] = -1
|
||||
|
||||
return ans
|
||||
|
||||
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
|
||||
weights, _ = self.weights_proj(x)
|
||||
weights = weights * self.n_heads**-0.5
|
||||
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
|
||||
return weights
|
||||
|
||||
    def _get_q_k_bf16(
        self,
        q_lora: torch.Tensor,
        x: torch.Tensor,
        positions: torch.Tensor,
        enable_dual_stream: bool,
    ):
        """Compute bf16 indexer query/key: project, layernorm keys, apply RoPE
        to the rope sub-dim, then Hadamard-rotate both.

        When enable_dual_stream is set, query work runs on the current stream
        (with half the SMs reserved for deep-gemm) while key work runs on
        self.alt_stream; the streams are joined before the rope step.
        """

        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)

            # Query path on the current stream, capped to half the SMs so the
            # key path can overlap.
            with deep_gemm_wrapper.configure_deep_gemm_num_sms(
                self.half_device_sm_count
            ):
                query, _ = self.wq_b(q_lora)
                query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
                q_rope, _ = torch.split(
                    query,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )
            # Key path on the alternate stream.
            with torch.cuda.stream(self.alt_stream):
                key, _ = self.wk(x)
                key = self.k_norm(key)

                k_rope, _ = torch.split(
                    key,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )

            current_stream.wait_stream(self.alt_stream)
        else:
            query, _ = self.wq_b(q_lora)
            if dumper._enable:
                after_wq_b = query.clone()  # debug snapshot
            query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)

            q_rope, _ = torch.split(
                query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )

            key, _ = self.wk(x)
            if dumper._enable:
                after_wk = key.clone()  # debug snapshot
            key = self.k_norm(key)
            if dumper._enable:
                after_k_norm = key.clone()  # debug snapshot
            k_rope, _ = torch.split(
                key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )
        # Rotate only the leading rope_head_dim slice; write back in place.
        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
        query[..., : self.rope_head_dim] = q_rope
        key[..., : self.rope_head_dim] = k_rope

        if dumper._enable:
            q_before_hadamard = query.clone()  # debug snapshot
            k_before_hadamard = key.clone()  # debug snapshot

        if enable_dual_stream:
            # Hadamard rotation also overlapped across the two streams.
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            query = rotate_activation(query)

            with torch.cuda.stream(self.alt_stream):
                key = rotate_activation(key)
            current_stream.wait_stream(self.alt_stream)
        else:
            query = rotate_activation(query)
            key = rotate_activation(key)

        return query, key
|
||||
|
||||
    def _get_topk_paged(
        self,
        forward_batch: ForwardBatch,
        layer_id: int,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        metadata: BaseIndexerMetadata,
    ) -> torch.Tensor:
        """Decode path: score queries against the paged fp8 index-k cache and
        select top-k token indices via the backend's topk_transform."""
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)

        page_size = forward_batch.token_to_kv_pool.page_size
        # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm_v32
        assert page_size == 64, "only support page size 64"

        # NOTE(dark): this support extend/decode/decode+graph
        block_tables = metadata.get_page_table_64()

        max_seq_len = block_tables.shape[1] * page_size
        kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
            layer_id=layer_id
        )

        blocksize = page_size
        seqlens_32 = metadata.get_seqlens_int32()
        # NOTE(dark): 132 is SM count on H200/B200, not magic number
        # schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
        #     seqlens_32, blocksize, self.sm_count
        # )

        assert len(q_fp8.shape) == 3
        q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
        assert len(kv_cache_fp8.shape) == 2
        # Reshape flat cache pages into (pages, block, kv_head, dim+scale):
        # 132 = 128 fp8 data bytes + 4 scale bytes per token.
        block_kv = 64
        num_heads_kv = 1
        head_dim_with_sf = 132
        kv_cache_fp8 = kv_cache_fp8.view(
            kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
        )
        assert len(weights.shape) == 3
        weights = weights.squeeze(2)

        # Fallback (non-deep-gemm) scorer; see fallback_fp8_paged_mqa_logits.
        logits = fallback_fp8_paged_mqa_logits(
            q_fp8,
            kv_cache_fp8,
            weights,
            seqlens_32,
            block_tables,
            max_seq_len,
        )

        # NOTE(dark): logits should be cleaned in topk_transform
        topk_result = metadata.topk_transform(logits, self.index_topk)
        return topk_result
|
||||
|
||||
    def _get_topk_ragged(
        self,
        forward_batch: ForwardBatch,
        layer_id: int,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        metadata: BaseIndexerMetadata,
    ) -> torch.Tensor:
        """Extend/prefill path: gather each sequence's contiguous fp8 index-k
        (plus scales) from the pool, score ragged queries against them, and
        select top-k via the backend's topk_transform."""
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)

        page_size = forward_batch.token_to_kv_pool.page_size
        assert page_size == 64, "only support page size 64"
        assert len(weights.shape) == 3
        weights = weights.squeeze(-1)
        k_fp8_list = []
        k_scale_list = []
        ks_list = []
        offset = 0

        block_tables = metadata.get_page_table_64()

        assert (
            forward_batch.seq_lens_cpu is not None
            and forward_batch.extend_seq_lens_cpu is not None
        )

        for i in range(forward_batch.batch_size):
            seq_len = forward_batch.seq_lens_cpu[i].item()
            assert isinstance(seq_len, int)
            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
                layer_id,
                seq_len,
                block_tables[i],
            )
            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
                layer_id,
                seq_len,
                block_tables[i],
            )
            extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
            # ks: per-query start offset of this sequence's keys in the
            # concatenated key tensor.
            ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
            k_fp8_list.append(k_fp8)
            k_scale_list.append(k_scale)
            ks_list.append(ks)
            offset += extend_seq_len

        k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
        k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
        # NOTE(review): kv_fp8 appears unused below (leftover from the
        # deep-gemm call signature) — confirm before removing.
        kv_fp8 = (k_fp8, k_scale)
        ks = torch.cat(ks_list, dim=0)
        seq_lens_expanded = metadata.get_seqlens_expanded()
        # ke: per-query end offset (exclusive) for causal masking.
        ke = ks + seq_lens_expanded

        logits = fallback_fp8_mqa_logits(
            q_fp8,
            k_fp8,
            weights,
            ks,
            ke
        )

        assert logits.shape[0] == len(seq_lens_expanded)
        topk_result = metadata.topk_transform(logits, self.index_topk)

        return topk_result
|
||||
|
||||
    def _forward(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """Full indexer forward: build q/k, quantize, persist k into the pool,
        and return top-k token indices (or None if the backend skips NSA)."""
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)

        metadata = forward_batch.attn_backend.get_indexer_metadata(
            layer_id, forward_batch
        )

        # Dual-stream only during CUDA-graph capture with a small token count.
        enable_dual_stream = (
            NSA_DUAL_STREAM
            and self.alt_stream is not None
            and get_is_capture_mode()
            and q_lora.shape[0] > 0
            and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
        )

        # skip NSA if attention backend choose to skip this batch
        if metadata is None:
            return None

        if not NSA_USE_REAL_INDEXER:  # temporary
            return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)

        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)

        # Defaults (fp32, unit scales) are overwritten by act_quant below.
        q_fp8 = query.to(torch.float32)
        k_fp8 = key.to(torch.float32)
        q_scale = torch.ones((query.shape[0], 1), dtype=torch.float32, device="cuda")
        k_scale = torch.ones((key.shape[0], 1), dtype=torch.float32, device="cuda")

        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)

            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
            with torch.cuda.stream(self.alt_stream):
                k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
            current_stream.wait_stream(self.alt_stream)
        else:
            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
            k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)

        # k_fp8: (seq_len, head_dim) fp8_e4m3fn
        # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
        # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
        # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
        forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
            layer_id=layer_id,
            loc=forward_batch.out_cache_loc,
            index_k=k_fp8,
            index_k_scale=k_scale,
        )

        weights = self._get_logits_head_gate(x, q_scale)

        assert forward_batch.seq_lens_cpu is not None
        if len(forward_batch.seq_lens_cpu) == 0:
            # this seems b/c max-pad, no worries?
            # if x.shape[0] != 0:
            #     print(
            #         "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result"
            #     )
            # Empty batch: every slot is marked invalid (-1).
            return torch.full(
                (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
            )

        # Decode uses the paged scorer; extend uses the ragged scorer.
        if forward_batch.forward_mode.is_decode_or_idle():
            topk_result = self._get_topk_paged(
                forward_batch, layer_id, q_fp8, weights, metadata
            )
        else:
            topk_result = self._get_topk_ragged(
                forward_batch, layer_id, q_fp8, weights, metadata
            )

        return topk_result
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
q_lora: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layer_id: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self._forward(x, q_lora, positions, forward_batch, layer_id)
|
||||
|
||||
def forward_npu(
    self,
    x: torch.Tensor,
    q_lora: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    layer_id: int,
) -> torch.Tensor:
    """Ascend-NPU indexer path.

    Builds the indexer query/key (rope applied via `npu_interleave_rope`),
    writes the index-k cache, and returns per-token top-k KV indices from
    the `npu_lightning_indexer` custom op.

    Args:
        x: hidden states; viewed as (num_tokens, hidden_size) below.
        q_lora: low-rank query activations fed to `wq_b`.
        positions: rope position ids used to index `cos_sin_cache`.
        forward_batch: batch metadata (cache slots, seq lens, attn backend).
        layer_id: index of the calling layer; also gates the CP fast path.

    Returns:
        Top-k KV index tensor as produced by `npu_lightning_indexer`
        (all-gathered across attention-TP ranks on the prefill CP path).
    """
    # `custom_ops` is imported for its import-time side effects — presumably
    # registering `torch.ops.custom.npu_lightning_indexer`; TODO(review) confirm.
    import custom_ops
    import torch_npu

    from sglang.srt.layers.dp_attention import (
        get_attention_tp_group,  # FIX: was missing; used by both all-gather paths below
        get_attention_tp_rank,
        get_attention_tp_size,
    )
    from sglang.srt.utils import get_bool_env_var

    # Prefer the CPU-side copy of the KV sequence lengths when available.
    if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:
        actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens
    else:
        actual_seq_lengths_kv = (
            forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int
        )
    # Context-parallel indexing is opt-in via env var and only enabled from
    # layer 4 onwards.
    enable_index_cp = (
        get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
    )
    is_prefill = forward_batch.forward_mode.is_extend()

    attention_tp_rank = get_attention_tp_rank()
    attention_tp_size = get_attention_tp_size()

    # Rope tables, broadcast-shaped to (tokens, 1, 1, rope_head_dim) as
    # expected by `npu_interleave_rope`.
    cos_sin = self.rotary_emb.cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
    sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
    if is_prefill and enable_index_cp:
        # Under CP each rank only processes its contiguous slice of tokens.
        slice_length = cos.shape[0] // attention_tp_size
        cos = cos[
            slice_length
            * attention_tp_rank : slice_length
            * (attention_tp_rank + 1)
        ]
        sin = sin[
            slice_length
            * attention_tp_rank : slice_length
            * (attention_tp_rank + 1)
        ]

    slot_mapping = forward_batch.out_cache_loc
    block_table = forward_batch.attn_backend.forward_metadata.block_tables

    bs = x.shape[0]

    q = self.wq_b(q_lora)[0]  # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
    q = q.view(bs, self.n_heads, self.head_dim)  # [bs, 64, 128]
    q_pe, q_nope = torch.split(
        q,
        [self.rope_head_dim, self.head_dim - self.rope_head_dim],
        dim=-1,
    )  # [bs, 64, 64 + 64]

    q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
    q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view(
        bs, self.n_heads, self.rope_head_dim
    )  # [bs, n, d]
    q = torch.cat([q_pe, q_nope], dim=-1)

    k_proj = self.wk(x)[0]  # [b, s, 7168] @ [7168, 128] = [b, s, 128]
    k = self.k_norm(k_proj)
    k_pe, k_nope = torch.split(
        k,
        [self.rope_head_dim, self.head_dim - self.rope_head_dim],
        dim=-1,
    )  # [bs, 64 + 64]

    k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
    k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view(
        bs, 1, self.rope_head_dim
    )  # [bs, 1, d]
    k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1)  # [bs, 1, 128]

    if is_prefill and enable_index_cp:
        # Every rank needs the full K before writing the shared index cache.
        k, local_k = (
            torch.empty(
                (k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]),
                dtype=k.dtype,
                device=k.device,
            ),
            k,
        )
        get_attention_tp_group().all_gather_into_tensor(k, local_k)

    forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k)

    if is_prefill:
        actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device)
        actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to(
            device=q.device
        )
        if enable_index_cp:
            # Shift the global cumulative query offsets into this rank's
            # local token window and clamp to [0, bs].
            actual_seq_lengths_q -= bs * attention_tp_rank
            actual_seq_lengths_q = torch.max(
                actual_seq_lengths_q,
                torch.zeros_like(actual_seq_lengths_q).to(
                    device=actual_seq_lengths_q.device
                ),
            )
            actual_seq_lengths_q = torch.min(
                actual_seq_lengths_q,
                torch.full(actual_seq_lengths_q.shape, bs).to(
                    device=actual_seq_lengths_q.device
                ),
            )

    else:
        if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
            # Fallback: offsets 1..bs, i.e. one query token per request —
            # presumably the pure-decode layout; TODO(review) confirm.
            actual_seq_lengths_q = torch.tensor(
                [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device
            )
        else:
            actual_seq_lengths_q = (
                forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
            )

    past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)

    x = x.view(-1, self.hidden_size)
    weights = self.weights_proj(x)[0]
    block_table = (
        block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
    )

    topk_indices = torch.ops.custom.npu_lightning_indexer(
        query=q.view(-1, self.n_heads, self.head_dim),
        key=past_key_states,
        weights=weights,
        actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),
        actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32),
        block_table=block_table,
        layout_query="TND",
        layout_key="PA_BSND",
        sparse_count=self.index_topk,
        sparse_mode=3,
    )

    if is_prefill and enable_index_cp:
        # Gather every rank's local top-k so all ranks return the full result.
        topk_indices, local_topk_indices = (
            torch.empty(
                (
                    topk_indices.shape[0] * attention_tp_size,
                    topk_indices.shape[1],
                    topk_indices.shape[2],
                ),
                dtype=topk_indices.dtype,
                device=topk_indices.device,
            ),
            topk_indices,
        )
        get_attention_tp_group().all_gather_into_tensor(
            topk_indices, local_topk_indices
        )

    return topk_indices
|
||||
255
python/sglang/srt/layers/attention/nsa/quant_k_cache.py
Normal file
255
python/sglang/srt/layers/attention/nsa/quant_k_cache.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.nsa.utils import NSA_QUANT_K_CACHE_FAST
|
||||
|
||||
|
||||
def quantize_k_cache(cache_k):
    """Quantize the index-k cache, dispatching to the fast (triton) or slow
    (reference) implementation based on the `NSA_QUANT_K_CACHE_FAST` flag."""
    # TODO upstream can skip concat([k_nope, k_pe]) since we split them here
    impl = (
        _quantize_k_cache_fast_wrapped
        if NSA_QUANT_K_CACHE_FAST
        else _quantize_k_cache_slow
    )
    return impl(cache_k)
|
||||
|
||||
|
||||
# Copied from original
|
||||
def _quantize_k_cache_slow(
|
||||
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
|
||||
dv: int = 512,
|
||||
tile_size: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Quantize the k-cache
|
||||
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
|
||||
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
|
||||
"""
|
||||
assert dv % tile_size == 0
|
||||
num_tiles = dv // tile_size
|
||||
num_blocks, block_size, h_k, d = input_k_cache.shape
|
||||
assert h_k == 1
|
||||
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
|
||||
input_elem_size = input_k_cache.element_size()
|
||||
|
||||
result = torch.empty(
|
||||
(num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
device=input_k_cache.device,
|
||||
)
|
||||
result_k_nope_part = result[..., :dv]
|
||||
result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
|
||||
result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
|
||||
result_k_rope_part[:] = input_k_cache[..., dv:]
|
||||
|
||||
for tile_idx in range(0, num_tiles):
|
||||
cur_scale_factors_inv = (
|
||||
torch.abs(
|
||||
input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
|
||||
)
|
||||
.max(dim=-1)
|
||||
.values
|
||||
/ 448.0
|
||||
) # [num_blocks, block_size]
|
||||
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
|
||||
|
||||
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
|
||||
cur_quantized_nope = (
|
||||
input_k_cache[
|
||||
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
|
||||
].float()
|
||||
/ cur_scale_factors_inv.float()
|
||||
).to(torch.float8_e4m3fn)
|
||||
result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
|
||||
cur_quantized_nope
|
||||
)
|
||||
|
||||
result = result.view(num_blocks, block_size, 1, -1)
|
||||
return result
|
||||
|
||||
|
||||
def _quantize_k_cache_fast_wrapped(
    input_k_cache: torch.Tensor,
    dv: int = 512,
    tile_size: int = 128,
) -> torch.Tensor:
    """Adapt the 4D cache layout to the 2D fast quantizer and back.

    # TODO the final API may be 2D instead of 4D, thus we convert them here
    """
    num_blocks, block_size, _, dim_total = input_k_cache.shape
    assert dv == 512
    assert dim_total == 512 + 64
    assert tile_size == 128
    flat = input_k_cache.view((-1, dim_total))

    # TODO deliberately split into two tensors, then upstream can provide the
    # two tensors instead of concat into one
    quantized = _quantize_k_cache_fast(k_nope=flat[:, :dv], k_rope=flat[:, dv:])

    return quantized.view(num_blocks, block_size, 1, -1)
|
||||
|
||||
|
||||
def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128):
    """Triton-backed k-cache quantizer.

    Quantizes `k_nope` to fp8_e4m3fn per `group_size`-wide group (one float32
    scale per group) and copies `k_rope` through as raw bf16 bytes, packing
    everything into a single buffer of logical dtype float8_e4m3fn.

    :param k_nope: (num_tokens, dim_nope 512)
    :param k_rope: (num_tokens, dim_rope 64)
    """

    assert k_nope.dtype == torch.bfloat16
    assert k_rope.dtype == torch.bfloat16

    num_tokens, dim_nope = k_nope.shape
    rope_tokens, dim_rope = k_rope.shape
    assert num_tokens == rope_tokens
    assert dim_nope == 512
    assert dim_rope == 64
    assert k_nope.dtype == k_rope.dtype
    num_tiles = dim_nope // group_size

    # Innermost dim must be contiguous for the flat pointer math in the kernel.
    assert k_nope.stride(1) == 1
    assert k_rope.stride(1) == 1

    # Layout: [quantized nope | float32 scales | raw bf16 rope bytes].
    packed = torch.empty(
        (num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope),
        dtype=torch.float8_e4m3fn,
        device=k_nope.device,
    )
    nope_q = packed[..., :dim_nope]
    nope_s = packed[..., dim_nope : dim_nope + num_tiles * 4].view(torch.float32)
    rope_view = packed[..., dim_nope + num_tiles * 4 :].view(torch.bfloat16)

    blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
    assert blocks_per_token == 5

    assert dim_nope % group_size == 0
    num_nope_blocks = dim_nope // group_size

    # One program per (token, group); nope groups quantize, the rope group copies.
    _quantize_k_cache_fast_kernel[(num_tokens, blocks_per_token)](
        nope_q,
        nope_s,
        rope_view,
        k_nope,
        k_rope,
        nope_q.stride(0),
        nope_s.stride(0),
        rope_view.stride(0),
        k_nope.stride(0),
        k_rope.stride(0),
        NUM_NOPE_BLOCKS=num_nope_blocks,
        GROUP_SIZE=group_size,
        DIM_NOPE=dim_nope,
        DIM_ROPE=dim_rope,
        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
    )

    return packed
|
||||
|
||||
|
||||
@triton.jit
def _quantize_k_cache_fast_kernel(
    output_nope_q_ptr,
    output_nope_s_ptr,
    output_rope_ptr,
    k_nope_ptr,
    k_rope_ptr,
    output_nope_q_stride_0: int,
    output_nope_s_stride_0: int,
    output_rope_stride_0: int,
    k_nope_stride_0: int,
    k_rope_stride_0: int,
    NUM_NOPE_BLOCKS: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    DIM_NOPE: tl.constexpr,
    DIM_ROPE: tl.constexpr,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    """Per-(token, group) quantize/copy kernel.

    Grid: (num_tokens, NUM_NOPE_BLOCKS + rope groups). Programs with
    raw_block_id < NUM_NOPE_BLOCKS quantize one GROUP_SIZE-wide nope group to
    fp8 with one scale; the remaining programs copy the rope features verbatim.
    """
    token_id = tl.program_id(0)
    raw_block_id = tl.program_id(1)

    if raw_block_id < NUM_NOPE_BLOCKS:
        # a. quant nope
        effective_block_id = raw_block_id

        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_NOPE
        ptr = k_nope_ptr + token_id * k_nope_stride_0 + offs

        y = tl.load(ptr, mask=mask, other=0.0).to(tl.float32)

        # the ref impl does not have a `tl.maximum(..., eps)` guard, so we
        # omit it here too to match its numerics exactly
        y_s = tl.max(tl.abs(y)) / FP8_MAX
        y_s_inv = 1.0 / y_s
        y_q = tl.clamp(y * y_s_inv, FP8_MIN, FP8_MAX).to(
            output_nope_q_ptr.dtype.element_ty
        )

        dst_q_ptr = output_nope_q_ptr + token_id * output_nope_q_stride_0 + offs
        dst_s_ptr = (
            output_nope_s_ptr + token_id * output_nope_s_stride_0 + effective_block_id
        )

        tl.store(dst_q_ptr, y_q, mask=mask)
        tl.store(dst_s_ptr, y_s)
    else:
        # b. copy rope
        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS

        # GROUP_SIZE may exceed DIM_ROPE (128 > 64); the mask handles the tail.
        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs < DIM_ROPE

        src_ptr = k_rope_ptr + token_id * k_rope_stride_0 + offs
        dst_ptr = output_rope_ptr + token_id * output_rope_stride_0 + offs

        data = tl.load(src_ptr, mask=mask)
        tl.store(dst_ptr, data, mask=mask)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: quantize with both the slow (reference) and fast
    # (triton) paths, dequantize via the sibling `dequant_k_cache` module,
    # and check the round-trips agree within a loose tolerance.
    # NOTE(review): requires a CUDA device and `dequant_k_cache` on sys.path.
    for num_blocks, block_size in [
        (1, 1),
        (10, 64),
    ]:
        dim_nope_and_rope = 512 + 64

        input_k_cache = torch.randn(
            (num_blocks, block_size, 1, dim_nope_and_rope),
            dtype=torch.bfloat16,
            device="cuda",
        )
        # temp debug
        # input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)

        ref_quant = _quantize_k_cache_slow(input_k_cache)
        actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
        # Bitwise comparison of the packed buffers is disabled below —
        # presumably the two paths round differently, so agreement is checked
        # after dequantization instead; TODO(review) confirm.
        # print(f"{input_k_cache=}")
        # print(f"{ref_quant=}")
        # print(f"{actual_quant=}")
        # print(f"{ref_quant == actual_quant=}")
        # print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
        # print(f"{ref_quant.view(torch.bfloat16)=}")
        # print(f"{actual_quant.view(torch.bfloat16)=}")
        # assert torch.all(ref_quant == actual_quant)

        import dequant_k_cache

        ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
        ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
        actual_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(
            actual_quant
        )

        print(f"{ref_ref_dequant=}")
        print(f"{actual_actual_dequant=}")
        print(f"{actual_actual_dequant - ref_ref_dequant=}")
        print(f"{torch.mean(ref_ref_dequant - actual_actual_dequant)=}")

        # TODO too different?
        torch.testing.assert_close(
            ref_ref_dequant, ref_actual_dequant, atol=0.2, rtol=0.2
        )
        torch.testing.assert_close(
            ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
        )

    print("Passed")
|
||||
813
python/sglang/srt/layers/attention/nsa/tilelang_kernel.py
Normal file
813
python/sglang/srt/layers/attention/nsa/tilelang_kernel.py
Normal file
@@ -0,0 +1,813 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
# import tilelang
|
||||
# import tilelang.language as T
|
||||
import torch
|
||||
|
||||
# tilelang.set_log_level("WARNING")
|
||||
|
||||
# pass_configs = {
|
||||
# tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
||||
# tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||
# tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
|
||||
# }
|
||||
|
||||
BF16 = "bfloat16"
|
||||
FP8 = "float8_e4m3"
|
||||
FP32 = "float32"
|
||||
|
||||
'''
|
||||
def fast_log2_ceil(x):
|
||||
bits_x = T.reinterpret("uint32", x)
|
||||
exp_x = (bits_x >> 23) & 0xFF
|
||||
man_bits = bits_x & ((1 << 23) - 1)
|
||||
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
|
||||
|
||||
|
||||
def fast_pow2(x):
|
||||
bits_x = (x + 127) << 23
|
||||
return T.reinterpret("float32", bits_x)
|
||||
|
||||
|
||||
def fast_round_scale(amax, fp8_max_inv):
|
||||
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
|
||||
|
||||
@tilelang.jit(pass_configs=pass_configs)
|
||||
def act_quant_kernel(
|
||||
N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
|
||||
):
|
||||
M = T.symbolic("M")
|
||||
fp8_min = -448.0
|
||||
fp8_max = 448.0
|
||||
fp8_max_inv = 1 / fp8_max
|
||||
num_stages = 0 if round_scale else 2
|
||||
blk_m = 32
|
||||
group_size = 128
|
||||
|
||||
@T.prim_func
|
||||
def act_quant_kernel_(
|
||||
X: T.Tensor[(M, N), in_dtype],
|
||||
Y: T.Tensor[(M, N), out_dtype],
|
||||
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
|
||||
):
|
||||
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
|
||||
pid_m,
|
||||
pid_n,
|
||||
):
|
||||
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
|
||||
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
|
||||
amax_local = T.alloc_fragment((blk_m,), scale_dtype)
|
||||
s_local = T.alloc_fragment((blk_m,), scale_dtype)
|
||||
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
|
||||
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
|
||||
|
||||
for _ in T.Pipelined(1, num_stages=num_stages):
|
||||
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
|
||||
T.copy(x_shared, x_local)
|
||||
T.reduce_absmax(x_local, amax_local, dim=1)
|
||||
for i in T.Parallel(blk_m):
|
||||
amax_local[i] = T.max(amax_local[i], 1e-4)
|
||||
if round_scale:
|
||||
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
|
||||
else:
|
||||
s_local[i] = amax_local[i] * fp8_max_inv
|
||||
for i, j in T.Parallel(blk_m, group_size):
|
||||
y_local[i, j] = T.clamp(
|
||||
x_local[i, j] / s_local[i], fp8_min, fp8_max
|
||||
)
|
||||
for i in T.Parallel(blk_m):
|
||||
S[pid_m * blk_m + i, pid_n] = s_local[i]
|
||||
T.copy(y_local, y_shared)
|
||||
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
|
||||
|
||||
return act_quant_kernel_
|
||||
|
||||
def act_quant(
|
||||
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantizes the input tensor `x` using block-wise quantization.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
|
||||
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
|
||||
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- The quantized tensor with dtype `torch.float8_e4m3fn`.
|
||||
- A tensor of scaling factors with dtype `torch.float32`.
|
||||
"""
|
||||
assert x.is_contiguous(), "Input tensor must be contiguous"
|
||||
assert (
|
||||
x.size(-1) % block_size == 0
|
||||
), f"Last dimension size must be divisible by block_size (block_size={block_size})"
|
||||
N = x.size(-1)
|
||||
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
|
||||
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
|
||||
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
|
||||
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
|
||||
return y, s
|
||||
|
||||
|
||||
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
|
||||
def fp8_index_kernel(h: int, d: int):
|
||||
b = T.symbolic("b")
|
||||
m = T.symbolic("m")
|
||||
n = T.symbolic("n")
|
||||
|
||||
blk_n1 = 512
|
||||
blk_n2 = 128
|
||||
|
||||
@T.prim_func
|
||||
def fp8_index_kernel_(
|
||||
q: T.Tensor[(b, m, h, d), FP8],
|
||||
q_s: T.Tensor[(b, m, h), FP32],
|
||||
k: T.Tensor[(b, n, d), FP8],
|
||||
k_s: T.Tensor[(b, n), FP32],
|
||||
o: T.Tensor[(b, m, n), FP32],
|
||||
) -> None:
|
||||
with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
|
||||
q_smem = T.alloc_shared((h, d), FP8)
|
||||
T.copy(q[i_b, i_m, 0, 0], q_smem)
|
||||
|
||||
q_s_frag = T.alloc_fragment(h, FP32)
|
||||
T.copy(q_s[i_b, i_m, 0], q_s_frag)
|
||||
|
||||
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
|
||||
k_smem = T.alloc_shared((blk_n2, d), FP8)
|
||||
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
|
||||
|
||||
k_s_frag = T.alloc_fragment(blk_n2, FP32)
|
||||
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
|
||||
|
||||
logits = T.alloc_fragment((blk_n2, h), FP32)
|
||||
T.gemm(
|
||||
k_smem,
|
||||
q_smem,
|
||||
logits,
|
||||
transpose_A=False,
|
||||
transpose_B=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
|
||||
for i_h, i3_n in T.Parallel(h, blk_n2):
|
||||
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
|
||||
|
||||
logits_sum = T.alloc_fragment(blk_n2, FP32)
|
||||
T.reduce_sum(logits, logits_sum, dim=1)
|
||||
|
||||
for i3_n in T.Parallel(blk_n2):
|
||||
logits_sum[i3_n] *= k_s_frag[i3_n]
|
||||
|
||||
T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
|
||||
|
||||
return fp8_index_kernel_
|
||||
|
||||
|
||||
def fp8_index(
|
||||
q: torch.Tensor,
|
||||
q_s: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
k_s: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform index score using FP8 precision.
|
||||
|
||||
Args:
|
||||
q (torch.Tensor): The Q tensor, must be contiguous.
|
||||
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
|
||||
k (torch.Tensor): The K tensor, must be contiguous.
|
||||
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
|
||||
|
||||
fp8 q @ fp8 k -> fp32 logits
|
||||
relu(fp32 logits) * q_s (weights) -> fp32 logits
|
||||
fp32 logits -> fp32 logits_sum
|
||||
fp32 logits_sum * k_s (e8m0) -> fp32 index_score
|
||||
"""
|
||||
return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
out_idx=[-1],
|
||||
pass_configs={
|
||||
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
||||
},
|
||||
)
|
||||
def sparse_attention_fwd_kernel_v1(
|
||||
num_heads,
|
||||
dim,
|
||||
tail_dim,
|
||||
topk,
|
||||
*,
|
||||
kv_group=1,
|
||||
sm_scale=None,
|
||||
is_causal=True,
|
||||
block_I=64,
|
||||
num_stages=2,
|
||||
threads=256,
|
||||
):
|
||||
assert dim == tilelang.math.next_power_of_2(
|
||||
dim
|
||||
), f"haven't check padding correctness yet, dim={dim}"
|
||||
assert tail_dim == tilelang.math.next_power_of_2(
|
||||
tail_dim
|
||||
), f"haven't check padding correctness yet, dim={tail_dim}"
|
||||
assert is_causal == True, "non-casual is not supported"
|
||||
assert (
|
||||
topk % block_I == 0
|
||||
), "otherwise will load some index=0 thus causing wrong kv to be loaded"
|
||||
if sm_scale is None:
|
||||
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e)
|
||||
else:
|
||||
sm_scale = sm_scale * 1.44269504 # log2(e)
|
||||
|
||||
batch = T.symbolic("batch")
|
||||
seq_len = T.symbolic("seq_len")
|
||||
seq_len_kv = T.symbolic("seq_len_kv")
|
||||
|
||||
head_kv = num_heads // kv_group
|
||||
q_shape = [batch, seq_len, num_heads, dim + tail_dim]
|
||||
kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
|
||||
o_shape = [batch, seq_len, num_heads, dim]
|
||||
indices_shape = [batch, seq_len, kv_group, topk]
|
||||
indices_dtype = "int32"
|
||||
dtype = "bfloat16"
|
||||
accum_dtype = "float"
|
||||
|
||||
H = head_kv
|
||||
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
|
||||
if padded_H != H:
|
||||
assert kv_group == 1
|
||||
BI = block_I
|
||||
NI = tilelang.cdiv(topk, block_I)
|
||||
D = dim
|
||||
D_tail = tail_dim
|
||||
|
||||
if head_kv > 64:
|
||||
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
|
||||
REPLICATE_H = head_kv // 64
|
||||
else:
|
||||
REPLICATE_H = 1
|
||||
|
||||
H_per_block = padded_H if REPLICATE_H == 1 else 64
|
||||
|
||||
@T.prim_func
|
||||
def main(
|
||||
Q: T.Tensor(q_shape, dtype), # type: ignore
|
||||
KV: T.Tensor(kv_shape, dtype), # type: ignore
|
||||
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
|
||||
Output: T.Tensor(o_shape, dtype), # type: ignore
|
||||
):
|
||||
with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
|
||||
bx,
|
||||
by,
|
||||
bz,
|
||||
):
|
||||
Q_shared = T.alloc_shared([H_per_block, D], dtype)
|
||||
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
|
||||
KV_shared = T.alloc_shared([BI, D], dtype)
|
||||
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
|
||||
O_shared = T.alloc_shared([H_per_block, D], dtype)
|
||||
mask = T.alloc_fragment([BI], "bool")
|
||||
|
||||
acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
|
||||
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
|
||||
S_shared = T.alloc_shared([H_per_block, BI], dtype)
|
||||
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
|
||||
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
|
||||
alpha = T.alloc_fragment([H_per_block], accum_dtype)
|
||||
m_i = T.alloc_fragment([H_per_block], accum_dtype)
|
||||
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
|
||||
|
||||
T.fill(acc_o, 0)
|
||||
T.fill(sumexp, 0)
|
||||
T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan
|
||||
|
||||
b_i, g_i = by, bz
|
||||
s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
|
||||
q_i = s_i
|
||||
max_kv_i = q_i
|
||||
|
||||
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
|
||||
H1 = H0 + H_per_block
|
||||
|
||||
T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)
|
||||
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
|
||||
|
||||
for i_i in T.Pipelined(NI, num_stages=num_stages):
|
||||
|
||||
for bi_i in T.Parallel(BI):
|
||||
mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] >= 0
|
||||
|
||||
for bi_i, d_i in T.Parallel(BI, D):
|
||||
KV_shared[bi_i, d_i] = KV[
|
||||
b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i
|
||||
]
|
||||
for bi_i, d_i in T.Parallel(BI, D_tail):
|
||||
K_tail_shared[bi_i, d_i] = KV[
|
||||
b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i
|
||||
]
|
||||
|
||||
for h_i, bi_i in T.Parallel(H_per_block, BI):
|
||||
acc_s[h_i, bi_i] = T.if_then_else(
|
||||
mask[bi_i], 0, -T.infinity(acc_s.dtype)
|
||||
)
|
||||
T.gemm(
|
||||
Q_shared,
|
||||
KV_shared,
|
||||
acc_s,
|
||||
transpose_B=True,
|
||||
policy=T.GemmWarpPolicy.FullCol,
|
||||
)
|
||||
T.gemm(
|
||||
Q_tail_shared,
|
||||
K_tail_shared,
|
||||
acc_s,
|
||||
transpose_B=True,
|
||||
policy=T.GemmWarpPolicy.FullCol,
|
||||
)
|
||||
T.copy(m_i, m_i_prev)
|
||||
T.reduce_max(acc_s, m_i, dim=1, clear=False)
|
||||
for h_i in T.Parallel(H_per_block):
|
||||
alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
|
||||
for h_i, bi_i in T.Parallel(H_per_block, BI):
|
||||
acc_s[h_i, bi_i] = T.exp2(
|
||||
acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
|
||||
)
|
||||
T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator?
|
||||
for h_i in T.Parallel(H_per_block):
|
||||
sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
|
||||
for h_i, d_i in T.Parallel(H_per_block, D):
|
||||
acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
|
||||
|
||||
T.copy(acc_s, S_shared)
|
||||
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
|
||||
|
||||
# Rescale
|
||||
for h_i, d_i in T.Parallel(H_per_block, D):
|
||||
acc_o[h_i, d_i] /= sumexp[h_i]
|
||||
for h_i in T.Parallel(H_per_block):
|
||||
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
|
||||
|
||||
T.copy(acc_o, O_shared)
|
||||
T.copy(acc_o, Output[b_i, s_i, H0:H1, :])
|
||||
|
||||
return main
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
out_idx=[-1],
|
||||
compile_flags=[
|
||||
"-O3",
|
||||
"-Wno-deprecated-declarations",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-U__CUDA_NO_HALF2_OPERATORS__",
|
||||
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
|
||||
"--expt-relaxed-constexpr",
|
||||
"--expt-extended-lambda",
|
||||
"--ptxas-options=-v,--register-usage-level=10",
|
||||
"-DNDEBUG",
|
||||
],
|
||||
) # type: ignore
|
||||
def sparse_attention_fwd_kernel_v2(
|
||||
num_heads: int,
|
||||
dim: int,
|
||||
tail_dim: int,
|
||||
topk: int,
|
||||
*,
|
||||
kv_group: int = 1,
|
||||
sm_scale: Optional[float] = None,
|
||||
block_I: int = 64,
|
||||
):
|
||||
assert dim == tilelang.math.next_power_of_2(
|
||||
dim
|
||||
), f"haven't check padding correctness yet, dim={dim}"
|
||||
assert tail_dim == tilelang.math.next_power_of_2(
|
||||
tail_dim
|
||||
), f"haven't check padding correctness yet, dim={tail_dim}"
|
||||
assert (
|
||||
topk % block_I == 0
|
||||
), "otherwise will load some index=0 thus causing wrong kv to be loaded"
|
||||
if sm_scale is None:
|
||||
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e)
|
||||
else:
|
||||
sm_scale = sm_scale * 1.44269504 # log2(e)
|
||||
threads = 384
|
||||
|
||||
batch = T.symbolic("batch")
|
||||
qo_len = T.symbolic("seq_len")
|
||||
num_pages = T.symbolic("num_pages")
|
||||
|
||||
q_shape = [batch, qo_len, num_heads, dim + tail_dim]
|
||||
kv_shape = [batch, num_pages, kv_group, dim + tail_dim]
|
||||
o_shape = [batch, qo_len, num_heads, dim]
|
||||
indices_shape = [batch, qo_len, kv_group, topk]
|
||||
|
||||
indices_dtype = "int32"
|
||||
dtype = "bfloat16"
|
||||
accum_dtype = "float"
|
||||
|
||||
H = num_heads
|
||||
padded_H = max(tilelang.math.next_power_of_2(num_heads), 16)
|
||||
if padded_H != H:
|
||||
assert kv_group == 1
|
||||
BI = block_I
|
||||
NI = tilelang.cdiv(topk, block_I)
|
||||
assert NI % 2 == 0, "NI should be a multiple of 2"
|
||||
D = dim
|
||||
D_tail = tail_dim
|
||||
if num_heads > 64:
|
||||
assert num_heads % 64 == 0, "head_kv should be a multiple of 64"
|
||||
REPLICATE_H = num_heads // 64
|
||||
else:
|
||||
REPLICATE_H = 1
|
||||
|
||||
H_per_block = padded_H if REPLICATE_H == 1 else 64
|
||||
|
||||
@T.prim_func
|
||||
def main(
|
||||
Q: T.Tensor(q_shape, dtype), # type: ignore
|
||||
KV: T.Tensor(kv_shape, dtype), # type: ignore
|
||||
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
|
||||
Output: T.Tensor(o_shape, dtype), # type: ignore
|
||||
):
|
||||
"""
|
||||
Q: [b, qo_len, H, D + D_tail] (bfloat16)
|
||||
KV: [b, num_pages, kv_group, D + D_tail] (bfloat16)
|
||||
Indices: [b, qo_len, kv_group, topk] (int32)
|
||||
"""
|
||||
|
||||
with T.Kernel(qo_len * REPLICATE_H, batch, 1, threads=threads) as (bx, by, bz): # type: ignore
|
||||
Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)
|
||||
Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)
|
||||
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
|
||||
KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype)
|
||||
KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype)
|
||||
KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype)
|
||||
KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype)
|
||||
K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype)
|
||||
K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype)
|
||||
O_shared_l = Q_shared_l
|
||||
O_shared_r = Q_shared_r
|
||||
is_kv_valid_0 = T.alloc_shared([BI], "bool", scope="shared")
|
||||
is_kv_valid_1 = T.alloc_shared([BI], "bool", scope="shared")
|
||||
|
||||
acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
|
||||
acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
|
||||
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
|
||||
S_shared = T.alloc_shared([H_per_block, BI], dtype)
|
||||
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
|
||||
sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype)
|
||||
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
|
||||
alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared")
|
||||
alpha_local = T.alloc_fragment([H_per_block], accum_dtype)
|
||||
m_i = T.alloc_fragment([H_per_block], accum_dtype)
|
||||
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
|
||||
indices_local = T.alloc_local([1], indices_dtype)
|
||||
indices_tmp = T.alloc_local([1], indices_dtype)
|
||||
|
||||
bar_q = T.alloc_barrier(arrive_count=384)
|
||||
bar_k_0_ready = T.alloc_barrier(arrive_count=128)
|
||||
bar_k_1_ready = T.alloc_barrier(arrive_count=128)
|
||||
bar_k_0_free = T.alloc_barrier(arrive_count=256)
|
||||
bar_k_1_free = T.alloc_barrier(arrive_count=256)
|
||||
bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)
|
||||
bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)
|
||||
|
||||
bar_0_128 = T.alloc_barrier(arrive_count=128)
|
||||
bar_1_128 = T.alloc_barrier(arrive_count=128)
|
||||
bar_2_128 = T.alloc_barrier(arrive_count=128)
|
||||
bar_final = T.alloc_barrier(arrive_count=128)
|
||||
|
||||
b_i, g_i = by, bz
|
||||
s_i = bx if REPLICATE_H == 1 else bx // REPLICATE_H
|
||||
|
||||
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
|
||||
H1 = H0 + H_per_block
|
||||
|
||||
tx = T.get_thread_binding()
|
||||
|
||||
T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l)
|
||||
T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r)
|
||||
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
|
||||
T.barrier_arrive(bar_q)
|
||||
|
||||
if tx < 128:
|
||||
T.set_max_nreg(240, 1)
|
||||
T.fill(sumexp, 0)
|
||||
T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan
|
||||
T.fill(acc_o_l, 0)
|
||||
T.barrier_wait(bar_q, 0)
|
||||
|
||||
for i_i in T.serial(T.ceildiv(NI, 2)):
|
||||
# Buffer 0
|
||||
# with sync_at(bar_0_128, 0):
|
||||
T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
|
||||
T.barrier_arrive(bar_0_128)
|
||||
T.barrier_wait(bar_0_128, 0)
|
||||
|
||||
for h_i, bi_i in T.Parallel(H_per_block, BI):
|
||||
acc_s[h_i, bi_i] = T.if_then_else(
|
||||
is_kv_valid_0[bi_i], 0, -T.infinity(acc_s.dtype)
|
||||
)
|
||||
T.gemm(
|
||||
Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1
|
||||
)
|
||||
T.gemm(
|
||||
Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1
|
||||
)
|
||||
T.gemm(
|
||||
Q_tail_shared,
|
||||
K_tail_shared_0,
|
||||
acc_s,
|
||||
transpose_B=True,
|
||||
wg_wait=-1,
|
||||
)
|
||||
|
||||
T.wait_wgmma(0)
|
||||
|
||||
if i_i != 0:
|
||||
T.barrier_arrive(bar_sScale_and_sS_free)
|
||||
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)
|
||||
|
||||
T.copy(m_i, m_i_prev)
|
||||
T.reduce_max(acc_s, m_i, dim=1, clear=False)
|
||||
for h_i in T.Parallel(H_per_block):
|
||||
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
|
||||
for h_i, bi_i in T.Parallel(H_per_block, BI):
|
||||
acc_s[h_i, bi_i] = T.exp2(
|
||||
acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
|
||||
)
|
||||
T.reduce_sum(
|
||||
acc_s, sumexp_i, dim=1
|
||||
) # is this a accumulate operator?
|
||||
for h_i in T.Parallel(H_per_block):
|
||||
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
|
||||
for h_i, d_i in T.Parallel(H_per_block, D // 2):
|
||||
acc_o_l[h_i, d_i] *= alpha_local[h_i]
|
||||
T.copy(alpha_local, alpha_shared)
|
||||
|
||||
T.copy(acc_s, S_shared)
|
||||
T.gemm(S_shared, KV_shared_0_l, acc_o_l)
|
||||
|
||||
T.barrier_arrive(bar_sScale_and_sS_ready)
|
||||
T.barrier_arrive(bar_k_0_free[0])
|
||||
|
||||
# Buffer 1
|
||||
T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
|
||||
T.barrier_arrive(bar_0_128)
|
||||
T.barrier_wait(bar_0_128, 1)
|
||||
|
||||
for h_i, bi_i in T.Parallel(H_per_block, BI):
|
||||
acc_s[h_i, bi_i] = T.if_then_else(
|
||||
is_kv_valid_1[bi_i], 0, -T.infinity(acc_s.dtype)
|
||||
)
|
||||
T.gemm(
|
||||
Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1
|
||||
)
|
||||
T.gemm(
|
||||
Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1
|
||||
)
|
||||
T.gemm(
|
||||
Q_tail_shared,
|
||||
K_tail_shared_1,
|
||||
acc_s,
|
||||
transpose_B=True,
|
||||
wg_wait=-1,
|
||||
)
|
||||
|
||||
T.wait_wgmma(0)
|
||||
|
||||
T.barrier_arrive(bar_sScale_and_sS_free)
|
||||
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)
|
||||
|
||||
T.copy(m_i, m_i_prev)
|
||||
T.reduce_max(acc_s, m_i, dim=1, clear=False)
|
||||
for h_i in T.Parallel(H_per_block):
|
||||
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
|
||||
for h_i, bi_i in T.Parallel(H_per_block, BI):
|
||||
acc_s[h_i, bi_i] = T.exp2(
|
||||
acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
|
||||
)
|
||||
T.reduce_sum(
|
||||
acc_s, sumexp_i, dim=1
|
||||
) # is this a accumulate operator?
|
||||
for h_i in T.Parallel(H_per_block):
|
||||
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
|
||||
for h_i, d_i in T.Parallel(H_per_block, D // 2):
|
||||
acc_o_l[h_i, d_i] *= alpha_local[h_i]
|
||||
T.copy(alpha_local, alpha_shared)
|
||||
|
||||
T.copy(acc_s, S_shared)
|
||||
T.gemm(S_shared, KV_shared_1_l, acc_o_l)
|
||||
|
||||
T.barrier_arrive(bar_sScale_and_sS_ready)
|
||||
T.barrier_arrive(bar_k_1_free[0])
|
||||
|
||||
# Rescale
|
||||
for h_i in T.Parallel(H_per_block):
|
||||
sum_exp_shared[h_i] = sumexp[h_i]
|
||||
T.barrier_arrive(bar_final)
|
||||
for h_i, d_i in T.Parallel(H_per_block, D // 2):
|
||||
acc_o_l[h_i, d_i] /= sumexp[h_i]
|
||||
for h_i in T.Parallel(H_per_block):
|
||||
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
|
||||
T.copy(acc_o_l, O_shared_l)
|
||||
T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2])
|
||||
elif tx >= 128 and tx < 256:
|
||||
# T.set_max_nreg(168, 1)
|
||||
T.fill(acc_o_r, 0)
|
||||
for i_i in T.serial(T.ceildiv(NI, 2)):
|
||||
# Buffer 0
|
||||
T.barrier_arrive(bar_sScale_and_sS_ready)
|
||||
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))
|
||||
T.barrier_arrive(bar_1_128)
|
||||
T.barrier_wait(bar_1_128, 0)
|
||||
for h_i, d_i in T.Parallel(H_per_block, D // 2):
|
||||
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
|
||||
T.gemm(S_shared, KV_shared_0_r, acc_o_r)
|
||||
T.barrier_arrive(bar_k_0_free[0])
|
||||
T.barrier_arrive(bar_sScale_and_sS_free)
|
||||
|
||||
# Buffer 1
|
||||
T.barrier_arrive(bar_sScale_and_sS_ready)
|
||||
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))
|
||||
T.barrier_arrive(bar_1_128)
|
||||
T.barrier_wait(bar_1_128, 1)
|
||||
for h_i, d_i in T.Parallel(H_per_block, D // 2):
|
||||
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
|
||||
T.gemm(S_shared, KV_shared_1_r, acc_o_r)
|
||||
T.barrier_arrive(bar_k_1_free[0])
|
||||
if i_i != T.ceildiv(NI, 2) - 1:
|
||||
T.barrier_arrive(bar_sScale_and_sS_free)
|
||||
|
||||
# Rescale
|
||||
T.barrier_wait(bar_final, 0)
|
||||
for h_i, d_i in T.Parallel(H_per_block, D // 2):
|
||||
acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
|
||||
|
||||
T.copy(acc_o_r, O_shared_r)
|
||||
T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D])
|
||||
elif tx >= 256:
|
||||
# producer
|
||||
T.set_max_nreg(80, 0)
|
||||
indices_local[0] = 0
|
||||
for i_i in T.serial(T.ceildiv(NI, 2)):
|
||||
# Buffer 0
|
||||
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
|
||||
T.barrier_arrive(bar_2_128)
|
||||
T.barrier_wait(bar_2_128, 0)
|
||||
|
||||
for r in T.serial(4):
|
||||
indices_tmp[0] = Indices[
|
||||
b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8
|
||||
]
|
||||
is_kv_valid_0[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
|
||||
if is_kv_valid_0[r * 16 + (tx - 256) // 8]:
|
||||
indices_local[0] = indices_tmp[0]
|
||||
|
||||
with T.attr("default", "async_scope", 1): # type: ignore
|
||||
for u in T.serial(4):
|
||||
for v in T.vectorized(8):
|
||||
KV_shared_0_l[
|
||||
r * 16 + (tx - 256) // 8,
|
||||
64 * u + (tx - 256) % 8 * 8 + v,
|
||||
] = KV[
|
||||
b_i,
|
||||
indices_local[0],
|
||||
g_i,
|
||||
64 * u + (tx - 256) % 8 * 8 + v,
|
||||
]
|
||||
KV_shared_0_r[
|
||||
r * 16 + (tx - 256) // 8,
|
||||
64 * u + (tx - 256) % 8 * 8 + v,
|
||||
] = KV[
|
||||
b_i,
|
||||
indices_local[0],
|
||||
g_i,
|
||||
D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
|
||||
]
|
||||
with T.attr("default", "async_scope", 1): # type: ignore
|
||||
for v in T.vectorized(8):
|
||||
K_tail_shared_0[
|
||||
r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
|
||||
] = KV[
|
||||
b_i,
|
||||
indices_local[0],
|
||||
g_i,
|
||||
D + (tx - 256) % 8 * 8 + v,
|
||||
]
|
||||
|
||||
T.cp_async_barrier_noinc(bar_k_0_ready[0])
|
||||
|
||||
# Buffer 1
|
||||
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
|
||||
T.barrier_arrive(bar_2_128)
|
||||
T.barrier_wait(bar_2_128, 1)
|
||||
|
||||
for r in T.serial(4):
|
||||
indices_tmp[0] = Indices[
|
||||
b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8
|
||||
]
|
||||
is_kv_valid_1[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
|
||||
if is_kv_valid_1[r * 16 + (tx - 256) // 8]:
|
||||
indices_local[0] = indices_tmp[0]
|
||||
|
||||
with T.attr("default", "async_scope", 1): # type: ignore
|
||||
for u in T.serial(4):
|
||||
for v in T.vectorized(8):
|
||||
KV_shared_1_l[
|
||||
r * 16 + (tx - 256) // 8,
|
||||
64 * u + (tx - 256) % 8 * 8 + v,
|
||||
] = KV[
|
||||
b_i,
|
||||
indices_local[0],
|
||||
g_i,
|
||||
64 * u + (tx - 256) % 8 * 8 + v,
|
||||
]
|
||||
KV_shared_1_r[
|
||||
r * 16 + (tx - 256) // 8,
|
||||
64 * u + (tx - 256) % 8 * 8 + v,
|
||||
] = KV[
|
||||
b_i,
|
||||
indices_local[0],
|
||||
g_i,
|
||||
D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
|
||||
]
|
||||
with T.attr("default", "async_scope", 1): # type: ignore
|
||||
for v in T.vectorized(8):
|
||||
K_tail_shared_1[
|
||||
r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
|
||||
] = KV[
|
||||
b_i,
|
||||
indices_local[0],
|
||||
g_i,
|
||||
D + (tx - 256) % 8 * 8 + v,
|
||||
]
|
||||
|
||||
T.cp_async_barrier_noinc(bar_k_1_ready[0])
|
||||
|
||||
return main
|
||||
|
||||
def tilelang_sparse_fwd(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
sm_scale: float,
|
||||
d_v: int = 512,
|
||||
) -> torch.Tensor:
|
||||
assert q.dim() == 3 and kv.dim() == 3 and indices.dim() == 3
|
||||
num_heads = q.shape[1]
|
||||
dim = q.shape[2]
|
||||
tail_dim = dim - d_v
|
||||
topk = indices.shape[-1]
|
||||
assert topk == 2048
|
||||
# NOTE(dark): v2 offers better performance than v1
|
||||
kernel = sparse_attention_fwd_kernel_v2(
|
||||
num_heads, d_v, tail_dim, topk, sm_scale=sm_scale
|
||||
)
|
||||
return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0)) # type: ignore
|
||||
'''
|
||||
def act_quant(
|
||||
x: torch.Tensor,
|
||||
block_size: int = 128,
|
||||
scale_fmt: Optional[str] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
PyTorch fallback for act_quant
|
||||
Block-wise FP8 E4M3 quantization
|
||||
"""
|
||||
if not x.is_contiguous():
|
||||
x = x.contiguous()
|
||||
|
||||
N = x.size(-1)
|
||||
assert N % block_size == 0, f"Last dim {N} must be divisible by block_size={block_size}"
|
||||
|
||||
# Reshape to blocks
|
||||
x_2d = x.view(-1, N)
|
||||
x_blocks = x_2d.view(-1, block_size)
|
||||
|
||||
# Compute absmax per block
|
||||
amax = x_blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-4)
|
||||
|
||||
# FP8 E4M3 max value is ~448
|
||||
fp8_max = 448.0
|
||||
scale = amax / fp8_max
|
||||
|
||||
if scale_fmt is not None:
|
||||
# Simulate rounded scale (power-of-2 rounding)
|
||||
scale = torch.round(scale * 256) / 256
|
||||
|
||||
# Quantize and clamp
|
||||
y_blocks = torch.clamp(torch.round(x_blocks / scale), -fp8_max, fp8_max)
|
||||
|
||||
# Convert to FP8
|
||||
q = y_blocks.view_as(x_2d).to(torch.float8_e4m3fn)
|
||||
|
||||
# Reshape scale
|
||||
s = scale.view(x_2d.size(0), N // block_size).to(torch.float32)
|
||||
s = s.view(*x.shape[:-1], N // block_size)
|
||||
|
||||
return q.view_as(x), s
|
||||
65
python/sglang/srt/layers/attention/nsa/topk.py
Normal file
65
python/sglang/srt/layers/attention/nsa/topk.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import align
|
||||
|
||||
# NOTE(dark): flashmla P requires `params.topk % (2*B_TOPK) == 0`,
|
||||
# where `B_TOPK=64`. So we align to 128 by default.
|
||||
|
||||
_TOPK_ALIGNMENT = 128
|
||||
|
||||
|
||||
# TODO(dark): maybe this torch_op can support torch.compile
|
||||
# TODO(dark): maybe this torch_op can support torch.compile
def _fast_topk_torch(
    input: torch.Tensor, seq_lens: torch.Tensor, topk: int, alignment: int
) -> torch.Tensor:
    """Per-row top-k over scores, masking positions beyond each row's seq_len.

    Args:
        input: [bs, max_seq_len] scores. NOTE: mutated in place on the slow
            path (out-of-range positions are overwritten with -inf).
        seq_lens: [bs] valid length of each row.
        topk: number of indices to return per row.
        alignment: padding granularity (see _TOPK_ALIGNMENT).

    Returns:
        Index tensor whose invalid slots are filled with -1. Shape is
        [bs, topk] on the general path, but [bs, padded_max_seq_len] on the
        short-sequence fast path below (padded_max_seq_len <= topk).
    """
    # Fallback to torch.topk
    bs, max_seq_len = input.shape
    assert len(seq_lens) == bs
    # set those out-of-bound input to -inf
    padded_max_seq_len = align(max_seq_len, alignment)
    positions = torch.arange(
        padded_max_seq_len, device=input.device, dtype=seq_lens.dtype
    )
    positions = positions.unsqueeze(0).expand(bs, -1)
    # mask[b, j] is True when position j is past row b's valid length
    mask = positions >= seq_lens.unsqueeze(1)

    # NOTE(dark): just return all valid indices as an optimization
    if padded_max_seq_len <= topk:
        return positions.masked_fill(mask, -1)

    assert topk % alignment == 0

    # in-place operation: mask invalid inputs to -inf
    input = input.masked_fill_(mask[:, :max_seq_len], float("-inf"))
    result = input.topk(topk, dim=-1, sorted=True)
    # sorted=True puts the -inf (masked) entries at the tail, so masking the
    # leading-topk position grid with -1 lines up with the invalid slots.
    return result.indices.masked_fill_(mask[:, :topk], -1)
|
||||
|
||||
|
||||
def fast_topk_impl(
    input: torch.Tensor,
    seq_lens: torch.Tensor,
    topk: int,
    alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
    """Length-masked top-k entry point; currently always the torch fallback."""
    return _fast_topk_torch(input, seq_lens, topk, alignment)
|
||||
|
||||
|
||||
def fast_topk_transform_fused_cuda(
    input: torch.Tensor,
    seq_lens: torch.Tensor,
    topk: int,
    dst_page_table: torch.Tensor,
    src_page_table: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
    """Fused CUDA top-k that also gathers page-table entries for the winners.

    Selects the top-k scores per row of `input` and writes the corresponding
    entries of `src_page_table` into `dst_page_table` in one kernel, avoiding
    a separate transform-index pass.
    """
    # Local import: the CUDA extension is only needed when this path is taken.
    from sglang.srt.layers.attention.nsa.cuda import fast_topk_transform

    # The fused kernel is specialized: topk must be exactly 2048 (see assert).
    assert topk == 2048 and topk % alignment == 0
    return fast_topk_transform(
        score=input,
        lengths=seq_lens,
        dst_page_table=dst_page_table,
        src_page_table=src_page_table,
        cu_seqlens=cu_seqlens_q,
    )
|
||||
144
python/sglang/srt/layers/attention/nsa/transform_index.py
Normal file
144
python/sglang/srt/layers/attention/nsa/transform_index.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def transform_index_page_table_prefill(**kwargs):
    """Prefill-path page-table transform; currently routed to the torch reference."""
    return transform_index_page_table_prefill_ref(**kwargs)
|
||||
|
||||
|
||||
def transform_index_page_table_decode(**kwargs):
    """Decode-path page-table transform; currently routed to the torch reference."""
    return transform_index_page_table_decode_ref(**kwargs)
|
||||
|
||||
|
||||
@triton.jit
def transform_index_page_table_decode_kernel(
    page_table_ptr: torch.Tensor,
    topk_indices_ptr: torch.Tensor,
    result_ptr: torch.Tensor,
    page_size: tl.constexpr,
    max_seqlen_k: tl.constexpr,
):
    """One program per request row: result[req, j] = page_table[req, topk_indices[req, j]].

    Slots whose top-k index is negative (padding) are written as -1.
    NOTE(review): `page_size` is accepted but never read here; the Python
    wrapper asserts page_size == 1 -- confirm before supporting paging.
    """
    TOPK: tl.constexpr = 2048
    req_id = tl.program_id(0)
    # Advance the base pointers to this request's row.
    page_table_ptr = page_table_ptr + req_id * max_seqlen_k
    topk_indices_ptr = topk_indices_ptr + req_id * TOPK
    result_ptr = result_ptr + req_id * TOPK

    offset = tl.arange(0, TOPK)  # topk should be 2048
    loaded_topk_indices = tl.load(topk_indices_ptr + offset)
    # Negative indices mark padding; mask them out of the gather.
    mask = loaded_topk_indices >= 0
    loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask)
    tl.store(result_ptr + offset, loaded_kv_indices, mask=mask)
    # Second store fills the padding slots with the -1 sentinel.
    tl.store(result_ptr + offset, -1, mask=~mask)
|
||||
|
||||
|
||||
def transform_index_page_table_decode_fast(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
) -> torch.Tensor:
    """
    Transform the page table according to topk indices for sparse topk attention.

    Args:
        page_table: [qo_len, max_seqlen_k], the original page table
        topk_indices: [qo_len, topk], the topk indices for each query position
        result: optional [qo_len, topk] int32 output buffer (allocated if None)
        page_size: must be 1 (the triton kernel ignores it)
    Returns:
        transformed_page_table: [qo_len, topk], the transformed page table
        For out-of-bound indices in topk_indices, this should be filled with -1.
    """
    assert page_size == 1
    assert page_table.shape[0] == topk_indices.shape[0]
    # The triton kernel hard-codes TOPK = 2048.
    assert topk_indices.shape[1] == 2048
    qo_len = topk_indices.shape[0]
    max_seqlen_k = page_table.shape[1]
    if result is None:
        result = torch.empty_like(topk_indices, dtype=torch.int32)
    # Launch triton kernel
    grid = (qo_len,)  # one program per query row
    transform_index_page_table_decode_kernel[grid](
        page_table,
        topk_indices,
        result,
        page_size,
        max_seqlen_k=max_seqlen_k,
    )
    return result
|
||||
|
||||
|
||||
def transform_index_page_table_prefill_fast(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
) -> torch.Tensor:
    """Prefill-path transform via the triton decode kernel, one request at a time.

    Request i contributes extend_lens_cpu[i] consecutive query rows, all of
    which share the row page_table[i].
    """
    # TODO(baizhou): can be implemented with another triton kernel
    assert page_size == 1
    assert len(extend_lens_cpu) == page_table.shape[0]
    result = torch.empty_like(topk_indices, dtype=torch.int32)
    start = 0
    for req_idx, qo_len in enumerate(extend_lens_cpu):
        stop = start + qo_len
        transform_index_page_table_decode_fast(
            page_table[req_idx].unsqueeze(0).expand(qo_len, -1),
            topk_indices[start:stop],
            result=result[start:stop],
        )
        start = stop
    assert start == topk_indices.shape[0]
    return result
|
||||
|
||||
|
||||
def transform_index_page_table_decode_ref(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
) -> torch.Tensor:
    """Reference (pure torch) page-table gather for sparse decode.

    Computes result[row, j] = page_table[row, topk_indices[row, j]] and
    writes -1 wherever the top-k slot is padding (negative index).
    """
    assert page_size == 1
    assert page_table.shape[0] == topk_indices.shape[0]
    if result is None:
        result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert result.shape == topk_indices.shape
    # Clamp padding indices to 0 so the gather stays in bounds; the padded
    # slots are overwritten with the -1 sentinel right after.
    safe_indices = topk_indices.clamp(min=0).long()
    torch.gather(page_table, dim=1, index=safe_indices, out=result)
    result.masked_fill_(topk_indices < 0, -1)
    return result
|
||||
|
||||
|
||||
def transform_index_page_table_prefill_ref(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
) -> torch.Tensor:
    """Reference prefill-path transform: apply the decode reference per request.

    Request i contributes extend_lens_cpu[i] consecutive query rows, all of
    which share the row page_table[i].
    """
    assert page_size == 1
    assert len(extend_lens_cpu) == page_table.shape[0]
    result = torch.empty_like(topk_indices, dtype=torch.int32)
    start = 0
    for req_idx, qo_len in enumerate(extend_lens_cpu):
        stop = start + qo_len
        transform_index_page_table_decode_ref(
            page_table[req_idx].unsqueeze(0).expand(qo_len, -1),
            topk_indices[start:stop],
            result=result[start:stop],
        )
        start = stop
    assert start == topk_indices.shape[0]
    return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: the fused triton path must agree with the torch reference.
    bs, topk, max_seqlen = 10, 2048, 3000
    # BUGFIX: both implementations write into an int32 result buffer, and
    # torch.gather(..., out=result) requires `out` to match the input dtype.
    # The previous default-int64 tensors made the reference path raise.
    page_table = torch.randint(
        0, 100, (bs, max_seqlen), device="cuda", dtype=torch.int32
    )
    topk_indices = torch.full((bs, topk), -1, device="cuda", dtype=torch.int32)
    # First 1600 slots are valid ascending indices; the rest stay -1 padding.
    topk_indices[:, :1600] = (
        torch.arange(1600, device="cuda").unsqueeze(0).repeat(bs, 1)
    )
    ref_result = transform_index_page_table_decode_ref(page_table, topk_indices)
    result = transform_index_page_table_decode_fast(page_table, topk_indices)
    assert torch.all(result == ref_result)
    print("Passed")
|
||||
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
    """Toy module for benchmarking two equivalent head-gate formulations."""

    def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
        super().__init__()
        self.weights_proj = nn.Linear(d_in, 1024)
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale

    def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Baseline: scale the projected logits first, then apply gate and scale."""
        logits = self.weights_proj(x)  # (B, 1024)
        logits = logits * self.n_heads**-0.5
        gate = q_scale.unsqueeze(1)  # (B, 1, 1)
        return logits.unsqueeze(-1) * gate * self.softmax_scale

    def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Reordered: fold the scalar factors and gate into one tensor first."""
        logits = self.weights_proj(x)
        gate = q_scale.unsqueeze(1)  # (B, 1, 1)
        fused_scale = self.n_heads**-0.5 * gate * self.softmax_scale  # (B, 1, 1)
        return logits.unsqueeze(-1) * fused_scale  # (B, 1024, 1)
|
||||
|
||||
|
||||
def main():
    """Micro-benchmark: time the original vs. reordered head-gate computation
    on CPU tensors and verify the two produce the same result."""
    torch.manual_seed(0)
    model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
    x = torch.randn(128, 2048)  # batch=128, d_in=2048
    q_scale = torch.randn(128, 1)

    import time

    # Wall-clock each variant over 1000 iterations.
    start = time.time()
    for _ in range(1000):
        out_orig = model._get_logits_head_gate_orig(x, q_scale)
    print("Original version time:", time.time() - start)

    start = time.time()
    for _ in range(1000):
        out_opt = model._get_logits_head_gate_opt(x, q_scale)
    print("Optimized version time:", time.time() - start)

    # The reordering only changes float op order, so results should match
    # to within allclose tolerance.
    print("Difference:", (out_orig - out_opt).abs().max().item())
    assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"
|
||||
|
||||
|
||||
# Run the micro-benchmark when executed directly.
if __name__ == "__main__":
    main()
|
||||
|
||||
|
||||
"""
|
||||
Original version time: 0.49235057830810547
|
||||
Optimized version time: 0.4087331295013428
|
||||
Difference: 1.4901161193847656e-08
|
||||
"""
|
||||
32
python/sglang/srt/layers/attention/nsa/utils.py
Normal file
32
python/sglang/srt/layers/attention/nsa/utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# temp NSA debugging environ
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
# Feature toggles, all overridable via environment variables (default on).
NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")

# FP8 behavior of the FlashMLA decode path and the K-cache quant/dequant
# fast paths (the latter three default off).
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 = get_bool_env_var(
    "SGLANG_NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8", "true"
)
NSA_KV_CACHE_STORE_FP8 = get_bool_env_var("SGLANG_NSA_KV_CACHE_STORE_FP8", "false")
NSA_QUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_QUANT_K_CACHE_FAST", "false")
NSA_DEQUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_DEQUANT_K_CACHE_FAST", "false")
|
||||
|
||||
|
||||
def _print_bool_env_vars():
|
||||
msg = ""
|
||||
for k, v in globals().items():
|
||||
if k.startswith("NSA_") and isinstance(v, bool):
|
||||
msg += f"{k}={v} "
|
||||
print(msg, flush=True)
|
||||
|
||||
|
||||
# Log the effective flag values once at import time.
_print_bool_env_vars()


# NOTE(review): the assert enforces that storing the KV cache in FP8 is only
# allowed when the FlashMLA decode path also computes in FP8 -- confirm the
# intended coupling with the cache layout code.
if not NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8:
    assert not NSA_KV_CACHE_STORE_FP8
|
||||
|
||||
|
||||
def compute_nsa_seqlens(original_seq_lens, nsa_index_topk: int):
    """Cap per-request sequence lengths at the NSA sparse top-k budget."""
    capped = original_seq_lens.clamp_max(nsa_index_topk)
    return capped
|
||||
869
python/sglang/srt/layers/attention/nsa_backend.py
Normal file
869
python/sglang/srt/layers/attention/nsa_backend.py
Normal file
@@ -0,0 +1,869 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeAlias,
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
|
||||
from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache
|
||||
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
|
||||
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
|
||||
from sglang.srt.layers.attention.nsa.topk import (
|
||||
fast_topk_impl,
|
||||
fast_topk_transform_fused_cuda,
|
||||
)
|
||||
from sglang.srt.layers.attention.nsa.transform_index import (
|
||||
transform_index_page_table_decode,
|
||||
transform_index_page_table_prefill,
|
||||
)
|
||||
from sglang.srt.layers.attention.nsa.utils import (
|
||||
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
|
||||
NSA_FUSE_TOPK,
|
||||
NSA_KV_CACHE_STORE_FP8,
|
||||
compute_nsa_seqlens,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.two_batch_overlap import global_server_args_dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NSAFlashMLAMetadata:
    """Metadata only needed by FlashMLA"""

    # Scheduler metadata tensor produced for FlashMLA; kept whole when slicing.
    flashmla_metadata: torch.Tensor
    # Split counts; presumably one entry per request (it is the sliced field)
    # -- confirm against the producer.
    num_splits: torch.Tensor

    def slice(self, sli):
        """Return a view for a sub-batch: num_splits is indexed by `sli`,
        flashmla_metadata is shared unchanged."""
        return NSAFlashMLAMetadata(
            flashmla_metadata=self.flashmla_metadata,
            num_splits=self.num_splits[sli],
        )

    def copy_(self, other: "NSAFlashMLAMetadata"):
        """In-place copy of both tensors from `other` (torch copy_ semantics;
        shapes must be compatible)."""
        self.flashmla_metadata.copy_(other.flashmla_metadata)
        self.num_splits.copy_(other.num_splits)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NSAMetadata:
    """Per-forward-batch metadata consumed by the NSA attention backend."""

    page_size: int

    # Sequence lengths for the forward batch
    cache_seqlens_int32: torch.Tensor
    # Maximum sequence length for query
    max_seq_len_q: int
    # Maximum sequence length for key
    max_seq_len_k: int
    # Cumulative sequence lengths for query
    cu_seqlens_q: torch.Tensor
    # Cumulative sequence lengths for key
    cu_seqlens_k: torch.Tensor
    # Page table, the index of KV Cache Tables/Blocks
    # this table is always with page_size = 1
    page_table_1: torch.Tensor

    # NOTE(dark): This property will be used in:
    # 1. dense decode/prefill: we use paged flash attention, which needs real_page_table
    # 2. sparse decode/prefill: the indexer needs real_page_table to compute the score
    real_page_table: torch.Tensor

    # NSA metadata (nsa prefill are expanded)
    nsa_cache_seqlens_int32: torch.Tensor  # this seqlens is clipped to `topk`
    nsa_cu_seqlens_q: torch.Tensor  # must be arange(0, len(nsa_cu_seqlens_k))
    nsa_cu_seqlens_k: torch.Tensor  # cumsum of `nsa_cache_seqlens_int32`
    nsa_extend_seq_lens_list: List[int]
    nsa_seqlens_expanded: torch.Tensor  # expanded, unclipped `seqlens`
    nsa_max_seqlen_q: Literal[1] = 1  # always 1 for decode, variable for extend

    # Present only when the FlashMLA implementation is selected.
    flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NSAIndexerMetadata(BaseIndexerMetadata):
    """Indexer-facing view over NSAMetadata plus the top-k selection hook."""

    attn_metadata: NSAMetadata

    def get_seqlens_int32(self) -> torch.Tensor:
        # Unclipped per-request KV lengths.
        return self.attn_metadata.cache_seqlens_int32

    def get_page_table_64(self) -> torch.Tensor:
        return self.attn_metadata.real_page_table

    def get_seqlens_expanded(self) -> torch.Tensor:
        return self.attn_metadata.nsa_seqlens_expanded

    def topk_transform(
        self,
        logits: torch.Tensor,
        topk: int,
    ) -> torch.Tensor:
        """Select top-k KV positions per query row.

        With NSA_FUSE_TOPK off, returns raw top-k indices. With it on, returns
        a page table already gathered through page_table_1 by the fused kernel.
        """
        if not NSA_FUSE_TOPK:
            return fast_topk_impl(logits, self.get_seqlens_expanded(), topk)

        # NOTE(dark): if fused, we return a transformed page table directly
        dst_page_table = torch.empty(
            (logits.shape[0], topk), dtype=torch.int32, device=logits.device
        )
        fast_topk_transform_fused_cuda(
            input=logits,
            seq_lens=self.get_seqlens_expanded(),
            topk=topk,
            dst_page_table=dst_page_table,
            src_page_table=self.attn_metadata.page_table_1,
            cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
        )
        return dst_page_table
|
||||
|
||||
|
||||
def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
    """Prefix-sum `seqlens` into cumulative lengths with a leading zero.

    Input must be an int32 CUDA tensor of shape [bs]; output has shape
    [bs + 1] with cu[0] = 0 and cu[i] = sum(seqlens[:i]).
    """
    assert seqlens.dtype == torch.int32 and seqlens.is_cuda
    running_total = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    # Left-pad a single zero so the result can be used as varlen offsets.
    return torch.nn.functional.pad(running_total, (1, 0))
|
||||
|
||||
|
||||
# Accepted NSA implementation names (selected via server args).
_NSA_IMPL_T: TypeAlias = Literal[
    "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
]

# Module-level selections, assigned from server args in
# NativeSparseAttnBackend.__init__ (declared here without values).
NSA_PREFILL_IMPL: _NSA_IMPL_T
NSA_DECODE_IMPL: _NSA_IMPL_T
|
||||
|
||||
|
||||
class NativeSparseAttnBackend(AttentionBackend):
|
||||
    def __init__(self, model_runner: ModelRunner):
        """Bind model-runner state needed by the NSA attention backend.

        Reads page size, NSA index top-k, head counts and the req->token map
        from `model_runner`, and publishes the prefill/decode implementation
        choice to module-level globals.
        """
        super().__init__()
        self.forward_metadata: NSAMetadata
        self.device = model_runner.device
        assert isinstance(model_runner.page_size, int)
        self.real_page_size = model_runner.page_size
        # 1 split when deterministic inference is requested; 0 otherwise
        # (presumably "let the kernel choose" -- TODO confirm).
        self.num_splits = (
            1 if model_runner.server_args.enable_deterministic_inference else 0
        )
        self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
        assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
        self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
        self.max_context_len = model_runner.model_config.context_len
        # Query heads are sharded across attention tensor parallelism.
        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim

        assert model_runner.req_to_token_pool is not None
        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        # NOTE(review): module-level globals, not instance attributes --
        # assumes a single backend configuration per process.
        global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode

        # Cached int32 arange, grown on demand by get_device_int32_arange.
        self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
|
||||
|
||||
    def get_device_int32_arange(self, l: int) -> torch.Tensor:
        """Return arange(l) as an int32 device tensor served from a cached buffer.

        The cache grows to the next power of two when `l` exceeds its length,
        so repeated calls amortize to a single allocation.
        """
        if l > len(self._arange_buf):
            # Smallest power of two >= l.
            next_pow_of_2 = 1 << (l - 1).bit_length()
            self._arange_buf = torch.arange(
                next_pow_of_2, device=self.device, dtype=torch.int32
            )
        return self._arange_buf[:l]
|
||||
|
||||
def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:
|
||||
page_size = self.real_page_size
|
||||
if page_size == 1:
|
||||
return page_table
|
||||
max_seqlen_k = page_table.shape[1]
|
||||
strided_indices = torch.arange(
|
||||
0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
|
||||
)
|
||||
return page_table[:, strided_indices] // page_size
|
||||
|
||||
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init the metadata for a forward pass.

        Builds an NSAMetadata for either a decode/idle batch (one query token
        per request) or an extend batch, and stores it on
        `self.forward_metadata`. Speculative decoding is rejected.
        """
        batch_size = forward_batch.batch_size
        device = forward_batch.seq_lens.device

        assert (
            forward_batch.spec_info is None
        ), "Spec decoding is not supported for NSA backend now"
        cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
        assert forward_batch.seq_lens_cpu is not None
        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
        # Page-size-1 token table: one KV-cache slot index per token.
        page_table = forward_batch.req_to_token_pool.req_to_token[
            forward_batch.req_pool_indices, :max_seqlen_k
        ]

        if forward_batch.forward_mode.is_decode_or_idle():
            # Decode: exactly one new token per request.
            extend_seq_lens_cpu = [1] * batch_size
            max_seqlen_q = 1
            cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
            seqlens_expanded = cache_seqlens_int32
        elif forward_batch.forward_mode.is_extend():
            assert (
                forward_batch.extend_seq_lens_cpu is not None
                and forward_batch.extend_seq_lens is not None
                and forward_batch.extend_prefix_lens_cpu is not None
            ), "All of them must not be None"
            extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
            assert forward_batch.extend_seq_lens is not None
            if any(forward_batch.extend_prefix_lens_cpu):
                # Some requests have a cached prefix: q lengths differ from
                # the full kv lengths, so compute cu_seqlens_q separately.
                max_seqlen_q = max(extend_seq_lens_cpu)
                cu_seqlens_q = compute_cu_seqlens(
                    forward_batch.extend_seq_lens.to(torch.int32)
                )
            else:
                # No prefixes: every query covers its whole sequence, so the
                # k-side values can be reused as-is.
                max_seqlen_q = max_seqlen_k
                cu_seqlens_q = cu_seqlens_k
            # One entry per query token: the kv length visible to that token
            # under causal masking (kv_len - qo_len + 1 .. kv_len).
            seqlens_expanded = torch.cat(
                [
                    torch.arange(
                        kv_len - qo_len + 1,
                        kv_len + 1,
                        dtype=torch.int32,
                        device=device,
                    )
                    for qo_len, kv_len in zip(
                        forward_batch.extend_seq_lens_cpu,
                        forward_batch.seq_lens_cpu.tolist(),
                        strict=True,
                    )
                ]
            )
        else:
            assert False, f"Unsupported {forward_batch.forward_mode = }"

        # 1D, expanded seqlens (1D means cheap to compute, so always compute it)
        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
            original_seq_lens=seqlens_expanded,
            nsa_index_topk=self.nsa_index_topk,
        )
        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
        # One query token per expanded row, so the q-side cumsum is an arange.
        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))

        metadata = NSAMetadata(
            page_size=self.real_page_size,
            cache_seqlens_int32=cache_seqlens_int32,
            max_seq_len_q=max_seqlen_q,
            max_seq_len_k=max_seqlen_k,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            page_table_1=page_table,
            flashmla_metadata=(
                self._compute_flashmla_metadata(
                    cache_seqlens=nsa_cache_seqlens_int32,
                    seq_len_q=1,  # TODO handle MTP which is not 1
                )
                if NSA_DECODE_IMPL == "flashmla_decode"
                else None
            ),
            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
            nsa_cu_seqlens_q=nsa_cu_seqlens_q,
            nsa_cu_seqlens_k=nsa_cu_seqlens_k,
            nsa_seqlens_expanded=seqlens_expanded,
            nsa_extend_seq_lens_list=extend_seq_lens_cpu,
            real_page_table=self._transform_table_1_to_real(page_table),
        )

        self.forward_metadata = metadata
|
||||
|
||||
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Initialize CUDA graph state for the attention backend.

        Args:
            max_bs (int): Maximum batch size to support in CUDA graphs
            max_num_tokens (int): Maximum token count (unused in this backend)

        This creates fixed-size tensors that will be reused during CUDA graph replay
        to avoid memory allocations.
        """
        self.decode_cuda_graph_metadata: Dict = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            # Decode issues one query token per request, so cu_seqlens_q is a
            # fixed arange that never needs updating.
            "cu_seqlens_q": torch.arange(
                0, max_bs + 1, dtype=torch.int32, device=self.device
            ),
            "cu_seqlens_k": torch.zeros(
                max_bs + 1, dtype=torch.int32, device=self.device
            ),
            # fake page_table for sparse_prefill
            "page_table": torch.zeros(
                max_bs,
                self.max_context_len,
                dtype=torch.int32,
                device=self.device,
            ),
            # Pre-sized FlashMLA scheduler metadata; only materialized when
            # the flashmla_decode implementation is selected.
            "flashmla_metadata": (
                self._compute_flashmla_metadata(
                    cache_seqlens=torch.ones(
                        max_bs, dtype=torch.int32, device=self.device
                    ),
                    seq_len_q=1,  # TODO handle MTP which is not 1
                )
                if NSA_DECODE_IMPL == "flashmla_decode"
                else None
            ),
        }
|
||||
|
||||
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Initialize forward metadata for capturing CUDA graph.

        Builds an NSAMetadata out of the pre-allocated buffers from
        init_cuda_graph_state so replay can refresh the same tensors in place.
        Only decode/idle mode without speculative decoding is supported.
        """
        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
        assert (
            spec_info is None
        ), "Speculative decoding is not supported for NSA backend now"

        # Normal Decode
        # Get sequence information
        cache_seqlens_int32 = seq_lens.to(torch.int32)
        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)

        # Use max context length for seq_len_k
        page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
        max_seq_len_k = page_table_1.shape[1]

        # Precompute page table
        # Precompute cumulative sequence lengths

        # NOTE(dark): this is always arange, since we are decoding
        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
            cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
        )
        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
        real_page_table = self._transform_table_1_to_real(page_table_1)

        if NSA_DECODE_IMPL == "flashmla_decode":
            # Slice the pre-allocated metadata and fill it in place so the
            # tensors captured by the graph keep stable addresses.
            flashmla_metadata = self.decode_cuda_graph_metadata[
                "flashmla_metadata"
            ].slice(slice(0, bs + 1))
            flashmla_metadata.copy_(
                self._compute_flashmla_metadata(
                    cache_seqlens=nsa_cache_seqlens_int32,
                    seq_len_q=1,  # TODO handle MTP which is not 1
                )
            )
        else:
            flashmla_metadata = None

        metadata = NSAMetadata(
            page_size=self.real_page_size,
            cache_seqlens_int32=cache_seqlens_int32,
            max_seq_len_q=1,
            max_seq_len_k=max_seq_len_k,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            page_table_1=page_table_1,
            flashmla_metadata=flashmla_metadata,
            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
            nsa_cu_seqlens_q=nsa_cu_seqlens_q,
            nsa_cu_seqlens_k=nsa_cu_seqlens_k,
            nsa_seqlens_expanded=cache_seqlens_int32,
            real_page_table=real_page_table,
            nsa_extend_seq_lens_list=[1] * bs,
        )
        # Remember the per-bs metadata so replay can update it in place.
        self.decode_cuda_graph_metadata[bs] = metadata
        self.forward_metadata = metadata
|
||||
|
||||
    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
        out_cache_loc: Optional[torch.Tensor] = None,
    ):
        """Initialize forward metadata for replaying CUDA graph.

        Refreshes, in place, the tensors of the NSAMetadata captured for this
        batch size — replay must not allocate or rebind captured tensors.
        """
        assert seq_lens_cpu is not None
        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
        assert (
            spec_info is None
        ), "Speculative decoding is not supported for NSA backend now"
        # Inputs may be padded to the captured width; only the first `bs`
        # entries are live.
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        req_pool_indices = req_pool_indices[:bs]

        # Normal Decode
        metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
        max_len = int(seq_lens_cpu.max().item())

        cache_seqlens = seq_lens.to(torch.int32)
        metadata.cache_seqlens_int32.copy_(cache_seqlens)
        # Element 0 of a cu_seqlens tensor stays 0; only refresh the rest.
        metadata.cu_seqlens_k[1:].copy_(
            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
        )
        page_indices = self.req_to_token[req_pool_indices, :max_len]
        metadata.page_table_1[:, :max_len].copy_(page_indices)
        assert (
            metadata.nsa_cache_seqlens_int32 is not None
            and metadata.nsa_cu_seqlens_k is not None
            and self.nsa_index_topk is not None
        )
        nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
        metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
        metadata.nsa_cu_seqlens_k[1:].copy_(
            torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
        )
        # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy

        assert self.real_page_size == metadata.page_size
        if self.real_page_size > 1:
            real_table = self._transform_table_1_to_real(page_indices)
            new_len = real_table.shape[1]
            metadata.real_page_table[:, :new_len].copy_(real_table)
        else:
            # With page size 1 the two tables alias each other, so the
            # page_table_1 copy above already updated real_page_table.
            assert metadata.real_page_table is metadata.page_table_1

        if NSA_DECODE_IMPL == "flashmla_decode":
            metadata.flashmla_metadata.copy_(
                self._compute_flashmla_metadata(
                    cache_seqlens=nsa_cache_seqlens,
                    seq_len_q=1,  # TODO handle MTP which is not 1
                )
            )

        self.forward_metadata = metadata
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
# For multi-head latent attention
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
topk_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert (
|
||||
not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
), "NSA backend doesn't support speculative decoding"
|
||||
if k is not None:
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
cache_loc = (
|
||||
forward_batch.out_cache_loc
|
||||
if not layer.is_cross_attention
|
||||
else forward_batch.encoder_out_cache_loc
|
||||
)
|
||||
forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore
|
||||
layer,
|
||||
cache_loc,
|
||||
k,
|
||||
k_rope,
|
||||
)
|
||||
|
||||
metadata = self.forward_metadata
|
||||
causal = not layer.is_cross_attention
|
||||
assert causal, "NSA is causal only"
|
||||
|
||||
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
||||
kwargs = {}
|
||||
|
||||
# Do absorbed multi-latent attention
|
||||
assert q_rope is not None
|
||||
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
|
||||
# when store in fp8 and compute in fp8, no need to convert dtype
|
||||
if not (NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and NSA_KV_CACHE_STORE_FP8):
|
||||
kv_cache = kv_cache.to(q.dtype)
|
||||
|
||||
if q_rope is not None:
|
||||
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||
q_rope = q_rope.view(
|
||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||
)
|
||||
else:
|
||||
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||
|
||||
# NOTE(dark): here, we use page size = 1
|
||||
|
||||
if NSA_FUSE_TOPK:
|
||||
page_table_1 = topk_indices
|
||||
else:
|
||||
assert metadata.nsa_extend_seq_lens_list is not None
|
||||
page_table_1 = transform_index_page_table_prefill(
|
||||
page_table=metadata.page_table_1,
|
||||
topk_indices=topk_indices,
|
||||
extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
|
||||
page_size=1,
|
||||
)
|
||||
# if NSA_PREFILL_IMPL == "tilelang":
|
||||
# from sglang.srt.layers.attention.nsa.tilelang_kernel import (
|
||||
# tilelang_sparse_fwd,
|
||||
# )
|
||||
|
||||
# if q_rope is not None:
|
||||
# q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||
# return self._forward_tilelang(
|
||||
# q_all=q_all,
|
||||
# kv_cache=kv_cache,
|
||||
# page_table_1=page_table_1,
|
||||
# sm_scale=layer.scaling,
|
||||
# v_head_dim=layer.v_head_dim,
|
||||
# )
|
||||
# elif NSA_PREFILL_IMPL == "flashmla_prefill":
|
||||
|
||||
|
||||
# Skip tilelang dependencies
|
||||
if NSA_PREFILL_IMPL == "tilelang" or NSA_PREFILL_IMPL == "flashmla_prefill":
|
||||
if q_rope is not None:
|
||||
q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||
return self._forward_flashmla_prefill(
|
||||
q_all=q_all,
|
||||
kv_cache=kv_cache,
|
||||
page_table_1=page_table_1,
|
||||
sm_scale=layer.scaling,
|
||||
v_head_dim=layer.v_head_dim,
|
||||
)
|
||||
elif NSA_PREFILL_IMPL == "flashmla_decode":
|
||||
if q_rope is not None:
|
||||
q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||
return self._forward_flashmla_decode(
|
||||
q_all=q_all,
|
||||
kv_cache=kv_cache,
|
||||
sm_scale=layer.scaling,
|
||||
v_head_dim=layer.v_head_dim,
|
||||
# TODO optimize args
|
||||
layer=layer,
|
||||
forward_batch=forward_batch,
|
||||
metadata=metadata,
|
||||
topk_indices=topk_indices,
|
||||
block_table=metadata.real_page_table,
|
||||
)
|
||||
elif NSA_PREFILL_IMPL == "fa3":
|
||||
return self._forward_fa3(
|
||||
q_rope=q_rope,
|
||||
kv_cache=kv_cache,
|
||||
v_head_dim=layer.v_head_dim,
|
||||
q_nope=q_nope,
|
||||
page_table=page_table_1,
|
||||
cache_seqlens=metadata.nsa_cache_seqlens_int32,
|
||||
cu_seqlens_q=metadata.nsa_cu_seqlens_q,
|
||||
cu_seqlens_k=metadata.nsa_cu_seqlens_k,
|
||||
max_seqlen_q=metadata.nsa_max_seqlen_q,
|
||||
sm_scale=layer.scaling,
|
||||
logit_cap=layer.logit_cap,
|
||||
page_size=1,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }")
|
||||
|
||||
    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        # For multi-head latent attention
        q_rope: Optional[torch.Tensor] = None,
        k_rope: Optional[torch.Tensor] = None,
        topk_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run NSA attention for a decode batch.

        Optionally writes (k, k_rope) into the MLA KV cache, then dispatches
        to the kernel selected by the module-level NSA_DECODE_IMPL.
        """
        if k is not None:
            assert v is not None
            if save_kv_cache:
                cache_loc = (
                    forward_batch.out_cache_loc
                    if not layer.is_cross_attention
                    else forward_batch.encoder_out_cache_loc
                )
                forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
                    layer,
                    cache_loc,
                    k,
                    k_rope,
                )

        metadata = self.forward_metadata
        causal = not layer.is_cross_attention
        assert causal, "NSA is causal only"

        # Do absorbed multi-latent attention
        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        if q_rope is not None:
            # q comes pre-split into nope/rope parts; just reshape per head.
            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
            q_rope = q_rope.view(
                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
            )
        else:
            # q carries the full head_dim; slice it into nope/rope parts.
            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
            q_nope = q_all[:, :, : layer.v_head_dim]
            q_rope = q_all[:, :, layer.v_head_dim :]

        if NSA_FUSE_TOPK:
            # Indexer already produced indices in KV-cache space.
            page_table_1 = topk_indices
        else:
            page_table_1 = transform_index_page_table_decode(
                page_table=metadata.page_table_1,
                topk_indices=topk_indices,
                page_size=1,
            )

        if NSA_DECODE_IMPL == "flashmla_prefill":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_flashmla_prefill(
                q_all=q_all,
                kv_cache=kv_cache,
                page_table_1=page_table_1,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
            )
        elif NSA_DECODE_IMPL == "flashmla_decode":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_flashmla_decode(
                q_all=q_all,
                kv_cache=kv_cache,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
                # TODO optimize args
                layer=layer,
                forward_batch=forward_batch,
                metadata=metadata,
                topk_indices=topk_indices,
                block_table=metadata.real_page_table,
            )
        elif NSA_DECODE_IMPL == "tilelang":
            if q_rope is not None:
                q_all = torch.cat([q_nope, q_rope], dim=-1)
            return self._forward_tilelang(
                q_all=q_all,
                kv_cache=kv_cache,
                page_table_1=page_table_1,
                sm_scale=layer.scaling,
                v_head_dim=layer.v_head_dim,
            )
        elif NSA_DECODE_IMPL == "fa3":
            return self._forward_fa3(
                q_rope=q_rope,
                kv_cache=kv_cache,
                v_head_dim=layer.v_head_dim,
                q_nope=q_nope,
                page_table=page_table_1,
                cache_seqlens=metadata.nsa_cache_seqlens_int32,
                cu_seqlens_q=metadata.nsa_cu_seqlens_q,
                cu_seqlens_k=metadata.nsa_cu_seqlens_k,
                max_seqlen_q=metadata.nsa_max_seqlen_q,
                sm_scale=layer.scaling,
                logit_cap=layer.logit_cap,
                page_size=1,
            )
        else:
            assert False, f"Unsupported {NSA_DECODE_IMPL = }"
|
||||
|
||||
    def _forward_fa3(
        self,
        q_rope: torch.Tensor,
        kv_cache: torch.Tensor,
        v_head_dim: int,
        q_nope: torch.Tensor,
        page_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        sm_scale: float,
        logit_cap: float,
        page_size: int,
    ) -> torch.Tensor:
        """Absorbed-MLA attention through FA3's flash_attn_with_kvcache.

        The MLA cache is split along the last dim into a compressed-KV part
        (first `v_head_dim` channels, passed as the value cache) and a rope
        part (the remainder, passed as the key cache); q_nope goes in via the
        `qv` argument.
        """
        k_rope_cache = kv_cache[:, :, v_head_dim:]
        c_kv_cache = kv_cache[:, :, :v_head_dim]
        qk_rope_dim = k_rope_cache.shape[-1]
        # Reshape to (num_pages, page_size, 1 kv head, dim) as FA3 expects.
        k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
        c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
        o = flash_attn_with_kvcache(
            q=q_rope,
            k_cache=k_rope_cache,
            v_cache=c_kv_cache,
            qv=q_nope,
            page_table=page_table,
            cache_seqlens=cache_seqlens,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k_new=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            softmax_scale=sm_scale,
            causal=True,
            softcap=logit_cap,
            return_softmax_lse=False,
            # num_splits=1 forces determinism (see __init__).
            num_splits=self.num_splits,
        )
        return o  # type: ignore
|
||||
|
||||
def _forward_flashmla_prefill(
|
||||
self,
|
||||
q_all: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
v_head_dim: int,
|
||||
page_table_1: torch.Tensor,
|
||||
sm_scale: float,
|
||||
) -> torch.Tensor:
|
||||
#from flash_mla import flash_mla_sparse_fwd
|
||||
from sglang.srt.layers.attention.native_mla import native_mla_sparse_fwd
|
||||
_, _, o = native_mla_sparse_fwd(
|
||||
q=q_all,
|
||||
kv=kv_cache,
|
||||
indices=page_table_1.unsqueeze(1),
|
||||
sm_scale=sm_scale,
|
||||
d_v=v_head_dim,
|
||||
)
|
||||
return o
|
||||
|
||||
    def _forward_flashmla_decode(
        self,
        q_all: torch.Tensor,
        kv_cache: torch.Tensor,
        v_head_dim: int,
        sm_scale: float,
        layer,
        forward_batch: ForwardBatch,
        metadata: NSAMetadata,
        topk_indices,
        block_table,
    ) -> torch.Tensor:
        """Sparse MLA decode via the native FlashMLA-style kernel.

        Top-k token indices are translated into flat KV-cache slots with
        _compute_indices_in_kvcache; the kernel's own block_table argument is
        given an empty placeholder (see note below).
        """
        # from flash_mla import flash_mla_with_kvcache
        from sglang.srt.layers.attention.native_mla import native_mla_with_kvcache

        cache_seqlens = metadata.nsa_cache_seqlens_int32

        # TODO the 2nd dim is seq_len_q, need to be >1 when MTP
        q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
        kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)
        assert self.real_page_size == 64, "only page size 64 is supported"

        if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not NSA_KV_CACHE_STORE_FP8:
            # inefficiently quantize the whole cache
            kv_cache = quantize_k_cache(kv_cache)

        o, _ = native_mla_with_kvcache(
            q=q_all,
            k_cache=kv_cache,
            cache_seqlens=cache_seqlens,
            head_dim_v=v_head_dim,
            tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
            num_splits=metadata.flashmla_metadata.num_splits,
            softmax_scale=sm_scale,
            # TODO improve
            indices=_compute_indices_in_kvcache(
                block_table=block_table,
                topk_indices=topk_indices.to(torch.int32),
                page_size=self.real_page_size,
                nsa_index_topk=self.nsa_index_topk,
            ),
            # doc says it is not used, but if pass in None then error
            block_table=torch.empty(
                (q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
            ),
            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
        )

        # TODO shape correct?
        return o
|
||||
|
||||
def _forward_tilelang(
|
||||
self,
|
||||
q_all: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
v_head_dim: int,
|
||||
page_table_1: torch.Tensor,
|
||||
sm_scale: float,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd
|
||||
|
||||
return tilelang_sparse_fwd(
|
||||
q=q_all,
|
||||
kv=kv_cache,
|
||||
indices=page_table_1.unsqueeze(1),
|
||||
sm_scale=sm_scale,
|
||||
d_v=v_head_dim,
|
||||
)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
"""Get the fill value for sequence length in CUDA graph."""
|
||||
return 1
|
||||
|
||||
def get_indexer_metadata(
|
||||
self, layer_id: int, forward_batch: ForwardBatch
|
||||
) -> NSAIndexerMetadata:
|
||||
return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
|
||||
|
||||
    def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
        """Build FlashMLA tile-scheduler metadata for the given KV lengths.

        Returns an NSAFlashMLAMetadata wrapping the scheduler metadata and
        split counts produced by flash_mla.get_mla_metadata.
        """
        from flash_mla import get_mla_metadata

        flashmla_metadata, num_splits = get_mla_metadata(
            cache_seqlens=cache_seqlens,
            # TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`
            # but the name looks like need seq_len_q?
            num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,
            num_heads_k=1,
            num_heads_q=self.num_q_heads,
            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
            topk=self.nsa_index_topk,
        )

        return NSAFlashMLAMetadata(
            flashmla_metadata=flashmla_metadata,
            num_splits=num_splits,
        )
|
||||
|
||||
|
||||
# TODO speedup
|
||||
def _compute_indices_in_kvcache(block_table, topk_indices, page_size, nsa_index_topk):
|
||||
topk_indices_safe = topk_indices.masked_fill(topk_indices == -1, 0)
|
||||
|
||||
idx0 = torch.arange(block_table.size(0), device=topk_indices_safe.device).unsqueeze(
|
||||
1
|
||||
)
|
||||
block_idx = block_table[idx0, topk_indices_safe // page_size]
|
||||
offset = topk_indices_safe % page_size
|
||||
indices_in_kvcache = block_idx * page_size + offset
|
||||
|
||||
# the kernel requires invalid entry to be -1
|
||||
assert indices_in_kvcache.shape == topk_indices.shape
|
||||
indices_in_kvcache[topk_indices == -1] = -1
|
||||
|
||||
# return: (batch_size, seqlen_q_ori, topk)
|
||||
indices_in_kvcache = indices_in_kvcache[:, None, :]
|
||||
|
||||
indices_in_kvcache = torch.nn.functional.pad(
|
||||
indices_in_kvcache,
|
||||
(0, nsa_index_topk - indices_in_kvcache.shape[-1]),
|
||||
"constant",
|
||||
-1,
|
||||
)
|
||||
assert indices_in_kvcache.shape[-1] == nsa_index_topk
|
||||
|
||||
return indices_in_kvcache
|
||||
@@ -127,8 +127,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
"disable_chunked_prefix_cache"
|
||||
]
|
||||
|
||||
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||
|
||||
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
||||
"""
|
||||
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
|
||||
@@ -219,7 +217,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
"""Initialize metadata for CUDA graph capture."""
|
||||
|
||||
# Delegate to parent for non-decode modes.
|
||||
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
|
||||
if not forward_mode.is_decode_or_idle():
|
||||
return super().init_forward_metadata_capture_cuda_graph(
|
||||
bs,
|
||||
num_tokens,
|
||||
@@ -230,9 +228,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
spec_info,
|
||||
)
|
||||
|
||||
if forward_mode.is_target_verify():
|
||||
seq_lens = seq_lens + self.num_draft_tokens
|
||||
|
||||
# Custom fast-path for decode/idle.
|
||||
# Capture with full width so future longer sequences are safe during replay
|
||||
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
|
||||
@@ -275,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
):
|
||||
"""Replay CUDA graph with new inputs."""
|
||||
# Delegate to parent for non-decode modes.
|
||||
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
|
||||
if not forward_mode.is_decode_or_idle():
|
||||
return super().init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
req_pool_indices,
|
||||
@@ -287,10 +282,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
seq_lens_cpu,
|
||||
)
|
||||
|
||||
if forward_mode.is_target_verify():
|
||||
seq_lens = seq_lens + self.num_draft_tokens
|
||||
del seq_lens_sum # not handle "num_draft_tokens" but we do not need it
|
||||
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
|
||||
# Update block indices for new sequences.
|
||||
@@ -341,10 +332,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
cum_seq_lens_q,
|
||||
seq_lens,
|
||||
)
|
||||
elif (
|
||||
forward_batch.forward_mode.is_decode_or_idle()
|
||||
or forward_batch.forward_mode.is_target_verify()
|
||||
):
|
||||
elif forward_batch.forward_mode.is_decode_or_idle():
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
# Get maximum sequence length.
|
||||
@@ -353,19 +341,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
else:
|
||||
max_seq = forward_batch.seq_lens.max().item()
|
||||
|
||||
seq_lens = forward_batch.seq_lens
|
||||
|
||||
if forward_batch.forward_mode.is_target_verify():
|
||||
max_seq = max_seq + self.num_draft_tokens
|
||||
seq_lens = seq_lens + self.num_draft_tokens
|
||||
|
||||
max_seqlen_pad = self._calc_padded_blocks(max_seq)
|
||||
block_kv_indices = self._create_block_kv_indices(
|
||||
bs,
|
||||
max_seqlen_pad,
|
||||
forward_batch.req_pool_indices,
|
||||
seq_lens,
|
||||
seq_lens.device,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens.device,
|
||||
)
|
||||
|
||||
max_seq_len_val = int(max_seq)
|
||||
@@ -505,7 +487,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
q_rope_reshaped = q_rope.view(
|
||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||
)
|
||||
query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
|
||||
if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
|
||||
query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
|
||||
else:
|
||||
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
|
||||
else:
|
||||
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
|
||||
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
@@ -568,134 +553,84 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
save_kv_cache: bool = True,
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
):
|
||||
if (
|
||||
forward_batch.forward_mode.is_target_verify()
|
||||
or forward_batch.forward_mode.is_draft_extend()
|
||||
):
|
||||
return super().forward_extend(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||
)
|
||||
# chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
|
||||
if forward_batch.attn_attend_prefix_cache is None:
|
||||
return super().forward_extend(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||
)
|
||||
|
||||
# Save KV cache if requested
|
||||
if save_kv_cache:
|
||||
assert (
|
||||
k is not None and k_rope is not None
|
||||
), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
|
||||
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
|
||||
layer, forward_batch.out_cache_loc, k, k_rope
|
||||
)
|
||||
|
||||
if q_rope is not None:
|
||||
q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||
q_rope = q_rope.view(
|
||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||
)
|
||||
q = _concat_mla_absorb_q_general(q, q_rope)
|
||||
|
||||
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
|
||||
if k_rope is not None:
|
||||
k = torch.cat([k, k_rope], dim=-1)
|
||||
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
|
||||
|
||||
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
|
||||
|
||||
if forward_batch.forward_mode.is_target_verify():
|
||||
metadata = (
|
||||
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
|
||||
or self.forward_decode_metadata
|
||||
)
|
||||
|
||||
# Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
|
||||
bs = forward_batch.batch_size
|
||||
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
||||
|
||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
|
||||
|
||||
q_scale = 1.0
|
||||
k_scale = (
|
||||
layer.k_scale_float
|
||||
if getattr(layer, "k_scale_float", None) is not None
|
||||
else 1.0
|
||||
)
|
||||
|
||||
bmm1_scale = q_scale * k_scale * layer.scaling
|
||||
|
||||
seq_lens = (
|
||||
forward_batch.seq_lens.to(torch.int32)
|
||||
+ forward_batch.spec_info.draft_token_num
|
||||
)
|
||||
max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
|
||||
|
||||
# TODO may use `mla_rope_quantize_fp8` fusion
|
||||
q = q.to(self.data_type)
|
||||
assert kv_cache.dtype == self.data_type
|
||||
|
||||
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q,
|
||||
kv_cache=kv_cache,
|
||||
workspace_buffer=self.workspace_buffer,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=metadata.block_kv_indices,
|
||||
seq_lens=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=bmm1_scale,
|
||||
)
|
||||
|
||||
# Reshape output directly without slicing
|
||||
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
return output
|
||||
|
||||
if forward_batch.attn_attend_prefix_cache:
|
||||
# MHA for chunked prefix kv cache when running model with MLA
|
||||
assert forward_batch.prefix_chunk_idx is not None
|
||||
assert forward_batch.prefix_chunk_cu_seq_lens is not None
|
||||
assert q_rope is None
|
||||
assert k_rope is None
|
||||
chunk_idx = forward_batch.prefix_chunk_idx
|
||||
|
||||
output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
|
||||
return flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
||||
if not forward_batch.attn_attend_prefix_cache:
|
||||
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
|
||||
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
|
||||
output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
workspace_buffer=self.workspace_buffer,
|
||||
seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
|
||||
seq_lens=self.forward_prefill_metadata.seq_lens,
|
||||
max_q_len=self.forward_prefill_metadata.max_seq_len,
|
||||
max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
|
||||
max_kv_len=self.forward_prefill_metadata.max_seq_len,
|
||||
bmm1_scale=layer.scaling,
|
||||
bmm2_scale=1.0,
|
||||
o_sf_scale=-1.0,
|
||||
o_sf_scale=1.0,
|
||||
batch_size=forward_batch.batch_size,
|
||||
window_left=-1,
|
||||
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
|
||||
cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
|
||||
cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
|
||||
enable_pdl=False,
|
||||
is_causal=False,
|
||||
return_lse=True,
|
||||
out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
|
||||
is_causal=True,
|
||||
return_lse=forward_batch.mha_return_lse,
|
||||
)
|
||||
else:
|
||||
if not (
|
||||
forward_batch.attn_attend_prefix_cache is not None
|
||||
and forward_batch.mha_return_lse
|
||||
):
|
||||
output = super().forward_extend(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||
)
|
||||
else:
|
||||
# MHA for chunked prefix kv cache when running model with MLA
|
||||
assert forward_batch.prefix_chunk_idx is not None
|
||||
assert forward_batch.prefix_chunk_cu_seq_lens is not None
|
||||
assert q_rope is None
|
||||
assert k_rope is None
|
||||
chunk_idx = forward_batch.prefix_chunk_idx
|
||||
|
||||
return flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
workspace_buffer=self.workspace_buffer,
|
||||
seq_lens=self.forward_prefill_metadata.seq_lens,
|
||||
max_q_len=self.forward_prefill_metadata.max_seq_len,
|
||||
max_kv_len=self.forward_prefill_metadata.max_seq_len,
|
||||
bmm1_scale=layer.scaling,
|
||||
bmm2_scale=1.0,
|
||||
o_sf_scale=1.0,
|
||||
batch_size=forward_batch.batch_size,
|
||||
window_left=-1,
|
||||
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
|
||||
cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
|
||||
enable_pdl=False,
|
||||
is_causal=True,
|
||||
return_lse=forward_batch.mha_return_lse,
|
||||
)
|
||||
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
|
||||
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
|
||||
output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
|
||||
output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
workspace_buffer=self.workspace_buffer,
|
||||
seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
|
||||
max_q_len=self.forward_prefill_metadata.max_seq_len,
|
||||
max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
|
||||
bmm1_scale=layer.scaling,
|
||||
bmm2_scale=1.0,
|
||||
o_sf_scale=-1.0,
|
||||
batch_size=forward_batch.batch_size,
|
||||
window_left=-1,
|
||||
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
|
||||
cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
|
||||
enable_pdl=False,
|
||||
is_causal=False,
|
||||
return_lse=True,
|
||||
out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
||||
@@ -713,10 +648,3 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
||||
kv_indptr_buf=self.kv_indptr[i],
|
||||
q_indptr_decode_buf=self.q_indptr_decode,
|
||||
)
|
||||
|
||||
|
||||
def _concat_mla_absorb_q_general(q_nope, q_rope):
|
||||
if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
|
||||
return concat_mla_absorb_q(q_nope, q_rope)
|
||||
else:
|
||||
return torch.cat([q_nope, q_rope], dim=-1)
|
||||
|
||||
@@ -16,19 +16,14 @@ from sglang.srt.utils import (
|
||||
get_device_capability,
|
||||
is_blackwell,
|
||||
is_cuda,
|
||||
is_npu,
|
||||
print_info_once,
|
||||
)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
||||
|
||||
if _is_npu:
|
||||
import torch_npu
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
@@ -336,63 +331,10 @@ class VisionFlash3Attention(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class VisionAscendAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
if not _is_npu:
|
||||
raise Exception("VisionAscendAttention is only available for ascend npu")
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
|
||||
bsz: int,
|
||||
seq_len: int,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
cu_seqlens: [b]
|
||||
Returns:
|
||||
[b * s, h, head_size]
|
||||
"""
|
||||
if cu_seqlens is None:
|
||||
cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
|
||||
|
||||
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
if seq_lens.is_npu:
|
||||
# cu_seqlens must be on cpu because of operator restriction
|
||||
seq_lens = seq_lens.to("cpu")
|
||||
_, num_heads, head_size = q.shape
|
||||
num_kv_heads = k.shape[1]
|
||||
output = torch.empty_like(q)
|
||||
|
||||
# operator requires pta version >= 2.5.1
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=seq_lens.to(torch.int32),
|
||||
scale_value=head_size**-0.5,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
out=output,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
QKV_BACKEND_IMPL = {
|
||||
"triton_attn": VisionTritonAttention,
|
||||
"sdpa": VisionSdpaAttention,
|
||||
"fa3": VisionFlash3Attention,
|
||||
"ascend_attn": VisionAscendAttention,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -50,7 +50,6 @@ from sglang.srt.utils import (
|
||||
is_hip,
|
||||
is_sm90_supported,
|
||||
is_sm100_supported,
|
||||
prepare_weight_cache,
|
||||
)
|
||||
|
||||
_is_flashinfer_available = is_flashinfer_available()
|
||||
@@ -276,11 +275,7 @@ class LayerCommunicator:
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
cache=None,
|
||||
):
|
||||
if cache is not None:
|
||||
self._context.cache = cache
|
||||
|
||||
return self._communicate_with_all_reduce_and_layer_norm_fn(
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
@@ -354,7 +349,6 @@ class CommunicateContext:
|
||||
attn_tp_size: int
|
||||
attn_dp_size: int
|
||||
tp_size: int
|
||||
cache = None
|
||||
|
||||
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
|
||||
return self.process_group_sizes[a] == self.process_group_sizes[b]
|
||||
@@ -539,8 +533,6 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
if context.cache is not None:
|
||||
_ = prepare_weight_cache(hidden_states, context.cache)
|
||||
hidden_states, residual = layernorm(hidden_states, residual)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@@ -187,9 +187,7 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
|
||||
|
||||
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
|
||||
assert len(x.shape) == 2
|
||||
assert (
|
||||
x.shape == residual.shape and x.dtype == residual.dtype
|
||||
), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
|
||||
assert x.shape == residual.shape and x.dtype == residual.dtype
|
||||
output, mid = torch.empty_like(x), torch.empty_like(x)
|
||||
bs, hidden_dim = x.shape
|
||||
if autotune:
|
||||
|
||||
@@ -127,69 +127,34 @@ class RMSNorm(CustomOp):
|
||||
return output, residual_out
|
||||
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
# def forward_hip(
|
||||
# self,
|
||||
# x: torch.Tensor,
|
||||
# residual: Optional[torch.Tensor] = None,
|
||||
# ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
# if not x.is_contiguous():
|
||||
# # NOTE: Remove this if aiter kernel supports discontinuous input
|
||||
# x = x.contiguous()
|
||||
# if residual is not None:
|
||||
# if _vllm_version < Version("0.9"):
|
||||
# fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
||||
# return x, residual
|
||||
# else:
|
||||
# residual_out = torch.empty_like(x)
|
||||
# output = torch.empty_like(x)
|
||||
# fused_add_rms_norm(
|
||||
# output,
|
||||
# x,
|
||||
# residual_out,
|
||||
# residual,
|
||||
# self.weight.data,
|
||||
# self.variance_epsilon,
|
||||
# )
|
||||
# return output, residual_out
|
||||
# out = torch.empty_like(x)
|
||||
# rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
# return out
|
||||
def forward_hip(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if not x.is_contiguous():
|
||||
# NOTE: Remove this if aiter kernel supports discontinuous input
|
||||
x = x.contiguous()
|
||||
|
||||
if residual is not None:
|
||||
try:
|
||||
output = torch.empty_like(x)
|
||||
residual_out = torch.empty_like(x)
|
||||
fused_add_rms_norm(
|
||||
output,
|
||||
x,
|
||||
residual_out,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
|
||||
#if _vllm_version < Version("0.9"):
|
||||
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
||||
return x, residual
|
||||
# else:
|
||||
# residual_out = torch.empty_like(x)
|
||||
# output = torch.empty_like(x)
|
||||
# fused_add_rms_norm(
|
||||
# output,
|
||||
# x,
|
||||
# residual_out,
|
||||
# residual,
|
||||
# self.weight.data,
|
||||
# self.variance_epsilon,
|
||||
# )
|
||||
# return output, residual_out
|
||||
out = torch.empty_like(x)
|
||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
@@ -31,7 +31,6 @@ from sglang.srt.layers.parameter import (
|
||||
_ColumnvLLMParameter,
|
||||
)
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.utils import pad_or_narrow_weight
|
||||
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -626,16 +625,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
||||
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
||||
end_idx = start_idx + shard_size
|
||||
if end_idx > loaded_weight.shape[output_dim]:
|
||||
loaded_weight = pad_or_narrow_weight(
|
||||
loaded_weight, output_dim, start_idx, shard_size
|
||||
)
|
||||
else:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
output_dim, start_idx, shard_size
|
||||
)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
output_dim, start_idx, shard_size
|
||||
)
|
||||
|
||||
# Special case for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
@@ -1310,16 +1302,7 @@ class RowParallelLinear(LinearBase):
|
||||
shard_size,
|
||||
)
|
||||
else:
|
||||
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
||||
end_idx = start_idx + shard_size
|
||||
if end_idx > loaded_weight.shape[input_dim]:
|
||||
loaded_weight = pad_or_narrow_weight(
|
||||
loaded_weight, input_dim, start_idx, shard_size
|
||||
)
|
||||
else:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
input_dim, start_idx, shard_size
|
||||
)
|
||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
|
||||
|
||||
# Special case for loading scales off disk, which often do not
|
||||
# have a shape (such as in the case of AutoFP8).
|
||||
|
||||
@@ -220,7 +220,6 @@ class LogitsProcessor(nn.Module):
|
||||
self.config = config
|
||||
self.logit_scale = logit_scale
|
||||
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
|
||||
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
|
||||
if self.use_attn_tp_group:
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.do_tensor_parallel_all_gather = (
|
||||
@@ -462,11 +461,7 @@ class LogitsProcessor(nn.Module):
|
||||
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
|
||||
|
||||
if hasattr(lm_head, "weight"):
|
||||
if self.use_fp32_lm_head:
|
||||
logits = torch.matmul(
|
||||
hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
|
||||
)
|
||||
elif use_intel_amx_backend(lm_head):
|
||||
if use_intel_amx_backend(lm_head):
|
||||
logits = torch.ops.sgl_kernel.weight_packed_linear(
|
||||
hidden_states.to(lm_head.weight.dtype),
|
||||
lm_head.weight,
|
||||
@@ -480,15 +475,7 @@ class LogitsProcessor(nn.Module):
|
||||
else:
|
||||
# GGUF models
|
||||
# TODO: use weight_packed_linear for GGUF models
|
||||
if self.use_fp32_lm_head:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
logits = lm_head.quant_method.apply(
|
||||
lm_head, hidden_states.to(torch.float32), embedding_bias
|
||||
)
|
||||
else:
|
||||
logits = lm_head.quant_method.apply(
|
||||
lm_head, hidden_states, embedding_bias
|
||||
)
|
||||
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
|
||||
|
||||
if self.logit_scale is not None:
|
||||
logits.mul_(self.logit_scale)
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@@ -125,6 +124,7 @@ class EPMoE(FusedMoE):
|
||||
)
|
||||
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
if isinstance(quant_config, Fp8Config):
|
||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||
self.block_shape = (
|
||||
@@ -135,23 +135,11 @@ class EPMoE(FusedMoE):
|
||||
self.use_fp8_w8a8 = True
|
||||
self.fp8_dtype = torch.float8_e4m3fn
|
||||
self.activation_scheme = quant_config.activation_scheme
|
||||
self.use_w4a8_marlin = False
|
||||
elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
|
||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||
self.block_shape = (
|
||||
self.quant_method.quant_config.weight_block_size
|
||||
if self.use_block_quant
|
||||
else None
|
||||
)
|
||||
self.use_fp8_w8a8 = False
|
||||
self.activation_scheme = None
|
||||
self.use_w4a8_marlin = True
|
||||
else:
|
||||
self.use_fp8_w8a8 = False
|
||||
self.use_block_quant = False
|
||||
self.block_shape = None
|
||||
self.activation_scheme = None
|
||||
self.use_w4a8_marlin = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||
@@ -398,11 +386,11 @@ class DeepEPMoE(EPMoE):
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
# if self.deepep_mode.enable_low_latency() and not _is_npu:
|
||||
# # NPU supports low_latency deepep without deepgemm
|
||||
# assert (
|
||||
# deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||
# ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||
if self.deepep_mode.enable_low_latency() and not _is_npu:
|
||||
# NPU supports low_latency deepep without deepgemm
|
||||
assert (
|
||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||
if _use_aiter:
|
||||
# expert_mask is of size (self.num_local_experts + 1),
|
||||
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
|
||||
@@ -416,23 +404,23 @@ class DeepEPMoE(EPMoE):
|
||||
)
|
||||
# the last one is invalid rank_id
|
||||
self.expert_mask[:-1] = 1
|
||||
# elif not _is_npu:
|
||||
# self.w13_weight_fp8 = (
|
||||
# self.w13_weight,
|
||||
# (
|
||||
# self.w13_weight_scale_inv
|
||||
# if self.use_block_quant
|
||||
# else self.w13_weight_scale
|
||||
# ),
|
||||
# )
|
||||
# self.w2_weight_fp8 = (
|
||||
# self.w2_weight,
|
||||
# (
|
||||
# self.w2_weight_scale_inv
|
||||
# if self.use_block_quant
|
||||
# else self.w2_weight_scale
|
||||
# ),
|
||||
# )
|
||||
elif not _is_npu:
|
||||
self.w13_weight_fp8 = (
|
||||
self.w13_weight,
|
||||
(
|
||||
self.w13_weight_scale_inv
|
||||
if self.use_block_quant
|
||||
else self.w13_weight_scale
|
||||
),
|
||||
)
|
||||
self.w2_weight_fp8 = (
|
||||
self.w2_weight,
|
||||
(
|
||||
self.w2_weight_scale_inv
|
||||
if self.use_block_quant
|
||||
else self.w2_weight_scale
|
||||
),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -478,15 +466,8 @@ class DeepEPMoE(EPMoE):
|
||||
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
||||
return self.forward_npu(dispatch_output)
|
||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||
#assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||
elif self.use_w4a8_marlin:
|
||||
return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Dispatch output is not supported"
|
||||
)
|
||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||
if get_moe_runner_backend().is_flashinfer_cutedsl():
|
||||
return self.forward_flashinfer_cutedsl(dispatch_output)
|
||||
@@ -545,34 +526,6 @@ class DeepEPMoE(EPMoE):
|
||||
expert_mask=self.expert_mask,
|
||||
)
|
||||
|
||||
def forward_deepgemm_w4a8_marlin_contiguous(
|
||||
self,
|
||||
dispatch_output: DeepEPNormalOutput,
|
||||
):
|
||||
hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
|
||||
dispatch_output
|
||||
)
|
||||
assert self.quant_method is not None
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
# if num_recv_tokens_per_expert is None:
|
||||
return hidden_states_int8.bfloat16()
|
||||
# expert_output = self.quant_method.apply_ep(
|
||||
# layer=self,
|
||||
# x=dispatch_output,
|
||||
# topk_weights=topk_weights,
|
||||
# topk_ids=topk_idx,
|
||||
# global_num_experts=self.global_num_experts,
|
||||
# expert_map=self.expert_map,
|
||||
# activation=self.activation,
|
||||
# apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
# use_nn_moe=self.use_nn_moe,
|
||||
# num_local_tokens=dispatch_recv_num_token,
|
||||
# config_select_bs=hidden_states.shape[0],
|
||||
# scales=dispatch_scales if self.use_int8_dispatch else None
|
||||
# # routed_scaling_factor=self.routed_scaling_factor,
|
||||
# )
|
||||
# return expert_output
|
||||
|
||||
def forward_deepgemm_contiguous(
|
||||
self,
|
||||
dispatch_output: DeepEPNormalOutput,
|
||||
@@ -836,45 +789,69 @@ class DeepEPMoE(EPMoE):
|
||||
if isinstance(hidden_states, tuple):
|
||||
per_token_scale = hidden_states[1]
|
||||
hidden_states = hidden_states[0]
|
||||
else:
|
||||
# dynamic quant
|
||||
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states
|
||||
)
|
||||
|
||||
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
|
||||
hidden_states.device
|
||||
)
|
||||
if self.w13_weight.dtype != torch.int8:
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight.permute(0, 2, 1)],
|
||||
# per_token_scale=[per_token_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight.permute(0, 2, 1)],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
else:
|
||||
if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
|
||||
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states
|
||||
)
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
scale=[self.w13_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[per_token_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
scale=[self.w13_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[per_token_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states
|
||||
)
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight],
|
||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight],
|
||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -883,47 +860,72 @@ class DeepEPMoE(EPMoE):
|
||||
assert isinstance(dispatch_output, DeepEPLLOutput)
|
||||
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
|
||||
|
||||
per_token_scale = hidden_states[1]
|
||||
hidden_states = hidden_states[0]
|
||||
if isinstance(hidden_states, tuple):
|
||||
per_token_scale = hidden_states[1]
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
group_list = group_list.to(torch.int64)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32,
|
||||
)[0]
|
||||
if self.w13_weight.dtype != torch.int8:
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight.permute(0, 2, 1)],
|
||||
# per_token_scale=[per_token_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight.permute(0, 2, 1)],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
else:
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32,
|
||||
)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=self.w13_weight_scale.to(torch.float32),
|
||||
activation_scale=per_token_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=self.w13_weight_scale.to(torch.float32),
|
||||
activation_scale=per_token_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight],
|
||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w2_weight],
|
||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=output_dtype,
|
||||
)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -51,14 +51,10 @@ def get_moe_configs(
|
||||
|
||||
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
|
||||
# so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
|
||||
config_dir = os.environ.get(
|
||||
"SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
|
||||
)
|
||||
|
||||
triton_version = triton.__version__
|
||||
version_dir = f"triton_{triton_version.replace('.', '_')}"
|
||||
config_file_path = os.path.join(
|
||||
config_dir,
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"configs",
|
||||
version_dir,
|
||||
json_file_name,
|
||||
@@ -79,7 +75,7 @@ def get_moe_configs(
|
||||
if try_triton_version == triton_version:
|
||||
continue
|
||||
try_config_file_path = os.path.join(
|
||||
config_dir,
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"configs",
|
||||
f"triton_{try_triton_version.replace('.', '_')}",
|
||||
json_file_name,
|
||||
|
||||
@@ -575,10 +575,7 @@ class FusedMoE(torch.nn.Module):
|
||||
)
|
||||
|
||||
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
|
||||
if (
|
||||
should_use_flashinfer_trtllm_moe()
|
||||
and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
|
||||
):
|
||||
if should_use_flashinfer_trtllm_moe():
|
||||
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
||||
|
||||
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
|
||||
|
||||
@@ -431,32 +431,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
deepep_post_reorder_triton_kernel,
|
||||
)
|
||||
|
||||
#if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
||||
output = hidden_states
|
||||
# else:
|
||||
# if hidden_states.shape[0] > 0:
|
||||
# num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||
# output = torch.empty(
|
||||
# (num_tokens, hidden_states.shape[1]),
|
||||
# device=hidden_states.device,
|
||||
# dtype=hidden_states.dtype,
|
||||
# )
|
||||
# deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||
# hidden_states,
|
||||
# output,
|
||||
# self.src2dst,
|
||||
# topk_idx,
|
||||
# topk_weights,
|
||||
# self.router_topk,
|
||||
# hidden_states.shape[1],
|
||||
# BLOCK_SIZE=512,
|
||||
# )
|
||||
# else:
|
||||
# output = torch.zeros(
|
||||
# (0, hidden_states.shape[1]),
|
||||
# device=hidden_states.device,
|
||||
# dtype=hidden_states.dtype,
|
||||
# )
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
||||
output = hidden_states
|
||||
else:
|
||||
if hidden_states.shape[0] > 0:
|
||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||
output = torch.empty(
|
||||
(num_tokens, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||
hidden_states,
|
||||
output,
|
||||
self.src2dst,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
self.router_topk,
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
else:
|
||||
output = torch.zeros(
|
||||
(0, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
previous_event = Buffer.capture() if self.async_finish else None
|
||||
return output, previous_event
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Callable, Optional, Union
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from sglang.srt.layers.utils import pad_or_narrow_weight
|
||||
from sglang.srt.utils import is_cpu
|
||||
|
||||
__all__ = [
|
||||
@@ -157,17 +156,9 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
)
|
||||
else:
|
||||
if not use_presharded_weights:
|
||||
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
||||
start_idx = tp_rank * shard_size
|
||||
end_idx = start_idx + shard_size
|
||||
if end_idx > loaded_weight.shape[self.output_dim]:
|
||||
loaded_weight = pad_or_narrow_weight(
|
||||
loaded_weight, self.output_dim, start_idx, shard_size
|
||||
)
|
||||
else:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, start_idx, shard_size
|
||||
)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, tp_rank * shard_size, shard_size
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
@@ -267,17 +258,9 @@ class RowvLLMParameter(BasevLLMParameter):
|
||||
|
||||
return
|
||||
else:
|
||||
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
||||
start_idx = tp_rank * shard_size
|
||||
end_idx = start_idx + shard_size
|
||||
if end_idx > loaded_weight.shape[self.input_dim]:
|
||||
loaded_weight = pad_or_narrow_weight(
|
||||
loaded_weight, self.input_dim, start_idx, shard_size
|
||||
)
|
||||
else:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.input_dim, start_idx, shard_size
|
||||
)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.input_dim, tp_rank * shard_size, shard_size
|
||||
)
|
||||
|
||||
if len(loaded_weight.shape) == 0:
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
@@ -61,7 +61,6 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
|
||||
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
|
||||
|
||||
_is_mxfp_supported = mxfp_supported()
|
||||
@@ -87,7 +86,6 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"w4afp8": W4AFp8Config,
|
||||
"petit_nvfp4": PetitNvFp4Config,
|
||||
"fbgemm_fp8": FBGEMMFp8Config,
|
||||
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
|
||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8,
|
||||
)
|
||||
from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
|
||||
__all__ = [
|
||||
"CompressedTensorsScheme",
|
||||
"CompressedTensorsW8A8Fp8",
|
||||
"CompressedTensorsW8A16Fp8",
|
||||
"CompressedTensorsW8A8Int8",
|
||||
]
|
||||
|
||||
@@ -1,173 +0,0 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
from torch.nn import Parameter
|
||||
|
||||
from sglang.srt.layers.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import int8_scaled_mm
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(
|
||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
|
||||
):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# If per tensor, when we have a fused module (e.g. QKV) with per
|
||||
# tensor scales (thus N scales being passed to the kernel),
|
||||
# requantize so we can always run per channel
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
# If channelwise, scales are already lined up, so just transpose.
|
||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization strategy {self.strategy}")
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme and hasattr(layer, "input_scale"):
|
||||
if self.input_symmetric:
|
||||
layer.input_scale = Parameter(
|
||||
layer.input_scale.max(), requires_grad=False
|
||||
)
|
||||
else:
|
||||
input_scale = layer.input_scale
|
||||
input_zero_point = layer.input_zero_point
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
||||
|
||||
layer.input_scale = Parameter(scale, requires_grad=False)
|
||||
layer.input_zero_point = Parameter(azp, requires_grad=False)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
layer.input_zero_point = None
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
if not self.input_symmetric:
|
||||
weight = layer.weight
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if self.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = layer.input_zero_point * azp_adj
|
||||
layer.azp_adj = Parameter(azp_adj, requires_grad=False)
|
||||
else:
|
||||
layer.azp_adj = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition, input_size_per_partition, dtype=torch.int8
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
if not self.input_symmetric:
|
||||
# Note: compressed-tensors stores the zp using the same dtype
|
||||
# as the weights
|
||||
# AZP loaded as int8 but used as int32
|
||||
input_zero_point = PerTensorScaleParameter(
|
||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||
)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
# TODO: add cutlass_scaled_mm_azp support
|
||||
x_q, x_scale = per_token_quant_int8(x)
|
||||
|
||||
return int8_scaled_mm(
|
||||
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
|
||||
)
|
||||
@@ -1,5 +1,7 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -13,6 +15,7 @@ def _compute_enable_deep_gemm():
|
||||
try:
|
||||
import deep_gemm
|
||||
except ImportError:
|
||||
logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
|
||||
return False
|
||||
|
||||
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
|
||||
|
||||
@@ -843,18 +843,10 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
|
||||
topk_weights = topk_weights.to(
|
||||
torch.float32
|
||||
) # aiter's moe_sorting requires topk_weights to be FP32
|
||||
|
||||
if hasattr(torch, "float4_e2m1fn_x2"):
|
||||
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
|
||||
w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
|
||||
else:
|
||||
w13_weight = layer.w13_weight
|
||||
w2_weight = layer.w2_weight
|
||||
|
||||
output = fused_moe(
|
||||
x,
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type=QuantType.per_1x32,
|
||||
|
||||
@@ -183,17 +183,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
moe_runner_config = self.moe_runner_config
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
|
||||
if hasattr(torch, "float4_e2m1fn_x2"):
|
||||
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
|
||||
w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
|
||||
else:
|
||||
w13_weight = layer.w13_weight
|
||||
w2_weight = layer.w2_weight
|
||||
|
||||
output = fused_moe(
|
||||
x,
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type=QuantType.per_1x32,
|
||||
|
||||
@@ -1,415 +0,0 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from sglang.srt.layers.linear import set_weight_attrs
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from torch.nn.parameter import Parameter
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.quantization.base_config import LinearMethodBase, QuantizationConfig, QuantizeMethodBase, FusedMoEMethodBase
|
||||
from sglang.srt.layers.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
_ColumnvLLMParameter,
|
||||
RowvLLMParameter,
|
||||
)
|
||||
from lmslim.layers.gemm.int8_utils import (
|
||||
per_token_group_quant_int8,
|
||||
per_token_quant_int8)
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from vllm.utils import W8a8GetCacheJSON
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
|
||||
import os
|
||||
|
||||
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
|
||||
"""
|
||||
Parameter class for linear layer weights. Uses both column and
|
||||
row parallelism.
|
||||
"""
|
||||
pass
|
||||
|
||||
W8A8_TRITONJSON=W8a8GetCacheJSON()
|
||||
|
||||
def baseline_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
scales= scale_a* scale_b.T
|
||||
gemmout= torch.mm(
|
||||
a.to(dtype=torch.float32), b.to(dtype=torch.float32))
|
||||
output = (scales *gemmout).to(out_dtype)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(out_dtype)
|
||||
|
||||
|
||||
class SlimQuantW4A8Int8Config(QuantizationConfig):
|
||||
"""Config class for W8A8 Int8 Quantization.
|
||||
|
||||
- Weight: static, per-channel, symmetric
|
||||
- Activation: dynamic, per-token, symmetric
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> str:
|
||||
return "slimquant_w4a8"
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return SlimQuantW4A8Int8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return SlimQuantW4A8Int8MoEMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
|
||||
self.quantization_config = quantization_config
|
||||
self.tritonsingleton= W8a8GetCacheJSON()
|
||||
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
n=layer.weight.shape[0]
|
||||
k=layer.weight.shape[1]
|
||||
|
||||
if self.w8a8_strategy==1:
|
||||
if {n,k} not in self.tritonsingleton.weight_shapes:
|
||||
self.tritonsingleton.weight_shapes.append({n,k})
|
||||
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
|
||||
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
|
||||
|
||||
if configs_dict:
|
||||
self.tritonsingleton.triton_json_dict.update(configs_dict)
|
||||
|
||||
for key, value in configs_dict.items():
|
||||
m=int(key.split('_')[0])
|
||||
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
|
||||
else:
|
||||
weight_data=layer.weight.data
|
||||
_weight=weight_data.T.contiguous().reshape(n,-1)
|
||||
layer.weight.data=_weight
|
||||
|
||||
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
self.logical_widths = output_partition_sizes
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
input_quant_args: Optional[list[torch.Tensor]] = None,
|
||||
silu_quant_args: Optional[list[torch.Tensor]] = None
|
||||
):
|
||||
# if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
|
||||
# assert len(input_quant_args) == 2
|
||||
# x_q, x_scale = input_quant_args
|
||||
# elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
|
||||
# x_q, x_scale = silu_quant_args
|
||||
# else:
|
||||
x_q, x_scale = per_token_quant_int8(x)
|
||||
|
||||
if self.w8a8_strategy==1:
|
||||
m=x_q.shape[0]
|
||||
k=x_q.shape[1]
|
||||
n=layer.weight.shape[1]
|
||||
|
||||
if len(W8A8_TRITONJSON.triton_json_dict)==0:
|
||||
best_config=None
|
||||
|
||||
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
|
||||
if m<=16:
|
||||
m_=m
|
||||
elif m<=64:
|
||||
m_= (m + 3) & -4 #取值到最近的4的倍数
|
||||
elif m<=160:
|
||||
m_=(m + 7) & -8
|
||||
|
||||
elif m<200: #256
|
||||
m_=160
|
||||
elif m<480: #512
|
||||
m_=256
|
||||
elif m<960: #1024
|
||||
m_=512
|
||||
elif m<2048:
|
||||
m_=1024
|
||||
elif m<4096:
|
||||
m_=2048
|
||||
elif m<6000:
|
||||
m_=4096
|
||||
else:
|
||||
m_=8192
|
||||
|
||||
best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
|
||||
|
||||
else:
|
||||
best_config=None
|
||||
|
||||
#if best_config==None:
|
||||
# print("m:{},n:{},k:{}".format(m,n,k))
|
||||
# print("config not found!")
|
||||
|
||||
return ops.triton_scaled_mm(x_q,
|
||||
layer.weight,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias,best_config=best_config)
|
||||
elif self.w8a8_strategy==2:
|
||||
return ops.cutlass_scaled_mm(x_q,
|
||||
layer.weight,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias)
|
||||
else:
|
||||
return ops.rocblas_scaled_mm(x_q,
|
||||
layer.weight,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias)
|
||||
|
||||
|
||||
class SlimQuantW4A8Int8MoEMethod:
|
||||
"""MoE method for W4A8INT8.
|
||||
Supports loading INT8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||
activation scaling. The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config):
|
||||
self.quant_config = quant_config
|
||||
self.tritonsingleton= W8a8GetCacheJSON()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||
)
|
||||
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
w13_input_scale = None
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
|
||||
w2_input_scale = None
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
E=layer.w13_weight.shape[0]
|
||||
N1=layer.w13_weight.shape[1]
|
||||
N2=layer.w2_weight.shape[1]
|
||||
K=N1//2
|
||||
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
|
||||
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
|
||||
|
||||
TOPK= self.tritonsingleton.topk
|
||||
|
||||
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,use_int4_w4a8=True)
|
||||
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
|
||||
|
||||
#warmup
|
||||
if configs_dict:
|
||||
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
|
||||
|
||||
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
|
||||
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(
|
||||
layer.w13_weight_scale.data, requires_grad=False
|
||||
)
|
||||
layer.w2_weight_scale = Parameter(
|
||||
layer.w2_weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
use_nn_moe: Optional[bool] = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
use_fused_gate: Optional[bool] = False,
|
||||
**_
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
|
||||
# Expert selection
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_fused_gate=use_fused_gate
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_int4_w4a8=True,
|
||||
per_channel_quant=True,
|
||||
activation=activation,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
use_nn_moe=use_nn_moe,
|
||||
)
|
||||
@@ -1,318 +0,0 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
|
||||
import torch
|
||||
from sglang.srt import _custom_ops as ops
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from torch.nn.parameter import Parameter
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.quantization import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.w4a8_utils import w4a8_weight_repack_impl
|
||||
from sglang.srt.layers.quantization.base_config import (FusedMoEMethodBase, QuantizeMethodBase)
|
||||
from sglang.srt.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
|
||||
try:
|
||||
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
|
||||
except Exception:
|
||||
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
|
||||
|
||||
|
||||
class MarlinMoeWorkspace:
|
||||
"""
|
||||
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
|
||||
global_reduce_buffer will take 1.5MB * cus (about 120MB for BW200) memoery in each device
|
||||
"""
|
||||
_instances = {}
|
||||
def __new__(cls, device):
|
||||
if device not in cls._instances:
|
||||
instance = super().__new__(cls)
|
||||
instance._initialized = False
|
||||
cls._instances[device] = instance
|
||||
return cls._instances[device]
|
||||
|
||||
def __init__(self, device):
|
||||
if self._initialized:
|
||||
return
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
self.workspace = torch.zeros(
|
||||
500, dtype=torch.int, device=device, requires_grad=False
|
||||
)
|
||||
self.global_reduce_buffer = torch.zeros(
|
||||
sms * 6 * 128 * 512, dtype=torch.int, device=device, requires_grad=False
|
||||
)
|
||||
self._initialized = True
|
||||
|
||||
def get_buffers(self):
|
||||
return self.workspace, self.global_reduce_buffer
|
||||
|
||||
def baseline_scaled_mm(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
scales= scale_a* scale_b.T
|
||||
gemmout= torch.mm(
|
||||
a.to(dtype=torch.float32), b.to(dtype=torch.float32))
|
||||
output = (scales *gemmout).to(out_dtype)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(out_dtype)
|
||||
|
||||
|
||||
class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
|
||||
"""Config class for W4A8 Int8 Quantization.
|
||||
- Weight: static, per-channel, symmetric
|
||||
- Activation: dynamic, per-token, symmetric
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> str:
|
||||
return "slimquant_w4a8_marlin"
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8MarlinConfig":
|
||||
return cls()
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||
if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \
|
||||
and user_quant == "slimquant_w4a8_marlin":
|
||||
return cls.get_name()
|
||||
return None
|
||||
def get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return SlimQuantW4A8Int8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return SlimQuantW4A8Int8MarlinMoEMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class SlimQuantW4A8Int8MarlinMoEMethod:
|
||||
"""MoE method for W4A8INT8 Marlin.
|
||||
Supports loading INT8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
|
||||
if not hasattr(cls, "_initialized"):
|
||||
original_init = cls.__init__
|
||||
new_cls = type(
|
||||
cls.__name__,
|
||||
(FusedMoEMethodBase,),
|
||||
{
|
||||
"__init__": original_init,
|
||||
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
|
||||
},
|
||||
)
|
||||
obj = super(new_cls, new_cls).__new__(new_cls)
|
||||
obj.__init__(*args, **kwargs)
|
||||
return obj
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, quant_config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size = intermediate_size_per_partition
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||
)
|
||||
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
w13_input_scale = None
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
|
||||
w2_input_scale = None
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.w13_weight_scale = Parameter(
|
||||
layer.w13_weight_scale.data, requires_grad=False
|
||||
)
|
||||
layer.w2_weight_scale = Parameter(
|
||||
layer.w2_weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
|
||||
layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
x, topk_weights = apply_topk_weights_cpu(
|
||||
self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
|
||||
output = fused_experts_impl_w4a8_marlin(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
workspace=workspace,
|
||||
global_reduce_buffer=global_reduce_buffer,
|
||||
inplace=True,
|
||||
use_int4_w4a8=True,
|
||||
per_channel_quant=True,
|
||||
activation=layer.moe_runner_config.activation,
|
||||
expert_map=layer.expert_map_gpu,
|
||||
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
|
||||
global_num_experts=layer.moe_runner_config.num_experts,
|
||||
w1_scale=(layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
use_nn_moe=False,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
# def _apply(
|
||||
# self,
|
||||
# layer: torch.nn.Module,
|
||||
# x: torch.Tensor,
|
||||
# router_logits: torch.Tensor,
|
||||
# top_k: int,
|
||||
# #renormalize: bool,
|
||||
# #use_grouped_topk: bool = False,
|
||||
# topk_group: Optional[int] = None,
|
||||
# num_expert_group: Optional[int] = None,
|
||||
# global_num_experts: int = -1,
|
||||
# expert_map: Optional[torch.Tensor] = None,
|
||||
# custom_routing_function: Optional[Callable] = None,
|
||||
# scoring_func: str = "softmax",
|
||||
# e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
# apply_router_weight_on_input: bool = False,
|
||||
# activation: str = "silu",
|
||||
# enable_eplb: bool = False,
|
||||
# use_nn_moe: Optional[bool] = False,
|
||||
# routed_scaling_factor: Optional[float] = None,
|
||||
# use_fused_gate: Optional[bool] = False,
|
||||
# **_
|
||||
# ) -> torch.Tensor:
|
||||
# from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
|
||||
# from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
# if enable_eplb:
|
||||
# raise NotImplementedError(
|
||||
# "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
|
||||
# # Expert selection
|
||||
# topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
# hidden_states=x,
|
||||
# router_logits=router_logits,
|
||||
# #use_grouped_topk=use_grouped_topk,
|
||||
# top_k=top_k,
|
||||
# #renormalize=renormalize,
|
||||
# topk_group=topk_group,
|
||||
# num_expert_group=num_expert_group,
|
||||
# custom_routing_function=custom_routing_function,
|
||||
# scoring_func=scoring_func,
|
||||
# e_score_correction_bias=e_score_correction_bias,
|
||||
# routed_scaling_factor=routed_scaling_factor,
|
||||
# use_fused_gate=use_fused_gate
|
||||
# )
|
||||
# workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
|
||||
# return fused_experts_impl_w4a8_marlin(
|
||||
# x,
|
||||
# layer.w13_weight,
|
||||
# layer.w2_weight,
|
||||
# topk_weights=topk_weights,
|
||||
# topk_ids=topk_ids,
|
||||
# workspace=workspace,
|
||||
# global_reduce_buffer=global_reduce_buffer,
|
||||
# inplace=True,
|
||||
# use_int4_w4a8=True,
|
||||
# per_channel_quant=True,
|
||||
# activation=activation,
|
||||
# expert_map=expert_map,
|
||||
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
# global_num_experts=global_num_experts,
|
||||
# w1_scale=(layer.w13_weight_scale),
|
||||
# w2_scale=(layer.w2_weight_scale),
|
||||
# a1_scale=layer.w13_input_scale,
|
||||
# a2_scale=layer.w2_input_scale,
|
||||
# use_nn_moe=use_nn_moe,
|
||||
# )
|
||||
@@ -1,92 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from lightop import awq_marlin_repack_w4a8
|
||||
use_lightop = False
|
||||
except Exception:
|
||||
use_lightop = False
|
||||
|
||||
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
|
||||
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
|
||||
|
||||
Args:
|
||||
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
|
||||
"""
|
||||
if tensor_int8.dtype != torch.int8:
|
||||
raise ValueError("Input tensor must be of type torch.int8")
|
||||
|
||||
N, K_half = tensor_int8.shape
|
||||
tensor_uint8 = tensor_int8.to(torch.uint8)
|
||||
high4 = tensor_uint8 & 0x0F
|
||||
low4 = (tensor_uint8 >> 4) & 0x0F
|
||||
unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
|
||||
unpacked[:, 0::2] = low4.to(torch.int32)
|
||||
unpacked[:, 1::2] = high4.to(torch.int32)
|
||||
|
||||
return unpacked
|
||||
|
||||
def get_weight_perms(interleave: bool=True):
|
||||
perm = []
|
||||
for i in range(64):
|
||||
|
||||
for col in range(4):
|
||||
cur_col = (i % 16) * 4 + col
|
||||
for row in range(8):
|
||||
cur_row = (i // 16) * 8 + row
|
||||
cur_idx = cur_row * 64 + cur_col
|
||||
perm.append(cur_idx)
|
||||
|
||||
perm = np.array(perm)
|
||||
if interleave:
|
||||
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
|
||||
perm = perm.reshape((-1, 8))[:, interleave].ravel()
|
||||
|
||||
perm = torch.from_numpy(perm)
|
||||
|
||||
return perm
|
||||
|
||||
def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
|
||||
size_k, size_n = q_w.shape
|
||||
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
|
||||
q_w = q_w.permute((0, 2, 1, 3))
|
||||
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
|
||||
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
|
||||
|
||||
orig_device = q_w.device
|
||||
q_w = q_w.contiguous().to(torch.int32)
|
||||
M, N = q_w.shape
|
||||
assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
|
||||
q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
|
||||
for i in range(pack_factor):
|
||||
q_packed += q_w[:, i::pack_factor] << (4 * i)
|
||||
|
||||
return q_packed
|
||||
|
||||
def w4a8_2_marlin_weight(w4a8_w):
|
||||
full_w4a8_w = unpack_int8_to_int4(w4a8_w)
|
||||
full_w4a8_w = full_w4a8_w.T
|
||||
weight_perm = get_weight_perms()
|
||||
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
|
||||
return marlin_q_w
|
||||
|
||||
def w4a8_weight_repack_impl(input):
|
||||
if use_lightop:
|
||||
size_batch = input.shape[0]
|
||||
size_n = input.shape[1]
|
||||
size_k = input.shape[2] * 2
|
||||
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
|
||||
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
|
||||
else:
|
||||
w_marlin_list = []
|
||||
for e in range(input.shape[0]):
|
||||
w_marlin_in = w4a8_2_marlin_weight(input[e])
|
||||
w_marlin_list.append(w_marlin_in)
|
||||
output = torch.stack(w_marlin_list, dim=0)
|
||||
|
||||
return output
|
||||
@@ -19,6 +19,10 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import is_npu, set_weight_attrs
|
||||
|
||||
_is_npu = is_npu()
|
||||
if not _is_npu:
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
|
||||
@@ -393,23 +393,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
||||
x.dtype,
|
||||
True, # is_vnni
|
||||
)
|
||||
|
||||
x_q, x_scale = per_token_quant_int8(x)
|
||||
|
||||
x_q_2d = x_q.view(-1, x_q.shape[-1])
|
||||
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
||||
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
||||
|
||||
output = int8_scaled_mm(
|
||||
x_q_2d,
|
||||
layer.weight,
|
||||
x_scale_2d,
|
||||
layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias,
|
||||
return int8_scaled_mm(
|
||||
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
|
||||
)
|
||||
|
||||
return output.view(output_shape)
|
||||
|
||||
|
||||
class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for INT8.
|
||||
@@ -648,7 +638,6 @@ class NPU_W8A8LinearMethodImpl:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
||||
|
||||
|
||||
class NPU_W8A8LinearMethodMTImpl:
|
||||
@@ -841,7 +830,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
||||
|
||||
|
||||
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
|
||||
|
||||
@@ -12,7 +12,6 @@ from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
get_compiler_backend,
|
||||
is_cpu,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
@@ -27,19 +26,13 @@ _is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_cpu = is_cpu()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
|
||||
else:
|
||||
FusedSetKVBufferArg = None
|
||||
|
||||
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
||||
if _use_aiter:
|
||||
from aiter.rotary_embedding import get_rope as aiter_get_rope
|
||||
|
||||
if is_npu():
|
||||
import torch_npu
|
||||
|
||||
NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
|
||||
NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
|
||||
|
||||
|
||||
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
@@ -149,13 +142,8 @@ class RotaryEmbedding(CustomOp):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A PyTorch-native implementation of forward()."""
|
||||
assert (
|
||||
fused_set_kv_buffer_arg is None
|
||||
), "fused_set_kv_buffer_arg is not supported for native implementation"
|
||||
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
positions = positions.flatten()
|
||||
@@ -184,17 +172,12 @@ class RotaryEmbedding(CustomOp):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A PyTorch-npu implementation of forward()."""
|
||||
assert (
|
||||
fused_set_kv_buffer_arg is None
|
||||
), "fused_set_kv_buffer_arg is not supported for npu implementation"
|
||||
import os
|
||||
|
||||
if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
|
||||
return self.forward_native(
|
||||
positions, query, key, offsets, fused_set_kv_buffer_arg
|
||||
)
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
else:
|
||||
rotary_mode = "half"
|
||||
if self.is_neox_style:
|
||||
@@ -219,12 +202,7 @@ class RotaryEmbedding(CustomOp):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert (
|
||||
fused_set_kv_buffer_arg is None
|
||||
), "fused_set_kv_buffer_arg is not supported for cpu implementation"
|
||||
|
||||
positions = torch.add(positions, offsets) if offsets is not None else positions
|
||||
if _is_cpu_amx_available:
|
||||
return torch.ops.sgl_kernel.rotary_embedding_cpu(
|
||||
@@ -236,9 +214,7 @@ class RotaryEmbedding(CustomOp):
|
||||
self.is_neox_style,
|
||||
)
|
||||
else:
|
||||
return self.forward_native(
|
||||
positions, query, key, offsets, fused_set_kv_buffer_arg
|
||||
)
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
@@ -246,7 +222,7 @@ class RotaryEmbedding(CustomOp):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
fused_set_kv_buffer_arg=None, # Optional[FusedSetKVBufferArg]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if _is_cuda and (self.head_size in [64, 128, 256, 512]):
|
||||
apply_rope_with_cos_sin_cache_inplace(
|
||||
@@ -789,7 +765,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
|
||||
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
|
||||
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
|
||||
cos_for_key = cos[:, 0, ...]
|
||||
sin_for_key = sin[:, 0, ...]
|
||||
key_rot = key_rot * cos_for_key + rotate_fn(key_rot) * sin_for_key
|
||||
#key_rot = key_rot * cos + rotate_fn(key_rot) * sin
|
||||
|
||||
if self.rotary_dim < self.head_size:
|
||||
query = torch.cat((query_rot, query_pass), dim=-1)
|
||||
@@ -1059,7 +1038,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
|
||||
)
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
@torch.compile(dynamic=True)
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -1207,7 +1186,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
time_tensor_long = time_tensor.long()
|
||||
t_index = time_tensor_long.flatten()
|
||||
elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
|
||||
elif model_type == "qwen2_vl":
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
@@ -1918,30 +1897,17 @@ def apply_rotary_pos_emb_npu(
|
||||
sin: torch.Tensor,
|
||||
unsqueeze_dim=1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Ascend implementation equivalent to apply_rotary_pos_emb_native.
|
||||
|
||||
Args:
|
||||
q: [num_tokens, num_heads, head_size]
|
||||
k: [num_tokens, num_kv_heads, head_size]
|
||||
cos: [num_tokens, head_size]
|
||||
sin: [num_tokens, head_size]
|
||||
"""
|
||||
if (
|
||||
cos.dim() != 2
|
||||
or q.dim() != 3
|
||||
or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
|
||||
or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
|
||||
):
|
||||
# Note: num_heads and head_size of q must be less than 1000 and 896, respectively
|
||||
if q.shape[1] != 128:
|
||||
return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
|
||||
cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
|
||||
sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
|
||||
q = q.unsqueeze(0)
|
||||
k = k.unsqueeze(0)
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
q_embed = q_embed.squeeze(0)
|
||||
k_embed = k_embed.squeeze(0)
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
cos = torch.transpose(cos, 1, 2)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
sin = torch.transpose(sin, 1, 2)
|
||||
q = torch.transpose(q, 1, 2)
|
||||
k = torch.transpose(k, 1, 2)
|
||||
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
|
||||
q_embed = torch.transpose(q_embed, 1, 2)
|
||||
k_embed = torch.transpose(k_embed, 1, 2)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
|
||||
@@ -15,29 +15,6 @@ def get_layer_id(weight_name):
|
||||
return None
|
||||
|
||||
|
||||
def pad_or_narrow_weight(
|
||||
loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
|
||||
) -> torch.Tensor:
|
||||
# Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
|
||||
valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
|
||||
|
||||
if valid_size > 0:
|
||||
loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
|
||||
pad_shape = list(loaded_weight.shape)
|
||||
pad_shape[input_dim] = shard_size - valid_size
|
||||
pad = torch.zeros(
|
||||
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
|
||||
)
|
||||
return torch.cat([loaded_slice, pad], dim=input_dim)
|
||||
|
||||
# All padding
|
||||
pad_shape = list(loaded_weight.shape)
|
||||
pad_shape[input_dim] = shard_size
|
||||
return torch.zeros(
|
||||
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
|
||||
)
|
||||
|
||||
|
||||
class PPMissingLayer(torch.nn.Identity):
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
|
||||
|
||||
@@ -5,7 +5,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
from sglang.srt.utils import cached_triton_kernel
|
||||
from sglang.utils import cached_triton_kernel
|
||||
|
||||
|
||||
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
||||
|
||||
@@ -3,7 +3,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
from sglang.srt.utils import cached_triton_kernel
|
||||
from sglang.utils import cached_triton_kernel
|
||||
|
||||
|
||||
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
||||
|
||||
@@ -275,17 +275,43 @@ class HiCacheController:
|
||||
and self.storage_config.tp_rank != 0
|
||||
)
|
||||
|
||||
# Use storage backend factory for dynamic backend creation
|
||||
from sglang.srt.mem_cache.storage import StorageBackendFactory
|
||||
if storage_backend == "file":
|
||||
from sglang.srt.mem_cache.hicache_storage import HiCacheFile
|
||||
|
||||
try:
|
||||
self.storage_backend = StorageBackendFactory.create_backend(
|
||||
storage_backend, self.storage_config, self.mem_pool_host
|
||||
self.storage_backend = HiCacheFile(self.storage_config)
|
||||
elif storage_backend == "nixl":
|
||||
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
||||
|
||||
self.storage_backend = HiCacheNixl()
|
||||
elif storage_backend == "mooncake":
|
||||
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
|
||||
MooncakeStore,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Failed to create storage backend: {e}") from e
|
||||
|
||||
self.storage_backend.register_mem_pool_host(self.mem_pool_host)
|
||||
self.storage_backend = MooncakeStore(self.storage_config)
|
||||
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
||||
assert self.mem_pool_host.layout == "page_first"
|
||||
elif storage_backend == "hf3fs":
|
||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
||||
HiCacheHF3FS,
|
||||
)
|
||||
|
||||
if self.mem_pool_host.layout == "page_first":
|
||||
bytes_per_page = (
|
||||
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
||||
)
|
||||
elif self.mem_pool_host.layout == "layer_first":
|
||||
bytes_per_page = (
|
||||
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
||||
)
|
||||
dtype = mem_pool_host.dtype
|
||||
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||
bytes_per_page, dtype, self.storage_config
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported storage backend: {storage_backend}"
|
||||
)
|
||||
|
||||
self.enable_storage = True
|
||||
# todo: threshold policy for prefetching
|
||||
@@ -309,10 +335,18 @@ class HiCacheController:
|
||||
# Select the get and set functions
|
||||
self.page_get_func = self._generic_page_get
|
||||
self.page_set_func = self._generic_page_set
|
||||
|
||||
if self.storage_backend_type in ["hf3fs", "mooncake"]:
|
||||
self.page_get_func = self._page_get_zero_copy
|
||||
self.page_set_func = self._page_set_zero_copy
|
||||
self.batch_exists_func = self.storage_backend.batch_exists
|
||||
self.is_3fs_zerocopy = (
|
||||
self.storage_backend_type == "hf3fs"
|
||||
and self.mem_pool_host.layout == "page_first"
|
||||
)
|
||||
if self.storage_backend_type == "mooncake":
|
||||
self.page_get_func = self._mooncake_page_get
|
||||
self.page_set_func = self._mooncake_page_set
|
||||
elif self.is_3fs_zerocopy:
|
||||
self.page_get_func = self._3fs_zero_copy_page_get
|
||||
self.page_set_func = self._3fs_zero_copy_page_set
|
||||
self.batch_exists_func = self._3fs_zero_copy_batch_exists
|
||||
|
||||
self.device = self.mem_pool_device.device
|
||||
self.layer_num = self.mem_pool_device.layer_num
|
||||
@@ -436,6 +470,7 @@ class HiCacheController:
|
||||
host_indices = self.mem_pool_host.alloc(len(device_indices))
|
||||
if host_indices is None:
|
||||
return None
|
||||
self.mem_pool_host.protect_write(host_indices)
|
||||
self.write_queue.append(
|
||||
CacheOperation(host_indices, device_indices, node_id, priority)
|
||||
)
|
||||
@@ -459,6 +494,7 @@ class HiCacheController:
|
||||
self.mem_pool_host.backup_from_device_all_layer(
|
||||
self.mem_pool_device, host_indices, device_indices, self.io_backend
|
||||
)
|
||||
self.mem_pool_host.complete_io(op.host_indices)
|
||||
finish_event.record()
|
||||
# NOTE: We must save the host indices and device indices here,
|
||||
# this is because we need to guarantee that these tensors are
|
||||
@@ -482,6 +518,7 @@ class HiCacheController:
|
||||
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
|
||||
if device_indices is None:
|
||||
return None
|
||||
self.mem_pool_host.protect_load(host_indices)
|
||||
self.load_queue.append(
|
||||
CacheOperation(host_indices, device_indices, node_id, priority)
|
||||
)
|
||||
@@ -526,6 +563,7 @@ class HiCacheController:
|
||||
self.io_backend,
|
||||
)
|
||||
producer_event.complete(i)
|
||||
self.mem_pool_host.complete_io(op.host_indices)
|
||||
# NOTE: We must save the host indices and device indices here,
|
||||
# this is because we need to guarantee that these tensors are
|
||||
# still alive when the load stream is executing.
|
||||
@@ -543,16 +581,29 @@ class HiCacheController:
|
||||
)
|
||||
return producer_id
|
||||
|
||||
def evict_device(self, device_indices: torch.Tensor) -> int:
|
||||
self.mem_pool_device_allocator.free(device_indices)
|
||||
return len(device_indices)
|
||||
def evict_device(
|
||||
self, device_indices: torch.Tensor, host_indices: torch.Tensor
|
||||
) -> int:
|
||||
if self.mem_pool_host.is_synced(host_indices):
|
||||
self.mem_pool_device_allocator.free(device_indices)
|
||||
self.mem_pool_host.update_backup(host_indices)
|
||||
return len(device_indices)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
|
||||
)
|
||||
|
||||
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
|
||||
if not backup_only:
|
||||
raise ValueError("Other eviction policies are not supported yet.")
|
||||
|
||||
self.mem_pool_host.free(host_indices)
|
||||
return len(host_indices)
|
||||
if self.mem_pool_host.is_backup(host_indices):
|
||||
self.mem_pool_host.free(host_indices)
|
||||
return len(host_indices)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
|
||||
)
|
||||
|
||||
def prefetch(
|
||||
self,
|
||||
@@ -579,19 +630,42 @@ class HiCacheController:
|
||||
for chunk in chunks:
|
||||
self.host_mem_release_queue.put(chunk)
|
||||
|
||||
def _page_get_zero_copy(self, operation, hash_values, host_indices):
|
||||
results = self.storage_backend.batch_get_v1(hash_values, host_indices)
|
||||
inc = 0
|
||||
for i in range(len(hash_values)):
|
||||
if not results[i]:
|
||||
logger.warning(
|
||||
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
|
||||
)
|
||||
break
|
||||
inc += self.page_size
|
||||
operation.increment(inc)
|
||||
def _3fs_zero_copy_batch_exists(self, batch_hashes):
|
||||
_batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
|
||||
hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
|
||||
return hit_page_num
|
||||
|
||||
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
|
||||
hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
|
||||
hash_values, host_indices
|
||||
)
|
||||
page_data = self.storage_backend.batch_get(hashes, dsts)
|
||||
if page_data:
|
||||
inc = self.page_size * len(hashes) // factor
|
||||
operation.increment(inc)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
|
||||
)
|
||||
|
||||
def _mooncake_page_get(self, operation, hash_values, host_indices):
|
||||
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
||||
hash_values,
|
||||
host_indices,
|
||||
self.storage_config.tp_rank,
|
||||
)
|
||||
get_result = self.storage_backend.batch_get(
|
||||
key_strs,
|
||||
target_locations=buffer_ptrs,
|
||||
target_sizes=buffer_sizes,
|
||||
)
|
||||
if get_result != len(hash_values):
|
||||
logger.warning(
|
||||
f"Prefetch operation {operation.request_id} failed or partially failed."
|
||||
)
|
||||
if get_result != 0:
|
||||
operation.increment(get_result * self.page_size)
|
||||
|
||||
# todo: deprecate
|
||||
def _generic_page_get(self, operation, hash_values, host_indices):
|
||||
dummy_page_dst = [
|
||||
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
|
||||
@@ -681,7 +755,7 @@ class HiCacheController:
|
||||
batch_tokens[i : i + self.page_size], last_hash
|
||||
)
|
||||
batch_hashes.append(last_hash)
|
||||
hit_page_num = self.storage_backend.batch_exists(batch_hashes)
|
||||
hit_page_num = self.batch_exists_func(batch_hashes)
|
||||
hash_value.extend(batch_hashes[:hit_page_num])
|
||||
storage_query_count += hit_page_num * self.page_size
|
||||
if hit_page_num < len(batch_hashes):
|
||||
@@ -750,16 +824,34 @@ class HiCacheController:
|
||||
self.backup_queue.put(operation)
|
||||
return operation.id
|
||||
|
||||
# todo: deprecate
|
||||
# non-zero copy
|
||||
def _generic_page_set(self, hash_values, host_indices) -> bool:
|
||||
data = [
|
||||
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
|
||||
self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
|
||||
for i in range(len(hash_values))
|
||||
]
|
||||
return self.storage_backend.batch_set(hash_values, data)
|
||||
|
||||
def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
|
||||
return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
|
||||
# zero copy
|
||||
def _mooncake_page_set(self, hash_values, host_indices) -> bool:
|
||||
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
||||
hash_values,
|
||||
host_indices,
|
||||
self.storage_config.tp_rank,
|
||||
)
|
||||
success = self.storage_backend.batch_set(
|
||||
key_strs,
|
||||
target_locations=buffer_ptrs,
|
||||
target_sizes=buffer_sizes,
|
||||
)
|
||||
return success
|
||||
|
||||
# zero copy
|
||||
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
|
||||
hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
|
||||
hash_values, host_indices
|
||||
)
|
||||
return self.storage_backend.batch_set(hashes, dsts)
|
||||
|
||||
# Backup batch by batch
|
||||
def _page_backup(self, operation):
|
||||
|
||||
@@ -35,7 +35,6 @@ else:
|
||||
Image = Any
|
||||
|
||||
|
||||
# Parameters for a session
|
||||
@dataclass
|
||||
class SessionParams:
|
||||
id: Optional[str] = None
|
||||
@@ -133,23 +132,18 @@ class GenerateReqInput:
|
||||
# Conversation id used for tracking requests
|
||||
conversation_id: Optional[str] = None
|
||||
|
||||
# Label for the request
|
||||
label: Optional[str] = None
|
||||
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[Union[List[str], str]] = None
|
||||
|
||||
# Whether to disallow logging for this request (e.g. due to ZDR)
|
||||
no_logs: bool = False
|
||||
|
||||
# For custom metric labels
|
||||
custom_labels: Optional[Dict[str, str]] = None
|
||||
|
||||
# (Deprecated, please use custom_labels) Label for the request
|
||||
label: Optional[str] = None
|
||||
# (Internal) Whether to return bytes for image generation
|
||||
# Image gen grpc migration
|
||||
return_bytes: bool = False
|
||||
|
||||
# For customer metric labels
|
||||
customer_labels: Optional[Dict[str, str]] = None
|
||||
|
||||
def contains_mm_input(self) -> bool:
|
||||
return (
|
||||
has_valid_data(self.image_data)
|
||||
@@ -548,11 +542,8 @@ class GenerateReqInput:
|
||||
self.data_parallel_rank if self.data_parallel_rank is not None else None
|
||||
),
|
||||
conversation_id=self.conversation_id,
|
||||
priority=self.priority,
|
||||
extra_key=self.extra_key,
|
||||
no_logs=self.no_logs,
|
||||
custom_labels=self.custom_labels,
|
||||
label=self.label,
|
||||
priority=self.priority,
|
||||
return_bytes=self.return_bytes,
|
||||
)
|
||||
|
||||
@@ -609,23 +600,18 @@ class TokenizedGenerateReqInput:
|
||||
# For dp balance
|
||||
dp_balance_id: int = -1
|
||||
|
||||
# Label for the request
|
||||
label: Optional[str] = None
|
||||
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
# Extra key for classifying the request (e.g. cache_salt)
|
||||
extra_key: Optional[str] = None
|
||||
|
||||
# Whether to disallow logging for this request (e.g. due to ZDR)
|
||||
no_logs: bool = False
|
||||
# Image gen grpc migration
|
||||
return_bytes: bool = False
|
||||
|
||||
# tracing context
|
||||
trace_context: Optional[Dict] = None
|
||||
|
||||
# (Deprecated, please use custom_labels) Label for the request
|
||||
label: Optional[str] = None
|
||||
# (Internal) Whether to return bytes for image generation
|
||||
return_bytes: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchTokenizedGenerateReqInput:
|
||||
|
||||
@@ -507,7 +507,6 @@ def embed_mm_inputs(
|
||||
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||
] = None,
|
||||
placeholder_tokens: dict[Modality, List[int]] = None,
|
||||
use_deepstack: bool = False,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Embed multimodal inputs and integrate them with text token embeddings.
|
||||
@@ -523,7 +522,7 @@ def embed_mm_inputs(
|
||||
Returns:
|
||||
Combined embedding tensor with multimodal content integrated
|
||||
"""
|
||||
other_info = {}
|
||||
|
||||
if mm_inputs_list is None:
|
||||
return None
|
||||
|
||||
@@ -533,7 +532,7 @@ def embed_mm_inputs(
|
||||
for mm_inputs in mm_inputs_list:
|
||||
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
|
||||
|
||||
embeddings, masks, deepstack_embeddings = [], [], []
|
||||
embeddings, masks = [], []
|
||||
# 2. Get multimodal embedding separately
|
||||
# Try get mm embedding if any
|
||||
for modality in Modality.all():
|
||||
@@ -579,12 +578,6 @@ def embed_mm_inputs(
|
||||
extend_length=extend_seq_lens,
|
||||
items_offset_list=items_offsets,
|
||||
)
|
||||
|
||||
if use_deepstack and embedding is not None:
|
||||
embedding, deepstack_embedding = (
|
||||
multimodal_model.separate_deepstack_embeds(embedding)
|
||||
)
|
||||
deepstack_embeddings += [deepstack_embedding]
|
||||
embeddings += [embedding]
|
||||
masks += [mask]
|
||||
|
||||
@@ -598,37 +591,13 @@ def embed_mm_inputs(
|
||||
inputs_embeds = input_embedding(input_ids)
|
||||
|
||||
# 4. scatter embeddings into input embedding
|
||||
|
||||
# deepstack embedding
|
||||
if use_deepstack:
|
||||
num_deepstack_embeddings = (
|
||||
len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
|
||||
)
|
||||
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
|
||||
inputs_embeds.shape[-1] * num_deepstack_embeddings,
|
||||
)
|
||||
|
||||
input_deepstack_embeds = torch.zeros(
|
||||
deepstack_embedding_shape,
|
||||
device=inputs_embeds.device,
|
||||
dtype=inputs_embeds.dtype,
|
||||
)
|
||||
|
||||
other_info["input_deepstack_embeds"] = input_deepstack_embeds
|
||||
|
||||
for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
|
||||
for embedding, mask in zip(embeddings, masks):
|
||||
if embedding is None or mask is None:
|
||||
continue
|
||||
# in-place update
|
||||
indices = torch.where(mask.squeeze(dim=-1))[0]
|
||||
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
|
||||
if use_deepstack:
|
||||
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
|
||||
inputs_embeds.device, inputs_embeds.dtype
|
||||
)
|
||||
|
||||
return inputs_embeds, other_info
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
def general_mm_embed_routine(
|
||||
@@ -640,7 +609,6 @@ def general_mm_embed_routine(
|
||||
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||
] = None,
|
||||
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
|
||||
use_deepstack: bool = False,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -652,7 +620,6 @@ def general_mm_embed_routine(
|
||||
language_model: Base language model to use
|
||||
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
|
||||
placeholder_tokens: Token IDs for multimodal placeholders
|
||||
use_deepstack: Whether to use deepstack embeddings
|
||||
**kwargs: Additional arguments passed to language model
|
||||
|
||||
Returns:
|
||||
@@ -678,20 +645,16 @@ def general_mm_embed_routine(
|
||||
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
|
||||
if forward_batch.mm_inputs[i] is not None
|
||||
]
|
||||
inputs_embeds, other_info = embed_mm_inputs(
|
||||
inputs_embeds = embed_mm_inputs(
|
||||
mm_inputs_list=mm_inputs_list,
|
||||
extend_prefix_lens=extend_prefix_lens,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
input_ids=input_ids,
|
||||
multimodal_model=multimodal_model,
|
||||
input_embedding=embed_tokens,
|
||||
multimodal_model=multimodal_model,
|
||||
data_embedding_func_mapping=data_embedding_funcs,
|
||||
placeholder_tokens=placeholder_tokens,
|
||||
use_deepstack=use_deepstack,
|
||||
)
|
||||
# add for qwen3_vl deepstack
|
||||
if use_deepstack:
|
||||
kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
|
||||
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
|
||||
# just being defensive here
|
||||
forward_batch.mm_inputs = None
|
||||
|
||||
@@ -12,7 +12,8 @@ logger = logging.getLogger(__name__)
|
||||
PROCESSOR_MAPPING = {}
|
||||
|
||||
|
||||
def import_processors(package_name: str):
|
||||
def import_processors():
|
||||
package_name = "sglang.srt.multimodal.processors"
|
||||
package = importlib.import_module(package_name)
|
||||
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
||||
if not ispkg:
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def _resolve_future_token_ids(input_ids, future_token_ids_map):
|
||||
input_ids[:] = torch.where(
|
||||
input_ids < 0,
|
||||
future_token_ids_map[torch.clamp(-input_ids, min=0)],
|
||||
input_ids,
|
||||
)
|
||||
|
||||
|
||||
class FutureMap:
|
||||
def __init__(
|
||||
self,
|
||||
max_running_requests: int,
|
||||
device: torch.device,
|
||||
):
|
||||
self.future_ct = 0
|
||||
# A factor of 3 is used to avoid collision in the circular buffer.
|
||||
self.future_limit = max_running_requests * 3
|
||||
# A factor of 5 is used to ensure the buffer is large enough.
|
||||
self.future_buffer_len = max_running_requests * 5
|
||||
self.device = device
|
||||
|
||||
self.token_ids_buf = torch.empty(
|
||||
(self.future_buffer_len,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
|
||||
def update_ct(self, bs: int) -> int:
|
||||
"""Update the circular buffer pointer and return the current pointer."""
|
||||
cur_future_ct = self.future_ct
|
||||
self.future_ct = (cur_future_ct + bs) % self.future_limit
|
||||
return cur_future_ct
|
||||
|
||||
def resolve_future(self, model_worker_batch: ModelWorkerBatch):
|
||||
input_ids = model_worker_batch.input_ids
|
||||
_resolve_future_token_ids(input_ids, self.token_ids_buf)
|
||||
|
||||
def update_next_future(self, future_ct: int, bs: int):
|
||||
return torch.arange(
|
||||
-(future_ct + 1),
|
||||
-(future_ct + 1 + bs),
|
||||
-1,
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
|
||||
self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
|
||||
@@ -67,14 +67,14 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import flatten_nested_list, support_triton
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
@@ -90,7 +90,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
||||
"disable_flashinfer_cutlass_moe_fp4_allgather",
|
||||
"disable_radix_cache",
|
||||
"enable_dp_lm_head",
|
||||
"enable_fp32_lm_head",
|
||||
"flashinfer_mxfp4_moe_precision",
|
||||
"enable_flashinfer_allreduce_fusion",
|
||||
"moe_dense_tp_size",
|
||||
@@ -113,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
||||
"enable_custom_logit_processor",
|
||||
"disaggregation_mode",
|
||||
"enable_deterministic_inference",
|
||||
"nsa_prefill",
|
||||
"nsa_decode",
|
||||
]
|
||||
|
||||
# Put some global args for easy access
|
||||
@@ -492,7 +493,7 @@ class Req:
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
# extra key for classifying the request (e.g. cache_salt)
|
||||
# extra key for classifying the request (e.g. lora_id, cache_salt)
|
||||
if lora_id is not None:
|
||||
extra_key = (
|
||||
extra_key or ""
|
||||
@@ -608,8 +609,6 @@ class Req:
|
||||
) = None
|
||||
self.hidden_states: List[List[float]] = []
|
||||
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
|
||||
self.output_topk_p = None
|
||||
self.output_topk_index = None
|
||||
|
||||
# Embedding (return values)
|
||||
self.embedding = None
|
||||
@@ -954,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
|
||||
None
|
||||
)
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
] = None
|
||||
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: bool = False
|
||||
@@ -1609,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
if (
|
||||
self.spec_algorithm.is_eagle()
|
||||
or self.spec_algorithm.is_standalone()
|
||||
or self.spec_algorithm.is_ngram()
|
||||
or self.spec_algorithm.is_lookahead()
|
||||
):
|
||||
# if spec decoding is used, the decode batch is prepared inside
|
||||
# `forward_batch_speculative_generation` after running draft models.
|
||||
@@ -1736,14 +1735,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
|
||||
if self.spec_info:
|
||||
if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
|
||||
has_been_filtered = False
|
||||
else:
|
||||
has_been_filtered = True
|
||||
self.spec_info.filter_batch(
|
||||
new_indices=keep_indices_device,
|
||||
has_been_filtered=has_been_filtered,
|
||||
)
|
||||
self.spec_info.filter_batch(keep_indices_device)
|
||||
|
||||
def merge_batch(self, other: "ScheduleBatch"):
|
||||
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||
@@ -1992,9 +1984,9 @@ class ModelWorkerBatch:
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
|
||||
None
|
||||
)
|
||||
spec_info: Optional[
|
||||
Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
|
||||
] = None
|
||||
# If set, the output of the batch contains the hidden states of the run.
|
||||
capture_hidden_mode: CaptureHiddenMode = None
|
||||
hicache_consumer_index: int = -1
|
||||
|
||||
@@ -318,6 +318,7 @@ class PrefillAdder:
|
||||
new_token_ratio: float,
|
||||
rem_input_tokens: int,
|
||||
rem_chunk_tokens: Optional[int],
|
||||
max_prefill_bs: Optional[int],
|
||||
mixed_with_decode_tokens: int = 0,
|
||||
priority_scheduling_preemption_threshold: int = 0,
|
||||
):
|
||||
@@ -358,6 +359,10 @@ class PrefillAdder:
|
||||
priority_scheduling_preemption_threshold
|
||||
)
|
||||
|
||||
self.max_prefill_bs = (
|
||||
max_prefill_bs if max_prefill_bs is not None else 2147483647
|
||||
)
|
||||
|
||||
def _get_running_request_total_token_offset(self, req: Req) -> int:
|
||||
return (
|
||||
min(
|
||||
@@ -549,6 +554,9 @@ class PrefillAdder:
|
||||
def add_one_req(
|
||||
self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
|
||||
):
|
||||
if len(self.can_run_list) >= self.max_prefill_bs:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
|
||||
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user